mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
check (#5237)
This commit is contained in:
@@ -292,8 +292,7 @@ def _build_stream_transfer_data(
|
||||
decoder_state=DecoderState.TEXT, tokens=output_token_per_sample, batch_id=bid
|
||||
)
|
||||
if logprobs:
|
||||
logprobs = logprobs.slice_rows(bid, bid + 1)
|
||||
stream_transfer_data.logprobs = logprobs
|
||||
stream_transfer_data.logprobs = logprobs.slice_rows(bid, bid + 1)
|
||||
if prompt_logprobs_list:
|
||||
stream_transfer_data.prompt_logprobs = prompt_logprobs_list[bid]
|
||||
stream_transfer_datas.append(stream_transfer_data)
|
||||
|
||||
@@ -117,11 +117,12 @@ class LogprobsTensors(NamedTuple):
|
||||
Slice rows.
|
||||
Keeps the number of max_num_logprobs unchanged.
|
||||
"""
|
||||
return LogprobsTensors(
|
||||
self.logprob_token_ids[start:end],
|
||||
self.logprobs[start:end],
|
||||
self.selected_token_ranks[start:end],
|
||||
)
|
||||
with paddle.no_grad():
|
||||
return LogprobsTensors(
|
||||
paddle.to_tensor(self.logprob_token_ids[start:end], place=self.logprob_token_ids.place),
|
||||
paddle.to_tensor(self.logprobs[start:end], place=self.logprob_token_ids.place),
|
||||
paddle.to_tensor(self.selected_token_ranks[start:end], place=self.logprob_token_ids.place),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user