mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* support entropy * fix bug --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
100 lines
4.4 KiB
Python
100 lines
4.4 KiB
Python
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
|
|
|
|
import paddle
|
|
|
|
from fastdeploy.utils import data_processor_logger
|
|
|
|
|
|
def calculate_logits_entropy(logits, share_inputs, temperature):
    """Accumulate per-token entropy of softmax(logits) and log per-request means.

    For every slot ``i`` in ``range(real_bsz)`` (``real_bsz`` is the number of
    rows in ``share_inputs["seq_lens_this_time"]``), the entropies of that
    slot's rows in ``logits`` are appended to
    ``share_inputs["entropy_list"][i]``. When ``share_inputs["stop_flags"][i]``
    is truthy and that list is non-empty, the mean entropy is logged together
    with ``share_inputs["req_ids"][i]`` and the list is reset to ``[]``.

    Args:
        logits: 2-D tensor holding one row of logits per token, with rows
            grouped by slot in slot order (assumed from how the rows are
            consumed below — TODO confirm against caller). The caller's tensor
            is NOT modified; temperature scaling is applied to a copy.
        share_inputs: shared runner state. Reads ``seq_lens_this_time``,
            ``seq_lens_encoder``, ``stop_flags`` and ``req_ids``; mutates
            ``entropy_list``.
        temperature: tensor indexed by slot id along axis 0. Slots whose value
            is ``> 0`` and ``!= 1.0`` have their logits multiplied by
            ``1 / temperature`` before the entropy is computed; other slots are
            left unscaled.
    """
    real_bsz = share_inputs["seq_lens_this_time"].shape[0]
    # Slots with a non-zero seq_lens_encoder contribute exactly one row of
    # logits (presumably prefill, where only the last position is kept);
    # the others contribute seq_lens_this_time rows.
    real_seq_lens = paddle.where(
        share_inputs["seq_lens_encoder"][:real_bsz].squeeze(1) != 0,
        paddle.ones([1], dtype="int32"),
        share_inputs["seq_lens_this_time"].squeeze(1),
    )

    def get_entropy(logits):
        """Return the entropy (natural log) of softmax(logits) along the last axis.

        Uses H = sum(p * (log Z - a)) with a = logits - max(logits), which is
        -sum(p * log p) written so that exp() never overflows.
        """
        shifted = logits - paddle.max(logits, axis=-1, keepdim=True)
        exp_shifted = paddle.exp(shifted)
        partition = paddle.sum(exp_shifted, axis=-1, keepdim=True)
        probs = exp_shifted / partition
        return paddle.sum(probs * (paddle.log(partition) - shifted), axis=-1)

    def apply_temperature(token_logits, batch_id_per_token):
        """Return a scaled copy of ``token_logits`` using each row's slot temperature."""
        # reshape makes this work whether temperature is [bsz] or [bsz, 1].
        token_temp = paddle.gather(temperature, batch_id_per_token).reshape([-1, 1])
        needs_scaling = paddle.logical_and(token_temp > 0, token_temp != 1.0)
        # 1 / 0 in the unselected branch is discarded by where().
        scale = paddle.where(needs_scaling, 1.0 / token_temp, paddle.ones_like(token_temp))
        return token_logits * scale.astype(token_logits.dtype)

    batch_indices = paddle.arange(real_bsz, dtype="int32")
    batch_id_per_token = paddle.repeat_interleave(batch_indices, real_seq_lens)

    entropy = []
    if logits.shape[0] > 0:
        entropy = get_entropy(apply_temperature(logits, batch_id_per_token)).tolist()

    # One device->host transfer each instead of one per slot.
    seq_lens = real_seq_lens.tolist()
    stop_flags = share_inputs["stop_flags"][:real_bsz].flatten().tolist()

    cursor = 0
    for i, n_tokens in enumerate(seq_lens):
        values = share_inputs["entropy_list"][i]
        values.extend(entropy[cursor : cursor + n_tokens])
        cursor += n_tokens
        if stop_flags[i] and len(values) != 0:
            data_processor_logger.info(
                f"req_id: {share_inputs['req_ids'][i]}, entropy: {sum(values)/len(values)}"
            )
            share_inputs["entropy_list"][i] = []
