add detoken switch (#5463)

Author: qwes5s5 (committed by GitHub)
Date: 2025-12-10 21:44:02 +08:00
Parent: 3bdd54ef6e
Commit: d79438bb86
7 changed files with 77 additions and 32 deletions
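
Within this file, the change threads a new per-request boolean, include_logprobs_decode_token, from the completion request into OpenAIServingCompletion._build_prompt_logprobs, so callers can skip detokenizing prompt-logprob token IDs when they only need the IDs and values. A minimal sketch of how a client request might use the switch; prompt_logprobs and include_logprobs_decode_token are the only field names attested by this diff, and the rest of the payload is assumed:

    payload = {
        "model": "default",                      # assumed field
        "prompt": "Hello, world",                # assumed field
        "prompt_logprobs": 5,                    # top-5 prompt logprobs (-1 = full vocab, per the diff)
        "include_logprobs_decode_token": False,  # new switch: skip decoding logprob token ids to text
    }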


@@ -452,7 +452,7 @@ class OpenAIServingCompletion:
                 else self.engine_client.ori_vocab_size
             )
             prompt_logprobs_res = self._build_prompt_logprobs(
-                prompt_logprobs_tensors, num_prompt_logprobs
+                prompt_logprobs_tensors, num_prompt_logprobs, request.include_logprobs_decode_token
             )
         if request.return_token_ids:
             chunk = CompletionStreamResponse(
@@ -651,7 +651,9 @@ class OpenAIServingCompletion:
             num_prompt_logprobs = (
                 request.prompt_logprobs if request.prompt_logprobs != -1 else self.engine_client.ori_vocab_size
             )
-            prompt_logprobs_res = self._build_prompt_logprobs(prompt_logprobs_tensors, num_prompt_logprobs)
+            prompt_logprobs_res = self._build_prompt_logprobs(
+                prompt_logprobs_tensors, num_prompt_logprobs, request.include_logprobs_decode_token
+            )
         if request.echo:
             prompt_text = self._echo_back_prompt(request, idx // (1 if request.n is None else request.n))
             token_ids = [*prompt_token_ids, *output["token_ids"]]
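
Both the streaming path (the @@ -452 hunk) and the non-streaming path above now forward the flag from the request. The request-model side of the switch lives in one of the other changed files and is not shown here; a minimal sketch of what that field plausibly looks like, assuming a Pydantic-style request class and a default of True to preserve the pre-change behavior (both the class shape and the default are assumptions):

    from pydantic import BaseModel

    class CompletionRequest(BaseModel):
        # Hypothetical declaration of the new switch: when False,
        # _build_prompt_logprobs returns logprobs without decoded text.
        include_logprobs_decode_token: bool = True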
@@ -817,6 +819,7 @@ class OpenAIServingCompletion:
         self,
         prompt_logprobs_tensors: LogprobsTensors,
         num_prompt_logprobs: int,
+        include_logprobs_decode_token: bool,
     ):
         """Update with prompt logprobs from worker.
         Args:
@@ -828,10 +831,13 @@ class OpenAIServingCompletion:
         # Detokenize non-incrementally.
         # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
-        decoded_tokens = [
-            self.engine_client.data_processor.process_logprob_response(token_id)
-            for token_id in token_ids.flatten().tolist()
-        ]
+        if include_logprobs_decode_token:
+            decoded_tokens = [
+                self.engine_client.data_processor.process_logprob_response(token_id)
+                for token_id in token_ids.flatten().tolist()
+            ]
+        else:
+            decoded_tokens = None
         # Recover shapes.
         num_prompt_tokens, num_logprobs = logprobs.shape
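
The behavior of the switch in isolation, as a runnable sketch. The flat [num_tok * num_lps] layout and the process_logprob_response call come from the diff; the stub processor and the driver code are invented for illustration:

    import numpy as np

    class StubDataProcessor:
        # Stand-in for engine_client.data_processor; the real method decodes
        # a token id into its text form.
        def process_logprob_response(self, token_id: int) -> str:
            return f"<tok{token_id}>"

    def decode_logprob_tokens(token_ids, processor, include_logprobs_decode_token):
        # Mirrors the patched branch: detokenize the flattened id list only
        # when the switch is on; otherwise return None and skip the work.
        if include_logprobs_decode_token:
            return [
                processor.process_logprob_response(token_id)
                for token_id in token_ids.flatten().tolist()
            ]
        return None

    ids = np.array([[101, 102], [103, 104]])  # [num_tok, num_lps]
    print(decode_logprob_tokens(ids, StubDataProcessor(), True))   # ['<tok101>', '<tok102>', ...]
    print(decode_logprob_tokens(ids, StubDataProcessor(), False))  # None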