mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[OPs] Universal optimization and Fix early_stop cuda 700 (#3375)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
* delete nonzero * delete setup_ops_base.py * check if * check gcp infer_seed.cpu() * fix repetition_early_stopper_kernel cuda 700
This commit is contained in:
@@ -29,6 +29,7 @@ def top_k_top_p_sampling(
|
||||
x: paddle.Tensor,
|
||||
top_p: paddle.Tensor,
|
||||
top_k: Optional[paddle.Tensor] = None,
|
||||
top_k_list: Optional[list] = None,
|
||||
threshold: Optional[paddle.Tensor] = None,
|
||||
topp_seed: Optional[paddle.Tensor] = None,
|
||||
seed: int = -1,
|
||||
@@ -64,7 +65,7 @@ def top_k_top_p_sampling(
|
||||
if top_p_class == "air":
|
||||
_, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed=seed, k=k, mode=mode)
|
||||
elif top_p_class == "rejection":
|
||||
ids = rejection_top_p_sampling(x, top_p, top_k, seed, order)
|
||||
ids = rejection_top_p_sampling(x, top_p, top_k, top_k_list, seed, order)
|
||||
_ = None
|
||||
elif top_p_class == "base_non_truncated":
|
||||
_, ids = paddle.tensor.top_p_sampling(
|
||||
@@ -121,6 +122,7 @@ def rejection_top_p_sampling(
|
||||
x: paddle.Tensor,
|
||||
top_p: paddle.Tensor,
|
||||
top_k: paddle.Tensor,
|
||||
top_k_list: list,
|
||||
seed: int = -1,
|
||||
order: Literal["top_k_first", "joint"] = "top_k_first",
|
||||
) -> paddle.Tensor:
|
||||
@@ -139,7 +141,7 @@ def rejection_top_p_sampling(
|
||||
top_k_renorm_probs,
|
||||
)
|
||||
|
||||
if paddle.count_nonzero(top_k) == 0:
|
||||
if not any(x > 0 for x in top_k_list):
|
||||
ids = rejection_top_p_sampling(
|
||||
x,
|
||||
top_p,
|
||||
@@ -170,11 +172,12 @@ def rejection_top_p_sampling(
|
||||
def min_p_sampling(
|
||||
probs: paddle.tensor,
|
||||
min_p_arr: Optional[paddle.Tensor],
|
||||
min_p_arr_cpu: Optional[list],
|
||||
) -> tuple[paddle.Tensor, paddle.Tensor]:
|
||||
"""
|
||||
min_p_sampling
|
||||
"""
|
||||
if paddle.count_nonzero(min_p_arr) == 0:
|
||||
if not any(x > 0 for x in min_p_arr_cpu):
|
||||
return probs
|
||||
else:
|
||||
if current_platform.is_cuda():
|
||||
|
Reference in New Issue
Block a user