add detoken switch (#5463)

Author: qwes5s5
Date: 2025-12-10 21:44:02 +08:00 (committed by GitHub)
Commit: d79438bb86 (parent 3bdd54ef6e)
7 changed files with 77 additions and 32 deletions

View File

@@ -457,6 +457,7 @@ class CompletionRequest(BaseModel):
     frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
     logprobs: Optional[int] = None
     include_draft_logprobs: Optional[bool] = False
+    include_logprobs_decode_token: Optional[bool] = True
     prompt_logprobs: Optional[int] = None
     # For logits and logprobs post processing
     temp_scaled_logprobs: bool = False
@@ -620,6 +621,7 @@ class ChatCompletionRequest(BaseModel):
     top_logprobs: Optional[int] = None
     prompt_logprobs: Optional[int] = None
     include_draft_logprobs: Optional[bool] = False
+    include_logprobs_decode_token: Optional[bool] = True
     # For logits and logprobs post processing
     temp_scaled_logprobs: bool = False
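For context, a minimal sketch of how a client could turn the new switch off through the OpenAI-compatible API. The server URL, model name, and use of the requests library are placeholders for illustration, not part of this commit; only include_logprobs_decode_token, logprobs, and top_logprobs come from the changed request models.

# Hypothetical request against an OpenAI-compatible chat endpoint; URL and model are placeholders.
import requests

payload = {
    "model": "my-model",
    "messages": [{"role": "user", "content": "Hello"}],
    "logprobs": True,
    "top_logprobs": 5,
    # New field from this commit: when False, logprob entries skip detokenization
    # and come back with empty token strings.
    "include_logprobs_decode_token": False,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["logprobs"])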

View File

@@ -303,7 +303,7 @@ class OpenAIServingChat:
     else self.engine_client.ori_vocab_size
 )
 prompt_logprobs_res = self._build_prompt_logprobs(
-    prompt_logprobs_tensors, num_prompt_logprobs
+    prompt_logprobs_tensors, num_prompt_logprobs, request.include_logprobs_decode_token
 )
 choice = ChatCompletionResponseStreamChoice(
     index=i,
@@ -370,12 +370,18 @@ class OpenAIServingChat:
     request.top_logprobs if request.top_logprobs != -1 else self.engine_client.ori_vocab_size
 )
 logprobs_res = self._create_chat_logprobs(
-    output_top_logprobs, request.logprobs, num_top_logprobs
+    output_top_logprobs,
+    request.logprobs,
+    num_top_logprobs,
+    request.include_logprobs_decode_token,
 )
 if request.include_draft_logprobs and output_draft_top_logprobs is not None:
     draft_logprobs_res = self._create_chat_logprobs(
-        output_draft_top_logprobs, request.logprobs, num_top_logprobs
+        output_draft_top_logprobs,
+        request.logprobs,
+        num_top_logprobs,
+        request.include_logprobs_decode_token,
     )
 delta_message = DeltaMessage(
@@ -577,7 +583,10 @@ class OpenAIServingChat:
 )
 # logprobs
 logprobs_res = self._create_chat_logprobs(
-    output_top_logprobs, request.logprobs, num_top_logprobs
+    output_top_logprobs,
+    request.logprobs,
+    num_top_logprobs,
+    request.include_logprobs_decode_token,
 )
 if logprobs_res and logprobs_res.content is not None:
     logprob_contents[idx].extend(logprobs_res.content)
@@ -585,7 +594,10 @@ class OpenAIServingChat:
 # draft_logprobs
 if request.include_draft_logprobs and output_draft_top_logprobs is not None:
     draft_logprobs_res = self._create_chat_logprobs(
-        output_draft_top_logprobs, request.logprobs, num_top_logprobs
+        output_draft_top_logprobs,
+        request.logprobs,
+        num_top_logprobs,
+        request.include_logprobs_decode_token,
     )
     if draft_logprobs_res and draft_logprobs_res.content is not None:
         draft_logprob_contents[idx].extend(draft_logprobs_res.content)
@@ -596,7 +608,9 @@ class OpenAIServingChat:
     if request.prompt_logprobs != -1
     else self.engine_client.ori_vocab_size
 )
-prompt_logprobs_res = self._build_prompt_logprobs(prompt_logprobs_tensors, num_prompt_logprobs)
+prompt_logprobs_res = self._build_prompt_logprobs(
+    prompt_logprobs_tensors, num_prompt_logprobs, request.include_logprobs_decode_token
+)
 if prompt_logprobs_res:
     prompt_logprobs_res_list[idx].extend(clamp_prompt_logprobs(prompt_logprobs_res))
 if data["finished"]:
@@ -738,6 +752,7 @@ class OpenAIServingChat:
     output_top_logprobs,
     request_logprobs: Optional[bool] = None,
     request_top_logprobs: Optional[int] = None,
+    request_decode_flag: Optional[bool] = True,
 ) -> Optional[LogProbs]:
     """Create OpenAI-style logprobs for chat completions."""
     if output_top_logprobs is None or len(output_top_logprobs) < 3 or any(not lst for lst in output_top_logprobs):
@@ -755,6 +770,7 @@ class OpenAIServingChat:
     request_logprobs=request_logprobs,
     response_logprobs=top_logprobs,
     request_top_logprobs=request_top_logprobs,
+    request_decode_flag=request_decode_flag,
 )
 if logprobs_res is None:
     logprobs_res = step_logprobs_res
@@ -767,6 +783,7 @@ class OpenAIServingChat:
     request_logprobs: bool,
     response_logprobs: Optional[LogprobsLists],
     request_top_logprobs: int,
+    request_decode_flag: bool,
 ) -> Optional[LogProbs]:
     """
     Construct a logprobs response object in line with the OpenAI style.
@@ -796,12 +813,16 @@ class OpenAIServingChat:
 # Construct the candidate token structure (LogProbEntry) of topk
 top_logprob_entries: List[LogProbEntry] = []
 for tid, lp in zip(topk_token_ids, topk_logprobs):
-    token_str = self.engine_client.data_processor.process_logprob_response(
-        [tid], clean_up_tokenization_spaces=False
-    )
-    token_bytes = token_str.encode("utf-8", errors="replace")
-    if "\ufffd" in token_str:
-        token_str = "bytes:" + "".join(f"\\x{byte:02x}" for byte in token_bytes)
+    if request_decode_flag:
+        token_str = self.engine_client.data_processor.process_logprob_response(
+            [tid], clean_up_tokenization_spaces=False
+        )
+        token_bytes = token_str.encode("utf-8", errors="replace")
+        if "\ufffd" in token_str:
+            token_str = "bytes:" + "".join(f"\\x{byte:02x}" for byte in token_bytes)
+    else:
+        token_str = ""
+        token_bytes = []
     entry = LogProbEntry(token=token_str, logprob=lp, bytes=list(token_bytes))
     top_logprob_entries.append(entry)
 # Construct the sampled token object (avoid sharing references with top_logprob_entries)
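As a standalone illustration of the branch above, the sketch below mirrors the decode-or-skip behavior outside the serving class. The simplified LogProbEntry dataclass and the decode callable are stand-ins for the real LogProbEntry model and data_processor.process_logprob_response; this is not the shipped implementation.

from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class LogProbEntry:  # simplified stand-in for the real response model
    token: str
    logprob: float
    bytes: List[int] = field(default_factory=list)


def build_top_entries(topk_token_ids, topk_logprobs, decode: Callable[[list], str], decode_flag: bool = True):
    """Detokenize each candidate only when the switch is on; otherwise emit empty strings."""
    entries = []
    for tid, lp in zip(topk_token_ids, topk_logprobs):
        if decode_flag:
            token_str = decode([tid])
            token_bytes = token_str.encode("utf-8", errors="replace")
            if "\ufffd" in token_str:
                # Byte-escape tokens that do not decode to valid UTF-8 on their own.
                token_str = "bytes:" + "".join(f"\\x{b:02x}" for b in token_bytes)
        else:
            token_str, token_bytes = "", b""
        entries.append(LogProbEntry(token=token_str, logprob=lp, bytes=list(token_bytes)))
    return entries


# Example: with the flag off, no decoding happens and token strings stay empty.
print(build_top_entries([11, 22], [-0.1, -2.3], decode=lambda ids: f"<tok{ids[0]}>", decode_flag=False))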
@@ -823,6 +844,7 @@ class OpenAIServingChat:
     self,
     prompt_logprobs_tensors: LogprobsTensors,
     num_prompt_logprobs: int,
+    include_logprobs_decode_token: bool,
 ):
     """Update with prompt logprobs from worker.
     Args:
@@ -834,10 +856,13 @@ class OpenAIServingChat:
 # Detokenize non-incrementally.
 # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
-decoded_tokens = [
-    self.engine_client.data_processor.process_logprob_response(token_id)
-    for token_id in token_ids.flatten().tolist()
-]
+if include_logprobs_decode_token:
+    decoded_tokens = [
+        self.engine_client.data_processor.process_logprob_response(token_id)
+        for token_id in token_ids.flatten().tolist()
+    ]
+else:
+    decoded_tokens = None
 # Recover shapes.
 num_prompt_tokens, num_logprobs = logprobs.shape
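To make the flatten-then-regroup step concrete, here is a small sketch with a plain NumPy array and a stand-in decoder; the real code operates on the LogprobsTensors delivered by the worker, so the array shape and the decode lambda below are illustrative assumptions.

import numpy as np

# token_ids has shape [num_tok, num_lps]; detokenization runs over the flat list.
token_ids = np.array([[101, 102, 103],
                      [104, 105, 106]])      # 2 prompt tokens x 3 logprobs each
decode = lambda tid: f"<tok{tid}>"           # stand-in for process_logprob_response

include_logprobs_decode_token = True
if include_logprobs_decode_token:
    decoded_tokens = [decode(t) for t in token_ids.flatten().tolist()]
else:
    decoded_tokens = None                    # downstream then skips token strings entirely

# Recover the per-token grouping from the flat list.
num_prompt_tokens, num_logprobs = token_ids.shape
if decoded_tokens is not None:
    rows = [decoded_tokens[i * num_logprobs:(i + 1) * num_logprobs] for i in range(num_prompt_tokens)]
    print(rows)  # [['<tok101>', '<tok102>', '<tok103>'], ['<tok104>', '<tok105>', '<tok106>']]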

View File

@@ -452,7 +452,7 @@ class OpenAIServingCompletion:
     else self.engine_client.ori_vocab_size
 )
 prompt_logprobs_res = self._build_prompt_logprobs(
-    prompt_logprobs_tensors, num_prompt_logprobs
+    prompt_logprobs_tensors, num_prompt_logprobs, request.include_logprobs_decode_token
 )
 if request.return_token_ids:
     chunk = CompletionStreamResponse(
@@ -651,7 +651,9 @@ class OpenAIServingCompletion:
 num_prompt_logprobs = (
     request.prompt_logprobs if request.prompt_logprobs != -1 else self.engine_client.ori_vocab_size
 )
-prompt_logprobs_res = self._build_prompt_logprobs(prompt_logprobs_tensors, num_prompt_logprobs)
+prompt_logprobs_res = self._build_prompt_logprobs(
+    prompt_logprobs_tensors, num_prompt_logprobs, request.include_logprobs_decode_token
+)
 if request.echo:
     prompt_text = self._echo_back_prompt(request, idx // (1 if request.n is None else request.n))
 token_ids = [*prompt_token_ids, *output["token_ids"]]
@@ -817,6 +819,7 @@ class OpenAIServingCompletion:
     self,
     prompt_logprobs_tensors: LogprobsTensors,
     num_prompt_logprobs: int,
+    include_logprobs_decode_token: bool,
 ):
     """Update with prompt logprobs from worker.
     Args:
@@ -828,10 +831,13 @@ class OpenAIServingCompletion:
 # Detokenize non-incrementally.
 # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
-decoded_tokens = [
-    self.engine_client.data_processor.process_logprob_response(token_id)
-    for token_id in token_ids.flatten().tolist()
-]
+if include_logprobs_decode_token:
+    decoded_tokens = [
+        self.engine_client.data_processor.process_logprob_response(token_id)
+        for token_id in token_ids.flatten().tolist()
+    ]
+else:
+    decoded_tokens = None
 # Recover shapes.
 num_prompt_tokens, num_logprobs = logprobs.shape
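The completions path mirrors the chat path, so the same switch applies to /v1/completions as well. A minimal sketch, again with placeholder URL and model name; only include_logprobs_decode_token, logprobs, and prompt_logprobs come from the changed CompletionRequest model.

import requests

payload = {
    "model": "my-model",                      # placeholder model name
    "prompt": "The capital of France is",
    "logprobs": 5,
    "prompt_logprobs": 3,
    # Skip detokenization of logprob tokens; numeric logprob values are unaffected.
    "include_logprobs_decode_token": False,
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload)
print(resp.json()["choices"][0])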