mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] Entropy calculation support (#5692)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* support entropy * fix bug --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -206,6 +206,7 @@ class ModelConfig:
|
||||
self.revision = None
|
||||
self.prefix_layer_name = "layers"
|
||||
self.kv_cache_quant_scale_path = ""
|
||||
self.enable_entropy = False
|
||||
|
||||
self.partial_rotary_factor: float = 1.0
|
||||
self.num_nextn_predict_layers = 0
|
||||
|
||||
@@ -509,6 +509,11 @@ class EngineArgs:
|
||||
Whether to skip port availability check. Default is False (not skip).
|
||||
"""
|
||||
|
||||
enable_entropy: bool = False
|
||||
"""
|
||||
Flag to enable entropy output. Default is False (disabled).
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
Post-initialization processing to set default tokenizer if not provided.
|
||||
@@ -854,6 +859,12 @@ class EngineArgs:
|
||||
default=EngineArgs.logits_processors,
|
||||
help="FQCNs (Fully Qualified Class Names) of logits processors supported by the service.",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--enable-entropy",
|
||||
action="store_true",
|
||||
default=EngineArgs.enable_entropy,
|
||||
help="Enable output of token-level entropy.",
|
||||
)
|
||||
|
||||
# Parallel processing parameters group
|
||||
parallel_group = parser.add_argument_group("Parallel Configuration")
|
||||
|
||||
@@ -1706,6 +1706,7 @@ class EngineService:
|
||||
"disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe,
|
||||
"enable_logprob": self.cfg.model_config.enable_logprob,
|
||||
"lm_head_fp32": self.cfg.model_config.lm_head_fp32,
|
||||
"enable_entropy": self.cfg.model_config.enable_entropy,
|
||||
}
|
||||
for worker_flag, value in worker_store_true_flag.items():
|
||||
if value:
|
||||
|
||||
@@ -591,6 +591,7 @@ class LLMEngine:
|
||||
"enable_logprob": self.cfg.model_config.enable_logprob,
|
||||
"lm_head_fp32": self.cfg.model_config.lm_head_fp32,
|
||||
"shutdown_comm_group_if_worker_idle": self.cfg.parallel_config.shutdown_comm_group_if_worker_idle,
|
||||
"enable_entropy": self.cfg.model_config.enable_entropy,
|
||||
}
|
||||
for worker_flag, value in worker_store_true_flag.items():
|
||||
if value:
|
||||
|
||||
99
fastdeploy/model_executor/entropy_utils.py
Normal file
99
fastdeploy/model_executor/entropy_utils.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.utils import data_processor_logger
|
||||
|
||||
|
||||
def calculate_logits_entropy(logits, share_inputs, temperature):
    """Compute per-token softmax entropy of the logits and accumulate it per request.

    Args:
        logits: [num_tokens, vocab_size] tensor of raw (unscaled) logits,
            one row per token generated this step.
        share_inputs: dict of shared buffers. Reads "seq_lens_this_time",
            "seq_lens_encoder", "stop_flags" and "req_ids"; appends to the
            per-request Python lists in "entropy_list".
        temperature: per-request sampling temperature tensor (shape [bsz] or
            [bsz, 1]); entries with 0 < t != 1 rescale the logits before the
            entropy is taken, other entries leave the logits unchanged.

    Side effects:
        Appends one entropy value per generated token to
        share_inputs["entropy_list"][i]; when a request has stopped, logs the
        mean entropy for that request and clears its list. Unlike the
        previous implementation, the caller's `logits` tensor is NOT
        modified in place.
    """
    real_bsz = share_inputs["seq_lens_this_time"].shape[0]
    # A prefill (encoder) step emits exactly one sampled token; a decode
    # step emits seq_lens_this_time tokens.
    real_seq_lens = paddle.where(
        share_inputs["seq_lens_encoder"][:real_bsz].squeeze(1) != 0,
        paddle.ones([1], dtype="int32"),
        share_inputs["seq_lens_this_time"].squeeze(1),
    )

    # Map each token row of `logits` to the request it belongs to.
    batch_indices = paddle.arange(real_bsz, dtype="int32")
    batch_id_per_token = paddle.repeat_interleave(batch_indices, real_seq_lens)

    # Vectorized temperature scaling (replaces a per-row Python loop that
    # called scale_ in place and therefore clobbered the caller's logits).
    # Temperatures that are <= 0 or exactly 1 must leave the row untouched.
    flat_temp = paddle.flatten(temperature).astype(logits.dtype)
    token_temp = paddle.gather(flat_temp, batch_id_per_token)
    divisor = paddle.where(
        paddle.logical_and(token_temp > 0, token_temp != 1.0),
        token_temp,
        paddle.ones_like(token_temp),
    )
    scaled_logits = logits / divisor.unsqueeze(-1)

    # Numerically stable softmax entropy: H = sum(p * (log(z) - a)) with
    # a = logits - max(logits).
    shifted = scaled_logits - paddle.max(scaled_logits, axis=-1, keepdim=True)
    exp_shifted = paddle.exp(shifted)
    partition = paddle.sum(exp_shifted, axis=-1, keepdim=True)
    probs = exp_shifted / partition
    entropy = paddle.sum(probs * (paddle.log(partition) - shifted), axis=-1).tolist()

    # Hand the flat entropy list back to the requests with an O(n) cursor
    # (the previous list.pop(0) loop was O(n^2)).
    cursor = 0
    lens = real_seq_lens.tolist()
    for i in range(real_bsz):
        share_inputs["entropy_list"][i].extend(entropy[cursor : cursor + lens[i]])
        cursor += lens[i]
        if share_inputs["stop_flags"][i] and len(share_inputs["entropy_list"][i]) != 0:
            data_processor_logger.info(
                f"req_id: {share_inputs['req_ids'][i]}, entropy: {sum(share_inputs['entropy_list'][i])/len(share_inputs['entropy_list'][i])}"
            )
            share_inputs["entropy_list"][i] = []
