diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 381faa81e..4c3f262ff 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -206,6 +206,7 @@ class ModelConfig: self.revision = None self.prefix_layer_name = "layers" self.kv_cache_quant_scale_path = "" + self.enable_entropy = False self.partial_rotary_factor: float = 1.0 self.num_nextn_predict_layers = 0 diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 6c77e5eb2..d14ad8897 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -509,6 +509,11 @@ class EngineArgs: Whether to skip port availability check. Default is False (not skip). """ + enable_entropy: bool = False + """ + Flag to enable entropy output. Default is False (disabled). + """ + def __post_init__(self): """ Post-initialization processing to set default tokenizer if not provided. @@ -854,6 +859,12 @@ class EngineArgs: default=EngineArgs.logits_processors, help="FQCNs (Fully Qualified Class Names) of logits processors supported by the service.", ) + model_group.add_argument( + "--enable-entropy", + action="store_true", + default=EngineArgs.enable_entropy, + help="Enable output of token-level entropy.", + ) # Parallel processing parameters group parallel_group = parser.add_argument_group("Parallel Configuration") diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index d574f77be..e756a8e00 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -1706,6 +1706,7 @@ class EngineService: "disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe, "enable_logprob": self.cfg.model_config.enable_logprob, "lm_head_fp32": self.cfg.model_config.lm_head_fp32, + "enable_entropy": self.cfg.model_config.enable_entropy, } for worker_flag, value in worker_store_true_flag.items(): if value: diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 43eb18e47..845857427 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -591,6 +591,7 @@ class LLMEngine: "enable_logprob": self.cfg.model_config.enable_logprob, "lm_head_fp32": self.cfg.model_config.lm_head_fp32, "shutdown_comm_group_if_worker_idle": self.cfg.parallel_config.shutdown_comm_group_if_worker_idle, + "enable_entropy": self.cfg.model_config.enable_entropy, } for worker_flag, value in worker_store_true_flag.items(): if value: diff --git a/fastdeploy/model_executor/entropy_utils.py b/fastdeploy/model_executor/entropy_utils.py new file mode 100644 index 000000000..c9fc431b4 --- /dev/null +++ b/fastdeploy/model_executor/entropy_utils.py @@ -0,0 +1,99 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import paddle + +from fastdeploy.utils import data_processor_logger + + +def calculate_logits_entropy(logits, share_inputs, temperature): + real_bsz = share_inputs["seq_lens_this_time"].shape[0] + real_seq_lens = paddle.where( + share_inputs["seq_lens_encoder"][:real_bsz].squeeze(1) != 0, + paddle.ones([1], dtype="int32"), + share_inputs["seq_lens_this_time"].squeeze(1), + ) + + def get_entropy(logits): + a0 = logits - paddle.max(logits, axis=-1, keepdim=True) + ea0 = paddle.exp(a0) + z0 = paddle.sum(ea0, axis=-1, keepdim=True) + p0 = ea0 / z0 + return paddle.sum(p0 * (paddle.log(z0) - a0), axis=-1) + + batch_indices = paddle.arange(real_bsz, dtype="int32") + batch_id_per_token = paddle.repeat_interleave(batch_indices, real_seq_lens) + for i in range(logits.shape[0]): + if temperature[batch_id_per_token[i]] > 0 and temperature[batch_id_per_token[i]] != 1.0: + logits[i] = logits[i].scale_(1 / temperature[batch_id_per_token[i]]) + + entropy_tensor = get_entropy(logits) + entropy = entropy_tensor.tolist() + + for i in range(real_bsz): + for _ in range(real_seq_lens[i]): + share_inputs["entropy_list"][i].append(entropy.pop(0)) + if share_inputs["stop_flags"][i] and len(share_inputs["entropy_list"][i]) != 0: + data_processor_logger.info( + f"req_id: {share_inputs['req_ids'][i]}, entropy: {sum(share_inputs['entropy_list'][i])/len(share_inputs['entropy_list'][i])}" + ) + share_inputs["entropy_list"][i] = [] + + +def speculate_calculate_logits_entropy(logits, share_inputs, temperature): + # get accepted logits + real_bsz = share_inputs["seq_lens_this_time"].shape[0] + total_accepted_num = paddle.sum(share_inputs["accept_num"]) + real_seq_lens = paddle.where( + share_inputs["seq_lens_encoder"][:real_bsz].squeeze(1) != 0, + paddle.ones([1], dtype="int32"), + share_inputs["seq_lens_this_time"].squeeze(1), + ) + seq_start_idx = paddle.concat([paddle.zeros([1], dtype="int32"), paddle.cumsum(real_seq_lens, dtype="int32")]) + repeated_starts = paddle.repeat_interleave(seq_start_idx[:-1], share_inputs["accept_num"][:real_bsz]) + offsets = paddle.concat([paddle.arange(share_inputs["accept_num"][i].item()) for i in range(real_bsz)]).astype( + "int32" + ) + accepted_idx = repeated_starts + offsets + + accepted_logits = paddle.empty([total_accepted_num, logits.shape[1]], dtype=logits.dtype) + for i in range(total_accepted_num): + accepted_logits[i] = logits[accepted_idx[i]] + + def get_entropy(logits): + a0 = logits - paddle.max(logits, axis=-1, keepdim=True) + ea0 = paddle.exp(a0) + z0 = paddle.sum(ea0, axis=-1, keepdim=True) + p0 = ea0 / z0 + return paddle.sum(p0 * (paddle.log(z0) - a0), axis=-1) + + batch_indices = paddle.arange(share_inputs["accept_num"].shape[0], dtype="int32") + batch_id_per_token = paddle.repeat_interleave(batch_indices, share_inputs["accept_num"]) + for i in range(accepted_logits.shape[0]): + if temperature[batch_id_per_token[i]] > 0 and temperature[batch_id_per_token[i]] != 1.0: + accepted_logits[i] = accepted_logits[i].scale_(1 / temperature[batch_id_per_token[i]]) + + entropy_tensor = get_entropy(accepted_logits) + entropy = entropy_tensor.tolist() + + for i in range(real_bsz): + for _ in range(share_inputs["accept_num"][i]): + share_inputs["entropy_list"][i].append(entropy.pop(0)) + if share_inputs["stop_flags"][i] and len(share_inputs["entropy_list"][i]) != 0: + data_processor_logger.info( + f"req_id: {share_inputs['req_ids'][i]}, entropy: {sum(share_inputs['entropy_list'][i])/len(share_inputs['entropy_list'][i])}" + ) + share_inputs["entropy_list"][i] = [] diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index c3a426488..96a14ee93 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -546,6 +546,7 @@ class Sampler(nn.Layer): # token per request. sampled_token_ids=next_tokens, logprobs_tensors=logprobs_tensors, + logits=logits, ) return sampler_output @@ -845,6 +846,7 @@ class SpeculativeSampler(nn.Layer): logprobs_tensors=logprobs_tensors, token_num_per_batch=share_inputs["accept_num"], cu_batch_token_offset=share_inputs["cu_batch_token_offset"], + logits=logits, ) return sampler_output diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index f07c663b5..dee8dc372 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -93,6 +93,11 @@ else: speculate_limit_thinking_content_length_v2, ) +from fastdeploy.model_executor.entropy_utils import ( + calculate_logits_entropy, + speculate_calculate_logits_entropy, +) +from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata from fastdeploy.output.pooler import PoolerOutput, PoolingSequenceGroupOutput from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, SamplerOutput @@ -307,12 +312,14 @@ def post_process_normal( sampler_output: SamplerOutput, model_output: ModelOutputData, share_inputs: Dict[str, paddle.Tensor], + sampling_metadata: SamplingMetadata, block_size: int = 64, save_each_rank: bool = False, skip_save_output: bool = False, async_output_queue: queue.Queue = None, think_end_id: int = -1, line_break_id: int = -1, + enable_entropy: bool = False, ): """Post-processing steps after completing a single token generation.""" if think_end_id > 0: @@ -371,6 +378,9 @@ def post_process_normal( False, ) + if enable_entropy: + calculate_logits_entropy(sampler_output.logits, share_inputs, sampling_metadata.temperature) + # 2. Update the input buffer of the model with paddle.framework._no_check_dy2st_diff(): if envs.ENABLE_V1_KVCACHE_SCHEDULER: @@ -436,10 +446,12 @@ def post_process_specualate( sampler_output: SamplerOutput, model_output: ModelOutputData, share_inputs: Dict[str, paddle.Tensor], + sampling_metadata: SamplingMetadata, save_each_rank: bool = False, skip_save_output: bool = False, think_end_id: int = -1, line_break_id: int = -1, + enable_entropy: bool = False, ): if think_end_id > 0: speculate_limit_thinking_content_length( @@ -464,6 +476,10 @@ def post_process_specualate( model_output.eos_token_id, model_output.min_tokens, ) + + if enable_entropy: + speculate_calculate_logits_entropy(sampler_output.logits, share_inputs, sampling_metadata.temperature) + speculate_update( model_output.seq_lens_encoder, model_output.seq_lens_decoder, @@ -525,6 +541,7 @@ def post_process( sampler_or_pooler_output: Union[SamplerOutput, PoolerOutput], model_output: ModelOutputData, share_inputs: Dict[str, paddle.Tensor], + sampling_metadata: SamplingMetadata = None, block_size: int = 64, save_each_rank: bool = False, speculative_decoding: bool = False, @@ -532,6 +549,7 @@ def post_process( async_output_queue: queue.Queue = None, think_end_id: int = -1, line_break_id: int = -1, + enable_entropy: bool = False, ) -> None: """Post-processing steps after completing a single token generation.""" @@ -551,22 +569,26 @@ def post_process( sampler_or_pooler_output, model_output, share_inputs, + sampling_metadata, save_each_rank, skip_save_output, think_end_id, line_break_id, + enable_entropy, ) else: post_process_normal( sampler_or_pooler_output, model_output, share_inputs, + sampling_metadata, block_size, save_each_rank, skip_save_output, async_output_queue, think_end_id, line_break_id, + enable_entropy, ) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 9bf533605..4265acfdf 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -235,6 +235,8 @@ class GPUModelRunner(ModelRunnerBase): ) self.async_output_copy_thread.start() + self.enable_entropy = self.model_config.enable_entropy + def _async_output_busy_loop(self): """Entrypoint for the thread which handles outputs asynchronously.""" while True: @@ -643,6 +645,7 @@ class GPUModelRunner(ModelRunnerBase): request = req_dicts[i] # assert isinstance(request, Request) idx = request.idx + self.share_inputs["req_ids"][idx] = str(request.request_id) if hasattr(request, "pooling_params") and request.pooling_params is not None: batch_pooling_params.append(request.pooling_params) @@ -1309,6 +1312,9 @@ class GPUModelRunner(ModelRunnerBase): -1, dtype="int64", ) + self.share_inputs["req_ids"] = [""] * max_num_seqs + self.share_inputs["entropy_list"] = [[] for _ in range(max_num_seqs)] + if self.speculative_decoding: max_draft_token_num = self.speculative_config.num_speculative_tokens self.share_inputs["input_ids_cpu"] = paddle.full( @@ -1830,6 +1836,7 @@ class GPUModelRunner(ModelRunnerBase): sampler_or_pooler_output=pooler_output, model_output=model_output_data, share_inputs=self.share_inputs, + sampling_metadata=self.sampling_metadata, block_size=self.cache_config.block_size, speculative_decoding=self.speculative_decoding, skip_save_output=True, @@ -1932,12 +1939,14 @@ class GPUModelRunner(ModelRunnerBase): sampler_or_pooler_output=sampler_output, model_output=model_output_data, share_inputs=self.share_inputs, + sampling_metadata=self.sampling_metadata, block_size=self.cache_config.block_size, speculative_decoding=self.speculative_decoding, skip_save_output=True, async_output_queue=self.async_output_queue, think_end_id=self.model_config.think_end_id, line_break_id=self.model_config.line_break_id, + enable_entropy=self.enable_entropy, ) if self.speculative_decoding: if self.speculative_method == "mtp": @@ -2398,11 +2407,13 @@ class GPUModelRunner(ModelRunnerBase): sampler_or_pooler_output=pooler_output, model_output=model_output_data, share_inputs=self.share_inputs, + sampling_metadata=self.sampling_metadata, block_size=self.cache_config.block_size, save_each_rank=self.parallel_config.use_ep, speculative_decoding=self.speculative_decoding, skip_save_output=False, async_output_queue=self.async_output_queue, + enable_entropy=self.enable_entropy, ) return None @@ -2524,6 +2535,7 @@ class GPUModelRunner(ModelRunnerBase): sampler_or_pooler_output=sampler_output, model_output=model_output_data, share_inputs=self.share_inputs, + sampling_metadata=self.sampling_metadata, block_size=self.cache_config.block_size, save_each_rank=self.parallel_config.use_ep, speculative_decoding=self.speculative_decoding, @@ -2531,6 +2543,7 @@ class GPUModelRunner(ModelRunnerBase): async_output_queue=self.async_output_queue, think_end_id=self.model_config.think_end_id, line_break_id=self.model_config.line_break_id, + enable_entropy=self.enable_entropy, ) if self.guided_backend is not None and sampler_output is not None: self.sampler.post_process(sampler_output.sampled_token_ids) diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py index 13d822b09..76710c4b0 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -172,6 +172,7 @@ class SamplerOutput: logprobs_tensors: Optional[LogprobsTensors] token_num_per_batch: Optional[paddle.Tensor] = None cu_batch_token_offset: Optional[paddle.Tensor] = None + logits: Optional[paddle.Tensor] = None @dataclass diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index d82343a7e..4fcecbddd 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -893,6 +893,12 @@ def parse_args(): help="Shutdown comm group if worker idle.", ) + parser.add_argument( + "--enable_entropy", + action="store_true", + help="Enable output of token-level entropy.", + ) + args = parser.parse_args() return args diff --git a/tests/model_executor/test_entropy_utils.py b/tests/model_executor/test_entropy_utils.py new file mode 100644 index 000000000..1135a77f5 --- /dev/null +++ b/tests/model_executor/test_entropy_utils.py @@ -0,0 +1,212 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle + +from fastdeploy.model_executor.entropy_utils import ( + calculate_logits_entropy, + speculate_calculate_logits_entropy, +) + + +class TestCalculateLogitsEntropy(unittest.TestCase): + + def test_basic_functionality(self): + share_inputs = { + "seq_lens_this_time": paddle.to_tensor([[1], [0], [15]], dtype="int32"), + "seq_lens_encoder": paddle.to_tensor([[0], [0], [15]], dtype="int32"), + "entropy_list": [[], [], []], + "stop_flags": paddle.to_tensor([[False], [True], [False]], dtype="bool"), + "req_ids": ["req_1", "req_2", "req_3"], + } + + logits = paddle.to_tensor( + [ + [10.0, 1.0, 1.0], + [1.0, 1.0, 10.0], + ], + dtype="float32", + ) + temperature = paddle.ones([3], dtype="float32") + + calculate_logits_entropy(logits, share_inputs, temperature) + + self.assertEqual(len(share_inputs["entropy_list"][0]), 1) + self.assertEqual(len(share_inputs["entropy_list"][1]), 0) + self.assertEqual(len(share_inputs["entropy_list"][2]), 1) + + self.assertAlmostEqual(share_inputs["entropy_list"][0][0], 0.0024676250759512186, places=6) + self.assertAlmostEqual(share_inputs["entropy_list"][2][0], 0.0024676250759512186, places=6) + + def test_temperature_effect(self): + share_inputs = { + "seq_lens_this_time": paddle.to_tensor([[1], [0], [15]], dtype="int32"), + "seq_lens_encoder": paddle.to_tensor([[0], [0], [15]], dtype="int32"), + "entropy_list": [[], [], []], + "stop_flags": paddle.to_tensor([[False], [True], [False]], dtype="bool"), + "req_ids": ["req_1", "req_2", "req_3"], + } + + logits = paddle.to_tensor( + [ + [10.0, 1.0, 1.0], + [1.0, 1.0, 10.0], + ], + dtype="float32", + ) + temperature = paddle.to_tensor([[0.8], [1.0], [0.8]], dtype="float32") + + calculate_logits_entropy(logits, share_inputs, temperature) + + self.assertEqual(len(share_inputs["entropy_list"][0]), 1) + self.assertEqual(len(share_inputs["entropy_list"][1]), 0) + self.assertEqual(len(share_inputs["entropy_list"][2]), 1) + + self.assertAlmostEqual(share_inputs["entropy_list"][0][0], 0.0003187173861078918, places=6) + self.assertAlmostEqual(share_inputs["entropy_list"][2][0], 0.0003187173861078918, places=6) + + def test_entropy_list_clear(self): + share_inputs = { + "seq_lens_this_time": paddle.to_tensor([[1], [0], [15]], dtype="int32"), + "seq_lens_encoder": paddle.to_tensor([[0], [0], [15]], dtype="int32"), + "entropy_list": [[], [], []], + "stop_flags": paddle.to_tensor([[True], [True], [False]], dtype="bool"), + "req_ids": ["req_1", "req_2", "req_3"], + } + + logits = paddle.to_tensor( + [ + [10.0, 1.0, 1.0], + [1.0, 1.0, 10.0], + ], + dtype="float32", + ) + temperature = paddle.to_tensor([[0.8], [1.0], [0.8]], dtype="float32") + + calculate_logits_entropy(logits, share_inputs, temperature) + + self.assertEqual(len(share_inputs["entropy_list"][0]), 0) + self.assertEqual(len(share_inputs["entropy_list"][1]), 0) + self.assertEqual(len(share_inputs["entropy_list"][2]), 1) + + self.assertAlmostEqual(share_inputs["entropy_list"][2][0], 0.0003187173861078918, places=6) + + +class TestSpeculateCalculateLogitsEntropy(unittest.TestCase): + + def test_basic_functionality(self): + share_inputs = { + "seq_lens_this_time": paddle.to_tensor([[2], [2], [0], [15]], dtype="int32"), + "seq_lens_encoder": paddle.to_tensor([[0], [0], [0], [15]], dtype="int32"), + "entropy_list": [[], [], [], []], + "stop_flags": paddle.to_tensor([[False], [False], [True], [False]], dtype="bool"), + "req_ids": ["req_1", "req_2", "req_3", "req_4"], + "accept_num": paddle.to_tensor([2, 1, 0, 0], dtype="int32"), # 推理接受数量 + } + + logits = paddle.to_tensor( + [ + [10.0, 1.0, 1.0], + [1.0, 10.0, 1.0], + [1.0, 1.0, 10.0], + [1.0, 1.0, 10.0], + ], + dtype="float32", + ) + temperature = paddle.ones([3], dtype="float32") + + speculate_calculate_logits_entropy(logits, share_inputs, temperature) + + print(share_inputs["entropy_list"]) + + self.assertEqual(len(share_inputs["entropy_list"][0]), 2) + self.assertEqual(len(share_inputs["entropy_list"][1]), 1) + self.assertEqual(len(share_inputs["entropy_list"][2]), 0) + self.assertEqual(len(share_inputs["entropy_list"][3]), 0) + + self.assertAlmostEqual(share_inputs["entropy_list"][0][0], 0.0024676250759512186, places=6) + self.assertAlmostEqual(share_inputs["entropy_list"][0][1], 0.0024676250759512186, places=6) + self.assertAlmostEqual(share_inputs["entropy_list"][1][0], 0.0024676250759512186, places=6) + + def test_temperature_effect(self): + share_inputs = { + "seq_lens_this_time": paddle.to_tensor([[2], [2], [0], [15]], dtype="int32"), + "seq_lens_encoder": paddle.to_tensor([[0], [0], [0], [15]], dtype="int32"), + "entropy_list": [[], [], [], []], + "stop_flags": paddle.to_tensor([[False], [False], [True], [False]], dtype="bool"), + "req_ids": ["req_1", "req_2", "req_3", "req_4"], + "accept_num": paddle.to_tensor([2, 1, 0, 0], dtype="int32"), # 推理接受数量 + } + + logits = paddle.to_tensor( + [ + [10.0, 1.0, 1.0], + [1.0, 10.0, 1.0], + [1.0, 1.0, 10.0], + [1.0, 1.0, 10.0], + ], + dtype="float32", + ) + temperature = paddle.to_tensor([[0.8], [0.8], [0.8], [0.8]], dtype="float32") + + speculate_calculate_logits_entropy(logits, share_inputs, temperature) + + print(share_inputs["entropy_list"]) + + self.assertEqual(len(share_inputs["entropy_list"][0]), 2) + self.assertEqual(len(share_inputs["entropy_list"][1]), 1) + self.assertEqual(len(share_inputs["entropy_list"][2]), 0) + self.assertEqual(len(share_inputs["entropy_list"][3]), 0) + + self.assertAlmostEqual(share_inputs["entropy_list"][0][0], 0.0003187173861078918, places=6) + self.assertAlmostEqual(share_inputs["entropy_list"][0][1], 0.0003187173861078918, places=6) + self.assertAlmostEqual(share_inputs["entropy_list"][1][0], 0.0003187173861078918, places=6) + + def test_entropy_list_clear(self): + share_inputs = { + "seq_lens_this_time": paddle.to_tensor([[2], [2], [0], [15]], dtype="int32"), + "seq_lens_encoder": paddle.to_tensor([[0], [0], [0], [15]], dtype="int32"), + "entropy_list": [[], [], [], []], + "stop_flags": paddle.to_tensor([[True], [False], [True], [False]], dtype="bool"), + "req_ids": ["req_1", "req_2", "req_3", "req_4"], + "accept_num": paddle.to_tensor([2, 1, 0, 0], dtype="int32"), # 推理接受数量 + } + + logits = paddle.to_tensor( + [ + [10.0, 1.0, 1.0], + [1.0, 10.0, 1.0], + [1.0, 1.0, 10.0], + [1.0, 1.0, 10.0], + ], + dtype="float32", + ) + temperature = paddle.ones([3], dtype="float32") + + speculate_calculate_logits_entropy(logits, share_inputs, temperature) + + print(share_inputs["entropy_list"]) + + self.assertEqual(len(share_inputs["entropy_list"][0]), 0) + self.assertEqual(len(share_inputs["entropy_list"][1]), 1) + self.assertEqual(len(share_inputs["entropy_list"][2]), 0) + self.assertEqual(len(share_inputs["entropy_list"][3]), 0) + + self.assertAlmostEqual(share_inputs["entropy_list"][1][0], 0.0024676250759512186, places=6) + + +if __name__ == "__main__": + unittest.main()