""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ import threading from concurrent.futures import ThreadPoolExecutor from typing import Any, Dict, List, Optional import paddle import paddle.nn.functional as F from paddle import nn from fastdeploy.config import FDConfig from fastdeploy.model_executor.guided_decoding.base_guided_decoding import ( LogitsProcessorBase, ) from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata from fastdeploy.model_executor.layers.sample.ops import ( apply_penalty_multi_scores, apply_speculative_penalty_multi_scores, min_p_sampling, top_k_top_p_sampling, ) from fastdeploy.platforms import current_platform from fastdeploy.worker.output import LogprobsTensors, SamplerOutput class SamplerProcessor: """ SamplingProcessor for guided decoding. """ def __init__(self): self.async_step = None self.token_bitmask = None self.logits_processor: Dict[int, Optional[Any]] = dict() self.executor = ThreadPoolExecutor() self.logits_lock = threading.Lock() def add_logits_processor( self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = [], ): """add logits processor to SamplerProcessor""" with self.logits_lock: if future is None: if ids in self.logits_processor: del self.logits_processor[ids] return if isinstance(future, LogitsProcessorBase): self.logits_processor[ids] = future for token in prefill_tokens: self.logits_processor[ids].accept_token(token) elif future.done(): self.logits_processor[ids] = future.result() for token in prefill_tokens: self.logits_processor[ids].accept_token(token) else: self.logits_processor[ids] = [future, prefill_tokens] def update_vocab_mask(self, skip_idx_list: List[int] = []): """update vocab mask. (cpu-heavy operation)""" if len(self.logits_processor) == 0: return with self.logits_lock: for idx, processor in self.logits_processor.items(): if processor is None: del self.logits_processor[idx] continue if not isinstance(processor, LogitsProcessorBase): future, prefill_tokens = self.logits_processor[idx] self.logits_processor[idx] = future.result() for token in prefill_tokens: self.logits_processor[idx].accept_token(token) available_processors = None for processor in self.logits_processor.values(): if processor.is_terminated(): continue available_processors = processor if available_processors is None: return # allocate token bitmask self.token_bitmask = available_processors.allocate_token_bitmask() with self.logits_lock: # fill token bitmask for idx, processor in self.logits_processor.items(): if processor.is_terminated() or idx in skip_idx_list: continue processor.fill_token_bitmask(self.token_bitmask, idx) def apply_token_mask(self, logits: paddle.Tensor, skip_idx_list: List[int] = []): """apply token mask to logits""" if len(self.logits_processor) == 0 or self.token_bitmask is None: return logits # self.async_step.result() available_processors = None with self.logits_lock: for processor in self.logits_processor.values(): if processor.is_terminated(): continue available_processors = processor if available_processors is None: return logits indices = list(self.logits_processor.keys()) mask_idx = [i for i in indices if i not in skip_idx_list] return available_processors.apply_token_mask(logits, self.token_bitmask, indices=mask_idx) def _accept_token(self, idx: int, token: int): """accept token""" if idx not in self.logits_processor: raise ValueError(f"Invalid index, idx: {idx}, logit_processors.keys: {self.logits_processor.keys()}") if self.logits_processor[idx].is_terminated(): return self.logits_processor[idx].accept_token(token) def update_output_tokens(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []): """update output tokens""" if len(self.logits_processor) == 0: return token_ids = next_tokens.numpy().tolist() with self.logits_lock: for idx in self.logits_processor.keys(): token = token_ids[idx][0] if token < 0 or self.logits_processor[idx] is None or idx in skip_idx_list: continue self._accept_token(idx, token) def pre_process(self, skip_idx_list: List[int] = []): """pre process before running""" # create async operation for guided decoding # TODO: support async self.update_vocab_mask(skip_idx_list) # self.async_step = self.executor.submit(self.update_vocab_mask) class Sampler(nn.Layer): """ Sampler for normal generation. """ def __init__(self): """ """ super().__init__() if ( current_platform.is_cuda() or current_platform.is_xpu() or current_platform.is_iluvatar() or current_platform.is_gcu() or current_platform.is_dcu() ): self.forward = self.forward_cuda else: raise NotImplementedError self.processor = SamplerProcessor() def apply_logits_processor( self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = [], ): """apply logits processor to sampler""" self.processor.add_logits_processor(ids, future, prefill_tokens) def pre_process(self, skip_idx_list: List[int] = []): """pre process before running""" self.processor.pre_process(skip_idx_list) def compute_logprobs(self, logits: paddle.Tensor) -> paddle.Tensor: """ """ return F.log_softmax(logits, axis=-1) def gather_logprobs( self, logprobs: paddle.Tensor, num_logprobs: int, token_ids: paddle.Tensor, ) -> LogprobsTensors: """ Gather logprobs for topk and sampled/prompt token. Args: logprobs: (num tokens) x (vocab) tensor num_logprobs: minimum number of logprobs to retain per token token_ids: prompt tokens (if prompt logprobs) or sampled tokens (if sampled logprobs); 1D token ID tensor with (num tokens) elements Must be int64. Returns: Top-k int indices tensor, (num tokens) x (num_logprobs + 1) Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1) Sampled token rank tensor, (num tokens) """ assert token_ids.dtype == paddle.int64 # Get with the logprob of the prompt or sampled token. token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1) # Compute the ranks of the actual token. token_ranks = (logprobs >= token_logprobs).sum(-1) if num_logprobs >= 1: # Find the topK values. topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1) indices = paddle.concat([token_ids, topk_indices], axis=1) top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1) else: indices = token_ids top_logprobs = token_logprobs return LogprobsTensors(indices, top_logprobs, token_ranks) def forward_cuda( self, logits: paddle.Tensor, sampling_metadata: SamplingMetadata, skip_idx_list: List[int] = [], ) -> SamplerOutput: """ """ num_logprobs = sampling_metadata.max_num_logprobs if num_logprobs is not None: raw_logprobs = self.compute_logprobs(logits) logits = self.processor.apply_token_mask(logits, skip_idx_list) logits = apply_penalty_multi_scores( sampling_metadata.pre_token_ids, sampling_metadata.prompt_ids, sampling_metadata.prompt_lens, logits, sampling_metadata.repetition_penalties, sampling_metadata.frequency_penalties, sampling_metadata.presence_penalties, sampling_metadata.temperature, sampling_metadata.bad_words_token_ids, sampling_metadata.step_idx, sampling_metadata.min_dec_lens, sampling_metadata.eos_token_ids, ) probs = F.softmax(logits) probs = min_p_sampling(probs, sampling_metadata.min_p) _, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k) logprobs_tensors = ( None if num_logprobs is None else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens) ) self.processor.update_output_tokens(next_tokens, skip_idx_list) sampler_output = SamplerOutput( # The sampled tokens are expanded to 2D tensor with shape # [num_requests, 1], where each row represents one generated # token per request. sampled_token_ids=next_tokens, logprobs_tensors=logprobs_tensors, ) return sampler_output class SpeculativeSampler(nn.Layer): """ Sampler for speculative generation. """ def __init__(self, fd_config: FDConfig): """ """ super().__init__() if current_platform.is_cuda(): self.forward = self.forward_cuda else: raise NotImplementedError self.speculative_verify_window = fd_config.speculative_config.verify_window self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode def pre_process(self, skip_idx_list: List[int] = []): """pre process before running""" pass def apply_logits_processor( self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = [], ): """apply logits processor to sampler""" pass def forward_cuda( self, logits: paddle.Tensor, sampling_metadata: SamplingMetadata, max_model_len: int, share_inputs: List[paddle.Tensor], ) -> paddle.Tensor: """ """ from fastdeploy.model_executor.ops.gpu import speculate_verify, top_p_candidates logits = apply_speculative_penalty_multi_scores( sampling_metadata.pre_token_ids, logits, sampling_metadata.repetition_penalties, sampling_metadata.frequency_penalties, sampling_metadata.presence_penalties, sampling_metadata.temperature, sampling_metadata.bad_words_token_ids, sampling_metadata.step_idx, sampling_metadata.min_dec_lens, sampling_metadata.eos_token_ids, share_inputs["seq_lens_this_time"], share_inputs["output_padding_offset"], share_inputs["output_cum_offsets"], max_model_len, ) probs = F.softmax(logits) verify_scores, verify_tokens, actual_candidate_len = top_p_candidates( probs, sampling_metadata.top_p, share_inputs["output_padding_offset"], self.speculative_max_candidate_len, max_model_len, ) speculate_verify( share_inputs["accept_tokens"], share_inputs["accept_num"], share_inputs["step_idx"], share_inputs["stop_flags"], share_inputs["seq_lens_encoder"], share_inputs["seq_lens_decoder"], share_inputs[ "draft_tokens" ], # Both input and output, need to write the last 1 token accepted to position 0. share_inputs["seq_lens_this_time"], verify_tokens, verify_scores, share_inputs["max_dec_len"], sampling_metadata.eos_token_ids, share_inputs["is_block_step"], share_inputs["output_cum_offsets"], actual_candidate_len, share_inputs["actual_draft_token_num"], sampling_metadata.top_p, max_model_len, self.speculative_verify_window, True, # enable_topp self.speculative_benchmark_mode, ) return None class MTPSampler(nn.Layer): """ """ def __init__(self, fd_config: FDConfig): """ """ super().__init__() if current_platform.is_cuda(): self.forward = self.forward_cuda else: raise NotImplementedError def pre_process(self, skip_idx_list: List[int] = []): """pre process before running""" pass def apply_logits_processor( self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = [], ): """apply logits processor to sampler""" pass def forward_cuda( self, logits: paddle.Tensor, sampling_metadata: SamplingMetadata, max_model_len: int, share_inputs: List[paddle.Tensor], ) -> paddle.Tensor: """ """ logits = apply_speculative_penalty_multi_scores( sampling_metadata.pre_token_ids, logits, sampling_metadata.repetition_penalties, sampling_metadata.frequency_penalties, sampling_metadata.presence_penalties, sampling_metadata.temperature, sampling_metadata.bad_words_token_ids, sampling_metadata.step_idx, sampling_metadata.min_dec_lens, sampling_metadata.eos_token_ids, share_inputs["seq_lens_this_time"], share_inputs["seq_lens_encoder"], share_inputs["seq_lens_decoder"], max_model_len, ) probs = F.softmax(logits) _, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k) return next_tokens