|
||||
|
||||
|
||||
def speculate_calculate_logits_entropy(logits, share_inputs, temperature):
    """Entropy bookkeeping for the speculative-decoding path.

    Only the logits of tokens the verifier actually accepted — the first
    `accept_num[i]` rows of each request's draft window — contribute to the
    per-request entropy lists.

    Args:
        logits: [num_draft_tokens, vocab_size] tensor of raw logits for the
            whole draft window of every request.
        share_inputs: dict of shared buffers. Reads "seq_lens_this_time",
            "seq_lens_encoder", "accept_num", "stop_flags" and "req_ids";
            appends to the per-request lists in "entropy_list".
        temperature: per-request sampling temperature tensor; entries with
            0 < t != 1 rescale the logits before the entropy is taken.

    Side effects:
        Appends one entropy value per accepted token to
        share_inputs["entropy_list"][i]; when a request has stopped, logs the
        mean entropy and clears its list. The caller's `logits` tensor is
        NOT modified in place.
    """
    real_bsz = share_inputs["seq_lens_this_time"].shape[0]
    accept_num = share_inputs["accept_num"]

    # A prefill step emits one token; a decode step emits
    # seq_lens_this_time tokens.
    real_seq_lens = paddle.where(
        share_inputs["seq_lens_encoder"][:real_bsz].squeeze(1) != 0,
        paddle.ones([1], dtype="int32"),
        share_inputs["seq_lens_this_time"].squeeze(1),
    )

    # Row index of the first token of every request's window, then the row
    # indices of the accepted tokens inside each window.
    seq_start_idx = paddle.concat([paddle.zeros([1], dtype="int32"), paddle.cumsum(real_seq_lens, dtype="int32")])
    repeated_starts = paddle.repeat_interleave(seq_start_idx[:-1], accept_num[:real_bsz])
    offsets = paddle.concat([paddle.arange(accept_num[i].item()) for i in range(real_bsz)]).astype("int32")
    accepted_idx = repeated_starts + offsets

    # Single gather kernel instead of a Python row-copy loop into an
    # uninitialized paddle.empty buffer.
    accepted_logits = paddle.gather(logits, accepted_idx)

    # Map each accepted token to the request it belongs to.
    batch_indices = paddle.arange(accept_num.shape[0], dtype="int32")
    batch_id_per_token = paddle.repeat_interleave(batch_indices, accept_num)

    # Vectorized temperature scaling; temperatures <= 0 or exactly 1 leave
    # the row untouched (same semantics as the per-row loop it replaces).
    flat_temp = paddle.flatten(temperature).astype(accepted_logits.dtype)
    token_temp = paddle.gather(flat_temp, batch_id_per_token)
    divisor = paddle.where(
        paddle.logical_and(token_temp > 0, token_temp != 1.0),
        token_temp,
        paddle.ones_like(token_temp),
    )
    scaled_logits = accepted_logits / divisor.unsqueeze(-1)

    # Numerically stable softmax entropy: H = sum(p * (log(z) - a)).
    shifted = scaled_logits - paddle.max(scaled_logits, axis=-1, keepdim=True)
    exp_shifted = paddle.exp(shifted)
    partition = paddle.sum(exp_shifted, axis=-1, keepdim=True)
    probs = exp_shifted / partition
    entropy = paddle.sum(probs * (paddle.log(partition) - shifted), axis=-1).tolist()

    # O(n) cursor distribution instead of repeated list.pop(0).
    cursor = 0
    accepted = accept_num.tolist()
    for i in range(real_bsz):
        share_inputs["entropy_list"][i].extend(entropy[cursor : cursor + accepted[i]])
        cursor += accepted[i]
        if share_inputs["stop_flags"][i] and len(share_inputs["entropy_list"][i]) != 0:
            data_processor_logger.info(
                f"req_id: {share_inputs['req_ids'][i]}, entropy: {sum(share_inputs['entropy_list'][i])/len(share_inputs['entropy_list'][i])}"
            )
            share_inputs["entropy_list"][i] = []
