[Feature] General support for logprobs (#2974)

* [Feature] support logprobs in chat/completions and completions endpoints

* Temporarily comment out text_offset due to incorrect logic

* Clean up temporary debug prints

* [Feature] support logprobs in offline mode via SamplingParams

* fix: serialize Logprob as dict before zmq send to fix msgpack error

* refactor: remove redundant methods to simplify codebase

* Fix missing fields in CompletionOutput.to_dict affecting msgpack serialization

* refactor: centralize param validation in engine_client to reduce duplication

* revert: rollback changes in offline_demo.py

* [bugfix] fix parameter validation for logprobs

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Author: SunLei
Date: 2025-07-31 20:25:56 +08:00
Committed by: GitHub
Parent: fe17410f9c
Commit: dade19d7a4
10 changed files with 330 additions and 44 deletions

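Taken together, these changes let clients ask for logprobs on both OpenAI-compatible endpoints. A minimal request sketch, assuming a locally running FastDeploy API server (host, port, and model name are placeholders):

import requests

# Assumes the server was started with logprob support enabled; the validation
# added below rejects logprobs requests otherwise.
resp = requests.post(
    "http://localhost:8188/v1/chat/completions",  # placeholder host/port
    json={
        "model": "default",                       # placeholder model name
        "messages": [{"role": "user", "content": "Hello!"}],
        "logprobs": True,                         # bool form used by chat/completions
        "top_logprobs": 5,                        # must be an int in [0, 20]
    },
)
print(resp.json()["choices"][0]["logprobs"])

For the legacy completions endpoint, `logprobs` is instead an integer giving the number of alternatives per token, as the validation added to EngineClient below shows.
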
@@ -424,7 +424,7 @@ class LLMEngine:
             else:
                 err, data = self.zmq_server.receive_pyobj_once(block)
                 if err is not None:
-                    llm_logger.error("Engine stops inserting zmq task into scheduler")
+                    llm_logger.error(f"Engine stops inserting zmq task into scheduler, err: {err}")
                     break
             request, insert_task = None, []

@@ -25,7 +25,7 @@ import numpy as np

 from fastdeploy.engine.sampling_params import SamplingParams
 from fastdeploy.utils import data_processor_logger
-from fastdeploy.worker.output import LogprobsLists
+from fastdeploy.worker.output import LogprobsLists, SampleLogprobs


 class RequestStatus(Enum):
@@ -245,6 +245,7 @@ class CompletionOutput:
     token_ids: list[int]
     logprob: Optional[float] = None
     top_logprobs: Optional[LogprobsLists] = None
+    logprobs: Optional[SampleLogprobs] = None
     draft_token_ids: list[int] = None
     text: Optional[str] = None
     reasoning_content: Optional[str] = None
@@ -259,6 +260,7 @@ class CompletionOutput:
             "token_ids": self.token_ids,
             "logprob": self.logprob,
             "top_logprobs": self.top_logprobs,
+            "logprobs": self.logprobs,
             "draft_token_ids": self.draft_token_ids,
             "text": self.text,
             "reasoning_content": self.reasoning_content,
@@ -281,7 +283,8 @@ class CompletionOutput:
             f"text={self.text!r}, "
             f"token_ids={self.token_ids}, "
             f"draft_token_ids={self.draft_token_ids}, "
-            f"reasoning_content={self.reasoning_content!r}"
+            f"reasoning_content={self.reasoning_content!r}, "
+            f"logprobs={self.logprobs}, "
         )
@@ -390,16 +393,20 @@ class RequestOutput:
     def add(self, next_output: RequestOutput) -> None:
         """Merge RequestOutput into this one"""
         self.prompt = next_output.prompt
         self.prompt_token_ids = next_output.prompt_token_ids
         self.finished |= next_output.finished
         self.outputs.index = next_output.outputs.index
         self.outputs.token_ids.extend(next_output.outputs.token_ids)
         if next_output.metrics.arrival_time is not None and self.metrics.inference_start_time is not None:
             self.metrics.model_forward_time = next_output.metrics.arrival_time - self.metrics.inference_start_time
         if next_output.metrics.arrival_time is not None and self.metrics.arrival_time is not None:
             self.metrics.model_execute_time = next_output.metrics.arrival_time - self.metrics.arrival_time
+        if next_output.outputs.top_logprobs is not None:
+            self.outputs.top_logprobs.logprob_token_ids.extend(next_output.outputs.top_logprobs.logprob_token_ids)
+            self.outputs.top_logprobs.logprobs.extend(next_output.outputs.top_logprobs.logprobs)
+            self.outputs.top_logprobs.sampled_token_ranks.extend(next_output.outputs.top_logprobs.sampled_token_ranks)

     def __repr__(self) -> str:
         return (
@@ -407,8 +414,9 @@ class RequestOutput:
             f"prompt={self.prompt!r}, "
             f"prompt_token_ids={self.prompt_token_ids}, "
             f"outputs={self.outputs}, "
+            f"finished={self.finished}, "
+            f"num_cached_tokens={self.num_cached_tokens}, "
             f"metrics={self.metrics}, "
-            f"num_cached_tokens={self.num_cached_tokens})"
         )

     @classmethod

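CompletionOutput now carries a `logprobs` field of type SampleLogprobs (list[dict[int, Logprob]]) and includes it in to_dict(). The commit message notes that Logprob entries are serialized as dicts before the zmq send to avoid a msgpack error; a standalone sketch of what that flattening looks like (the helper name here is made up for illustration):

from typing import NamedTuple, Optional

class Logprob(NamedTuple):
    """Mirror of fastdeploy.worker.output.Logprob (see the last file in this diff)."""
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None

def flatten_sample_logprobs(sample_logprobs):
    """Illustrative helper: turn list[dict[int, Logprob]] into msgpack-friendly plain dicts."""
    return [
        {token_id: lp._asdict() for token_id, lp in step.items()}
        for step in sample_logprobs
    ]

step = {1015: Logprob(logprob=-0.12, rank=1), 2277: Logprob(logprob=-2.40, rank=2)}
print(flatten_sample_logprobs([step]))
# [{1015: {'logprob': -0.12, 'rank': 1, 'decoded_token': None},
#   2277: {'logprob': -2.4, 'rank': 2, 'decoded_token': None}}]
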
@@ -42,6 +42,7 @@ class EngineClient:
         enable_mm=False,
         reasoning_parser=None,
         data_parallel_size=1,
+        enable_logprob=False,
     ):
         input_processor = InputPreprocessor(
             tokenizer,
@@ -50,6 +51,7 @@ class EngineClient:
             mm_processor_kwargs,
             enable_mm,
         )
+        self.enable_logprob = enable_logprob
         self.enable_mm = enable_mm
         self.reasoning_parser = reasoning_parser
         self.data_processor = input_processor.create_processor()
@@ -200,6 +202,44 @@ class EngineClient:
         if data.get("stream_options") and not data.get("stream"):
             raise ValueError("Stream options can only be defined when `stream=True`.")

+        # logprobs
+        logprobs = data.get("logprobs")
+        top_logprobs = None
+
+        if isinstance(logprobs, bool) and logprobs:
+            if not self.enable_logprob:
+                err_msg = "Logprobs is disabled, please enable it in startup config."
+                api_server_logger.error(err_msg)
+                raise ValueError(err_msg)
+            top_logprobs = data.get("top_logprobs")
+        elif isinstance(logprobs, int):
+            top_logprobs = logprobs
+        elif logprobs:
+            raise ValueError("Invalid type for 'logprobs'")
+
+        # enable_logprob
+        if top_logprobs:
+            if not self.enable_logprob:
+                err_msg = "Logprobs is disabled, please enable it in startup config."
+                api_server_logger.error(err_msg)
+                raise ValueError(err_msg)
+
+            if not isinstance(top_logprobs, int):
+                err_type = type(top_logprobs).__name__
+                err_msg = f"Invalid type for 'top_logprobs': expected int but got {err_type}."
+                api_server_logger.error(err_msg)
+                raise ValueError(err_msg)
+
+            if top_logprobs < 0:
+                err_msg = f"Invalid 'top_logprobs': must be >= 0, got {top_logprobs}."
+                api_server_logger.error(err_msg)
+                raise ValueError(err_msg)
+
+            if top_logprobs > 20:
+                err_msg = "Invalid value for 'top_logprobs': must be <= 20."
+                api_server_logger.error(err_msg)
+                raise ValueError(err_msg)
+
     def check_health(self, time_interval_threashold=30):
         """
         Check the health of the model server by checking whether all workers are alive.

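These checks centralize logprobs parameter validation in EngineClient, as one of the commit bullets notes. A condensed, standalone sketch of the accepted shapes (the function name is made up; the real checks live in the method above):

def check_logprobs_params(data: dict, enable_logprob: bool) -> None:
    logprobs = data.get("logprobs")
    top_logprobs = None

    if isinstance(logprobs, bool) and logprobs:      # chat/completions: logprobs is a bool
        if not enable_logprob:
            raise ValueError("Logprobs is disabled, please enable it in startup config.")
        top_logprobs = data.get("top_logprobs")
    elif isinstance(logprobs, int):                  # completions: logprobs is the top-k count
        top_logprobs = logprobs
    elif logprobs:
        raise ValueError("Invalid type for 'logprobs'")

    if top_logprobs:
        if not enable_logprob:
            raise ValueError("Logprobs is disabled, please enable it in startup config.")
        if not isinstance(top_logprobs, int):
            raise ValueError("Invalid type for 'top_logprobs'")
        if not 0 <= top_logprobs <= 20:
            raise ValueError("'top_logprobs' must be between 0 and 20.")

check_logprobs_params({"logprobs": True, "top_logprobs": 5}, enable_logprob=True)  # accepted
check_logprobs_params({"logprobs": 5}, enable_logprob=True)                        # accepted
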
@@ -31,6 +31,7 @@ from fastdeploy.engine.sampling_params import SamplingParams

 # from fastdeploy.entrypoints.chat_utils import ChatCompletionMessageParam
 from fastdeploy.utils import llm_logger, retrive_model_from_server
+from fastdeploy.worker.output import Logprob, LogprobsLists

 root_logger = logging.getLogger()
 for handler in root_logger.handlers[:]:
@@ -68,12 +69,14 @@ class LLM:
         model: str,
         revision: Optional[str] = "master",
         tokenizer: Optional[str] = None,
+        enable_logprob: Optional[bool] = False,
         **kwargs,
     ):
         model = retrive_model_from_server(model, revision)
         engine_args = EngineArgs(
             model=model,
             tokenizer=tokenizer,
+            enable_logprob=enable_logprob,
             **kwargs,
         )
@@ -169,8 +172,10 @@ class LLM:
         req_ids = self._add_request(prompts=prompts, sampling_params=sampling_params)

+        topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs
+
         # get output
-        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
+        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
         for i in range(len(outputs)):
             outputs[i].prompt = prompts[i]
         return outputs
@@ -223,8 +228,10 @@ class LLM:
             chat_template_kwargs=chat_template_kwargs,
         )

+        topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs
+
         # get output
-        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
+        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
         return outputs

     def _add_request(
@@ -278,7 +285,50 @@ class LLM:
             self.llm_engine.add_requests(tasks, current_sampling_params, enable_thinking=enable_thinking)
         return req_ids

-    def _run_engine(self, req_ids: list[str], use_tqdm: bool):
+    def _build_sample_logprobs(self, logprobs_lists: LogprobsLists, topk_logprobs: int) -> list[dict[int, Logprob]]:
+        """
+        Constructs a list of dictionaries mapping token IDs to Logprob objects,
+        based on sliced LogprobsLists data (excluding the sampled token at index 0).
+
+        Args:
+            logprobs_lists (LogprobsLists): Contains top-k token IDs, logprobs, and sampled ranks.
+            topk_logprobs (int): Maximum number of top logprobs to include (excluding the sampled token at index 0).
+
+        Returns:
+            list[dict[int, Logprob]]: One dict per request, mapping token ID to Logprob.
+        """
+        try:
+            llm_logger.info(f"filter logprobs, topk_logprobs: {topk_logprobs}")
+            if not logprobs_lists.logprob_token_ids:
+                llm_logger.warning("Empty logprob_token_ids in LogprobsLists")
+                return None
+
+            # exclude the sampled token at index 0
+            available_topk = len(logprobs_lists.logprob_token_ids[0]) - 1
+            effective_topk_logprobs = min(topk_logprobs, available_topk)
+
+            if effective_topk_logprobs <= 0:
+                llm_logger.warning(
+                    f"Invalid effective_topk_logprobs={effective_topk_logprobs}, "
+                    f"available_topk={available_topk}, topk_logprobs={topk_logprobs}; returning empty result."
+                )
+                return None
+
+            # slice columns 1 ~ (1 + effective_topk_logprobs)
+            sliced_logprobs_lists = logprobs_lists.slice_columns(1, 1 + effective_topk_logprobs)
+            result = []
+            for token_ids, logprobs in zip(sliced_logprobs_lists.logprob_token_ids, sliced_logprobs_lists.logprobs):
+                logprob_dict = {
+                    token_id: Logprob(logprob=logprob, rank=i + 1, decoded_token=None)
+                    for i, (token_id, logprob) in enumerate(zip(token_ids, logprobs))
+                }
+                result.append(logprob_dict)
+            return result
+
+        except Exception as e:
+            llm_logger.error(f"Error building sample logprobs from LogprobsLists: {e}")
+
+    def _run_engine(self, req_ids: list[str], use_tqdm: bool, topk_logprobs: Optional[int] = None):
         """
         Run the engine and return the list of results.
@@ -320,6 +370,13 @@ class LLM:
                     result = self.req_output.pop(req_id)
                     result = self.llm_engine.data_processor.process_response(result)

+                    # filter logprobs
+                    if result.outputs.top_logprobs and topk_logprobs:
+                        result.outputs.logprobs = self._build_sample_logprobs(
+                            result.outputs.top_logprobs, topk_logprobs
+                        )
+
                     output[pos] = result
                     finished.append(i)

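For the offline path, logprobs are requested through SamplingParams and surfaced on CompletionOutput.logprobs. A minimal sketch, assuming the usual top-level `from fastdeploy import LLM, SamplingParams` re-exports, a placeholder model path and prompt, and the parameter names that appear in the diff above:

from fastdeploy import LLM, SamplingParams  # assumed top-level re-exports

llm = LLM(model="path/to/your/model", enable_logprob=True)   # new constructor flag
sampling_params = SamplingParams(logprobs=5)                  # keep 5 alternatives per token

outputs = llm.generate(prompts=["Hello, my name is"], sampling_params=sampling_params)
# With logprobs requested, each CompletionOutput now carries
# outputs[i].outputs.logprobs: list[dict[int, Logprob]] built by _build_sample_logprobs.
for step in outputs[0].outputs.logprobs or []:
    print(step)  # {token_id: Logprob(logprob=..., rank=..., decoded_token=...), ...}
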
@@ -114,6 +114,7 @@ async def lifespan(app: FastAPI):
         args.enable_mm,
         args.reasoning_parser,
         args.data_parallel_size,
+        args.enable_logprob,
     )
     app.state.dynamic_load_weight = args.dynamic_load_weight
     chat_handler = OpenAIServingChat(engine_client, pid, args.ips)

@@ -18,7 +18,7 @@ from __future__ import annotations

 import json
 import time
-from typing import Any, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union

 from pydantic import BaseModel, Field, model_validator
@@ -220,7 +220,7 @@ class CompletionResponseChoice(BaseModel):
     prompt_token_ids: Optional[List[int]] = None
     completion_token_ids: Optional[List[int]] = None
     arrival_time: Optional[float] = None
-    logprobs: Optional[int] = None
+    logprobs: Optional[CompletionLogprobs] = None
     reasoning_content: Optional[str] = None
     finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
     tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None
@@ -239,6 +239,17 @@ class CompletionResponse(BaseModel):
     usage: UsageInfo


+class CompletionLogprobs(BaseModel):
+    """
+    Completion logprobs.
+    """
+
+    tokens: Optional[List[str]] = None
+    token_logprobs: Optional[List[float]] = None
+    top_logprobs: Optional[List[Dict]] = None
+    text_offset: Optional[List[int]] = None
+
+
 class CompletionResponseStreamChoice(BaseModel):
     """
     Completion response choice for stream response.
@@ -247,9 +258,9 @@ class CompletionResponseStreamChoice(BaseModel):
     index: int
     text: str
     arrival_time: float = None
+    logprobs: Optional[CompletionLogprobs] = None
     prompt_token_ids: Optional[List[int]] = None
     completion_token_ids: Optional[List[int]] = None
-    logprobs: Optional[float] = None
     reasoning_content: Optional[str] = None
     finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
     tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None

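CompletionLogprobs mirrors the legacy OpenAI completions logprobs shape: parallel token and logprob lists plus one top-k dict per token, with text_offset left unset for now per the commit message. Illustrative values:

from fastdeploy.entrypoints.openai.protocol import CompletionLogprobs

lp = CompletionLogprobs(
    tokens=["Hello", ","],
    token_logprobs=[-0.11, -0.53],
    top_logprobs=[{"Hello": -0.11, "Hi": -2.30}, {",": -0.53, "!": -1.90}],
    text_offset=None,  # temporarily disabled in this PR
)
print(lp.model_dump_json(exclude_none=True))
# {"tokens":["Hello",","],"token_logprobs":[-0.11,-0.53],
#  "top_logprobs":[{"Hello":-0.11,"Hi":-2.3},{",":-0.53,"!":-1.9}]}
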
@@ -76,6 +76,7 @@ class OpenAIServingChat:
             err_msg = f"Only master node can accept completion request, please send request to master node: {self.pod_ips[0]}"
             api_server_logger.error(err_msg)
             return ErrorResponse(message=err_msg, code=400)
+
         if request.user is not None:
             request_id = f"chatcmpl-{request.user}-{uuid.uuid4()}"
         else:
@@ -225,18 +226,11 @@ class OpenAIServingChat:
                         output = res["outputs"]
                         delta_text = output["text"]
-                        raw_top_logprobs = output["top_logprobs"]
-                        logprobs_res = None
-                        if raw_top_logprobs is not None:
-                            top_logprobs = LogprobsLists(
-                                logprob_token_ids=raw_top_logprobs[0],
-                                logprobs=raw_top_logprobs[1],
-                                sampled_token_ranks=raw_top_logprobs[2],
-                            )
-                            logprobs_res = self.build_logprobs_response(
-                                request_logprobs=request.logprobs,
-                                response_logprobs=top_logprobs,
-                                request_top_logprobs=request.top_logprobs,
-                            )
+                        output_top_logprobs = output["top_logprobs"]
+                        logprobs_res: Optional[LogProbs] = None
+                        if request.logprobs and output_top_logprobs is not None:
+                            logprobs_res = self._create_chat_logprobs(
+                                output_top_logprobs, request.logprobs, request.top_logprobs
+                            )

                         previous_num_tokens += len(output["token_ids"])
@@ -375,17 +369,10 @@ class OpenAIServingChat:
                     completion_token_ids.extend(data["outputs"]["token_ids"])
                     # The logprob for handling the response
                     output = data["outputs"]
-                    raw_top_logprobs = output["top_logprobs"]
-                    if raw_top_logprobs is not None:
-                        top_logprobs = LogprobsLists(
-                            logprob_token_ids=raw_top_logprobs[0],
-                            logprobs=raw_top_logprobs[1],
-                            sampled_token_ranks=raw_top_logprobs[2],
-                        )
-                        logprobs_res = self.build_logprobs_response(
-                            request_logprobs=request.logprobs,
-                            response_logprobs=top_logprobs,
-                            request_top_logprobs=request.top_logprobs,
-                        )
+                    output_top_logprobs = output["top_logprobs"]
+                    if output_top_logprobs is not None:
+                        logprobs_res = self._create_chat_logprobs(
+                            output_top_logprobs, request.logprobs, request.top_logprobs
+                        )
                         if logprobs_res and logprobs_res.content is not None:
                             logprob_contents.extend(logprobs_res.content)
@@ -448,7 +435,36 @@ class OpenAIServingChat:
             usage=usage,
         )

-    def build_logprobs_response(
+    def _create_chat_logprobs(
+        self,
+        output_top_logprobs,
+        request_logprobs: Optional[bool] = None,
+        request_top_logprobs: Optional[int] = None,
+    ) -> Optional[LogProbs]:
+        """Create OpenAI-style logprobs for chat completions."""
+        if output_top_logprobs is None or len(output_top_logprobs) < 3 or any(not lst for lst in output_top_logprobs):
+            return None
+
+        logprobs_res: Optional[LogProbs] = None
+        for logprob_token_ids, logprobs, sampled_token_ranks in zip(
+            output_top_logprobs[0], output_top_logprobs[1], output_top_logprobs[2]
+        ):
+            top_logprobs = LogprobsLists(
+                logprob_token_ids=[logprob_token_ids],
+                logprobs=[logprobs],
+                sampled_token_ranks=[sampled_token_ranks],
+            )
+            step_logprobs_res = self._build_logprobs_response(
+                request_logprobs=request_logprobs,
+                response_logprobs=top_logprobs,
+                request_top_logprobs=request_top_logprobs,
+            )
+            if logprobs_res is None:
+                logprobs_res = step_logprobs_res
+            else:
+                logprobs_res.content.extend(step_logprobs_res.content)
+        return logprobs_res
+
+    def _build_logprobs_response(
         self,
         request_logprobs: bool,
         response_logprobs: Optional[LogprobsLists],
@@ -485,12 +501,10 @@ class OpenAIServingChat:
                     token_str = self.engine_client.data_processor.process_logprob_response(
                         [tid], clean_up_tokenization_spaces=False
                     )
-                    # token_bytes = token_str.encode("utf-8", errors="replace")
-                    entry = LogProbEntry(
-                        token=token_str,
-                        logprob=lp,
-                        # bytes=list(token_bytes)
-                    )
+                    token_bytes = token_str.encode("utf-8", errors="replace")
+                    if "\ufffd" in token_str:
+                        token_str = "bytes:" + "".join(f"\\x{byte:02x}" for byte in token_bytes)
+                    entry = LogProbEntry(token=token_str, logprob=lp, bytes=list(token_bytes))
                     top_logprob_entries.append(entry)
                 # Construct the sampled token object (avoid sharing references with top_logprob_entries)
                 sampled_entry = LogProbEntry(
@@ -503,6 +517,6 @@ class OpenAIServingChat:
             return LogProbs(content=[sampled_entry])

         except Exception as e:
-            api_server_logger.error("Error in build_logprobs_response: %s", e)
+            api_server_logger.error("Error in _build_logprobs_response: %s", e)
             api_server_logger.error(traceback.format_exc())
             return None

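One detail worth noting in `_build_logprobs_response`: tokens that do not detokenize cleanly (the string contains U+FFFD) are reported as escaped bytes, and the raw bytes now ride along in each LogProbEntry. A standalone sketch of that fallback:

def render_token(token_str: str) -> tuple[str, list[int]]:
    """Mirror of the undecodable-token handling above, without the engine's detokenizer."""
    token_bytes = token_str.encode("utf-8", errors="replace")
    if "\ufffd" in token_str:
        token_str = "bytes:" + "".join(f"\\x{byte:02x}" for byte in token_bytes)
    return token_str, list(token_bytes)

print(render_token("Hello"))   # ('Hello', [72, 101, 108, 108, 111])
print(render_token("\ufffd"))  # ('bytes:\\xef\\xbf\\xbd', [239, 191, 189])
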
@@ -17,7 +17,7 @@

 import asyncio
 import time
 import uuid
-from typing import List
+from typing import List, Optional

 import aiozmq
 import msgpack
@@ -26,6 +26,7 @@ from aiozmq import zmq

 from fastdeploy.engine.request import RequestOutput
 from fastdeploy.entrypoints.openai.protocol import (
+    CompletionLogprobs,
     CompletionRequest,
     CompletionResponse,
     CompletionResponseChoice,
@@ -35,6 +36,7 @@ from fastdeploy.entrypoints.openai.protocol import (
     UsageInfo,
 )
 from fastdeploy.utils import api_server_logger, get_host_ip
+from fastdeploy.worker.output import LogprobsLists


 class OpenAIServingCompletion:
@@ -160,6 +162,8 @@ class OpenAIServingCompletion:
         valid_results = [dict()] * num_choices
         output_tokens = [0] * num_choices
+        aggregated_top_logprobs = [[[], [], []]] * num_choices
+        aggregated_token_ids = [[]] * num_choices
         completion_batched_token_ids = [[] for _ in range(num_choices)]
         current_waiting_time = 0
         while num_choices > 0:
@@ -182,6 +186,15 @@ class OpenAIServingCompletion:
                     if data.get("error_code", 200) != 200:
                         raise ValueError("{}".format(data["error_msg"]))
+
+                    output = data["outputs"]
+                    output_top_logprobs = output["top_logprobs"]
+                    if output_top_logprobs is not None:
+                        aggregated_top_logprobs[rid][0].extend(output_top_logprobs[0])
+                        aggregated_top_logprobs[rid][1].extend(output_top_logprobs[1])
+                        aggregated_top_logprobs[rid][2].extend(output_top_logprobs[2])
+
+                    aggregated_token_ids[rid].extend(data["outputs"]["token_ids"])
+
                     self.engine_client.data_processor.process_response_dict(
                         data, stream=False, include_stop_str_in_output=request.include_stop_str_in_output
                     )
@@ -189,6 +202,8 @@ class OpenAIServingCompletion:
                     completion_batched_token_ids[rid].extend(data["outputs"]["token_ids"])
                     if data.get("finished", False):
                         data["output_token_ids"] = output_tokens[rid]
+                        data["outputs"]["top_logprobs"] = aggregated_top_logprobs[rid]
+                        data["outputs"]["token_ids"] = aggregated_token_ids[rid]
                         valid_results[rid] = data
                         num_choices -= 1
                         break
@@ -292,6 +307,10 @@ class OpenAIServingCompletion:
                     arrival_time = res["metrics"]["arrival_time"] - inference_start_time[idx]

                     output = res["outputs"]
+                    output_top_logprobs = output["top_logprobs"]
+                    logprobs_res: Optional[CompletionLogprobs] = None
+                    if request.logprobs and output_top_logprobs is not None:
+                        logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)

                     choices.append(
                         CompletionResponseStreamChoice(
@@ -302,6 +321,7 @@ class OpenAIServingCompletion:
                             tool_calls=output.get("tool_call_content"),
                             reasoning_content=output.get("reasoning_content"),
                             arrival_time=arrival_time,
+                            logprobs=logprobs_res,
                         )
                     )
                     if res["finished"]:
@@ -367,6 +387,7 @@ class OpenAIServingCompletion:
         choices: List[CompletionResponseChoice] = []
         num_prompt_tokens = 0
         num_generated_tokens = 0
+        aggregated_logprobs: Optional[CompletionLogprobs] = None

         for idx in range(len(final_res_batch)):
             final_res = final_res_batch[idx]
@@ -376,6 +397,18 @@ class OpenAIServingCompletion:
             completion_token_ids = completion_batched_token_ids[idx]

             output = final_res["outputs"]
+            output_top_logprobs = output["top_logprobs"]
+            if output_top_logprobs is not None:
+                logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)
+                if aggregated_logprobs is None:
+                    aggregated_logprobs = logprobs_res
+                else:
+                    aggregated_logprobs.tokens.extend(logprobs_res.tokens)
+                    aggregated_logprobs.token_logprobs.extend(logprobs_res.token_logprobs)
+                    aggregated_logprobs.top_logprobs.extend(logprobs_res.top_logprobs)
+                    aggregated_logprobs.text_offset.extend(logprobs_res.text_offset)
+
             if request.echo:
                 assert prompt_text is not None
                 if request.max_tokens == 0:
@@ -396,7 +429,7 @@ class OpenAIServingCompletion:
                 completion_token_ids=completion_token_ids if request.return_token_ids else None,
                 reasoning_content=output.get("reasoning_content"),
                 tool_calls=output.get("tool_call_content"),
-                logprobs=None,
+                logprobs=aggregated_logprobs,
                 finish_reason=None,
             )
             choices.append(choice_data)
@@ -419,3 +452,99 @@ class OpenAIServingCompletion:
             choices=choices,
             usage=usage,
         )
+
+    def _create_completion_logprobs(
+        self,
+        output_top_logprobs,
+        request_logprobs: Optional[int] = None,
+        prompt_text_offset: Optional[int] = None,
+    ) -> Optional[CompletionLogprobs]:
+        """Create OpenAI-style logprobs for completions."""
+        # Parameter validation
+        if output_top_logprobs is None or len(output_top_logprobs) < 3 or any(not lst for lst in output_top_logprobs):
+            return None
+
+        logprobs_res: Optional[CompletionLogprobs] = None
+        # Iterate over the top-k candidates for each generated token
+        for logprob_token_ids, logprobs, sampled_token_ranks in zip(
+            output_top_logprobs[0], output_top_logprobs[1], output_top_logprobs[2]
+        ):
+            top_logprobs = LogprobsLists(
+                logprob_token_ids=[logprob_token_ids],
+                logprobs=[logprobs],
+                sampled_token_ranks=[sampled_token_ranks],
+            )
+            # Build the logprobs response for this token
+            step_logprobs_res = self._build_logprobs_response(
+                response_logprobs=top_logprobs,
+                request_top_logprobs=request_logprobs,
+                prompt_text_offset=prompt_text_offset,
+            )
+            if logprobs_res is None:
+                logprobs_res = step_logprobs_res
+            else:
+                # Append the new tokens to the existing logprobs response
+                logprobs_res.tokens.extend(step_logprobs_res.tokens)
+                logprobs_res.token_logprobs.extend(step_logprobs_res.token_logprobs)
+                logprobs_res.top_logprobs.extend(step_logprobs_res.top_logprobs)
+        return logprobs_res
+
+    def _build_logprobs_response(
+        self,
+        response_logprobs: Optional[LogprobsLists] = None,
+        request_top_logprobs: Optional[int] = None,
+        prompt_text_offset: Optional[int] = None,
+    ) -> Optional[CompletionLogprobs]:
+        """
+        Construct a logprobs response object in line with the OpenAI style.
+        Retain the complete top-k candidates and avoid circular references.
+        """
+        # Parameter validation
+        if response_logprobs is None or request_top_logprobs is None or request_top_logprobs < 0:
+            return None
+
+        try:
+            # The top-k candidates for the current token
+            topk_token_ids = []
+            topk_logprobs = []
+
+            if response_logprobs.logprob_token_ids and len(response_logprobs.logprob_token_ids) > 0:
+                topk_token_ids = response_logprobs.logprob_token_ids[0][: request_top_logprobs + 1]
+
+            if response_logprobs.logprobs and len(response_logprobs.logprobs) > 0:
+                topk_logprobs = response_logprobs.logprobs[0][: request_top_logprobs + 1]
+
+            # Decode each candidate and split the sampled token (index 0) from its top-k alternatives
+            tokens = []
+            token_logprobs = []
+            top_logprobs = {}
+            idx = 0
+            for tid, lp in zip(topk_token_ids, topk_logprobs):
+                token_str = self.engine_client.data_processor.process_logprob_response(
+                    [tid], clean_up_tokenization_spaces=False
+                )
+                if "\ufffd" in token_str:
+                    token_bytes = token_str.encode("utf-8", errors="replace")
+                    token_str = "bytes:" + "".join(f"\\x{byte:02x}" for byte in token_bytes)
+                if idx == 0:
+                    tokens.append(token_str)
+                    token_logprobs.append(lp)
+                else:
+                    top_logprobs[token_str] = lp
+                idx += 1
+
+            # text_offset is temporarily disabled due to incorrect offset logic
+            # text_offset = prompt_text_offset + len(tokens) - 1
+            return CompletionLogprobs(
+                tokens=tokens,
+                token_logprobs=token_logprobs,
+                top_logprobs=[top_logprobs],
+                # text_offset=[text_offset],
+            )
+
+        except Exception as e:
+            api_server_logger.error("Error in _build_logprobs_response: %s", e)
+            return None

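For reference, a rough sketch of the data these two helpers consume and produce (illustrative values; real token strings come from the engine's data_processor):

# output["top_logprobs"] is a triple of per-token rows; in each row the sampled
# token sits at index 0, followed by its top-k alternatives.
output_top_logprobs = [
    [[15496, 5756], [11, 0]],          # logprob_token_ids, one row per generated token
    [[-0.11, -2.30], [-0.53, -1.90]],  # logprobs, aligned with the IDs above
    [0, 0],                            # sampled_token_ranks
]
# With request.logprobs == 1, _create_completion_logprobs would yield something like:
# CompletionLogprobs(
#     tokens=["<decoded 15496>", "<decoded 11>"],       # sampled token per step
#     token_logprobs=[-0.11, -0.53],
#     top_logprobs=[{"<decoded 5756>": -2.30}, {"<decoded 0>": -1.90}],
# )
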
@@ -394,6 +394,7 @@ class TokenProcessor:
                     logprobs=[topk_logprobs],
                     sampled_token_ranks=[sampled_rank],
                 )
+
             if token_id in task.eos_token_ids or is_prefill or recovery_stop:
                 result.finished = True
                 if recovery_stop:

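The final file in the diff (below) adds the Logprob named tuple, the SampleLogprobs alias, and LogprobsLists.slice_columns. A small usage sketch of those building blocks, with made-up values: slice_columns trims every per-token row to a [start, end) window while leaving sampled_token_ranks untouched, which is how _build_sample_logprobs drops the sampled token in column 0.

from fastdeploy.worker.output import Logprob, LogprobsLists

lists = LogprobsLists(
    logprob_token_ids=[[101, 7, 9], [102, 4, 8]],       # sampled token first, then top-k
    logprobs=[[-0.1, -1.2, -2.5], [-0.3, -0.9, -3.1]],
    sampled_token_ranks=[0, 0],
)
top_only = lists.slice_columns(1, 3)                     # keep only the alternatives
print(top_only.logprob_token_ids)                        # [[7, 9], [4, 8]]

entry = Logprob(logprob=-1.2, rank=1, decoded_token=None)
print(entry._asdict())  # {'logprob': -1.2, 'rank': 1, 'decoded_token': None}
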
@@ -20,6 +20,20 @@ from typing import NamedTuple, Optional

 import paddle


+class Logprob(NamedTuple):
+    """
+    A named tuple containing information about a token's log probability.
+    """
+
+    logprob: float
+    rank: Optional[int] = None
+    decoded_token: Optional[str] = None
+
+
+# [{token_id, logprob}] for tokens sampled from the top-k
+SampleLogprobs = list[dict[int, Logprob]]
+
+
 class LogprobsLists(NamedTuple):
     """ """
@@ -38,6 +52,17 @@ class LogprobsLists(NamedTuple):
             self.sampled_token_ranks[start:end],
         )

+    def slice_columns(self, start: int, end: int):
+        """
+        Slice columns (per-row top-k logprobs and token IDs).
+        Keeps the number of requests unchanged.
+        """
+        return LogprobsLists(
+            [row[start:end] for row in self.logprob_token_ids],
+            [row[start:end] for row in self.logprobs],
+            self.sampled_token_ranks,  # unchanged
+        )
+

 class LogprobsTensors(NamedTuple):
     """ """