mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-30 03:22:05 +08:00
native top_p_sampling (#2901)
This commit is contained in:
@@ -63,7 +63,12 @@ class SiluAndMul(nn.Layer):
|
|||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if current_platform.is_cuda() or current_platform.is_xpu() or current_platform.is_iluvatar():
|
if (
|
||||||
|
current_platform.is_cuda()
|
||||||
|
or current_platform.is_xpu()
|
||||||
|
or current_platform.is_iluvatar()
|
||||||
|
or current_platform.is_dcu()
|
||||||
|
):
|
||||||
self.forward = self.forward_cuda
|
self.forward = self.forward_cuda
|
||||||
elif current_platform.is_gcu():
|
elif current_platform.is_gcu():
|
||||||
self.forward = self.forward_gcu
|
self.forward = self.forward_gcu
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from fastdeploy.model_executor.layers.attention.ops import (
|
|||||||
)
|
)
|
||||||
from fastdeploy.platforms import current_platform
|
from fastdeploy.platforms import current_platform
|
||||||
|
|
||||||
if current_platform.is_cuda() and not current_platform.is_dcu():
|
if current_platform.is_cuda():
|
||||||
from fastdeploy.model_executor.ops.gpu import (
|
from fastdeploy.model_executor.ops.gpu import (
|
||||||
decode_mla_write_cache,
|
decode_mla_write_cache,
|
||||||
multi_head_latent_attention,
|
multi_head_latent_attention,
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import paddle
|
|||||||
|
|
||||||
from fastdeploy.platforms import current_platform
|
from fastdeploy.platforms import current_platform
|
||||||
|
|
||||||
if current_platform.is_cuda() and not current_platform.is_dcu():
|
if current_platform.is_cuda():
|
||||||
from fastdeploy.model_executor.ops.gpu import (
|
from fastdeploy.model_executor.ops.gpu import (
|
||||||
append_attention as append_attention_gpu,
|
append_attention as append_attention_gpu,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,5 +18,6 @@ dcu backend methods
|
|||||||
|
|
||||||
from .fused_moe_triton_backends import DCUTritonWeightOnlyMoEMethod
|
from .fused_moe_triton_backends import DCUTritonWeightOnlyMoEMethod
|
||||||
from .weight_only import DCUWeightOnlyLinearMethod
|
from .weight_only import DCUWeightOnlyLinearMethod
|
||||||
|
from .top_p_sampling import native_top_p_sampling
|
||||||
|
|
||||||
__all__ = ["DCUTritonWeightOnlyMoEMethod", "DCUWeightOnlyLinearMethod"]
|
__all__ = ["DCUTritonWeightOnlyMoEMethod", "DCUWeightOnlyLinearMethod", "native_top_p_sampling"]
|
||||||
|
|||||||
@@ -0,0 +1,40 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
|
||||||
|
def native_top_p_sampling(
|
||||||
|
probs: paddle.Tensor,
|
||||||
|
top_p: paddle.Tensor
|
||||||
|
) -> tuple[paddle.Tensor, paddle.Tensor]:
|
||||||
|
sorted_indices = paddle.argsort(probs, descending=True)
|
||||||
|
sorted_probs = paddle.sort(probs, descending=True)
|
||||||
|
cumulative_probs = paddle.cumsum(sorted_probs, axis=-1)
|
||||||
|
sorted_indices_to_remove = cumulative_probs > top_p
|
||||||
|
sorted_indices_to_remove = paddle.cast(sorted_indices_to_remove, dtype="int64")
|
||||||
|
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
|
||||||
|
sorted_indices_to_remove[:, 0] = 0
|
||||||
|
sorted_indices = sorted_indices + paddle.arange(probs.shape[0], dtype="int64").unsqueeze(-1) * probs.shape[-1]
|
||||||
|
|
||||||
|
condition = paddle.scatter(
|
||||||
|
sorted_indices_to_remove.flatten(), sorted_indices.flatten(), sorted_indices_to_remove.flatten()
|
||||||
|
)
|
||||||
|
|
||||||
|
condition = paddle.cast(condition, "bool").reshape(probs.shape)
|
||||||
|
probs = paddle.where(condition, paddle.full_like(probs, 0.0), probs)
|
||||||
|
next_tokens = paddle.multinomial(probs)
|
||||||
|
|
||||||
|
return None, next_tokens
|
||||||
@@ -61,6 +61,7 @@ class LinearBase(nn.Layer):
|
|||||||
or current_platform.is_xpu()
|
or current_platform.is_xpu()
|
||||||
or current_platform.is_iluvatar()
|
or current_platform.is_iluvatar()
|
||||||
or current_platform.is_gcu()
|
or current_platform.is_gcu()
|
||||||
|
or current_platform.is_dcu()
|
||||||
):
|
):
|
||||||
self.forward = self.forward_cuda
|
self.forward = self.forward_cuda
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from fastdeploy.platforms import current_platform
|
|||||||
from ..utils import create_and_set_parameter, get_tensor
|
from ..utils import create_and_set_parameter, get_tensor
|
||||||
from .fused_moe_backend_base import MoEMethodBase
|
from .fused_moe_backend_base import MoEMethodBase
|
||||||
|
|
||||||
if current_platform.is_cuda() and not current_platform.is_dcu():
|
if current_platform.is_cuda():
|
||||||
from fastdeploy.model_executor.ops.gpu import (
|
from fastdeploy.model_executor.ops.gpu import (
|
||||||
moe_expert_dispatch,
|
moe_expert_dispatch,
|
||||||
moe_expert_reduce,
|
moe_expert_reduce,
|
||||||
|
|||||||
@@ -53,6 +53,23 @@ def apply_penalty_multi_scores(
|
|||||||
min_dec_lens,
|
min_dec_lens,
|
||||||
eos_token_ids,
|
eos_token_ids,
|
||||||
)
|
)
|
||||||
|
elif current_platform.is_dcu():
|
||||||
|
from fastdeploy.model_executor.ops.gpu import \
|
||||||
|
get_token_penalty_multi_scores
|
||||||
|
logits = get_token_penalty_multi_scores(
|
||||||
|
pre_token_ids,
|
||||||
|
prompt_ids,
|
||||||
|
prompt_lens,
|
||||||
|
logits,
|
||||||
|
repetition_penalties,
|
||||||
|
frequency_penalties,
|
||||||
|
presence_penalties,
|
||||||
|
temperature,
|
||||||
|
bad_words_token_ids,
|
||||||
|
step_idx,
|
||||||
|
min_dec_lens,
|
||||||
|
eos_token_ids,
|
||||||
|
)
|
||||||
elif current_platform.is_xpu():
|
elif current_platform.is_xpu():
|
||||||
from fastdeploy.model_executor.ops.xpu import get_token_penalty_multi_scores
|
from fastdeploy.model_executor.ops.xpu import get_token_penalty_multi_scores
|
||||||
|
|
||||||
|
|||||||
@@ -79,6 +79,9 @@ def top_k_top_p_sampling(
|
|||||||
else:
|
else:
|
||||||
if current_platform.is_gcu():
|
if current_platform.is_gcu():
|
||||||
_, ids = gcu_top_p_sampling(x, top_p)
|
_, ids = gcu_top_p_sampling(x, top_p)
|
||||||
|
elif current_platform.is_dcu():
|
||||||
|
from fastdeploy.model_executor.layers.backends import native_top_p_sampling
|
||||||
|
_, ids = native_top_p_sampling(x, top_p)
|
||||||
else:
|
else:
|
||||||
_, ids = paddle.tensor.top_p_sampling(
|
_, ids = paddle.tensor.top_p_sampling(
|
||||||
x,
|
x,
|
||||||
|
|||||||
@@ -173,6 +173,7 @@ class Sampler(nn.Layer):
|
|||||||
or current_platform.is_xpu()
|
or current_platform.is_xpu()
|
||||||
or current_platform.is_iluvatar()
|
or current_platform.is_iluvatar()
|
||||||
or current_platform.is_gcu()
|
or current_platform.is_gcu()
|
||||||
|
or current_platform.is_dcu()
|
||||||
):
|
):
|
||||||
self.forward = self.forward_cuda
|
self.forward = self.forward_cuda
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ from fastdeploy.model_executor.models.ernie4_5_moe import (
|
|||||||
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
|
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
|
||||||
from fastdeploy.platforms import current_platform
|
from fastdeploy.platforms import current_platform
|
||||||
|
|
||||||
if current_platform.is_cuda() and not current_platform.is_dcu():
|
if current_platform.is_cuda():
|
||||||
from fastdeploy.model_executor.ops.gpu import (
|
from fastdeploy.model_executor.ops.gpu import (
|
||||||
extract_text_token_output,
|
extract_text_token_output,
|
||||||
text_image_gather_scatter,
|
text_image_gather_scatter,
|
||||||
|
|||||||
@@ -479,6 +479,17 @@ def rebuild_padding(
|
|||||||
if current_platform.is_cuda():
|
if current_platform.is_cuda():
|
||||||
from fastdeploy.model_executor.ops.gpu import rebuild_padding
|
from fastdeploy.model_executor.ops.gpu import rebuild_padding
|
||||||
|
|
||||||
|
hidden_states = rebuild_padding(
|
||||||
|
tmp_out,
|
||||||
|
cum_offsets,
|
||||||
|
seq_len_this_time,
|
||||||
|
seq_lens_decoder,
|
||||||
|
seq_lens_encoder,
|
||||||
|
output_padding_offset,
|
||||||
|
max_input_length,
|
||||||
|
)
|
||||||
|
elif current_platform.is_dcu():
|
||||||
|
from fastdeploy.model_executor.ops.gpu import rebuild_padding
|
||||||
hidden_states = rebuild_padding(
|
hidden_states = rebuild_padding(
|
||||||
tmp_out,
|
tmp_out,
|
||||||
cum_offsets,
|
cum_offsets,
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ class Platform:
|
|||||||
"""
|
"""
|
||||||
whether platform is cuda
|
whether platform is cuda
|
||||||
"""
|
"""
|
||||||
return paddle.is_compiled_with_cuda()
|
return paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm()
|
||||||
|
|
||||||
def is_npu(self) -> bool:
|
def is_npu(self) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -28,3 +28,10 @@ use-triton-in-paddle
|
|||||||
crcmod
|
crcmod
|
||||||
fastsafetensors==0.1.14
|
fastsafetensors==0.1.14
|
||||||
msgpack
|
msgpack
|
||||||
|
opentelemetry-api>=1.24.0
|
||||||
|
opentelemetry-sdk>=1.24.0
|
||||||
|
opentelemetry-instrumentation-redis
|
||||||
|
opentelemetry-instrumentation-mysql
|
||||||
|
opentelemetry-distro
|
||||||
|
opentelemetry-exporter-otlp
|
||||||
|
opentelemetry-instrumentation-fastapi
|
||||||
|
|||||||
Reference in New Issue
Block a user