mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-07 01:22:59 +08:00)
[OPs] Universal optimization and Fix early_stop cuda 700 (#3375)
* delete nonzero
* delete setup_ops_base.py
* check if
* check gcp infer_seed.cpu()
* fix repetition_early_stopper_kernel cuda 700
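Note on the common thread in the hunks below: host-side checks such as `paddle.count_nonzero(tensor) == 0` force a device-to-host read (and the stream synchronization that comes with it) on every decode step, so the sampling parameters now carry a plain Python list alongside the device tensor and the emptiness check runs entirely on the host. A minimal sketch of the pattern, with all names invented for illustration:

```python
import paddle

max_num_seqs = 8
top_k = paddle.full([max_num_seqs, 1], 0, dtype="int64")  # device tensor, consumed by kernels
top_k_list = [0] * max_num_seqs                           # host mirror, consumed by control flow

# Old-style check: reads the tensor back from the device, synchronizing the stream.
no_top_k_sync = paddle.count_nonzero(top_k) == 0

# New-style check: pure Python, no device traffic.
no_top_k = not any(k > 0 for k in top_k_list)
```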
@@ -90,10 +90,10 @@ class RepetitionEarlyStopper(EarlyStopper):
         )

         B, W = self.trunc_scores.shape
-        V = probs.shape[1]
+        real_bsz, V = probs.shape
         BLOCK_W = triton.next_power_of_2(W)

-        grid = (B,)
+        grid = (real_bsz,)
         repetition_early_stopper_kernel[grid](
             self.trunc_scores,
             probs,
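Why this hunk addresses the CUDA error 700: `trunc_scores` is sized for the padded maximum batch `B`, while `probs` apparently only holds `real_bsz` rows, so launching `B` Triton programs lets the trailing programs index past the end of `probs` (an illegal memory access). A hypothetical, self-contained sketch of the invariant the fix restores; none of these names are FastDeploy's, and it assumes the same Triton-with-Paddle setup the kernel above already relies on:

```python
import paddle
import triton
import triton.language as tl

@triton.jit
def row_max_kernel(probs_ptr, out_ptr, V, BLOCK_V: tl.constexpr):
    row = tl.program_id(0)              # one program per *real* batch row
    offs = tl.arange(0, BLOCK_V)
    mask = offs < V
    vals = tl.load(probs_ptr + row * V + offs, mask=mask, other=0.0)
    tl.store(out_ptr + row, tl.max(vals, axis=0))

real_bsz, V = 3, 1024                   # rows actually present this step (the padded max may be larger)
probs = paddle.rand([real_bsz, V], dtype="float32")
out = paddle.zeros([real_bsz], dtype="float32")
grid = (real_bsz,)                      # sized to the data, never to the padded maximum
row_max_kernel[grid](probs, out, V, BLOCK_V=triton.next_power_of_2(V))
```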
@@ -42,7 +42,9 @@ class SamplingMetadata:

     top_p: paddle.Tensor
     top_k: Optional[paddle.Tensor] = None
+    top_k_list: Optional[list] = None
     min_p: Optional[paddle.Tensor] = None
+    min_p_list: Optional[list] = None
     seed: Optional[paddle.Tensor] = None
     max_num_logprobs: Optional[int] = None
     enable_early_stop: Optional[int] = False
@@ -29,6 +29,7 @@ def top_k_top_p_sampling(
     x: paddle.Tensor,
     top_p: paddle.Tensor,
     top_k: Optional[paddle.Tensor] = None,
+    top_k_list: Optional[list] = None,
     threshold: Optional[paddle.Tensor] = None,
     topp_seed: Optional[paddle.Tensor] = None,
     seed: int = -1,
@@ -64,7 +65,7 @@ def top_k_top_p_sampling(
     if top_p_class == "air":
         _, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed=seed, k=k, mode=mode)
     elif top_p_class == "rejection":
-        ids = rejection_top_p_sampling(x, top_p, top_k, seed, order)
+        ids = rejection_top_p_sampling(x, top_p, top_k, top_k_list, seed, order)
         _ = None
     elif top_p_class == "base_non_truncated":
         _, ids = paddle.tensor.top_p_sampling(
@@ -121,6 +122,7 @@ def rejection_top_p_sampling(
     x: paddle.Tensor,
     top_p: paddle.Tensor,
     top_k: paddle.Tensor,
+    top_k_list: list,
     seed: int = -1,
     order: Literal["top_k_first", "joint"] = "top_k_first",
 ) -> paddle.Tensor:
@@ -139,7 +141,7 @@ def rejection_top_p_sampling(
             top_k_renorm_probs,
         )

-        if paddle.count_nonzero(top_k) == 0:
+        if not any(x > 0 for x in top_k_list):
             ids = rejection_top_p_sampling(
                 x,
                 top_p,
@@ -170,11 +172,12 @@ def rejection_top_p_sampling(
 def min_p_sampling(
     probs: paddle.tensor,
     min_p_arr: Optional[paddle.Tensor],
+    min_p_arr_cpu: Optional[list],
 ) -> tuple[paddle.Tensor, paddle.Tensor]:
     """
     min_p_sampling
     """
-    if paddle.count_nonzero(min_p_arr) == 0:
+    if not any(x > 0 for x in min_p_arr_cpu):
         return probs
     else:
         if current_platform.is_cuda():
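For reference, min-p filtering keeps only the tokens whose probability is at least `min_p` times the row's maximum probability and renormalizes what is left. A hedged, pure-Paddle sketch of that semantics; the `is_cuda()` branch above dispatches to a dedicated kernel, so this reference function is an illustration, not the op that ships:

```python
import paddle

def min_p_reference(probs: paddle.Tensor, min_p: paddle.Tensor) -> paddle.Tensor:
    """probs: [bsz, vocab]; min_p: [bsz, 1]. Zero out tokens below min_p * max_prob, renormalize."""
    max_prob = probs.max(axis=-1, keepdim=True)
    keep = probs >= max_prob * min_p
    filtered = paddle.where(keep, probs, paddle.zeros_like(probs))
    return filtered / filtered.sum(axis=-1, keepdim=True)

probs = paddle.to_tensor([[0.5, 0.3, 0.15, 0.05]])
print(min_p_reference(probs, paddle.to_tensor([[0.4]])))  # entries below 0.5 * 0.4 = 0.2 are dropped
```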
@@ -281,10 +281,13 @@ class Sampler(nn.Layer):

         probs = F.softmax(logits)

-        probs = min_p_sampling(probs, sampling_metadata.min_p)
-
+        probs = min_p_sampling(probs, sampling_metadata.min_p, sampling_metadata.min_p_list)
         _, next_tokens = top_k_top_p_sampling(
-            probs, sampling_metadata.top_p, sampling_metadata.top_k, seed=sampling_metadata.seed[0, 0]
+            probs,
+            sampling_metadata.top_p,
+            sampling_metadata.top_k,
+            sampling_metadata.top_k_list,
+            seed=sampling_metadata.seed[0, 0],
         )

         logprobs_tensors = (
@@ -19,7 +19,6 @@ from fastdeploy.import_ops import import_custom_ops

 PACKAGE = "fastdeploy.model_executor.ops.gpu"

-import_custom_ops(PACKAGE, "..base.fastdeploy_base_ops", globals())
 import_custom_ops(PACKAGE, ".fastdeploy_ops", globals())

@@ -17,7 +17,6 @@ from fastdeploy.import_ops import import_custom_ops

 PACKAGE = "fastdeploy.model_executor.ops.iluvatar"

-import_custom_ops(PACKAGE, "..base.fastdeploy_base_ops", globals())
 import_custom_ops(PACKAGE, ".fastdeploy_ops", globals())

 from .moe_ops import iluvatar_moe_expert_ffn as moe_expert_ffn  # noqa: F401
@@ -94,7 +94,7 @@ class GCUModelRunner(ModelRunnerBase):
             shape=[self.parallel_config.max_num_seqs, 1],
             fill_value=4,
             dtype="int64",
-        )
+        ).cpu()
         self.restore_chunked_prefill_request = dict()

         # Initialize attention Backend
@@ -239,7 +239,9 @@ class GCUModelRunner(ModelRunnerBase):
         self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
         self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7)
         self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
+        self.share_inputs["top_k_list"][idx] = request.get("top_k", 0)
         self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)
+        self.share_inputs["min_p_list"][idx] = request.get("min_p", 0.0)

         self.share_inputs["temperature"][idx : idx + 1] = get_attr_from_request(request, "temperature", 0.95)
         self.share_inputs["penalty_score"][idx : idx + 1] = get_attr_from_request(
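The runner-side discipline that keeps the host mirrors trustworthy: whenever a request writes its sampling parameter into the device tensor slot, the matching list entry is written in the same place, so later host-side `any(...)` checks agree with what the kernels will see. A standalone sketch of that idea, with the dict and request shapes invented for illustration:

```python
import paddle

max_num_seqs = 8
share_inputs = {
    "top_k": paddle.full([max_num_seqs, 1], 0, dtype="int64"),
    "top_k_list": [0] * max_num_seqs,
    "min_p": paddle.full([max_num_seqs, 1], 0.0, dtype="float32"),
    "min_p_list": [0.0] * max_num_seqs,
}

def insert_request(idx: int, request: dict) -> None:
    # Device tensor and host list are updated together, in the same code path.
    share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
    share_inputs["top_k_list"][idx] = request.get("top_k", 0)
    share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)
    share_inputs["min_p_list"][idx] = request.get("min_p", 0.0)

insert_request(0, {"top_k": 20, "min_p": 0.05})
```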
@@ -361,7 +363,9 @@ class GCUModelRunner(ModelRunnerBase):
         self.share_inputs["eos_token_id"] = paddle.full([self.model_config.eos_tokens_lens, 1], 0, dtype="int64")
         self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
         self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
+        self.share_inputs["top_k_list"] = [0] * max_num_seqs
         self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32")
+        self.share_inputs["min_p_list"] = [0.0] * max_num_seqs
         self.share_inputs["temperature"] = paddle.full(
             [max_num_seqs, 1], self.model_config.temperature, dtype="float32"
         )
@@ -408,7 +412,7 @@ class GCUModelRunner(ModelRunnerBase):
         self.share_inputs["need_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32")
         self.share_inputs["need_block_len"] = paddle.full([1], 0, dtype="int32")
         self.share_inputs["used_list_len"] = paddle.full([max_num_seqs], 0, dtype="int32")
-        self.share_inputs["infer_seed"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
+        self.share_inputs["infer_seed"] = paddle.full([max_num_seqs, 1], 0, dtype="int64").cpu()
         self.share_inputs["first_token_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
         self.share_inputs["ori_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["system_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
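Keeping `infer_seed` on the CPU matches how it is consumed: the sampler reads a single scalar out of it each step (`sampling_metadata.seed[0, 0]` above), and pulling that scalar from a GPU tensor would mean a device-to-host copy and an implicit synchronization per decode step. A small sketch of the difference, with sizes invented:

```python
import paddle

max_num_seqs = 8

# Allocated on the host: reading one element stays on the CPU.
infer_seed = paddle.full([max_num_seqs, 1], 0, dtype="int64").cpu()
seed_scalar = int(infer_seed[0, 0])  # cheap host read, no stream sync

# If the same tensor lived on the GPU, int(...) here would copy it back
# and wait for the device, once per step.
```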
@@ -539,7 +543,9 @@ class GCUModelRunner(ModelRunnerBase):
             temperature=self.share_inputs["temperature"],
             top_p=self.share_inputs["top_p"],
             top_k=self.share_inputs["top_k"],
+            top_k_list=self.share_inputs["top_k_list"],
             min_p=self.share_inputs["min_p"],
+            min_p_list=self.share_inputs["min_p_list"],
             seed=self.share_inputs["infer_seed"],
             step_idx=self.share_inputs["step_idx"],
             pre_token_ids=self.share_inputs["pre_ids"],
@@ -138,7 +138,7 @@ class GPUModelRunner(ModelRunnerBase):
             shape=[self.parallel_config.max_num_seqs, 1],
             fill_value=4,
             dtype="int64",
-        )
+        ).cpu()

         self.restore_chunked_prefill_request = dict()

@@ -315,6 +315,10 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)

         self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
         self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
+        self.share_inputs["top_k_list"][idx] = request.get("top_k", 0)
         self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)
+        self.share_inputs["min_p_list"][idx] = request.get("min_p", 0.0)
         self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95)
         self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
         self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
@@ -478,7 +482,9 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
         self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7)
         self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
+        self.share_inputs["top_k_list"][idx] = request.get("top_k", 0)
         self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)
+        self.share_inputs["min_p_list"][idx] = request.get("min_p", 0.0)

         self.share_inputs["temperature"][idx : idx + 1] = get_attr_from_request(request, "temperature", 0.95)
         self.share_inputs["penalty_score"][idx : idx + 1] = get_attr_from_request(
@@ -612,7 +618,9 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs["eos_token_id"] = paddle.full([self.model_config.eos_tokens_lens, 1], 0, dtype="int64")
         self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
         self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
+        self.share_inputs["top_k_list"] = [0] * max_num_seqs
         self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32")
+        self.share_inputs["min_p_list"] = [0.0] * max_num_seqs
         self.share_inputs["temperature"] = paddle.full(
             [max_num_seqs, 1], self.model_config.temperature, dtype="float32"
         )
@@ -661,7 +669,7 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs["need_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32")
         self.share_inputs["need_block_len"] = paddle.full([1], 0, dtype="int32")
         self.share_inputs["used_list_len"] = paddle.full([max_num_seqs], 0, dtype="int32")
-        self.share_inputs["infer_seed"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
+        self.share_inputs["infer_seed"] = paddle.full([max_num_seqs, 1], 0, dtype="int64").cpu()
         self.share_inputs["first_token_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
         self.share_inputs["ori_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["system_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
@@ -830,7 +838,9 @@ class GPUModelRunner(ModelRunnerBase):
             temperature=self.share_inputs["temperature"],
             top_p=self.share_inputs["top_p"],
             top_k=self.share_inputs["top_k"],
+            top_k_list=self.share_inputs["top_k_list"],
             min_p=self.share_inputs["min_p"],
+            min_p_list=self.share_inputs["min_p_list"],
             seed=self.share_inputs["infer_seed"],
             step_idx=self.share_inputs["step_idx"],
             pre_token_ids=self.share_inputs["pre_ids"],
@@ -361,7 +361,7 @@ class XPUModelRunner(ModelRunnerBase):
             shape=[self.parallel_config.max_num_seqs, 1],
             fill_value=4,
             dtype="int64",
-        )
+        ).cpu()

         # Initialize attention Backend
         # Note(gonshaotian): Currently, all attention layers share one attention backend instance.
@@ -435,6 +435,10 @@ class XPUModelRunner(ModelRunnerBase):
         self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)

         self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
         self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
+        self.share_inputs["top_k_list"][idx] = request.get("top_k", 0)
         self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)
+        self.share_inputs["min_p_list"][idx] = request.get("min_p", 0.0)
         self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95)
         self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
         self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
@@ -476,7 +480,9 @@ class XPUModelRunner(ModelRunnerBase):
         self.share_inputs["pre_ids"][idx : idx + 1] = -1
         self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
         self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
+        self.share_inputs["top_k_list"][idx] = request.get("top_k", 0)
         self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)
+        self.share_inputs["min_p_list"][idx] = request.get("min_p", 0.0)
         self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95)
         self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
         self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
@@ -547,7 +553,9 @@ class XPUModelRunner(ModelRunnerBase):
         self.share_inputs["eos_token_id"] = paddle.full([self.model_config.eos_tokens_lens, 1], 0, dtype="int64")
         self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
         self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
+        self.share_inputs["top_k_list"] = [0] * max_num_seqs
         self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32")
+        self.share_inputs["min_p_list"] = [0.0] * max_num_seqs
         self.share_inputs["temperature"] = paddle.full(
             [max_num_seqs, 1], self.model_config.temperature, dtype="float32"
         )
@@ -674,7 +682,9 @@ class XPUModelRunner(ModelRunnerBase):
             temperature=self.share_inputs["temperature"],
             top_p=self.share_inputs["top_p"],
             top_k=self.share_inputs["top_k"],
+            top_k_list=self.share_inputs["top_k_list"],
             min_p=self.share_inputs["min_p"],
+            min_p_list=self.share_inputs["min_p_list"],
             seed=self.share_inputs["infer_seed"],
             step_idx=self.share_inputs["step_idx"],
             pre_token_ids=self.share_inputs["pre_ids"],