|
||||
@@ -546,6 +546,7 @@ class Sampler(nn.Layer):
|
||||
# token per request.
|
||||
sampled_token_ids=next_tokens,
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
logits=logits,
|
||||
)
|
||||
|
||||
return sampler_output
|
||||
@@ -845,6 +846,7 @@ class SpeculativeSampler(nn.Layer):
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
token_num_per_batch=share_inputs["accept_num"],
|
||||
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
|
||||
logits=logits,
|
||||
)
|
||||
|
||||
return sampler_output
|
||||
|
||||
@@ -93,6 +93,11 @@ else:
|
||||
speculate_limit_thinking_content_length_v2,
|
||||
)
|
||||
|
||||
from fastdeploy.model_executor.entropy_utils import (
|
||||
calculate_logits_entropy,
|
||||
speculate_calculate_logits_entropy,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
|
||||
from fastdeploy.output.pooler import PoolerOutput, PoolingSequenceGroupOutput
|
||||
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
|
||||
from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, SamplerOutput
|
||||
@@ -307,12 +312,14 @@ def post_process_normal(
|
||||
sampler_output: SamplerOutput,
|
||||
model_output: ModelOutputData,
|
||||
share_inputs: Dict[str, paddle.Tensor],
|
||||
sampling_metadata: SamplingMetadata,
|
||||
block_size: int = 64,
|
||||
save_each_rank: bool = False,
|
||||
skip_save_output: bool = False,
|
||||
async_output_queue: queue.Queue = None,
|
||||
think_end_id: int = -1,
|
||||
line_break_id: int = -1,
|
||||
enable_entropy: bool = False,
|
||||
):
|
||||
"""Post-processing steps after completing a single token generation."""
|
||||
if think_end_id > 0:
|
||||
@@ -371,6 +378,9 @@ def post_process_normal(
|
||||
False,
|
||||
)
|
||||
|
||||
if enable_entropy:
|
||||
calculate_logits_entropy(sampler_output.logits, share_inputs, sampling_metadata.temperature)
|
||||
|
||||
# 2. Update the input buffer of the model
|
||||
with paddle.framework._no_check_dy2st_diff():
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
@@ -436,10 +446,12 @@ def post_process_specualate(
|
||||
sampler_output: SamplerOutput,
|
||||
model_output: ModelOutputData,
|
||||
share_inputs: Dict[str, paddle.Tensor],
|
||||
sampling_metadata: SamplingMetadata,
|
||||
save_each_rank: bool = False,
|
||||
skip_save_output: bool = False,
|
||||
think_end_id: int = -1,
|
||||
line_break_id: int = -1,
|
||||
enable_entropy: bool = False,
|
||||
):
|
||||
if think_end_id > 0:
|
||||
speculate_limit_thinking_content_length(
|
||||
@@ -464,6 +476,10 @@ def post_process_specualate(
|
||||
model_output.eos_token_id,
|
||||
model_output.min_tokens,
|
||||
)
|
||||
|
||||
if enable_entropy:
|
||||
speculate_calculate_logits_entropy(sampler_output.logits, share_inputs, sampling_metadata.temperature)
|
||||
|
||||
speculate_update(
|
||||
model_output.seq_lens_encoder,
|
||||
model_output.seq_lens_decoder,
|
||||
@@ -525,6 +541,7 @@ def post_process(
|
||||
sampler_or_pooler_output: Union[SamplerOutput, PoolerOutput],
|
||||
model_output: ModelOutputData,
|
||||
share_inputs: Dict[str, paddle.Tensor],
|
||||
sampling_metadata: SamplingMetadata = None,
|
||||
block_size: int = 64,
|
||||
save_each_rank: bool = False,
|
||||
speculative_decoding: bool = False,
|
||||
@@ -532,6 +549,7 @@ def post_process(
|
||||
async_output_queue: queue.Queue = None,
|
||||
think_end_id: int = -1,
|
||||
line_break_id: int = -1,
|
||||
enable_entropy: bool = False,
|
||||
) -> None:
|
||||
"""Post-processing steps after completing a single token generation."""
|
||||
|
||||
@@ -551,22 +569,26 @@ def post_process(
|
||||
sampler_or_pooler_output,
|
||||
model_output,
|
||||
share_inputs,
|
||||
sampling_metadata,
|
||||
save_each_rank,
|
||||
skip_save_output,
|
||||
think_end_id,
|
||||
line_break_id,
|
||||
enable_entropy,
|
||||
)
|
||||
else:
|
||||
post_process_normal(
|
||||
sampler_or_pooler_output,
|
||||
model_output,
|
||||
share_inputs,
|
||||
sampling_metadata,
|
||||
block_size,
|
||||
save_each_rank,
|
||||
skip_save_output,
|
||||
async_output_queue,
|
||||
think_end_id,
|
||||
line_break_id,
|
||||
enable_entropy,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -235,6 +235,8 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
)
|
||||
self.async_output_copy_thread.start()
|
||||
|
||||
self.enable_entropy = self.model_config.enable_entropy
|
||||
|
||||
def _async_output_busy_loop(self):
|
||||
"""Entrypoint for the thread which handles outputs asynchronously."""
|
||||
while True:
|
||||
@@ -643,6 +645,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
request = req_dicts[i]
|
||||
# assert isinstance(request, Request)
|
||||
idx = request.idx
|
||||
self.share_inputs["req_ids"][idx] = str(request.request_id)
|
||||
|
||||
if hasattr(request, "pooling_params") and request.pooling_params is not None:
|
||||
batch_pooling_params.append(request.pooling_params)
|
||||
@@ -1309,6 +1312,9 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
-1,
|
||||
dtype="int64",
|
||||
)
|
||||
self.share_inputs["req_ids"] = [""] * max_num_seqs
|
||||
self.share_inputs["entropy_list"] = [[] for _ in range(max_num_seqs)]
|
||||
|
||||
if self.speculative_decoding:
|
||||
max_draft_token_num = self.speculative_config.num_speculative_tokens
|
||||
self.share_inputs["input_ids_cpu"] = paddle.full(
|
||||
@@ -1830,6 +1836,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
sampler_or_pooler_output=pooler_output,
|
||||
model_output=model_output_data,
|
||||
share_inputs=self.share_inputs,
|
||||
sampling_metadata=self.sampling_metadata,
|
||||
block_size=self.cache_config.block_size,
|
||||
speculative_decoding=self.speculative_decoding,
|
||||
skip_save_output=True,
|
||||
@@ -1932,12 +1939,14 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
sampler_or_pooler_output=sampler_output,
|
||||
model_output=model_output_data,
|
||||
share_inputs=self.share_inputs,
|
||||
sampling_metadata=self.sampling_metadata,
|
||||
block_size=self.cache_config.block_size,
|
||||
speculative_decoding=self.speculative_decoding,
|
||||
skip_save_output=True,
|
||||
async_output_queue=self.async_output_queue,
|
||||
think_end_id=self.model_config.think_end_id,
|
||||
line_break_id=self.model_config.line_break_id,
|
||||
enable_entropy=self.enable_entropy,
|
||||
)
|
||||
if self.speculative_decoding:
|
||||
if self.speculative_method == "mtp":
|
||||
@@ -2398,11 +2407,13 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
sampler_or_pooler_output=pooler_output,
|
||||
model_output=model_output_data,
|
||||
share_inputs=self.share_inputs,
|
||||
sampling_metadata=self.sampling_metadata,
|
||||
block_size=self.cache_config.block_size,
|
||||
save_each_rank=self.parallel_config.use_ep,
|
||||
speculative_decoding=self.speculative_decoding,
|
||||
skip_save_output=False,
|
||||
async_output_queue=self.async_output_queue,
|
||||
enable_entropy=self.enable_entropy,
|
||||
)
|
||||
|
||||
return None
|
||||
@@ -2524,6 +2535,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
sampler_or_pooler_output=sampler_output,
|
||||
model_output=model_output_data,
|
||||
share_inputs=self.share_inputs,
|
||||
sampling_metadata=self.sampling_metadata,
|
||||
block_size=self.cache_config.block_size,
|
||||
save_each_rank=self.parallel_config.use_ep,
|
||||
speculative_decoding=self.speculative_decoding,
|
||||
@@ -2531,6 +2543,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
async_output_queue=self.async_output_queue,
|
||||
think_end_id=self.model_config.think_end_id,
|
||||
line_break_id=self.model_config.line_break_id,
|
||||
enable_entropy=self.enable_entropy,
|
||||
)
|
||||
if self.guided_backend is not None and sampler_output is not None:
|
||||
self.sampler.post_process(sampler_output.sampled_token_ids)
|
||||
|
||||
@@ -172,6 +172,7 @@ class SamplerOutput:
|
||||
logprobs_tensors: Optional[LogprobsTensors]
|
||||
token_num_per_batch: Optional[paddle.Tensor] = None
|
||||
cu_batch_token_offset: Optional[paddle.Tensor] = None
|
||||
logits: Optional[paddle.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -893,6 +893,12 @@ def parse_args():
|
||||
help="Shutdown comm group if worker idle.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--enable_entropy",
|
||||
action="store_true",
|
||||
help="Enable output of token-level entropy.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
212
tests/model_executor/test_entropy_utils.py
Normal file
212
tests/model_executor/test_entropy_utils.py
Normal file
@@ -0,0 +1,212 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.entropy_utils import (
|
||||
calculate_logits_entropy,
|
||||
speculate_calculate_logits_entropy,
|
||||
)
|
||||
|
||||
|
||||
class TestCalculateLogitsEntropy(unittest.TestCase):
    """Tests for calculate_logits_entropy (normal decoding path)."""

    # Entropy of softmax([10, 1, 1]) at temperature 1.0 and 0.8. The value
    # is permutation-invariant, so both logits rows share it.
    ENTROPY_T10 = 0.0024676250759512186
    ENTROPY_T08 = 0.0003187173861078918

    @staticmethod
    def _build_share_inputs(stop_flags):
        """Shared-buffer fixture: request 0 decodes one token, request 1 is
        idle this step, request 2 is a prefill step (contributes exactly one
        sampled token)."""
        return {
            "seq_lens_this_time": paddle.to_tensor([[1], [0], [15]], dtype="int32"),
            "seq_lens_encoder": paddle.to_tensor([[0], [0], [15]], dtype="int32"),
            "entropy_list": [[], [], []],
            "stop_flags": paddle.to_tensor(stop_flags, dtype="bool"),
            "req_ids": ["req_1", "req_2", "req_3"],
        }

    @staticmethod
    def _build_logits():
        """Two token rows (one per active request), peaked distributions."""
        return paddle.to_tensor(
            [
                [10.0, 1.0, 1.0],
                [1.0, 1.0, 10.0],
            ],
            dtype="float32",
        )

    def test_basic_functionality(self):
        share_inputs = self._build_share_inputs([[False], [True], [False]])
        temperature = paddle.ones([3], dtype="float32")

        calculate_logits_entropy(self._build_logits(), share_inputs, temperature)

        self.assertEqual(len(share_inputs["entropy_list"][0]), 1)
        self.assertEqual(len(share_inputs["entropy_list"][1]), 0)
        self.assertEqual(len(share_inputs["entropy_list"][2]), 1)

        self.assertAlmostEqual(share_inputs["entropy_list"][0][0], self.ENTROPY_T10, places=6)
        self.assertAlmostEqual(share_inputs["entropy_list"][2][0], self.ENTROPY_T10, places=6)

    def test_temperature_effect(self):
        share_inputs = self._build_share_inputs([[False], [True], [False]])
        # Temperature 0.8 sharpens the distribution, lowering the entropy.
        temperature = paddle.to_tensor([[0.8], [1.0], [0.8]], dtype="float32")

        calculate_logits_entropy(self._build_logits(), share_inputs, temperature)

        self.assertEqual(len(share_inputs["entropy_list"][0]), 1)
        self.assertEqual(len(share_inputs["entropy_list"][1]), 0)
        self.assertEqual(len(share_inputs["entropy_list"][2]), 1)

        self.assertAlmostEqual(share_inputs["entropy_list"][0][0], self.ENTROPY_T08, places=6)
        self.assertAlmostEqual(share_inputs["entropy_list"][2][0], self.ENTROPY_T08, places=6)

    def test_entropy_list_clear(self):
        # Request 0 has stopped: its entropy is logged and the list reset;
        # request 2 keeps accumulating.
        share_inputs = self._build_share_inputs([[True], [True], [False]])
        temperature = paddle.to_tensor([[0.8], [1.0], [0.8]], dtype="float32")

        calculate_logits_entropy(self._build_logits(), share_inputs, temperature)

        self.assertEqual(len(share_inputs["entropy_list"][0]), 0)
        self.assertEqual(len(share_inputs["entropy_list"][1]), 0)
        self.assertEqual(len(share_inputs["entropy_list"][2]), 1)

        self.assertAlmostEqual(share_inputs["entropy_list"][2][0], self.ENTROPY_T08, places=6)
|
||||
|
||||
|
||||
class TestSpeculateCalculateLogitsEntropy(unittest.TestCase):
    """Tests for speculate_calculate_logits_entropy (speculative decoding)."""

    # Entropy of softmax([10, 1, 1]) at temperature 1.0 and 0.8; the value
    # is permutation-invariant, so every logits row shares it.
    ENTROPY_T10 = 0.0024676250759512186
    ENTROPY_T08 = 0.0003187173861078918

    @staticmethod
    def _build_share_inputs(stop_flags):
        """Shared-buffer fixture: requests 0/1 decode with 2/1 accepted
        draft tokens, request 2 is idle, request 3 is a prefill step with
        no accepted tokens."""
        return {
            "seq_lens_this_time": paddle.to_tensor([[2], [2], [0], [15]], dtype="int32"),
            "seq_lens_encoder": paddle.to_tensor([[0], [0], [0], [15]], dtype="int32"),
            "entropy_list": [[], [], [], []],
            "stop_flags": paddle.to_tensor(stop_flags, dtype="bool"),
            "req_ids": ["req_1", "req_2", "req_3", "req_4"],
            # Number of draft tokens accepted by the verifier per request.
            "accept_num": paddle.to_tensor([2, 1, 0, 0], dtype="int32"),
        }

    @staticmethod
    def _build_logits():
        """One row per draft token; all rows are peaked distributions."""
        return paddle.to_tensor(
            [
                [10.0, 1.0, 1.0],
                [1.0, 10.0, 1.0],
                [1.0, 1.0, 10.0],
                [1.0, 1.0, 10.0],
            ],
            dtype="float32",
        )

    def test_basic_functionality(self):
        share_inputs = self._build_share_inputs([[False], [False], [True], [False]])
        # One temperature entry per request. (The original test passed a
        # 3-element tensor for a batch of 4; only indices 0/1 are ever read,
        # but the shape now matches the batch.)
        temperature = paddle.ones([4], dtype="float32")

        speculate_calculate_logits_entropy(self._build_logits(), share_inputs, temperature)

        self.assertEqual(len(share_inputs["entropy_list"][0]), 2)
        self.assertEqual(len(share_inputs["entropy_list"][1]), 1)
        self.assertEqual(len(share_inputs["entropy_list"][2]), 0)
        self.assertEqual(len(share_inputs["entropy_list"][3]), 0)

        self.assertAlmostEqual(share_inputs["entropy_list"][0][0], self.ENTROPY_T10, places=6)
        self.assertAlmostEqual(share_inputs["entropy_list"][0][1], self.ENTROPY_T10, places=6)
        self.assertAlmostEqual(share_inputs["entropy_list"][1][0], self.ENTROPY_T10, places=6)

    def test_temperature_effect(self):
        share_inputs = self._build_share_inputs([[False], [False], [True], [False]])
        # Temperature 0.8 sharpens the distribution, lowering the entropy.
        temperature = paddle.to_tensor([[0.8], [0.8], [0.8], [0.8]], dtype="float32")

        speculate_calculate_logits_entropy(self._build_logits(), share_inputs, temperature)

        self.assertEqual(len(share_inputs["entropy_list"][0]), 2)
        self.assertEqual(len(share_inputs["entropy_list"][1]), 1)
        self.assertEqual(len(share_inputs["entropy_list"][2]), 0)
        self.assertEqual(len(share_inputs["entropy_list"][3]), 0)

        self.assertAlmostEqual(share_inputs["entropy_list"][0][0], self.ENTROPY_T08, places=6)
        self.assertAlmostEqual(share_inputs["entropy_list"][0][1], self.ENTROPY_T08, places=6)
        self.assertAlmostEqual(share_inputs["entropy_list"][1][0], self.ENTROPY_T08, places=6)

    def test_entropy_list_clear(self):
        # Request 0 has stopped: its entropy is logged and the list reset;
        # request 1 keeps accumulating.
        share_inputs = self._build_share_inputs([[True], [False], [True], [False]])
        temperature = paddle.ones([4], dtype="float32")

        speculate_calculate_logits_entropy(self._build_logits(), share_inputs, temperature)

        self.assertEqual(len(share_inputs["entropy_list"][0]), 0)
        self.assertEqual(len(share_inputs["entropy_list"][1]), 1)
        self.assertEqual(len(share_inputs["entropy_list"][2]), 0)
        self.assertEqual(len(share_inputs["entropy_list"][3]), 0)

        self.assertAlmostEqual(share_inputs["entropy_list"][1][0], self.ENTROPY_T10, places=6)
|
||||
|
||||
|
||||
# Allow running this test module directly: `python test_entropy_utils.py`.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user