[Feature] support top_k_top_p sampling (#2753)

* support top_k_top_p sampling

* fix

* add api param

* add api para

* fix

* fix

* fix

* fix

* fix

* fix

* fix

This commit is contained in:
Sunny-bot1
2025-07-10 11:58:58 +08:00
committed by GitHub
parent b0f525955c
commit e45050cae3
15 changed files with 501 additions and 53 deletions

View File

@@ -16,8 +16,8 @@
import json
import os
from datetime import datetime
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
from fastdeploy import envs

View File

@@ -52,6 +52,7 @@ class SamplingParams:
the model more random. Zero means greedy sampling.
top_p: Float that controls the cumulative probability of the top tokens
to consider. Must be in [0, 1]. Set to 1 to consider all tokens.
top_k: Int that controls the number of top tokens to consider. Must be a non-negative integer; 0 (the default) disables top-k filtering.
seed: Random seed to use for the generation.
stop: list of strings that stop the generation when they are generated.
The returned output will not contain the stop strings.
@@ -81,7 +82,8 @@ class SamplingParams:
frequency_penalty: float = None
repetition_penalty: float = None
temperature: float = None
top_p: float = None
top_p: float = 1.0
top_k: int = 0
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
@@ -111,6 +113,7 @@ class SamplingParams:
repetition_penalty,
temperature,
top_p,
top_k,
seed=None,
stop=None,
stop_token_ids=None,
@@ -129,7 +132,8 @@ class SamplingParams:
repetition_penalty=repetition_penalty
if repetition_penalty is not None else 1.0,
temperature=temperature if temperature is not None else 1.0,
top_p=top_p if top_p is not None else 0.7,
top_p=top_p if top_p is not None else 1.0,
top_k=top_k if top_k is not None else 0,
seed=seed,
stop=stop,
stop_token_ids=stop_token_ids,
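
For reference, a minimal sketch of how the new field is used from the Python side. The import path and the exact constructor arguments are assumptions based on the fields shown above, not taken from this diff.

# Hedged sketch: constructing SamplingParams with the new top_k field.
from fastdeploy.engine.sampling_params import SamplingParams  # assumed module path

params = SamplingParams(
    temperature=0.8,
    top_p=0.95,  # nucleus threshold; 1.0 (the new default) keeps the full distribution
    top_k=50,    # new field; 0 (the default) disables top-k filtering
)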

View File

@@ -292,6 +292,7 @@ class CompletionRequest(BaseModel):
suffix: Optional[dict] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
user: Optional[str] = None
response_format: Optional[AnyResponseFormat] = None
@@ -405,6 +406,7 @@ class ChatCompletionRequest(BaseModel):
stream_options: Optional[StreamOptions] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
user: Optional[str] = None
metadata: Optional[dict] = None
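
For context, a hedged example of passing the new field through the OpenAI-compatible endpoint. The host, port, endpoint path, and model name are illustrative assumptions, not values from this diff.

import requests

# Assumed local deployment; replace host, port, and model name with real values.
payload = {
    "model": "default",
    "messages": [{"role": "user", "content": "Write a haiku about sampling."}],
    "temperature": 0.8,
    "top_p": 0.95,   # nucleus sampling threshold
    "top_k": 50,     # new request field added by this change
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])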

View File

@@ -27,34 +27,56 @@ if current_platform.is_gcu():
def top_p_sampling(
x: paddle.Tensor,
ps: paddle.Tensor,
top_p: paddle.Tensor,
top_k: Optional[paddle.Tensor] = None,
threshold: Optional[paddle.Tensor] = None,
topp_seed: Optional[paddle.Tensor] = None,
seed: int = -1,
k: int = 0,
mode: Literal['truncated', 'non-truncated'] = "truncated",
order: Literal['top_k_first', 'joint'] = "top_k_first",
) -> tuple[paddle.Tensor, paddle.Tensor]:
"""
top_p_sampling
x(Tensor): An input 2-D Tensor with type float32, float16 or bfloat16.
top_p(Tensor): A 1-D Tensor with type float32, float16 or bfloat16,
used to specify the top_p corresponding to each query.
top_k(Tensor|None, optional): A 1-D Tensor with type int64,
used to specify the top_k corresponding to each query.
Only used when FD_SAMPLING_CLASS is `rejection`.
threshold(Tensor|None, optional): A 1-D Tensor with type float32, float16 or bfloat16,
used to avoid sampling low-score tokens.
topp_seed(Tensor|None, optional): A 1-D Tensor with type int64,
used to specify the random seed for each query.
seed(int, optional): the random seed. Default is -1.
k(int): the number of top-k scores/ids to be returned. Default is 0.
Only used when FD_SAMPLING_CLASS is `air`.
mode(str): The mode that selects the sampling strategy. If the mode is `truncated`, sampling truncates the probability distribution at the top_p value.
If the mode is `non-truncated`, it does not. Default is `truncated`.
Only used when FD_SAMPLING_CLASS is `air` or `base`.
order(str): The order of applying top-k and top-p sampling, either `top_k_first` or `joint`.
If `top_k_first`, the top-k filter is applied first, then top-p sampling runs on the top-k results.
If `joint`, the top-k and top-p filters are applied simultaneously in each round. Default is `top_k_first`.
Only used when FD_SAMPLING_CLASS is `rejection`.
"""
top_p_class = envs.FD_SAMPLING_CLASS.lower()
if top_p_class == "air":
_, ids = air_top_p_sampling(x,
ps,
top_p,
threshold,
topp_seed,
seed=seed,
k=k,
mode=mode)
elif top_p_class == "rejection":
ids = rejection_top_p_sampling(x, ps, seed)
ids = rejection_top_p_sampling(x, top_p, top_k, seed, order)
_ = None
else:
if current_platform.is_gcu():
_, ids = gcu_top_p_sampling(x, ps)
_, ids = gcu_top_p_sampling(x, top_p)
else:
_, ids = paddle.tensor.top_p_sampling(x,
ps,
top_p,
threshold=threshold,
topp_seed=topp_seed,
seed=seed,
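
A minimal sketch of how a caller invokes the updated dispatcher, assuming the rejection backend is selected via FD_SAMPLING_CLASS and the custom GPU ops are installed. The tensor shapes follow the [batch, 1] buffers used by the model runners below; the module path is an assumption.

import os
import paddle

os.environ["FD_SAMPLING_CLASS"] = "rejection"  # select the rejection backend (assumed to be set before import)
from fastdeploy.model_executor.layers.sample.ops import top_p_sampling  # assumed module path

probs = paddle.to_tensor([[0.1, 0.2, 0.3, 0.4],
                          [0.4, 0.3, 0.2, 0.1]], dtype="float32")
top_p = paddle.to_tensor([[0.9], [0.8]], dtype="float32")  # one nucleus threshold per query
top_k = paddle.to_tensor([[2], [3]], dtype="int64")        # one cutoff per query; pass None to disable

_, next_tokens = top_p_sampling(probs, top_p, top_k)
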
@@ -65,7 +87,7 @@ def top_p_sampling(
def air_top_p_sampling(
x: paddle.Tensor,
ps: paddle.Tensor,
top_p: paddle.Tensor,
threshold: Optional[paddle.Tensor] = None,
topp_seed: Optional[paddle.Tensor] = None,
seed: int = -1,
@@ -77,7 +99,7 @@ def air_top_p_sampling(
"""
try:
from fastdeploy.model_executor.ops.gpu import air_top_p_sampling
out, ids = air_top_p_sampling(x, ps, threshold, topp_seed, seed, k,
out, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed, k,
mode)
except ImportError:
raise RuntimeError("Cannot import air_top_p_sampling op.")
@@ -86,19 +108,46 @@ def air_top_p_sampling(
def rejection_top_p_sampling(
x: paddle.Tensor,
ps: paddle.Tensor,
top_p: paddle.Tensor,
top_k: Optional[paddle.Tensor] = None,
seed: int = -1,
order: Literal['top_k_first', 'joint'] = "top_k_first",
) -> paddle.Tensor:
"""
rejection_top_p_sampling
"""
assert top_p is not None, "Top_p should not be none when FD_SAMPLING_CLASS is rejection"
try:
from fastdeploy.model_executor.ops.gpu import rejection_top_p_sampling
ids = rejection_top_p_sampling(
x,
ps,
seed,
)
from fastdeploy.model_executor.ops.gpu import (
rejection_top_p_sampling, top_k_renorm_probs)
if top_k is None:
ids = rejection_top_p_sampling(
x,
top_p,
None,
seed,
)
elif top_k is not None and top_p is not None:
if order == "top_k_first":
renorm_probs = top_k_renorm_probs(x, top_k)
ids = rejection_top_p_sampling(
renorm_probs,
top_p,
None,
seed,
)
else:
ids = rejection_top_p_sampling(
x,
top_p,
top_k,
seed,
)
else:
raise ValueError(
"top_p cannot be None."
)
except ImportError:
raise RuntimeError("Cannot import rejection_top_p_sampling op.")
return ids
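
For intuition, a hedged pure-NumPy reference of what the `top_k_first` order computes: renormalize over the k highest-probability tokens (the role of top_k_renorm_probs on GPU), then sample from the smallest prefix whose cumulative mass reaches top_p. This is a sketch of the semantics, not the CUDA implementation.

import numpy as np

def top_k_top_p_sample(probs, top_k, top_p, rng=None):
    """Reference sampler for a single query: top-k renormalization first, then top-p."""
    rng = rng or np.random.default_rng()
    p = np.asarray(probs, dtype=np.float64).copy()
    if top_k > 0:
        kth = np.sort(p)[-min(top_k, p.size)]
        p[p < kth] = 0.0               # drop everything outside the top-k tokens
    p /= p.sum()                       # renormalize over the kept tokens
    order = np.argsort(-p)             # sort descending by probability
    cdf = np.cumsum(p[order])
    cutoff = np.searchsorted(cdf, top_p) + 1   # smallest prefix with mass >= top_p
    kept = order[:cutoff]
    q = p[kept] / p[kept].sum()
    return int(rng.choice(kept, p=q))

token_id = top_k_top_p_sample([0.4, 0.3, 0.2, 0.1], top_k=3, top_p=0.9)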

View File

@@ -214,7 +214,7 @@ class Sampler(nn.Layer):
probs = F.softmax(logits)
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p)
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
self.processor.update_output_tokens(next_tokens, skip_idx_list)
return next_tokens
@@ -367,5 +367,5 @@ class MTPSampler(nn.Layer):
)
probs = F.softmax(logits)
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p)
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
return next_tokens

View File

@@ -154,12 +154,29 @@ class GCUModelRunner(ModelRunnerBase):
-1].disaggregate_info["role"] == "prefill":
os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1"
top_k_reqs = []
top_p_reqs = []
max_num_seqs = self.parallel_config.max_num_seqs
top_p_buffer = paddle.full([max_num_seqs, 1],
self.model_config.top_p,
dtype='float32')
top_k_buffer = paddle.full([max_num_seqs, 1],
0,
dtype='int64')
req_len = len(req_dicts)
for i in range(req_len):
request = req_dicts[i]
idx = request.idx
length = len(request.prompt_token_ids)
if sampling_params := request.sampling_params:
if sampling_params.top_p < 1:
top_p_reqs.append(idx)
top_k = sampling_params.top_k
if top_k > 0:
top_k_reqs.append(idx)
prefill_tokens = []
if (request.guided_json is not None
or request.guided_regex is not None
@@ -234,8 +251,8 @@ class GCUModelRunner(ModelRunnerBase):
request.eos_token_ids.append(request.eos_token_ids[0])
self.share_inputs["eos_token_id"][:] = np.array(
request.eos_token_ids, dtype="int64").reshape(-1, 1)
self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7)
top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
self.share_inputs["temperature"][idx:idx + 1] = request.get(
"temperature", 0.95)
self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
@@ -286,6 +303,16 @@ class GCUModelRunner(ModelRunnerBase):
if self.speculative_method in ["mtp"]:
self.proposer.insert_prefill_inputs(req_dicts)
if len(top_k_reqs) == 0:
self.share_inputs["top_k"] = None
else:
self.share_inputs["top_k"] = top_k_buffer
if len(top_p_reqs) == 0:
self.share_inputs["top_p"] = None
else:
self.share_inputs["top_p"] = top_p_buffer
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
expected_decode_len: int):
""" Set dummy prefill inputs to share_inputs """
@@ -340,8 +367,11 @@ class GCUModelRunner(ModelRunnerBase):
self.share_inputs["eos_token_id"] = paddle.full(
[self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
self.model_config.top_p,
dtype='float32')
self.model_config.top_p,
dtype='float32')
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1],
0,
dtype='int64')
self.share_inputs["temperature"] = paddle.full(
[max_num_seqs, 1], self.model_config.temperature, dtype='float32')
self.share_inputs["penalty_score"] = paddle.full(
@@ -563,6 +593,7 @@ class GCUModelRunner(ModelRunnerBase):
self.sampling_metadata = SamplingMetadata(
temperature=self.share_inputs["temperature"],
top_p=self.share_inputs["top_p"],
top_k=self.share_inputs["top_k"],
step_idx=self.share_inputs["step_idx"],
pre_token_ids=self.share_inputs["pre_ids"],
frequency_penalties=self.share_inputs["frequency_score"],
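
The per-request buffer handling above (repeated in the GPU, Iluvatar, and XPU runners below) boils down to the following hedged sketch; the request tuples are stand-ins for the real Request objects.

import paddle

max_num_seqs = 8
top_p_buffer = paddle.full([max_num_seqs, 1], 1.0, dtype="float32")
top_k_buffer = paddle.full([max_num_seqs, 1], 0, dtype="int64")

# Stand-in for req_dicts: (slot index, top_p, top_k) per request.
requests = [(0, 0.95, 50), (1, 1.0, 0)]
top_p_reqs, top_k_reqs = [], []
for idx, top_p, top_k in requests:
    if top_p < 1:
        top_p_reqs.append(idx)
    if top_k > 0:
        top_k_reqs.append(idx)
    top_p_buffer[idx:idx + 1] = top_p
    top_k_buffer[idx:idx + 1] = top_k

share_inputs = {}
# Handing None to the sampler lets it skip the corresponding filter entirely
# when no request in the batch asked for it.
share_inputs["top_k"] = top_k_buffer if top_k_reqs else None
share_inputs["top_p"] = top_p_buffer if top_p_reqs else None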

View File

@@ -161,6 +161,15 @@ class GPUModelRunner(ModelRunnerBase):
-1].disaggregate_info["role"] == "prefill":
os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1"
top_k_reqs = []
top_p_reqs = []
max_num_seqs = self.parallel_config.max_num_seqs
top_p_buffer = paddle.full([max_num_seqs, 1],
self.model_config.top_p,
dtype='float32')
top_k_buffer = paddle.full([max_num_seqs, 1],
0,
dtype='int64')
req_len = len(req_dicts)
for i in range(req_len):
request = req_dicts[i]
@@ -168,6 +177,13 @@ class GPUModelRunner(ModelRunnerBase):
length = len(request.prompt_token_ids)
assert length > 0, "The prompt requested must not be empty."
if sampling_params := request.sampling_params:
if sampling_params.top_p < 1:
top_p_reqs.append(idx)
top_k = sampling_params.top_k
if top_k > 0:
top_k_reqs.append(idx)
prefill_tokens = []
if (request.guided_json is not None
or request.guided_regex is not None
@@ -242,8 +258,8 @@ class GPUModelRunner(ModelRunnerBase):
request.eos_token_ids.append(request.eos_token_ids[0])
self.share_inputs["eos_token_id"][:] = np.array(
request.eos_token_ids, dtype="int64").reshape(-1, 1)
self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7)
top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
self.share_inputs["temperature"][idx:idx + 1] = request.get(
"temperature", 0.95)
self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
@@ -294,6 +310,16 @@ class GPUModelRunner(ModelRunnerBase):
if self.speculative_method in ["mtp"]:
self.proposer.insert_prefill_inputs(req_dicts)
if len(top_k_reqs) == 0:
self.share_inputs["top_k"] = None
else:
self.share_inputs["top_k"] = top_k_buffer
if len(top_p_reqs) == 0:
self.share_inputs["top_p"] = None
else:
self.share_inputs["top_p"] = top_p_buffer
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
expected_decode_len: int):
""" Set dummy prefill inputs to share_inputs """
@@ -349,8 +375,11 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["eos_token_id"] = paddle.full(
[self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
self.model_config.top_p,
dtype='float32')
self.model_config.top_p,
dtype='float32')
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1],
0,
dtype='int64')
self.share_inputs["temperature"] = paddle.full(
[max_num_seqs, 1], self.model_config.temperature, dtype='float32')
self.share_inputs["penalty_score"] = paddle.full(
@@ -574,6 +603,7 @@ class GPUModelRunner(ModelRunnerBase):
self.sampling_metadata = SamplingMetadata(
temperature=self.share_inputs["temperature"],
top_p=self.share_inputs["top_p"],
top_k=self.share_inputs["top_k"],
step_idx=self.share_inputs["step_idx"],
pre_token_ids=self.share_inputs["pre_ids"],
frequency_penalties=self.share_inputs["frequency_score"],

View File

@@ -29,9 +29,8 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import \
AttentionBackend
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import (Sampler,
SpeculativeSampler
)
from fastdeploy.model_executor.layers.sample.sampler import (
Sampler, SpeculativeSampler)
from fastdeploy.model_executor.model_loader import get_model_from_loader
from fastdeploy.model_executor.ops.iluvatar import set_value_by_flags_and_idx
from fastdeploy.model_executor.pre_and_post_process import (post_process,
@@ -145,12 +144,29 @@ class IluvatarModelRunner(ModelRunnerBase):
-1].disaggregate_info["role"] == "prefill":
os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1"
top_k_reqs = []
top_p_reqs = []
max_num_seqs = self.parallel_config.max_num_seqs
top_p_buffer = paddle.full([max_num_seqs, 1],
self.model_config.top_p,
dtype='float32')
top_k_buffer = paddle.full([max_num_seqs, 1],
0,
dtype='int64')
req_len = len(req_dicts)
for i in range(req_len):
request = req_dicts[i]
idx = request.idx
length = len(request.prompt_token_ids)
if sampling_params := request.sampling_params:
if sampling_params.top_p < 1:
top_p_reqs.append(idx)
top_k = sampling_params.top_k
if top_k > 0:
top_k_reqs.append(idx)
prefill_tokens = []
if (request.guided_json is not None
or request.guided_regex is not None
@@ -225,8 +241,8 @@ class IluvatarModelRunner(ModelRunnerBase):
request.eos_token_ids.append(request.eos_token_ids[0])
self.share_inputs["eos_token_id"][:] = np.array(
request.eos_token_ids, dtype="int64").reshape(-1, 1)
self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7)
top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
self.share_inputs["temperature"][idx:idx + 1] = request.get(
"temperature", 0.95)
self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
@@ -273,6 +289,15 @@ class IluvatarModelRunner(ModelRunnerBase):
idx, request.get("logits_processor"), prefill_tokens)
self.share_inputs["not_need_stop"][0] = True
if len(top_k_reqs) == 0:
self.share_inputs["top_k"] = None
else:
self.share_inputs["top_k"] = top_k_buffer
if len(top_p_reqs) == 0:
self.share_inputs["top_p"] = None
else:
self.share_inputs["top_p"] = top_p_buffer
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
expected_decode_len: int):
@@ -329,8 +354,11 @@ class IluvatarModelRunner(ModelRunnerBase):
self.share_inputs["eos_token_id"] = paddle.full(
[self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
self.model_config.top_p,
dtype='float32')
self.model_config.top_p,
dtype='float32')
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1],
0,
dtype='int64')
self.share_inputs["temperature"] = paddle.full(
[max_num_seqs, 1], self.model_config.temperature, dtype='float32')
self.share_inputs["penalty_score"] = paddle.full(
@@ -558,6 +586,7 @@ class IluvatarModelRunner(ModelRunnerBase):
self.sampling_metadata = SamplingMetadata(
temperature=self.share_inputs["temperature"],
top_p=self.share_inputs["top_p"],
top_k=self.share_inputs["top_k"],
step_idx=self.share_inputs["step_idx"],
pre_token_ids=self.share_inputs["pre_ids"],
frequency_penalties=self.share_inputs["frequency_score"],

View File

@@ -14,14 +14,14 @@
# limitations under the License.
"""
from abc import ABC, abstractmethod
import argparse
from abc import ABC, abstractmethod
import paddle
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from fastdeploy.config import ModelConfig
from fastdeploy.config import ModelConfig
from fastdeploy.utils import get_logger
logger = get_logger("worker", "worker.log")

View File

@@ -282,11 +282,26 @@ class XPUModelRunner(ModelRunnerBase):
def process_prefill_inputs(self, req_dicts: List[Request]):
""" Process inputs for prefill tasks and update share_inputs buffer """
top_k_reqs = []
top_p_reqs = []
max_num_seqs = self.parallel_config.max_num_seqs
top_p_buffer = paddle.full([max_num_seqs, 1],
self.model_config.top_p,
dtype='float32')
top_k_buffer = paddle.full([max_num_seqs, 1],
0,
dtype='int64')
req_len = len(req_dicts)
for i in range(req_len):
request = req_dicts[i]
idx = request.idx
length = request.prompt_token_ids_len
if sampling_params := request.sampling_params:
if sampling_params.top_p < 1:
top_p_reqs.append(idx)
top_k = sampling_params.top_k
if top_k > 0:
top_k_reqs.append(idx)
self.share_inputs["input_ids"][idx:idx + 1, :length] = np.array(
request.prompt_token_ids)
if len(request.eos_token_ids
@@ -295,7 +310,8 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs["eos_token_id"][:] = np.array(
request.eos_token_ids, dtype="int64").reshape(-1, 1)
self.share_inputs["pre_ids"][idx:idx + 1] = -1
self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7)
top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
self.share_inputs["temperature"][idx:idx + 1] = request.get(
"temperature", 0.95)
self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
@@ -344,6 +360,15 @@ class XPUModelRunner(ModelRunnerBase):
request.get("stop_token_ids"), dtype="int64")
self.share_inputs["not_need_stop"][0] = True
if len(top_k_reqs) == 0:
self.share_inputs["top_k"] = None
else:
self.share_inputs["top_k"] = top_k_buffer
if len(top_p_reqs) == 0:
self.share_inputs["top_p"] = None
else:
self.share_inputs["top_p"] = top_p_buffer
def _init_share_inputs(self, max_num_seqs: int):
"""Initialize all share buffers for model inputs.
@@ -363,8 +388,11 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs["eos_token_id"] = paddle.full(
[self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
self.model_config.top_p,
dtype='float32')
self.model_config.top_p,
dtype='float32')
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1],
0,
dtype='int64')
self.share_inputs["temperature"] = paddle.full(
[max_num_seqs, 1], self.model_config.temperature, dtype='float32')
self.share_inputs["penalty_score"] = paddle.full(
@@ -514,6 +542,7 @@ class XPUModelRunner(ModelRunnerBase):
self.sampling_metadata = SamplingMetadata(
temperature=self.share_inputs["temperature"],
top_p=self.share_inputs["top_p"],
top_k=self.share_inputs["top_k"],
step_idx=self.share_inputs["step_idx"],
pre_token_ids=self.share_inputs["pre_ids"],
frequency_penalties=self.share_inputs["frequency_score"],