[FIX 2.0.2] Topk topp sampling fix (#2805)

* fix topk-topp

* fix
This commit is contained in:
Sunny-bot1
2025-07-10 21:15:03 +08:00
committed by GitHub
parent e681e1e719
commit 4025ea7e5b
6 changed files with 27 additions and 70 deletions

View File

@@ -266,26 +266,11 @@ class XPUModelRunner(ModelRunnerBase):
def process_prefill_inputs(self, req_dicts: List[Request]):
""" Process inputs for prefill tasks and update share_inputs buffer """
top_k_reqs = []
top_p_reqs = []
max_num_seqs = self.parallel_config.max_num_seqs
top_p_buffer = paddle.full([max_num_seqs, 1],
self.model_config.top_p,
dtype='float32')
top_k_buffer = paddle.full([max_num_seqs, 1],
0,
dtype='int64')
req_len = len(req_dicts)
for i in range(req_len):
request = req_dicts[i]
idx = request.idx
length = request.prompt_token_ids_len
if sampling_params := request.sampling_params:
if sampling_params.top_p < 1:
top_p_reqs.append(idx)
top_k = sampling_params.top_k
if top_k > 0:
top_k_reqs.append(idx)
self.share_inputs["input_ids"][idx:idx + 1, :length] = np.array(
request.prompt_token_ids)
if len(request.eos_token_ids
@@ -294,8 +279,8 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs["eos_token_id"][:] = np.array(
request.eos_token_ids, dtype="int64").reshape(-1, 1)
self.share_inputs["pre_ids"][idx:idx + 1] = -1
top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 1.0)
self.share_inputs["top_k"][idx:idx + 1] = request.get("top_k", 0)
self.share_inputs["temperature"][idx:idx + 1] = request.get(
"temperature", 0.95)
self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
@@ -344,15 +329,6 @@ class XPUModelRunner(ModelRunnerBase):
request.get("stop_token_ids"), dtype="int64")
self.share_inputs["not_need_stop"][0] = True
if len(top_k_reqs) == 0:
self.share_inputs["top_k"] = None
else:
self.share_inputs["top_k"] = top_k_buffer
if len(top_p_reqs) == 0:
self.share_inputs["top_p"] = None
else:
self.share_inputs["top_p"] = top_p_buffer
def _init_share_inputs(self, max_num_seqs: int):
"""Initialize all share buffers for model inputs.