|
|
|
|
|
|
def speculate_calculate_logits_entropy(logits, share_inputs, temperature):
    """Speculative-decoding variant of ``calculate_logits_entropy``.

    Each slot ``i`` in ``range(real_bsz)`` owns a contiguous run of
    ``real_seq_lens[i]`` rows in ``logits``. Only the first
    ``accept_num[i]`` rows of that run (the accepted tokens) are used. Their
    entropies are appended to ``share_inputs["entropy_list"][i]``. When
    ``share_inputs["stop_flags"][i]`` is truthy and the list is non-empty, the
    mean entropy is logged with ``share_inputs["req_ids"][i]`` and the list is
    reset to ``[]``.

    Args:
        logits: 2-D tensor of per-row logits, with rows grouped by slot in slot
            order (assumed from the indexing below — TODO confirm against
            caller). Not modified.
        share_inputs: shared runner state. Reads ``seq_lens_this_time``,
            ``seq_lens_encoder``, ``accept_num``, ``stop_flags`` and
            ``req_ids``; mutates ``entropy_list``.
        temperature: tensor indexed by slot id along axis 0. Values ``> 0`` and
            ``!= 1.0`` scale that slot's logits by ``1 / temperature``.
    """
    real_bsz = share_inputs["seq_lens_this_time"].shape[0]
    # Slots with a non-zero seq_lens_encoder contribute exactly one row.
    real_seq_lens = paddle.where(
        share_inputs["seq_lens_encoder"][:real_bsz].squeeze(1) != 0,
        paddle.ones([1], dtype="int32"),
        share_inputs["seq_lens_this_time"].squeeze(1),
    )
    # Only the first real_bsz slots are consumed below, so every count must
    # come from this slice. Summing the whole accept_num buffer (as before)
    # mis-sizes the gather whenever an entry past real_bsz is non-zero.
    accept_num = share_inputs["accept_num"][:real_bsz].astype("int32")
    accept_counts = accept_num.tolist()
    total_accepted = sum(accept_counts)

    def get_entropy(logits):
        """Return the entropy (natural log) of softmax(logits) along the last axis.

        Uses H = sum(p * (log Z - a)) with a = logits - max(logits), which is
        -sum(p * log p) written so that exp() never overflows.
        """
        shifted = logits - paddle.max(logits, axis=-1, keepdim=True)
        exp_shifted = paddle.exp(shifted)
        partition = paddle.sum(exp_shifted, axis=-1, keepdim=True)
        probs = exp_shifted / partition
        return paddle.sum(probs * (paddle.log(partition) - shifted), axis=-1)

    def apply_temperature(token_logits, batch_id_per_token):
        """Return a scaled copy of ``token_logits`` using each row's slot temperature."""
        # reshape makes this work whether temperature is [bsz] or [bsz, 1].
        token_temp = paddle.gather(temperature, batch_id_per_token).reshape([-1, 1])
        needs_scaling = paddle.logical_and(token_temp > 0, token_temp != 1.0)
        # 1 / 0 in the unselected branch is discarded by where().
        scale = paddle.where(needs_scaling, 1.0 / token_temp, paddle.ones_like(token_temp))
        return token_logits * scale.astype(token_logits.dtype)

    entropy = []
    if total_accepted > 0:
        # Exclusive prefix sums: first row of each slot in `logits`, and first
        # position of each slot among the packed accepted tokens.
        seq_start_idx = paddle.cumsum(real_seq_lens, dtype="int32") - real_seq_lens
        accept_start = paddle.cumsum(accept_num, dtype="int32") - accept_num

        batch_id_per_token = paddle.repeat_interleave(paddle.arange(real_bsz, dtype="int32"), accept_num)
        # Offset of each accepted token inside its own slot: 0 .. accept_num[b]-1.
        offsets = paddle.arange(total_accepted, dtype="int32") - paddle.gather(accept_start, batch_id_per_token)
        accepted_idx = paddle.gather(seq_start_idx, batch_id_per_token) + offsets

        accepted_logits = paddle.index_select(logits, accepted_idx, axis=0)
        entropy = get_entropy(apply_temperature(accepted_logits, batch_id_per_token)).tolist()

    stop_flags = share_inputs["stop_flags"][:real_bsz].flatten().tolist()

    cursor = 0
    for i, n_tokens in enumerate(accept_counts):
        values = share_inputs["entropy_list"][i]
        values.extend(entropy[cursor : cursor + n_tokens])
        cursor += n_tokens
        if stop_flags[i] and len(values) != 0:
            data_processor_logger.info(
                f"req_id: {share_inputs['req_ids'][i]}, entropy: {sum(values)/len(values)}"
            )
            share_inputs["entropy_list"][i] = []
|