polish code with new pre-commit rule (#2923)

Author: Zero Rains
Committed: 2025-07-19 23:19:27 +08:00 (via GitHub)
Parent: b8676d71a8 · Commit: 25698d56d1
424 changed files with 14307 additions and 13518 deletions
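
The hunks below are uniformly mechanical: hanging-indent signatures collapse onto one line, backslash-continued imports become parenthesized, multi-clause conditions are wrapped in parentheses, and docstring padding is tightened (""" x """ becomes """x""", empty docstrings become """ """). This is characteristic of a black-style formatter with a long line limit. As a hedged sketch only — the actual hook in the repo's .pre-commit-config.yaml is not shown on this page, and the 120-column limit is an assumption inferred from the reflowed lines — the same kind of rewrite can be reproduced with black's Python API:

# Hedged sketch: reproduce this commit's style of rewrite with black's
# Python API. Assumptions (not confirmed by this page): the new pre-commit
# rule is a black-style formatter and the line limit is 120 columns.
import black

OLD_STYLE = '''\
def add_logits_processor(self,
                         ids,
                         future=None,
                         prefill_tokens=[]):
    """ add logits processor to SamplerProcessor """
    pass
'''

# format_str performs the rewrites seen throughout this commit: signatures
# collapse onto one line when they fit, and docstring padding is stripped.
print(black.format_str(OLD_STYLE, mode=black.Mode(line_length=120)))

Run across the tree (e.g. with pre-commit run --all-files), a hook like this would produce the kind of large, reformat-only diff recorded above.
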


@@ -13,21 +13,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
 import threading
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Dict, List, Optional

 import paddle
-import paddle.nn as nn
 import paddle.nn.functional as F
+from paddle import nn

 from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.guided_decoding.base_guided_decoding import \
-    LogitsProcessorBase
+from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
+    LogitsProcessorBase,
+)
 from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
 from fastdeploy.model_executor.layers.sample.ops import (
-    apply_penalty_multi_scores, apply_speculative_penalty_multi_scores,
-    top_k_top_p_sampling)
+    apply_penalty_multi_scores,
+    apply_speculative_penalty_multi_scores,
+    top_k_top_p_sampling,
+)
 from fastdeploy.platforms import current_platform
 from fastdeploy.worker.output import LogprobsTensors, SamplerOutput
@@ -44,11 +48,13 @@ class SamplerProcessor:
         self.executor = ThreadPoolExecutor()
         self.logits_lock = threading.Lock()

-    def add_logits_processor(self,
-                             ids: int,
-                             future: Optional[Any] = None,
-                             prefill_tokens: List[int] = []):
-        """ add logits processor to SamplerProcessor """
+    def add_logits_processor(
+        self,
+        ids: int,
+        future: Optional[Any] = None,
+        prefill_tokens: List[int] = [],
+    ):
+        """add logits processor to SamplerProcessor"""
         with self.logits_lock:
             if future is None:
                 if ids in self.logits_processor:
@@ -67,7 +73,7 @@ class SamplerProcessor:
             self.logits_processor[ids] = [future, prefill_tokens]

     def update_vocab_mask(self, skip_idx_list: List[int] = []):
-        """ update vocab mask. (cpu-heavy operation) """
+        """update vocab mask. (cpu-heavy operation)"""
         if len(self.logits_processor) == 0:
             return
@@ -102,10 +108,8 @@ class SamplerProcessor:
                 processor.fill_token_bitmask(self.token_bitmask, idx)

-    def apply_token_mask(self,
-                         logits: paddle.Tensor,
-                         skip_idx_list: List[int] = []):
-        """ apply token mask to logits """
+    def apply_token_mask(self, logits: paddle.Tensor, skip_idx_list: List[int] = []):
+        """apply token mask to logits"""
         if len(self.logits_processor) == 0 or self.token_bitmask is None:
             return logits
@@ -121,26 +125,20 @@ class SamplerProcessor:
         indices = list(self.logits_processor.keys())
         mask_idx = [i for i in indices if i not in skip_idx_list]
-        return available_processors.apply_token_mask(logits,
-                                                     self.token_bitmask,
-                                                     indices=mask_idx)
+        return available_processors.apply_token_mask(logits, self.token_bitmask, indices=mask_idx)

     def _accept_token(self, idx: int, token: int):
-        """ accept token """
+        """accept token"""
         if idx not in self.logits_processor:
-            raise ValueError(
-                f"Invalid index, idx: {idx}, logit_processors.keys: {self.logits_processor.keys()}"
-            )
+            raise ValueError(f"Invalid index, idx: {idx}, logit_processors.keys: {self.logits_processor.keys()}")

         if self.logits_processor[idx].is_terminated():
             return

         self.logits_processor[idx].accept_token(token)

-    def update_output_tokens(self,
-                             next_tokens: paddle.Tensor,
-                             skip_idx_list: List[int] = []):
-        """ update output tokens """
+    def update_output_tokens(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
+        """update output tokens"""
         if len(self.logits_processor) == 0:
             return
@@ -148,14 +146,13 @@ class SamplerProcessor:
         with self.logits_lock:
             for idx in self.logits_processor.keys():
                 token = token_ids[idx][0]
-                if token < 0 or self.logits_processor[
-                        idx] is None or idx in skip_idx_list:
+                if token < 0 or self.logits_processor[idx] is None or idx in skip_idx_list:
                     continue
                 self._accept_token(idx, token)

     def pre_process(self, skip_idx_list: List[int] = []):
-        """ pre process before running """
+        """pre process before running"""
         # create async operation for guided decoding
         # TODO: support async
         self.update_vocab_mask(skip_idx_list)
@@ -168,31 +165,35 @@ class Sampler(nn.Layer):
"""
def __init__(self):
"""
"""
""" """
super().__init__()
if current_platform.is_cuda() or current_platform.is_xpu(
) or current_platform.is_iluvatar() or current_platform.is_gcu():
if (
current_platform.is_cuda()
or current_platform.is_xpu()
or current_platform.is_iluvatar()
or current_platform.is_gcu()
):
self.forward = self.forward_cuda
else:
raise NotImplementedError()
raise NotImplementedError
self.processor = SamplerProcessor()
def apply_logits_processor(self,
ids: int,
future: Optional[Any] = None,
prefill_tokens: List[int] = []):
""" apply logits processor to sampler """
def apply_logits_processor(
self,
ids: int,
future: Optional[Any] = None,
prefill_tokens: List[int] = [],
):
"""apply logits processor to sampler"""
self.processor.add_logits_processor(ids, future, prefill_tokens)
def pre_process(self, skip_idx_list: List[int] = []):
""" pre process before running """
"""pre process before running"""
self.processor.pre_process(skip_idx_list)
def compute_logprobs(self, logits: paddle.Tensor) -> paddle.Tensor:
"""
"""
""" """
return F.log_softmax(logits, axis=-1)
def gather_logprobs(
@@ -226,9 +227,7 @@ class Sampler(nn.Layer):
         if num_logprobs >= 1:
             # Find the topK values.
-            topk_logprobs, topk_indices = paddle.topk(logprobs,
-                                                      num_logprobs,
-                                                      axis=-1)
+            topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1)
             indices = paddle.concat([token_ids, topk_indices], axis=1)
             top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1)
         else:
@@ -243,8 +242,7 @@ class Sampler(nn.Layer):
         sampling_metadata: SamplingMetadata,
         skip_idx_list: List[int] = [],
     ) -> SamplerOutput:
-        """
-        """
+        """ """
         num_logprobs = sampling_metadata.max_num_logprobs
         if num_logprobs is not None:
             raw_logprobs = self.compute_logprobs(logits)
@@ -270,8 +268,9 @@ class Sampler(nn.Layer):
         _, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)

-        logprobs_tensors = None if num_logprobs is None else \
-            self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens)
+        logprobs_tensors = (
+            None if num_logprobs is None else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens)
+        )

         self.processor.update_output_tokens(next_tokens, skip_idx_list)
@@ -291,26 +290,27 @@ class SpeculativeSampler(nn.Layer):
"""
def __init__(self, fd_config: FDConfig):
"""
"""
""" """
super().__init__()
if current_platform.is_cuda():
self.forward = self.forward_cuda
else:
raise NotImplementedError()
raise NotImplementedError
self.speculative_verify_window = fd_config.speculative_config.verify_window
self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len
self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode
def pre_process(self, skip_idx_list: List[int] = []):
""" pre process before running """
"""pre process before running"""
pass
def apply_logits_processor(self,
ids: int,
future: Optional[Any] = None,
prefill_tokens: List[int] = []):
""" apply logits processor to sampler """
def apply_logits_processor(
self,
ids: int,
future: Optional[Any] = None,
prefill_tokens: List[int] = [],
):
"""apply logits processor to sampler"""
pass
def forward_cuda(
@@ -320,11 +320,9 @@ class SpeculativeSampler(nn.Layer):
         max_model_len: int,
         share_inputs: List[paddle.Tensor],
     ) -> paddle.Tensor:
-        """
-        """
+        """ """
-        from fastdeploy.model_executor.ops.gpu import (speculate_verify,
-                                                       top_p_candidates)
+        from fastdeploy.model_executor.ops.gpu import speculate_verify, top_p_candidates

         logits = apply_speculative_penalty_multi_scores(
             sampling_metadata.pre_token_ids,
@@ -361,7 +359,8 @@ class SpeculativeSampler(nn.Layer):
             share_inputs["seq_lens_encoder"],
             share_inputs["seq_lens_decoder"],
             share_inputs[
-                "draft_tokens"],  # Both input and output, need to write the last 1 token accepted to position 0.
+                "draft_tokens"
+            ],  # Both input and output, need to write the last 1 token accepted to position 0.
             share_inputs["seq_lens_this_time"],
             verify_tokens,
             verify_scores,
@@ -382,27 +381,27 @@ class SpeculativeSampler(nn.Layer):
 class MTPSampler(nn.Layer):
-    """
-    """
+    """ """

     def __init__(self, fd_config: FDConfig):
-        """
-        """
+        """ """
         super().__init__()
         if current_platform.is_cuda():
             self.forward = self.forward_cuda
         else:
-            raise NotImplementedError()
+            raise NotImplementedError

     def pre_process(self, skip_idx_list: List[int] = []):
-        """ pre process before running """
+        """pre process before running"""
         pass

-    def apply_logits_processor(self,
-                               ids: int,
-                               future: Optional[Any] = None,
-                               prefill_tokens: List[int] = []):
-        """ apply logits processor to sampler """
+    def apply_logits_processor(
+        self,
+        ids: int,
+        future: Optional[Any] = None,
+        prefill_tokens: List[int] = [],
+    ):
+        """apply logits processor to sampler"""
         pass

     def forward_cuda(
@@ -412,8 +411,7 @@ class MTPSampler(nn.Layer):
         max_model_len: int,
         share_inputs: List[paddle.Tensor],
     ) -> paddle.Tensor:
-        """
-        """
+        """ """
         logits = apply_speculative_penalty_multi_scores(
             sampling_metadata.pre_token_ids,
             logits,