""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ import multiprocessing import time from concurrent.futures import Future, ThreadPoolExecutor from typing import Any, List, Optional import paddle import paddle.nn.functional as F from paddle import nn from paddleformers.utils.log import logger from fastdeploy.config import FDConfig from fastdeploy.envs import FD_FILL_BITMASK_BATCH from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase from fastdeploy.model_executor.layers.sample.early_stopper import ( get_early_stopper_cls_from_stragegy, ) from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata from fastdeploy.model_executor.layers.sample.ops import ( apply_penalty_multi_scores, apply_speculative_penalty_multi_scores, min_p_sampling, speculate_get_target_logits, speculate_insert_first_token, top_k_top_p_sampling, ) from fastdeploy.platforms import current_platform from fastdeploy.reasoning import ReasoningParser from fastdeploy.worker.output import LogprobsTensors, SamplerOutput def top_p_normalize_probs_paddle( probs: paddle.Tensor, top_ps: paddle.Tensor, ): probs_idx = probs.argsort(axis=-1, descending=True) probs_sort = paddle.take_along_axis(probs, probs_idx, axis=-1) probs_sum = paddle.cumsum(probs_sort, axis=-1) probs_sort = paddle.where((probs_sum - probs_sort) > top_ps, paddle.zeros_like(probs_sort), probs_sort) probs_sort.divide_(probs_sort.sum(axis=-1, keepdim=True)) return paddle.zeros_like(probs_sort).put_along_axis_(indices=probs_idx, values=probs_sort, axis=-1) def padding_sampling_params(top_p, top_k, seq_lens_this_time, seq_lens_encoder): real_bsz = seq_lens_this_time.shape[0] repeats = paddle.where(seq_lens_encoder[:real_bsz] == 0, seq_lens_this_time, paddle.ones_like(seq_lens_this_time)) top_p_padding = paddle.repeat_interleave(top_p[:real_bsz], repeats).unsqueeze(1) top_k_padding = paddle.repeat_interleave(top_k[:real_bsz], repeats).unsqueeze(1) return top_p_padding, top_k_padding class GuidedDecoding: """ processor for guided decoding. 
""" def __init__(self, fd_config: FDConfig): self.token_bitmask = None self.max_num_seqs: int = int( fd_config.scheduler_config.max_num_seqs if fd_config.scheduler_config is not None else 1 ) self.logits_processors: List[Any] = [None] * self.max_num_seqs self.reasoning_parser = None self._prefill_done_idxs: List[bool] = [False] * self.max_num_seqs # for pd self._tokens_to_acc: List[None | List[int]] = [None] * self.max_num_seqs self.fill_bitmask_parallel_batch_size: int = FD_FILL_BITMASK_BATCH max_workers = max( 1, min(multiprocessing.cpu_count() // 2, int(self.max_num_seqs) / int(self.fill_bitmask_parallel_batch_size)), ) self.executor_for_fillmask = ThreadPoolExecutor(max_workers=int(max_workers)) self._fillmask_futures: List[Future] = [None] * self.max_num_seqs self.is_cuda_platform = current_platform.is_cuda() logger.info( f"GuidedDecoding max_num_seqs={self.max_num_seqs} fill_bitmask_parallel_batch_size={self.fill_bitmask_parallel_batch_size} is_cuda_platform={self.is_cuda_platform} max_workers={max_workers}" ) def apply_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None): self.reasoning_parser = reasoning_parser def add_logits_processor( self, idx: int, future: Optional[Any] = None, prefill_tokens: List[int] = [], ): """add logits processor to SamplerProcessor""" self._prefill_done_idxs[idx] = False if future is None: # normal request without guided_backend self.logits_processors[idx] = None return if len(prefill_tokens) != 0: # first_token from prefill node self._prefill_done_idxs[idx] = True if future.done(): # cached xgrammar self.logits_processors[idx] = future.result() for token in prefill_tokens: self._accept_token(idx, token) else: # async self.logits_processors[idx] = future self._tokens_to_acc[idx] = prefill_tokens def should_fill_bitmask(self, idx: int) -> bool: """ Determines whether to fill a bitmask for the logits processor at the given index. Args: idx (int): The index of the logits processor to check Returns: bool: True if the idx request bitmask should be filled """ if self.reasoning_parser is not None: if self.logits_processors[idx].enable_reasoning: # guided return True if not self.logits_processors[idx].reasoning_ended: return False return True def reset_processor(self, idx: int): """reset idx""" self._prefill_done_idxs[idx] = False self.logits_processors[idx] = None def update_vocab_mask(self, prefill_done_idxs: List[int] = []): """update vocab mask. (cpu-heavy operation)""" for idx in prefill_done_idxs: if self.logits_processors[idx] is None: continue assert not self._prefill_done_idxs[idx] self._prefill_done_idxs[idx] = True if isinstance(self.logits_processors[idx], Future): continue idxs = [] for idx, processor in enumerate(self.logits_processors): if processor is None or not self._prefill_done_idxs[idx]: continue # skip, join at apply_token_mask if isinstance(processor, Future): continue if processor.is_terminated: self.reset_processor(idx) continue self.accept_tokens_from_prefill_node(idx) if self.token_bitmask is None: self.token_bitmask = self.logits_processors[idx].allocate_token_bitmask() if self.should_fill_bitmask(idx): idxs.append(idx) self._async_batch_fill_token_bitmask(idxs) def batch_fill_token_bitmask(self, batch: List[int]): """ Fills the token bitmask for a batch of logits processor indices. This method is typically called asynchronously via a thread pool executor to parallelize the bitmask filling operation. 

    def batch_fill_token_bitmask(self, batch: List[int]):
        """
        Fill the token bitmask for a batch of logits processor indices.

        This method is typically called asynchronously via a thread pool executor to
        parallelize the bitmask filling operation. It is important that any shared data
        structures accessed within this method (such as `self.token_bitmask` and
        `self.logits_processors`) are thread-safe or properly synchronized to avoid race
        conditions.

        Args:
            batch (List[int]): List of indices for which to fill the token bitmask.
        """
        for idx in batch:
            self.logits_processors[idx].fill_token_bitmask(self.token_bitmask, idx)

    def _async_batch_fill_token_bitmask(self, idxs: List[int]):
        """Launch asynchronous bitmask fills."""
        batch: List[int] = []
        for idx in idxs:
            batch.append(idx)
            if len(batch) == self.fill_bitmask_parallel_batch_size:
                promise = self.executor_for_fillmask.submit(self.batch_fill_token_bitmask, batch[:])
                self._fillmask_futures[idx] = promise
                batch = []
        if batch:
            promise = self.executor_for_fillmask.submit(self.batch_fill_token_bitmask, batch[:])
            self._fillmask_futures[batch[-1]] = promise

    def join_async_fillmask(self):
        """Join all pending async fill futures."""
        for idx, future in enumerate(self._fillmask_futures):
            if future is not None:
                try:
                    future.result()
                except Exception as e:
                    logger.error(f"Exception in async fillmask future at idx {idx}: {e}", exc_info=True)
                self._fillmask_futures[idx] = None

    def accept_tokens_from_prefill_node(self, idx: int):
        """Accept tokens produced on the prefill node (the processor must not be a Future here)."""
        if self._tokens_to_acc[idx] is not None:
            # accept tokens from the prefill node first
            for token in self._tokens_to_acc[idx]:
                self._accept_token(idx, token)
            self._tokens_to_acc[idx] = None
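
    # Batching sketch (comments only): with fill_bitmask_parallel_batch_size = 4 and
    # idxs = [0, 1, 2, 3, 4, 5], _async_batch_fill_token_bitmask submits two jobs to the
    # thread pool, [0, 1, 2, 3] and [4, 5]; each future is recorded under the last index of
    # its batch and joined in join_async_fillmask() before the bitmask is applied to logits.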

    def apply_token_mask(self, logits: paddle.Tensor, prefill_done_idxs: List[int] = []):
        """Apply the token mask to logits."""
        indices = []
        for idx, processor in enumerate(self.logits_processors):
            if processor is None or not self._prefill_done_idxs[idx]:
                continue
            # compilation done; check whether idx should fill (fill_token_bitmask was done in preprocess)
            if not isinstance(processor, Future):
                if self.should_fill_bitmask(idx):
                    indices.append(idx)
                continue

            # processor is a Future whose async compilation may not be ready yet: join and wait
            ts = time.time()
            wait = False
            if not processor.done():
                wait = True
            self.logits_processors[idx] = processor.result()
            if wait:
                logger.debug(f"[{idx}] join async compile xgrammar, time_cost: {time.time() - ts}")

            self.accept_tokens_from_prefill_node(idx)

            # Possible optimization: extract 'think' content validation from logits_processors,
            # allowing join operations to complete immediately after 'think' terminates.
            # Furthermore, the current idx could be skipped, with compilation overhead
            # estimated at only a few milliseconds.

            # check whether idx needs fill_token_bitmask
            if not self.should_fill_bitmask(idx):
                continue
            indices.append(idx)
            if self.token_bitmask is None:
                self.token_bitmask = self.logits_processors[idx].allocate_token_bitmask()
            # launch async fill
            self._async_batch_fill_token_bitmask([idx])

        if len(indices) == 0:
            return logits

        self.join_async_fillmask()
        from fastdeploy.model_executor.guided_decoding.xgrammar_backend import (
            apply_token_mask,
        )

        return apply_token_mask(logits, self.token_bitmask, indices=indices, is_cuda_platform=self.is_cuda_platform)

    def _accept_token(self, idx: int, token: int):
        """Accept a token for the processor at idx."""
        if self.reasoning_parser is not None:
            if not self.logits_processors[idx].enable_reasoning:
                if not self.logits_processors[idx].reasoning_ended:
                    reasoning_ended = self.reasoning_parser.is_reasoning_end([token])
                    self.logits_processors[idx].reasoning_ended = reasoning_ended
                    return

        if not self.logits_processors[idx].accept_token(token) or self.logits_processors[idx].is_terminated:
            self.reset_processor(idx)

    def update_output_tokens(self, next_tokens: paddle.Tensor):
        """Update output tokens."""
        if len(self.logits_processors) == 0:
            return

        token_ids = next_tokens.numpy().tolist()
        for idx, processor in enumerate(self.logits_processors):
            if not self._prefill_done_idxs[idx] or processor is None:
                continue
            if idx >= len(token_ids):
                continue

            token = token_ids[idx][0]
            if token < 0:
                self.reset_processor(idx)
                continue

            logger.debug(f"[{idx}] accept token {token}")
            self._accept_token(idx, token)

    def pre_process(self, prefill_done_idxs: List[int] = []):
        """Pre-process before running."""
        self.update_vocab_mask(prefill_done_idxs)


class Sampler(nn.Layer):
    """
    Sampler for normal generation.
    """

    def __init__(self, fd_config: FDConfig = None, logprobs_mode: str = "raw_logprobs"):
        """ """
        super().__init__()
        if (
            current_platform.is_cuda()
            or current_platform.is_xpu()
            or current_platform.is_iluvatar()
            or current_platform.is_gcu()
            or current_platform.is_dcu()
            or current_platform.is_maca()
        ):
            self.forward = self.forward_cuda
        elif current_platform.is_intel_hpu():
            self.forward = self.forward_intel_hpu
        else:
            raise NotImplementedError

        self.guided_decoding = GuidedDecoding(fd_config)
        self.logprobs_mode = fd_config.model_config.logprobs_mode if fd_config is not None else logprobs_mode
        # Only created when fd_config.early_stop_config.enable_early_stop is True.
        if (
            fd_config is not None
            and fd_config.early_stop_config is not None
            and fd_config.early_stop_config.enable_early_stop
        ):
            early_stopper_cls = get_early_stopper_cls_from_stragegy(fd_config.early_stop_config.strategy)
            self.early_stopper = early_stopper_cls()
            self.early_stopper.initialize(fd_config.scheduler_config.max_num_seqs, fd_config.early_stop_config)

    def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
        """Set the reasoning parser."""
        self.guided_decoding.apply_reasoning_parser(reasoning_parser)

    def apply_logits_processor(
        self, ids: int, future: Future[LogitsProcessorBase] = None, prefill_tokens: List[int] = []
    ):
        """Apply a logits processor to the sampler."""
        self.guided_decoding.add_logits_processor(ids, future, prefill_tokens)

    def pre_process(self, prefill_done_idxs: List[int] = []):
        """Pre-process before running."""
        self.guided_decoding.pre_process(prefill_done_idxs)

    def post_process(self, next_tokens: paddle.Tensor):
        """Post-process after running."""
        self.guided_decoding.update_output_tokens(next_tokens)
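
    # Note on logprobs_mode (comments only): "raw_logprobs"/"raw_logits" snapshot the logits
    # before penalties and temperature are applied in forward_cuda, while
    # "processed_logprobs"/"processed_logits" snapshot them afterwards. compute_logprobs below
    # additionally supports per-request temperature scaling and top-p renormalization of the
    # returned log-probabilities, controlled by sampling_metadata.temp_scaled_logprobs and
    # sampling_metadata.top_p_normalized_logprobs.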

    def compute_logprobs(
        self,
        logits: paddle.Tensor,
        sampling_metadata: Optional[SamplingMetadata] = None,
    ) -> paddle.Tensor:
        """ """
        if sampling_metadata is None:
            return F.log_softmax(logits, axis=-1)

        last_logits = logits
        real_bsz = last_logits.shape[0]
        temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
        top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
        share_inputs = sampling_metadata.share_inputs
        if temp_scaled_logprobs is not None:
            real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz]
            temperature = sampling_metadata.temperature[:real_bsz]
            temp_temperature = paddle.where(real_bsz_temp_scaled, temperature, paddle.ones_like(temperature))
            last_logits = last_logits / temp_temperature

        last_logprobs = F.log_softmax(last_logits, axis=-1)
        top_p_logprob = None
        top_p_req_mask = None

        if top_p_normalized_logprobs is not None and share_inputs is not None:
            seq_lens_this_time = share_inputs["seq_lens_this_time"].reshape([-1, 1])[:real_bsz]
            seq_lens_encoder = share_inputs["seq_lens_encoder"].reshape([-1, 1])[:real_bsz]
            seq_lens_decoder = share_inputs["seq_lens_decoder"].reshape([-1, 1])[:real_bsz]
            seq_lens_time_sum = seq_lens_this_time + seq_lens_encoder + seq_lens_decoder
            real_req_mask = seq_lens_time_sum > 0
            top_p_req_mask = paddle.logical_and(top_p_normalized_logprobs[:real_bsz], real_req_mask)
            real_req_top_p = sampling_metadata.top_p[:real_bsz]
            # Normalize logprobs if top_p normalization is enabled
            # NOTE: only normalize logprobs when top_p is set and not equal to 1.0
            top_p_req_mask = paddle.logical_and(top_p_req_mask, real_req_top_p != 1.0)
            if top_p_req_mask.any():
                probs = F.softmax(last_logits, axis=-1)
                probs = top_p_normalize_probs_paddle(probs, real_req_top_p)
                top_p_logprob = paddle.log(probs)
        if top_p_logprob is not None:
            last_logprobs = paddle.where(top_p_req_mask, top_p_logprob, last_logprobs)
        return last_logprobs
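
    # Shape sketch for gather_logprobs (comments only): with num_logprobs = 2, a vocab of size V
    # and N sampled tokens, logprobs is [N, V] and token_ids is [N] (int64). The returned
    # LogprobsTensors hold indices [N, 3] (the sampled token id first, then the top-2 ids),
    # logprobs [N, 3] in the same order, and ranks [N], where rank 1 means the sampled token had
    # the highest log-probability.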

    def gather_logprobs(
        self,
        logprobs: paddle.Tensor,
        num_logprobs: int,
        token_ids: paddle.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
            logprobs: (num tokens) x (vocab) tensor
            num_logprobs: minimum number of logprobs to retain per token
            token_ids: prompt tokens (if prompt logprobs) or sampled tokens (if sampled logprobs);
                       1D token ID tensor with (num tokens) elements. Must be int64.

        Returns:
            Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
            Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
            Sampled token rank tensor, (num tokens)
        """
        assert token_ids.dtype == paddle.int64
        logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
        # Get the logprob of the prompt or sampled token.
        if len(token_ids.shape) < len(logprobs.shape):
            token_ids = token_ids.unsqueeze(-1)
        token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)

        # Compute the rank of the actual token.
        token_ranks = (logprobs >= token_logprobs).sum(-1)

        if num_logprobs >= 1:
            # Find the top-k values.
            topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1)
            indices = paddle.concat([token_ids, topk_indices], axis=1)
            top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1)
        else:
            indices = token_ids
            top_logprobs = token_logprobs

        indices = indices.cpu()
        top_logprobs = top_logprobs.cpu()
        token_ranks = token_ranks.cpu()
        return LogprobsTensors(indices, top_logprobs, token_ranks)

    def forward_cuda(
        self,
        logits: paddle.Tensor,
        sampling_metadata: SamplingMetadata,
        p_done_idxs: List[int] = [],
    ) -> SamplerOutput:
        """ """
        logits = self.guided_decoding.apply_token_mask(logits, p_done_idxs)

        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            if self.logprobs_mode == "raw_logprobs":
                raw_logprobs = self.compute_logprobs(logits, sampling_metadata)
            elif self.logprobs_mode == "raw_logits":
                raw_logprobs = logits.clone()

        for proc in sampling_metadata.logits_processors or []:
            logits = proc.apply(logits)

        logits = apply_penalty_multi_scores(
            sampling_metadata.pre_token_ids,
            sampling_metadata.prompt_ids,
            sampling_metadata.prompt_lens,
            logits,
            sampling_metadata.repetition_penalties,
            sampling_metadata.frequency_penalties,
            sampling_metadata.presence_penalties,
            sampling_metadata.temperature,
            sampling_metadata.bad_words_token_ids,
            sampling_metadata.step_idx,
            sampling_metadata.min_dec_lens,
            sampling_metadata.eos_token_ids,
        )
        if num_logprobs is not None:
            if self.logprobs_mode == "processed_logprobs":
                raw_logprobs = self.compute_logprobs(logits, sampling_metadata)
            elif self.logprobs_mode == "processed_logits":
                raw_logprobs = logits.clone()

        probs = F.softmax(logits)
        probs = min_p_sampling(probs, sampling_metadata.min_p, sampling_metadata.min_p_list)

        _, next_tokens = top_k_top_p_sampling(
            probs,
            sampling_metadata.top_p,
            sampling_metadata.top_k,
            sampling_metadata.top_k_list,
            seed=sampling_metadata.seed[0, 0],
        )

        logprobs_tensors = (
            None if num_logprobs is None else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens)
        )

        if sampling_metadata.enable_early_stop:
            # the stopped requests will be marked in stop_flags
            assert sampling_metadata.stop_flags is not None, "need stop_flags for early stop"
            self.early_stopper.process(probs, next_tokens, sampling_metadata.stop_flags)

        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to a 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=next_tokens,
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

    def forward_intel_hpu(
        self,
        logits: paddle.Tensor,
        sampling_metadata: SamplingMetadata,
        batch_ids: paddle.Tensor,
        max_batch: int,
        rank: int,
        local_rank: int,
    ) -> paddle.Tensor:
        if logits.dtype != paddle.float32:
            logits = paddle.cast(logits, paddle.float32)
        from fastdeploy.model_executor.ops.intel_hpu import fused_sampler

        _, next_tokens = fused_sampler(
            sampling_metadata.pre_token_ids,
            sampling_metadata.prompt_ids,
            sampling_metadata.seq_lens_encoder,
            sampling_metadata.seq_lens_decoder,
            sampling_metadata.step_idx,
            sampling_metadata.stop_flags,
            logits,
            sampling_metadata.repetition_penalties,
            sampling_metadata.frequency_penalties,
            sampling_metadata.presence_penalties,
            sampling_metadata.temperature,
            sampling_metadata.bad_words_token_ids,
            sampling_metadata.step_idx,
            sampling_metadata.min_dec_lens,
            sampling_metadata.eos_token_ids,
            sampling_metadata.top_p,
            rank,
            local_rank,
        )
        if next_tokens.shape[0] != max_batch:
            dim = next_tokens.shape[-1]
            tmp_tokens = paddle.full((max_batch, dim), -1 if local_rank == 0 else 0, dtype=next_tokens.dtype)
            tmp_tokens = paddle.scatter(tmp_tokens, batch_ids, next_tokens[: batch_ids.shape[0], :])
            return tmp_tokens
        return next_tokens


class SpeculativeSampler(nn.Layer):
    """
    Sampler for speculative generation.
    """

    def __init__(self, fd_config: FDConfig):
        """ """
        super().__init__()
        if current_platform.is_cuda():
            self.forward = self.forward_cuda
        elif current_platform.is_xpu():
            self.forward = self.forward_xpu
        else:
            raise NotImplementedError
        self.logprobs_mode = fd_config.model_config.logprobs_mode
        self.speculative_verify_window = fd_config.speculative_config.verify_window
        self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len
        self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode

    def pre_process(self, skip_idx_list: List[int] = []):
        """Pre-process before running."""
        pass

    def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
        """Set the reasoning parser."""
        pass

    def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
        """Post-process after running."""
        pass

    def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []):
        """Apply a logits processor to the sampler."""
        pass
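
    # Flow sketch for forward_cuda below (comments only, based on the call signatures here):
    # penalties are applied to the target logits, a fallback token is sampled per position via
    # top_k_top_p_sampling, top_p_candidates builds per-position candidate sets, and
    # speculate_verify compares each request's draft_tokens against those candidates, writing
    # the accepted tokens and their count into share_inputs["accept_tokens"] /
    # share_inputs["accept_num"] for the rest of the pipeline.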

    def compute_logprobs(
        self,
        logits: paddle.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> paddle.Tensor:
        """Compute logprobs."""
        share_inputs = sampling_metadata.share_inputs
        last_logits = logits
        real_bsz = share_inputs["seq_lens_this_time"].shape[0]
        batch_token_num = share_inputs["accept_num"][:real_bsz]
        temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
        top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
        if temp_scaled_logprobs is not None:
            real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz]
            temperature = sampling_metadata.temperature[:real_bsz]
            real_bsz_temp_scaled = (
                real_bsz_temp_scaled.astype("int32").squeeze(1).repeat_interleave(batch_token_num).astype("bool")
            )
            temperature = temperature.squeeze(1).repeat_interleave(batch_token_num)
            temp_temperature = paddle.where(
                real_bsz_temp_scaled, temperature, paddle.ones_like(temperature)
            ).unsqueeze(1)
            last_logits = last_logits / temp_temperature

        last_logprobs = F.log_softmax(last_logits, axis=-1)
        top_p_logprob = None
        top_p_token_mask = None

        if top_p_normalized_logprobs is not None and share_inputs is not None:
            real_token_top_p = (
                sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(batch_token_num).unsqueeze(1)
            )
            top_p_normalized_logprobs = (
                top_p_normalized_logprobs[:real_bsz]
                .astype("int32")
                .squeeze(1)
                .repeat_interleave(batch_token_num)
                .astype("bool")
                .unsqueeze(1)
            )
            top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0)
            if top_p_token_mask.any():
                probs = F.softmax(last_logits, axis=-1)
                probs = top_p_normalize_probs_paddle(probs, real_token_top_p)
                top_p_logprob = paddle.log(probs)
        if top_p_logprob is not None:
            last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs)
        return last_logprobs

    def gather_logprobs(
        self,
        logprobs: paddle.Tensor,
        num_logprobs: int,
        token_ids: paddle.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
            logprobs: (num tokens) x (vocab) tensor
            num_logprobs: minimum number of logprobs to retain per token
            token_ids: prompt tokens (if prompt logprobs) or sampled tokens (if sampled logprobs);
                       1D token ID tensor with (num tokens) elements. Must be int64.

        Returns:
            Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
            Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
            Sampled token rank tensor, (num tokens)
        """
        assert token_ids.dtype == paddle.int64
        token_ids = token_ids.unsqueeze(1)
        logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
        # Get the logprob of the prompt or sampled token.
        token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)

        # Compute the rank of the actual token.
        token_ranks = (logprobs >= token_logprobs).sum(-1)

        if num_logprobs >= 1:
            # Find the top-k values.
            topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1)
            indices = paddle.concat([token_ids, topk_indices], axis=1)
            top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1)
        else:
            indices = token_ids
            top_logprobs = token_logprobs

        return LogprobsTensors(indices, top_logprobs, token_ranks)

    def forward_cuda(
        self,
        logits: paddle.Tensor,
        sampling_metadata: SamplingMetadata,
        max_model_len: int,
        share_inputs: List[paddle.Tensor],
        accept_all_drafts: bool = False,
        reject_all_drafts: bool = False,
    ) -> paddle.Tensor:
        """ """
        from fastdeploy.model_executor.ops.gpu import speculate_verify, top_p_candidates

        logits = apply_speculative_penalty_multi_scores(
            sampling_metadata.pre_token_ids,
            logits,
            sampling_metadata.repetition_penalties,
            sampling_metadata.frequency_penalties,
            sampling_metadata.presence_penalties,
            sampling_metadata.temperature,
            sampling_metadata.bad_words_token_ids,
            sampling_metadata.step_idx,
            sampling_metadata.min_dec_lens,
            sampling_metadata.eos_token_ids,
            share_inputs["seq_lens_this_time"],
            share_inputs["output_padding_offset"],
            share_inputs["output_cum_offsets"],
            max_model_len,
        )
        probs = F.softmax(logits)
        top_p, top_k = padding_sampling_params(
            sampling_metadata.top_p,
            sampling_metadata.top_k,
            share_inputs["seq_lens_this_time"],
            share_inputs["seq_lens_encoder"],
        )
        _, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, seed=sampling_metadata.seed[0, 0])

        verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
            probs,
            sampling_metadata.top_p,
            share_inputs["output_padding_offset"],
            self.speculative_max_candidate_len,
            max_model_len,
        )

        speculate_verify(
            sampled_token_ids,
            share_inputs["accept_tokens"],
            share_inputs["accept_num"],
            share_inputs["step_idx"],
            share_inputs["stop_flags"],
            share_inputs["seq_lens_encoder"],
            share_inputs["seq_lens_decoder"],
            # Both input and output; the last accepted token is written back to position 0.
            share_inputs["draft_tokens"],
            share_inputs["seq_lens_this_time"],
            verify_tokens,
            verify_scores,
            share_inputs["max_dec_len"],
            sampling_metadata.eos_token_ids,
            share_inputs["is_block_step"],
            share_inputs["output_cum_offsets"],
            actual_candidate_len,
            share_inputs["actual_draft_token_num"],
            sampling_metadata.top_p,
            max_model_len,
            self.speculative_verify_window,
            True,  # enable_topp
            (self.speculative_benchmark_mode or reject_all_drafts),
            accept_all_drafts,
        )

        num_logprobs = sampling_metadata.max_num_logprobs
        batch_token_num = None
        if num_logprobs is not None:
            real_bsz = share_inputs["seq_lens_this_time"].shape[0]
            batch_token_num = paddle.where(
                share_inputs["seq_lens_encoder"][:real_bsz] != 0,
                paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]),
                share_inputs["seq_lens_this_time"],
            ).squeeze(1)
            share_inputs["batch_token_num"] = batch_token_num
            ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
                "int32"
            )
            cu_batch_token_offset = paddle.concat(
                [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])]
            ).astype("int32")
            share_inputs["cu_batch_token_offset"] = cu_batch_token_offset
            target_logits = paddle.empty(
                [share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]], dtype=logits.dtype
            )
            speculate_get_target_logits(
                target_logits,
                logits,
                cu_batch_token_offset,
                ori_cu_batch_token_offset,
                share_inputs["seq_lens_this_time"],
                share_inputs["seq_lens_encoder"],
                share_inputs["accept_num"],
            )
            if self.logprobs_mode == "raw_logprobs":
                raw_logprobs = self.compute_logprobs(target_logits, sampling_metadata)
            elif self.logprobs_mode == "raw_logits":
                raw_logprobs = target_logits.clone()

        logprobs_tensors = None
        token_ids = share_inputs["accept_tokens"]
        if num_logprobs is not None:
            token_ids = paddle.concat(
                [share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]] for i in range(real_bsz)]
            )
            logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)

        sampler_output = SamplerOutput(
            sampled_token_ids=token_ids,
            logprobs_tensors=logprobs_tensors,
            token_num_per_batch=share_inputs["accept_num"],
            cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
        )
        return sampler_output

    def forward_xpu(
        self,
        logits: paddle.Tensor,
        sampling_metadata: SamplingMetadata,
        max_model_len: int,
        share_inputs: List[paddle.Tensor],
        accept_all_drafts: bool = False,
        reject_all_drafts: bool = False,
    ) -> paddle.Tensor:
        from fastdeploy.model_executor.ops.xpu import speculate_verify, top_p_candidates

        logits = apply_speculative_penalty_multi_scores(
            sampling_metadata.pre_token_ids,
            logits,
            sampling_metadata.repetition_penalties,
            sampling_metadata.frequency_penalties,
            sampling_metadata.presence_penalties,
            sampling_metadata.temperature,
            sampling_metadata.bad_words_token_ids,
            sampling_metadata.step_idx,
            sampling_metadata.min_dec_lens,
            sampling_metadata.eos_token_ids,
            share_inputs["seq_lens_this_time"],
            share_inputs["output_padding_offset"],
            share_inputs["output_cum_offsets"],
            max_model_len,
        )
        probs = F.softmax(logits)
        verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
            probs,
            sampling_metadata.top_p,
            share_inputs["output_padding_offset"],
            self.speculative_max_candidate_len,
            max_model_len,
        )

        speculate_verify(
            share_inputs["accept_tokens"],
            share_inputs["accept_num"],
            share_inputs["step_idx"],
            share_inputs["stop_flags"],
            share_inputs["seq_lens_encoder"],
            share_inputs["seq_lens_decoder"],
            # Both input and output; the last accepted token is written back to position 0.
            share_inputs["draft_tokens"],
share_inputs["seq_lens_this_time"], verify_tokens, verify_scores, share_inputs["max_dec_len"], sampling_metadata.eos_token_ids, share_inputs["is_block_step"], share_inputs["output_cum_offsets"], actual_candidate_len, share_inputs["actual_draft_token_num"], sampling_metadata.top_p, max_model_len, self.speculative_verify_window, True, # enable_topp (self.speculative_benchmark_mode or reject_all_drafts), accept_all_drafts, ) # TODO(chenhuan09): support return logprobs token_ids = share_inputs["accept_tokens"] sampler_output = SamplerOutput( sampled_token_ids=token_ids, logprobs_tensors=None, token_num_per_batch=share_inputs["accept_num"], cu_batch_token_offset=None, ) return sampler_output class MTPSampler(nn.Layer): """ """ def __init__(self, fd_config: FDConfig): """ """ super().__init__() if current_platform.is_cuda(): self.forward = self.forward_cuda elif current_platform.is_xpu(): self.forward = self.forward_xpu else: raise NotImplementedError self.logprobs_mode = fd_config.model_config.logprobs_mode def pre_process(self, skip_idx_list: List[int] = []): """pre process before running""" pass def apply_logits_processor( self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = [], ): """apply logits processor to sampler""" pass def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None): """set reasoning parser""" pass def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []): """post process after running""" pass def compute_logprobs( self, logits: paddle.Tensor, sampling_metadata: SamplingMetadata, ) -> paddle.Tensor: """compute logprobs""" share_inputs = sampling_metadata.share_inputs real_bsz = share_inputs["seq_lens_this_time"].shape[0] last_logits = logits temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs if temp_scaled_logprobs is not None: real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz] temperature = sampling_metadata.temperature[:real_bsz] real_bsz_temp_scaled = ( real_bsz_temp_scaled.astype("int32") .squeeze(1) .repeat_interleave(share_inputs["batch_token_num"][:real_bsz]) .astype("bool") ) temperature = temperature.squeeze(1).repeat_interleave(share_inputs["batch_token_num"][:real_bsz]) temp_temperature = paddle.where( real_bsz_temp_scaled, temperature, paddle.ones_like(temperature) ).unsqueeze(1) last_logits = last_logits / temp_temperature last_logprobs = F.log_softmax(last_logits, axis=-1) top_p_logprob = None top_p_token_mask = None if top_p_normalized_logprobs is not None and share_inputs is not None: real_token_top_p = ( sampling_metadata.top_p[:real_bsz] .squeeze(1) .repeat_interleave(share_inputs["batch_token_num"][:real_bsz]) .unsqueeze(1) ) top_p_normalized_logprobs = ( top_p_normalized_logprobs[:real_bsz] .astype("int32") .squeeze(1) .repeat_interleave(share_inputs["batch_token_num"][:real_bsz]) .astype("bool") .unsqueeze(1) ) top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0) if top_p_token_mask.any(): probs = F.softmax(last_logits, axis=-1) probs = top_p_normalize_probs_paddle(probs, real_token_top_p) top_p_logprob = paddle.log(probs) if top_p_logprob is not None: last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs) return last_logprobs def gather_logprobs( self, logprobs: paddle.Tensor, num_logprobs: int, token_ids: paddle.Tensor, ) -> LogprobsTensors: """ Gather logprobs for topk and sampled/prompt token. 

    def gather_logprobs(
        self,
        logprobs: paddle.Tensor,
        num_logprobs: int,
        token_ids: paddle.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
            logprobs: (num tokens) x (vocab) tensor
            num_logprobs: minimum number of logprobs to retain per token
            token_ids: prompt tokens (if prompt logprobs) or sampled tokens (if sampled logprobs);
                       1D token ID tensor with (num tokens) elements. Must be int64.

        Returns:
            Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
            Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
            Sampled token rank tensor, (num tokens)
        """
        assert token_ids.dtype == paddle.int64
        token_ids = token_ids.unsqueeze(1)
        logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
        # Get the logprob of the prompt or sampled token.
        token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)

        # Compute the rank of the actual token.
        token_ranks = (logprobs >= token_logprobs).sum(-1)

        if num_logprobs >= 1:
            # Find the top-k values.
            topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1)
            indices = paddle.concat([token_ids, topk_indices], axis=1)
            top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1)
        else:
            indices = token_ids
            top_logprobs = token_logprobs

        return LogprobsTensors(indices, top_logprobs, token_ranks)

    def forward_cuda(
        self,
        logits: paddle.Tensor,
        sampling_metadata: SamplingMetadata,
        max_model_len: int,
        share_inputs: List[paddle.Tensor],
    ) -> paddle.Tensor:
        """ """
        num_logprobs = sampling_metadata.max_num_logprobs
        real_bsz = share_inputs["seq_lens_this_time"].shape[0]
        if num_logprobs is not None and share_inputs["substep"] == 0:
            real_token_num = share_inputs["batch_token_num"][:real_bsz].sum()
            if self.logprobs_mode == "raw_logprobs":
                raw_logprobs = self.compute_logprobs(
                    share_inputs["draft_logits"][:real_token_num, :], sampling_metadata
                )
            elif self.logprobs_mode == "raw_logits":
                raw_logprobs = share_inputs["draft_logits"][:real_token_num, :].clone()

        logits = apply_speculative_penalty_multi_scores(
            sampling_metadata.pre_token_ids,
            logits,
            sampling_metadata.repetition_penalties,
            sampling_metadata.frequency_penalties,
            sampling_metadata.presence_penalties,
            sampling_metadata.temperature,
            sampling_metadata.bad_words_token_ids,
            sampling_metadata.step_idx,
            sampling_metadata.min_dec_lens,
            sampling_metadata.eos_token_ids,
            share_inputs["seq_lens_this_time"],
            share_inputs["output_padding_offset"],
            share_inputs["output_cum_offsets"],
            max_model_len,
        )
        probs = F.softmax(logits)
        top_p, top_k = padding_sampling_params(
            sampling_metadata.top_p,
            sampling_metadata.top_k,
            share_inputs["seq_lens_this_time"],
            share_inputs["seq_lens_encoder"],
        )
        _, next_tokens = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, seed=sampling_metadata.seed[0, 0])

        token_ids = None
        logprobs_tensors = None
        if num_logprobs is not None and share_inputs["substep"] == 0:
            token_ids = paddle.empty(real_token_num, dtype="int64")
            speculate_insert_first_token(
                token_ids,
                share_inputs["accept_tokens"],
                next_tokens,
                share_inputs["cu_next_token_offset"],
                share_inputs["cu_batch_token_offset"],
                share_inputs["seq_lens_this_time"],
                share_inputs["seq_lens_encoder"],
            )
            logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)

        sampler_output = SamplerOutput(
            sampled_token_ids=token_ids,
            logprobs_tensors=logprobs_tensors,
            token_num_per_batch=share_inputs["batch_token_num"][:real_bsz],
            cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
        )
        return next_tokens, sampler_output

    def forward_xpu(
        self,
        logits: paddle.Tensor,
        sampling_metadata: SamplingMetadata,
        max_model_len: int,
        share_inputs: List[paddle.Tensor],
    ) -> paddle.Tensor:
        logits = apply_speculative_penalty_multi_scores(
            sampling_metadata.pre_token_ids,
            logits,
            sampling_metadata.repetition_penalties,
            sampling_metadata.frequency_penalties,
            sampling_metadata.presence_penalties,
            sampling_metadata.temperature,
            sampling_metadata.bad_words_token_ids,
            sampling_metadata.step_idx,
            sampling_metadata.min_dec_lens,
            sampling_metadata.eos_token_ids,
            share_inputs["seq_lens_this_time"],
            share_inputs["output_padding_offset"],
            share_inputs["output_cum_offsets"],
            max_model_len,
        )
        probs = F.softmax(logits)
        _, next_tokens = top_k_top_p_sampling(
            probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list
        )

        # TODO(chenhuan09): add support for logprobs
        token_ids = None
        logprobs_tensors = None
        sampler_output = SamplerOutput(
            sampled_token_ids=token_ids,
            logprobs_tensors=logprobs_tensors,
            token_num_per_batch=None,
            cu_batch_token_offset=None,
        )
        return next_tokens, sampler_output