fix: replace list * n initialization with list comprehension to avoid shared references (#3618)

This commit is contained in:
SunLei
2025-08-26 17:53:31 +08:00
committed by GitHub
parent 3200a80de3
commit 2f28f40d90

View File

@@ -205,8 +205,8 @@ class OpenAIServingCompletion:
valid_results = [dict()] * num_choices valid_results = [dict()] * num_choices
output_tokens = [0] * num_choices output_tokens = [0] * num_choices
aggregated_top_logprobs = [[[], [], []]] * num_choices aggregated_top_logprobs = [[[], [], []] for _ in range(num_choices)]
aggregated_token_ids = [[]] * num_choices aggregated_token_ids = [[] for _ in range(num_choices)]
completion_batched_token_ids = [[] for _ in range(num_choices)] completion_batched_token_ids = [[] for _ in range(num_choices)]
current_waiting_time = 0 current_waiting_time = 0
while num_choices > 0: while num_choices > 0:
@@ -477,7 +477,6 @@ class OpenAIServingCompletion:
choices: List[CompletionResponseChoice] = [] choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
num_generated_tokens = 0 num_generated_tokens = 0
aggregated_logprobs: Optional[CompletionLogprobs] = None
for idx in range(len(final_res_batch)): for idx in range(len(final_res_batch)):
final_res = final_res_batch[idx] final_res = final_res_batch[idx]
@@ -489,15 +488,9 @@ class OpenAIServingCompletion:
output = final_res["outputs"] output = final_res["outputs"]
output_top_logprobs = output["top_logprobs"] output_top_logprobs = output["top_logprobs"]
aggregated_logprobs: Optional[CompletionLogprobs] = None
if output_top_logprobs is not None: if output_top_logprobs is not None:
logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0) aggregated_logprobs = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)
if aggregated_logprobs is None:
aggregated_logprobs = logprobs_res
else:
aggregated_logprobs.tokens.extend(logprobs_res.tokens)
aggregated_logprobs.token_logprobs.extend(logprobs_res.token_logprobs)
aggregated_logprobs.top_logprobs.extend(logprobs_res.top_logprobs)
aggregated_logprobs.text_offset.extend(logprobs_res.text_offset)
if request.echo: if request.echo:
assert prompt_text is not None assert prompt_text is not None