This commit is contained in:
chen
2025-11-26 17:07:26 +08:00
committed by GitHub
parent 214942e1ae
commit 00d0ef5134
2 changed files with 7 additions and 7 deletions

View File

@@ -292,8 +292,7 @@ def _build_stream_transfer_data(
decoder_state=DecoderState.TEXT, tokens=output_token_per_sample, batch_id=bid
)
if logprobs:
logprobs = logprobs.slice_rows(bid, bid + 1)
stream_transfer_data.logprobs = logprobs
stream_transfer_data.logprobs = logprobs.slice_rows(bid, bid + 1)
if prompt_logprobs_list:
stream_transfer_data.prompt_logprobs = prompt_logprobs_list[bid]
stream_transfer_datas.append(stream_transfer_data)

View File

@@ -117,11 +117,12 @@ class LogprobsTensors(NamedTuple):
Slice rows.
Keeps the number of max_num_logprobs unchanged.
"""
return LogprobsTensors(
self.logprob_token_ids[start:end],
self.logprobs[start:end],
self.selected_token_ranks[start:end],
)
with paddle.no_grad():
return LogprobsTensors(
paddle.to_tensor(self.logprob_token_ids[start:end], place=self.logprob_token_ids.place),
paddle.to_tensor(self.logprobs[start:end], place=self.logprob_token_ids.place),
paddle.to_tensor(self.selected_token_ranks[start:end], place=self.logprob_token_ids.place),
)
@dataclass