mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00
[GCU] Support gcu platform (#2702)
baseline: e7fa57ebae
Co-authored-by: yongqiangma <xing.wo@163.com>
This commit is contained in:
@@ -79,6 +79,21 @@ def apply_penalty_multi_scores(
|
||||
min_dec_lens,
|
||||
eos_token_ids,
|
||||
)
|
||||
elif current_platform.is_gcu():
|
||||
from fastdeploy.model_executor.ops.gcu import \
|
||||
get_token_penalty_multi_scores
|
||||
logits = get_token_penalty_multi_scores(
|
||||
pre_token_ids,
|
||||
logits,
|
||||
repetition_penalties,
|
||||
frequency_penalties,
|
||||
presence_penalties,
|
||||
temperature,
|
||||
bad_words_token_ids,
|
||||
step_idx,
|
||||
min_dec_lens,
|
||||
eos_token_ids,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@@ -19,7 +19,11 @@ from typing import Literal, Optional
|
||||
import paddle
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
if current_platform.is_gcu():
|
||||
from fastdeploy.model_executor.ops.gcu import \
|
||||
top_p_sampling as gcu_top_p_sampling
|
||||
|
||||
def top_p_sampling(
|
||||
x: paddle.Tensor,
|
||||
@@ -46,13 +50,16 @@ def top_p_sampling(
|
||||
ids = rejection_top_p_sampling(x, ps, seed)
|
||||
_ = None
|
||||
else:
|
||||
_, ids = paddle.tensor.top_p_sampling(x,
|
||||
ps,
|
||||
threshold=threshold,
|
||||
topp_seed=topp_seed,
|
||||
seed=seed,
|
||||
k=k,
|
||||
mode=mode)
|
||||
if current_platform.is_gcu():
|
||||
_, ids = gcu_top_p_sampling(x, ps)
|
||||
else:
|
||||
_, ids = paddle.tensor.top_p_sampling(x,
|
||||
ps,
|
||||
threshold=threshold,
|
||||
topp_seed=topp_seed,
|
||||
seed=seed,
|
||||
k=k,
|
||||
mode=mode)
|
||||
return _, ids
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user