native top_p_sampling (#2901)

This commit is contained in:
lifulll
2025-07-22 14:09:59 +08:00
committed by GitHub
parent 0eedbdaee0
commit 2c6a9e887e
14 changed files with 93 additions and 7 deletions

View File

@@ -63,7 +63,12 @@ class SiluAndMul(nn.Layer):
"""
super().__init__()
if current_platform.is_cuda() or current_platform.is_xpu() or current_platform.is_iluvatar():
if (
current_platform.is_cuda()
or current_platform.is_xpu()
or current_platform.is_iluvatar()
or current_platform.is_dcu()
):
self.forward = self.forward_cuda
elif current_platform.is_gcu():
self.forward = self.forward_gcu

View File

@@ -32,7 +32,7 @@ from fastdeploy.model_executor.layers.attention.ops import (
)
from fastdeploy.platforms import current_platform
if current_platform.is_cuda() and not current_platform.is_dcu():
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (
decode_mla_write_cache,
multi_head_latent_attention,

View File

@@ -20,7 +20,7 @@ import paddle
from fastdeploy.platforms import current_platform
if current_platform.is_cuda() and not current_platform.is_dcu():
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (
append_attention as append_attention_gpu,
)

View File

@@ -18,5 +18,6 @@ dcu backend methods
from .fused_moe_triton_backends import DCUTritonWeightOnlyMoEMethod
from .weight_only import DCUWeightOnlyLinearMethod
from .top_p_sampling import native_top_p_sampling
__all__ = ["DCUTritonWeightOnlyMoEMethod", "DCUWeightOnlyLinearMethod"]
__all__ = ["DCUTritonWeightOnlyMoEMethod", "DCUWeightOnlyLinearMethod", "native_top_p_sampling"]

View File

@@ -0,0 +1,40 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
def native_top_p_sampling(
    probs: paddle.Tensor,
    top_p: paddle.Tensor,
) -> tuple[None, paddle.Tensor]:
    """
    Top-p (nucleus) sampling built only from generic paddle tensor ops.

    Each row of ``probs`` is ranked in descending order. An entry is dropped
    when the cumulative probability of the entries ranked before it already
    exceeds ``top_p``; the highest-ranked entry of a row is therefore always
    kept. Dropped entries are set to 0 and one index per row is then drawn
    with ``paddle.multinomial``.

    Args:
        probs: 2-D tensor; rows are filtered and sampled independently. The
            slicing and flat-offset arithmetic below rely on exactly two
            dimensions. Presumably ``[batch, vocab]`` probabilities --
            confirm against the caller.
        top_p: Threshold compared against the per-row cumulative sum. It must
            broadcast against ``probs`` (presumably shape ``[batch, 1]`` --
            TODO confirm against the caller).

    Returns:
        ``(None, next_tokens)``. The first slot is always ``None`` (this
        implementation does not report the sampled probabilities);
        ``next_tokens`` is the output of ``paddle.multinomial`` called with
        its default arguments on the filtered tensor.
    """
    batch_size = probs.shape[0]
    vocab_size = probs.shape[-1]

    # Rank every row once; the ranked values are gathered through the same
    # permutation instead of running a second full sort.
    order = paddle.argsort(probs, descending=True)
    ranked_probs = paddle.take_along_axis(probs, order, axis=-1)
    cumulative = paddle.cumsum(ranked_probs, axis=-1)

    # int64 rather than bool because the mask is pushed through scatter below.
    over_budget = paddle.cast(cumulative > top_p, dtype="int64")

    # Shift the mask one rank to the right (first column becomes 0) so the
    # entry that crosses the threshold is itself kept and rank 0 always
    # survives.
    drop_ranked = paddle.concat(
        [paddle.zeros_like(over_budget[:, :1]), over_budget[:, :-1]],
        axis=-1,
    )

    # Translate per-row ranks into positions of the flattened tensor so one
    # 1-D scatter can move the mask from ranked order back to original order.
    row_offsets = paddle.arange(batch_size, dtype="int64").unsqueeze(-1) * vocab_size
    flat_positions = (order + row_offsets).flatten()
    flat_drop = drop_ranked.flatten()
    drop_mask = paddle.scatter(flat_drop, flat_positions, flat_drop)
    drop_mask = paddle.cast(drop_mask, "bool").reshape(probs.shape)

    # Zero the dropped entries; the result is handed to multinomial as-is,
    # with no explicit renormalisation in this function.
    filtered = paddle.where(drop_mask, paddle.full_like(probs, 0.0), probs)
    next_tokens = paddle.multinomial(filtered)
    return None, next_tokens

View File

@@ -61,6 +61,7 @@ class LinearBase(nn.Layer):
or current_platform.is_xpu()
or current_platform.is_iluvatar()
or current_platform.is_gcu()
or current_platform.is_dcu()
):
self.forward = self.forward_cuda
else:

View File

@@ -26,7 +26,7 @@ from fastdeploy.platforms import current_platform
from ..utils import create_and_set_parameter, get_tensor
from .fused_moe_backend_base import MoEMethodBase
if current_platform.is_cuda() and not current_platform.is_dcu():
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (
moe_expert_dispatch,
moe_expert_reduce,

View File

@@ -53,6 +53,23 @@ def apply_penalty_multi_scores(
min_dec_lens,
eos_token_ids,
)
elif current_platform.is_dcu():
from fastdeploy.model_executor.ops.gpu import \
get_token_penalty_multi_scores
logits = get_token_penalty_multi_scores(
pre_token_ids,
prompt_ids,
prompt_lens,
logits,
repetition_penalties,
frequency_penalties,
presence_penalties,
temperature,
bad_words_token_ids,
step_idx,
min_dec_lens,
eos_token_ids,
)
elif current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import get_token_penalty_multi_scores

View File

@@ -79,6 +79,9 @@ def top_k_top_p_sampling(
else:
if current_platform.is_gcu():
_, ids = gcu_top_p_sampling(x, top_p)
elif current_platform.is_dcu():
from fastdeploy.model_executor.layers.backends import native_top_p_sampling
_, ids = native_top_p_sampling(x, top_p)
else:
_, ids = paddle.tensor.top_p_sampling(
x,

View File

@@ -173,6 +173,7 @@ class Sampler(nn.Layer):
or current_platform.is_xpu()
or current_platform.is_iluvatar()
or current_platform.is_gcu()
or current_platform.is_dcu()
):
self.forward = self.forward_cuda
else:

View File

@@ -44,7 +44,7 @@ from fastdeploy.model_executor.models.ernie4_5_moe import (
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
from fastdeploy.platforms import current_platform
if current_platform.is_cuda() and not current_platform.is_dcu():
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (
extract_text_token_output,
text_image_gather_scatter,

View File

@@ -479,6 +479,17 @@ def rebuild_padding(
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import rebuild_padding
hidden_states = rebuild_padding(
tmp_out,
cum_offsets,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
max_input_length,
)
elif current_platform.is_dcu():
from fastdeploy.model_executor.ops.gpu import rebuild_padding
hidden_states = rebuild_padding(
tmp_out,
cum_offsets,

View File

@@ -39,7 +39,7 @@ class Platform:
"""
whether platform is cuda
"""
return paddle.is_compiled_with_cuda()
return paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm()
def is_npu(self) -> bool:
"""

View File

@@ -28,3 +28,10 @@ use-triton-in-paddle
crcmod
fastsafetensors==0.1.14
msgpack
opentelemetry-api>=1.24.0
opentelemetry-sdk>=1.24.0
opentelemetry-instrumentation-redis
opentelemetry-instrumentation-mysql
opentelemetry-distro 
opentelemetry-exporter-otlp
opentelemetry-instrumentation-fastapi