[Feature] support top_k_top_p sampling (#2753)

* support top_k_top_p sampling

* fix

* add api param

* add api param

* fix

* fix

* fix

* fix

* fix

* fix

* fix
Author: Sunny-bot1
Date: 2025-07-10 11:58:58 +08:00
Committed by: GitHub
Parent: b0f525955c
Commit: e45050cae3
15 changed files with 501 additions and 53 deletions
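
For context: top-k sampling keeps only the k most probable tokens, and top-p (nucleus) sampling keeps the smallest probability-sorted prefix whose cumulative mass reaches p; the two filters are applied back to back before the token is drawn. Below is a minimal NumPy sketch of that combined filter, purely illustrative rather than the FastDeploy kernel (the function name is made up), though the top_k=0 / top_p=1.0 "filter disabled" convention matches the defaults used throughout the diff:

    import numpy as np

    def top_k_top_p_sample(logits, top_k=0, top_p=1.0, rng=None):
        """Draw one token id from a 1-D logits vector after top-k and top-p filtering."""
        rng = rng or np.random.default_rng()
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()                       # softmax over the vocabulary
        order = np.argsort(-probs)                 # token ids sorted by probability, descending
        keep = np.ones(probs.size, dtype=bool)
        if top_k > 0:                              # top-k: drop everything outside the k best
            keep[order[top_k:]] = False
        if top_p < 1.0:                            # top-p: keep the smallest prefix with mass >= p
            cum = np.cumsum(probs[order])
            cutoff = np.searchsorted(cum, top_p) + 1
            keep[order[cutoff:]] = False
        probs = np.where(keep, probs, 0.0)
        probs /= probs.sum()                       # renormalise over the surviving tokens
        return int(rng.choice(probs.size, p=probs))

With top_k=0 and top_p=1.0 both filters are no-ops, which is why those values serve as the "feature off" defaults in the buffers the runner sets up below.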

@@ -29,9 +29,8 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import \
AttentionBackend
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import (Sampler,
                                                             SpeculativeSampler
                                                             )
from fastdeploy.model_executor.layers.sample.sampler import (
    Sampler, SpeculativeSampler)
from fastdeploy.model_executor.model_loader import get_model_from_loader
from fastdeploy.model_executor.ops.iluvatar import set_value_by_flags_and_idx
from fastdeploy.model_executor.pre_and_post_process import (post_process,
@@ -145,12 +144,29 @@ class IluvatarModelRunner(ModelRunnerBase):
                -1].disaggregate_info["role"] == "prefill":
            os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1"
        top_k_reqs = []
        top_p_reqs = []
        max_num_seqs = self.parallel_config.max_num_seqs
        top_p_buffer = paddle.full([max_num_seqs, 1],
                                   self.model_config.top_p,
                                   dtype='float32')
        top_k_buffer = paddle.full([max_num_seqs, 1],
                                   0,
                                   dtype='int64')
        req_len = len(req_dicts)
        for i in range(req_len):
            request = req_dicts[i]
            idx = request.idx
            length = len(request.prompt_token_ids)
            if sampling_params := request.sampling_params:
                if sampling_params.top_p < 1:
                    top_p_reqs.append(idx)
                top_k = sampling_params.top_k
                if top_k > 0:
                    top_k_reqs.append(idx)
            prefill_tokens = []
            if (request.guided_json is not None
                    or request.guided_regex is not None
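
The pattern above pre-fills fixed-shape [max_num_seqs, 1] buffers with the "disabled" defaults (the config's top_p and top_k = 0) and then overwrites individual slots per request; the top_p_reqs / top_k_reqs index lists only record whether any request in the batch actually opted in, which is what the None handling two hunks further down checks before exposing the buffers to the sampler.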
@@ -225,8 +241,8 @@ class IluvatarModelRunner(ModelRunnerBase):
                request.eos_token_ids.append(request.eos_token_ids[0])
            self.share_inputs["eos_token_id"][:] = np.array(
                request.eos_token_ids, dtype="int64").reshape(-1, 1)
            self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7)
            top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
            top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
            self.share_inputs["temperature"][idx:idx + 1] = request.get(
                "temperature", 0.95)
            self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
@@ -273,6 +289,15 @@ class IluvatarModelRunner(ModelRunnerBase):
                    idx, request.get("logits_processor"), prefill_tokens)
        self.share_inputs["not_need_stop"][0] = True
        if len(top_k_reqs) == 0:
            self.share_inputs["top_k"] = None
        else:
            self.share_inputs["top_k"] = top_k_buffer
        if len(top_p_reqs) == 0:
            self.share_inputs["top_p"] = None
        else:
            self.share_inputs["top_p"] = top_p_buffer

    def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
                              expected_decode_len: int):
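
A None entry in share_inputs therefore means "no request in this batch enabled the filter", so the batched draw only needs to consult the buffers that are actually present. A hedged sketch of how such a convention can be consumed, reusing the hypothetical top_k_top_p_sample helper from the sketch near the top (again illustrative, not the Sampler implementation):

    import numpy as np

    def batched_sample(logits, top_k=None, top_p=None, rng=None):
        """logits: [batch, vocab]; top_k / top_p: [batch, 1] arrays, or None when the
        whole batch left the corresponding filter disabled."""
        rng = rng or np.random.default_rng()
        out = np.empty(logits.shape[0], dtype=np.int64)
        for i in range(logits.shape[0]):
            k = int(top_k[i, 0]) if top_k is not None else 0       # 0   -> top-k off
            p = float(top_p[i, 0]) if top_p is not None else 1.0   # 1.0 -> top-p off
            out[i] = top_k_top_p_sample(logits[i], top_k=k, top_p=p, rng=rng)
        return out

A real kernel would vectorise this across the batch; the per-row loop is only there to show how the optional buffers are read.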
@@ -329,8 +354,11 @@ class IluvatarModelRunner(ModelRunnerBase):
self.share_inputs["eos_token_id"] = paddle.full(
[self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
self.model_config.top_p,
dtype='float32')
self.model_config.top_p,
dtype='float32')
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1],
0,
dtype='int64')
self.share_inputs["temperature"] = paddle.full(
[max_num_seqs, 1], self.model_config.temperature, dtype='float32')
self.share_inputs["penalty_score"] = paddle.full(
@@ -558,6 +586,7 @@ class IluvatarModelRunner(ModelRunnerBase):
        self.sampling_metadata = SamplingMetadata(
            temperature=self.share_inputs["temperature"],
            top_p=self.share_inputs["top_p"],
            top_k=self.share_inputs["top_k"],
            step_idx=self.share_inputs["step_idx"],
            pre_token_ids=self.share_inputs["pre_ids"],
            frequency_penalties=self.share_inputs["frequency_score"],