Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 00:33:03 +08:00)
[Feature] General support for logprobs (#2974)
* [Feature] support logprobs in chat/completions and completions endpoints
* Temporarily comment out text_offset due to incorrect logic
* Clean up temporary debug prints
* [Feature] support logprobs in offline mode via SamplingParams
* fix: serialize Logprob as dict before zmq send to fix msgpack error
* refactor: remove redundant methods to simplify codebase
* Fix missing fields in CompletionOutput.to_dict affecting msgpack serialization
* refactor: centralize param validation in engine_client to reduce duplication
* revert: rollback changes in offline_demo.py
* [bugfix] fix parameter validation for logprobs

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
@@ -424,7 +424,7 @@ class LLMEngine:
             else:
                 err, data = self.zmq_server.receive_pyobj_once(block)
                 if err is not None:
-                    llm_logger.error("Engine stops inserting zmq task into scheduler")
+                    llm_logger.error("Engine stops inserting zmq task into scheduler, err:{err}")
                     break

             request, insert_task = None, []
@@ -25,7 +25,7 @@ import numpy as np

 from fastdeploy.engine.sampling_params import SamplingParams
 from fastdeploy.utils import data_processor_logger
-from fastdeploy.worker.output import LogprobsLists
+from fastdeploy.worker.output import LogprobsLists, SampleLogprobs


 class RequestStatus(Enum):
@@ -245,6 +245,7 @@ class CompletionOutput:
     token_ids: list[int]
     logprob: Optional[float] = None
     top_logprobs: Optional[LogprobsLists] = None
+    logprobs: Optional[SampleLogprobs] = None
     draft_token_ids: list[int] = None
     text: Optional[str] = None
     reasoning_content: Optional[str] = None
@@ -259,6 +260,7 @@ class CompletionOutput:
             "token_ids": self.token_ids,
             "logprob": self.logprob,
             "top_logprobs": self.top_logprobs,
+            "logprobs": self.logprobs,
             "draft_token_ids": self.draft_token_ids,
             "text": self.text,
             "reasoning_content": self.reasoning_content,
@@ -281,7 +283,8 @@ class CompletionOutput:
             f"text={self.text!r}, "
             f"token_ids={self.token_ids}, "
             f"draft_token_ids={self.draft_token_ids}, "
-            f"reasoning_content={self.reasoning_content!r}"
+            f"reasoning_content={self.reasoning_content!r}, "
+            f"logprobs={self.logprobs}, "
         )
@@ -390,16 +393,20 @@ class RequestOutput:

     def add(self, next_output: RequestOutput) -> None:
         """Merge RequestOutput into this one"""

         self.prompt = next_output.prompt
         self.prompt_token_ids = next_output.prompt_token_ids
         self.finished |= next_output.finished
         self.outputs.index = next_output.outputs.index
         self.outputs.token_ids.extend(next_output.outputs.token_ids)

         if next_output.metrics.arrival_time is not None and self.metrics.inference_start_time is not None:
             self.metrics.model_forward_time = next_output.metrics.arrival_time - self.metrics.inference_start_time
         if next_output.metrics.arrival_time is not None and self.metrics.arrival_time is not None:
             self.metrics.model_execute_time = next_output.metrics.arrival_time - self.metrics.arrival_time
+        if next_output.outputs.top_logprobs is not None:
+            self.outputs.top_logprobs.logprob_token_ids.extend(next_output.outputs.top_logprobs.logprob_token_ids)
+            self.outputs.top_logprobs.logprobs.extend(next_output.outputs.top_logprobs.logprobs)
+            self.outputs.top_logprobs.sampled_token_ranks.extend(next_output.outputs.top_logprobs.sampled_token_ranks)

     def __repr__(self) -> str:
         return (
@@ -407,8 +414,9 @@ class RequestOutput:
             f"prompt={self.prompt!r}, "
             f"prompt_token_ids={self.prompt_token_ids}, "
             f"outputs={self.outputs}, "
             f"finished={self.finished}, "
+            f"num_cached_tokens={self.num_cached_tokens}, "
             f"metrics={self.metrics}, "
             f"num_cached_tokens={self.num_cached_tokens})"
         )

     @classmethod
@@ -42,6 +42,7 @@ class EngineClient:
         enable_mm=False,
         reasoning_parser=None,
         data_parallel_size=1,
+        enable_logprob=False,
     ):
         input_processor = InputPreprocessor(
             tokenizer,
@@ -50,6 +51,7 @@ class EngineClient:
             mm_processor_kwargs,
             enable_mm,
         )
+        self.enable_logprob = enable_logprob
         self.enable_mm = enable_mm
         self.reasoning_parser = reasoning_parser
         self.data_processor = input_processor.create_processor()
@@ -200,6 +202,44 @@ class EngineClient:
         if data.get("stream_options") and not data.get("stream"):
             raise ValueError("Stream options can only be defined when `stream=True`.")

+        # logprobs
+        logprobs = data.get("logprobs")
+        top_logprobs = None
+
+        if isinstance(logprobs, bool) and logprobs:
+            if not self.enable_logprob:
+                err_msg = "Logprobs is disabled, please enable it in startup config."
+                api_server_logger.error(err_msg)
+                raise ValueError(err_msg)
+            top_logprobs = data.get("top_logprobs")
+        elif isinstance(logprobs, int):
+            top_logprobs = logprobs
+        elif logprobs:
+            raise ValueError("Invalid type for 'logprobs'")
+
+        # enable_logprob
+        if top_logprobs:
+            if not self.enable_logprob:
+                err_msg = "Logprobs is disabled, please enable it in startup config."
+                api_server_logger.error(err_msg)
+                raise ValueError(err_msg)
+
+            if not isinstance(top_logprobs, int):
+                err_type = type(top_logprobs).__name__
+                err_msg = f"Invalid type for 'top_logprobs': expected int but got {err_type}."
+                api_server_logger.error(err_msg)
+                raise ValueError(err_msg)
+
+            if top_logprobs < 0:
+                err_msg = f"Invalid 'top_logprobs': must be >= 0, got {top_logprobs}."
+                api_server_logger.error(err_msg)
+                raise ValueError(err_msg)
+
+            if top_logprobs > 20:
+                err_msg = "Invalid value for 'top_logprobs': must be <= 20."
+                api_server_logger.error(err_msg)
+                raise ValueError(err_msg)
+
     def check_health(self, time_interval_threashold=30):
         """
         Check the health of the model server by checking whether all workers are alive.
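Taken together, these checks mean the two OpenAI-style endpoints express the option differently: chat requests pass a boolean `logprobs` plus an integer `top_logprobs` (0 to 20), while completion requests pass the integer count directly as `logprobs`; the `isinstance(logprobs, bool)` branch is checked before the `int` branch because `bool` is a subclass of `int` in Python. A hedged sketch of request bodies that would pass this validation (the model name and surrounding request fields are placeholders, not taken from this diff):

```python
# Illustrative payloads only; the logprobs fields follow the validation rules above.
chat_payload = {
    "model": "my-model",                                 # placeholder
    "messages": [{"role": "user", "content": "Hello"}],
    "logprobs": True,                                    # chat endpoint: boolean switch
    "top_logprobs": 5,                                   # int, rejected if < 0 or > 20
}

completion_payload = {
    "model": "my-model",                                 # placeholder
    "prompt": "Hello",
    "logprobs": 5,                                       # completions endpoint: the count itself
}

# Either shape is rejected with ValueError unless the server was started with
# logprob support enabled (EngineClient.enable_logprob).
```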
@@ -31,6 +31,7 @@ from fastdeploy.engine.sampling_params import SamplingParams

 # from fastdeploy.entrypoints.chat_utils import ChatCompletionMessageParam
 from fastdeploy.utils import llm_logger, retrive_model_from_server
+from fastdeploy.worker.output import Logprob, LogprobsLists

 root_logger = logging.getLogger()
 for handler in root_logger.handlers[:]:
@@ -68,12 +69,14 @@ class LLM:
         model: str,
         revision: Optional[str] = "master",
         tokenizer: Optional[str] = None,
+        enable_logprob: Optional[bool] = False,
         **kwargs,
     ):
         model = retrive_model_from_server(model, revision)
         engine_args = EngineArgs(
             model=model,
             tokenizer=tokenizer,
+            enable_logprob=enable_logprob,
             **kwargs,
         )
@@ -169,8 +172,10 @@ class LLM:

         req_ids = self._add_request(prompts=prompts, sampling_params=sampling_params)

+        topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs
+
         # get output
-        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
+        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
         for i in range(len(outputs)):
             outputs[i].prompt = prompts[i]
         return outputs
@@ -223,8 +228,10 @@ class LLM:
             chat_template_kwargs=chat_template_kwargs,
         )

+        topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs
+
         # get output
-        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
+        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
         return outputs

     def _add_request(
@@ -278,7 +285,50 @@ class LLM:
         self.llm_engine.add_requests(tasks, current_sampling_params, enable_thinking=enable_thinking)
         return req_ids

-    def _run_engine(self, req_ids: list[str], use_tqdm: bool):
+    def _build_sample_logprobs(self, logprobs_lists: LogprobsLists, topk_logprobs: int) -> list[dict[int, Logprob]]:
+        """
+        Constructs a list of dictionaries mapping token IDs to Logprob objects,
+        based on sliced LogprobsLists data (excluding the sampled token at index 0).
+
+        Args:
+            logprobs_lists (LogprobsLists): Contains top-k token IDs, logprobs, and sampled ranks.
+            topk_logprobs (int): Maximum number of top logprobs to include (excluding the sampled token at index 0).
+
+        Returns:
+            list[dict[int, Logprob]]: One dict per request, mapping token ID to Logprob.
+        """
+        try:
+            llm_logger.info(f"filter logprobs, topk_logprobs: {topk_logprobs}")
+            if not logprobs_lists.logprob_token_ids:
+                llm_logger.warning("Empty logprob_token_ids in LogprobsLists")
+                return None
+
+            # exclude the sampled token at index 0
+            available_topk = len(logprobs_lists.logprob_token_ids[0]) - 1
+            effective_topk_logprobs = min(topk_logprobs, available_topk)
+
+            if effective_topk_logprobs <= 0:
+                llm_logger.warning(
+                    f"Invalid effective_topk_logprobs={effective_topk_logprobs}, "
+                    f"available_topk={available_topk}, topk_logprobs={topk_logprobs}; returning empty result."
+                )
+                return None
+
+            # slice columns 1 ~ (1 + effective_topk_logprobs)
+            sliced_logprobs_lists = logprobs_lists.slice_columns(1, 1 + effective_topk_logprobs)
+            result = []
+            for token_ids, logprobs in zip(sliced_logprobs_lists.logprob_token_ids, sliced_logprobs_lists.logprobs):
+                logprob_dict = {
+                    token_id: Logprob(logprob=logprob, rank=i + 1, decoded_token=None)
+                    for i, (token_id, logprob) in enumerate(zip(token_ids, logprobs))
+                }
+                result.append(logprob_dict)
+            return result
+
+        except Exception as e:
+            llm_logger.error(f"Error building sample logprobs from LogprobsLists: {e}")
+
+    def _run_engine(self, req_ids: list[str], use_tqdm: bool, topk_logprobs: Optional[int] = None):
         """
         Run the engine and return the list of results.
@@ -320,6 +370,13 @@ class LLM:

                     result = self.req_output.pop(req_id)
                     result = self.llm_engine.data_processor.process_response(result)
+
+                    # filter logprobs
+                    if result.outputs.top_logprobs and topk_logprobs:
+                        result.outputs.logprobs = self._build_sample_logprobs(
+                            result.outputs.top_logprobs, topk_logprobs
+                        )
+
                     output[pos] = result
                     finished.append(i)
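On the offline path, the new `enable_logprob` constructor argument, the `logprobs` field of `SamplingParams`, and `_build_sample_logprobs` above combine as follows. This is a rough usage sketch: the top-level import path, the `generate` call, and the model path are assumptions based on the usual FastDeploy offline API rather than something shown in this diff.

```python
from fastdeploy import LLM, SamplingParams  # assumed import path

# Logprob collection has to be switched on when the engine is constructed.
llm = LLM(model="./my-model-dir", enable_logprob=True)   # model path is a placeholder
sampling = SamplingParams(temperature=0.8, logprobs=5)   # request top-5 logprobs per token

outputs = llm.generate(["Write a haiku about the sea."], sampling)

for out in outputs:
    # out.outputs.logprobs is a SampleLogprobs value: one dict per generated token,
    # mapping candidate token_id -> Logprob(logprob, rank, decoded_token)
    print(out.outputs.logprobs)
```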
@@ -114,6 +114,7 @@ async def lifespan(app: FastAPI):
         args.enable_mm,
         args.reasoning_parser,
         args.data_parallel_size,
+        args.enable_logprob,
     )
     app.state.dynamic_load_weight = args.dynamic_load_weight
     chat_handler = OpenAIServingChat(engine_client, pid, args.ips)
@@ -18,7 +18,7 @@ from __future__ import annotations

 import json
 import time
-from typing import Any, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union

 from pydantic import BaseModel, Field, model_validator
@@ -220,7 +220,7 @@ class CompletionResponseChoice(BaseModel):
     prompt_token_ids: Optional[List[int]] = None
     completion_token_ids: Optional[List[int]] = None
     arrival_time: Optional[float] = None
-    logprobs: Optional[int] = None
+    logprobs: Optional[CompletionLogprobs] = None
     reasoning_content: Optional[str] = None
     finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
     tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None
@@ -239,6 +239,17 @@ class CompletionResponse(BaseModel):
     usage: UsageInfo


+class CompletionLogprobs(BaseModel):
+    """
+    Completion logprobs.
+    """
+
+    tokens: Optional[List[str]] = None
+    token_logprobs: Optional[List[float]] = None
+    top_logprobs: Optional[List[Dict]] = None
+    text_offset: Optional[List[int]] = None
+
+
 class CompletionResponseStreamChoice(BaseModel):
     """
     Completion response choice for stream response.
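The new `CompletionLogprobs` model mirrors the classic OpenAI completions logprobs layout. An invented example of a populated instance, written as a plain dict: per the builder added later in this PR, the sampled token goes into `tokens`/`token_logprobs`, only the alternatives appear in `top_logprobs`, and `text_offset` stays unset because its logic is temporarily commented out.

```python
# Invented values for two generated tokens, top-2 alternatives each.
example_completion_logprobs = {
    "tokens": ["Hello", ","],                  # sampled token strings
    "token_logprobs": [-0.12, -1.53],          # logprob of each sampled token
    "top_logprobs": [                          # alternatives per position, token -> logprob
        {"Hi": -2.31, "Hey": -3.05},
        {"!": -1.97, ".": -2.40},
    ],
    "text_offset": None,                       # not populated in this PR
}
```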
@@ -247,9 +258,9 @@ class CompletionResponseStreamChoice(BaseModel):
     index: int
     text: str
     arrival_time: float = None
+    logprobs: Optional[CompletionLogprobs] = None
     prompt_token_ids: Optional[List[int]] = None
     completion_token_ids: Optional[List[int]] = None
-    logprobs: Optional[float] = None
     reasoning_content: Optional[str] = None
     finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
     tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None
@@ -76,6 +76,7 @@ class OpenAIServingChat:
             err_msg = f"Only master node can accept completion request, please send request to master node: {self.pod_ips[0]}"
             api_server_logger.error(err_msg)
             return ErrorResponse(message=err_msg, code=400)
+
         if request.user is not None:
             request_id = f"chatcmpl-{request.user}-{uuid.uuid4()}"
         else:
@@ -225,18 +226,11 @@ class OpenAIServingChat:

                 output = res["outputs"]
                 delta_text = output["text"]
-                raw_top_logprobs = output["top_logprobs"]
-                logprobs_res = None
-                if raw_top_logprobs is not None:
-                    top_logprobs = LogprobsLists(
-                        logprob_token_ids=raw_top_logprobs[0],
-                        logprobs=raw_top_logprobs[1],
-                        sampled_token_ranks=raw_top_logprobs[2],
-                    )
-                    logprobs_res = self.build_logprobs_response(
-                        request_logprobs=request.logprobs,
-                        response_logprobs=top_logprobs,
-                        request_top_logprobs=request.top_logprobs,
+                output_top_logprobs = output["top_logprobs"]
+                logprobs_res: Optional[LogProbs] = None
+                if request.logprobs and output_top_logprobs is not None:
+                    logprobs_res = self._create_chat_logprobs(
+                        output_top_logprobs, request.logprobs, request.top_logprobs
                     )

                 previous_num_tokens += len(output["token_ids"])
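The `output["top_logprobs"]` value consumed here is the three-list layout assembled upstream by the token processor: index 0 holds, for each generated token, the candidate token ids with the sampled id first; index 1 holds the matching log-probabilities; index 2 holds the sampled token's rank. A toy value for two generated tokens with two candidates each (numbers invented):

```python
# [logprob_token_ids, logprobs, sampled_token_ranks] for two generated tokens
output_top_logprobs = [
    [[1815, 9906], [11, 0]],             # candidate ids per step, sampled id at position 0
    [[-0.10, -2.40], [-0.30, -1.90]],    # matching log-probabilities
    [0, 0],                              # rank of the sampled token among the candidates
]
```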
@@ -375,17 +369,10 @@ class OpenAIServingChat:
                 completion_token_ids.extend(data["outputs"]["token_ids"])
                 # The logprob for handling the response
                 output = data["outputs"]
-                raw_top_logprobs = output["top_logprobs"]
-                if raw_top_logprobs is not None:
-                    top_logprobs = LogprobsLists(
-                        logprob_token_ids=raw_top_logprobs[0],
-                        logprobs=raw_top_logprobs[1],
-                        sampled_token_ranks=raw_top_logprobs[2],
-                    )
-                    logprobs_res = self.build_logprobs_response(
-                        request_logprobs=request.logprobs,
-                        response_logprobs=top_logprobs,
-                        request_top_logprobs=request.top_logprobs,
+                output_top_logprobs = output["top_logprobs"]
+                if output_top_logprobs is not None:
+                    logprobs_res = self._create_chat_logprobs(
+                        output_top_logprobs, request.logprobs, request.top_logprobs
                     )
                 if logprobs_res and logprobs_res.content is not None:
                     logprob_contents.extend(logprobs_res.content)
@@ -448,7 +435,36 @@ class OpenAIServingChat:
             usage=usage,
         )

-    def build_logprobs_response(
+    def _create_chat_logprobs(
+        self,
+        output_top_logprobs,
+        request_logprobs: Optional[bool] = None,
+        request_top_logprobs: Optional[int] = None,
+    ) -> Optional[LogProbs]:
+        """Create OpenAI-style logprobs for chat completions."""
+        if output_top_logprobs is None or len(output_top_logprobs) < 3 or any(not lst for lst in output_top_logprobs):
+            return None
+        logprobs_res: Optional[LogProbs] = None
+        for logprob_token_ids, logprobs, sampled_token_ranks in zip(
+            output_top_logprobs[0], output_top_logprobs[1], output_top_logprobs[2]
+        ):
+            top_logprobs = LogprobsLists(
+                logprob_token_ids=[logprob_token_ids],
+                logprobs=[logprobs],
+                sampled_token_ranks=[sampled_token_ranks],
+            )
+            step_logprobs_res = self._build_logprobs_response(
+                request_logprobs=request_logprobs,
+                response_logprobs=top_logprobs,
+                request_top_logprobs=request_top_logprobs,
+            )
+            if logprobs_res is None:
+                logprobs_res = step_logprobs_res
+            else:
+                logprobs_res.content.extend(step_logprobs_res.content)
+        return logprobs_res
+
+    def _build_logprobs_response(
         self,
         request_logprobs: bool,
         response_logprobs: Optional[LogprobsLists],
@@ -485,12 +501,10 @@ class OpenAIServingChat:
                     token_str = self.engine_client.data_processor.process_logprob_response(
                         [tid], clean_up_tokenization_spaces=False
                     )
-                    # token_bytes = token_str.encode("utf-8", errors="replace")
-                    entry = LogProbEntry(
-                        token=token_str,
-                        logprob=lp,
-                        # bytes=list(token_bytes)
-                    )
+                    token_bytes = token_str.encode("utf-8", errors="replace")
+                    if "\ufffd" in token_str:
+                        token_str = "bytes:" + "".join(f"\\x{byte:02x}" for byte in token_bytes)
+                    entry = LogProbEntry(token=token_str, logprob=lp, bytes=list(token_bytes))
                     top_logprob_entries.append(entry)
                 # Construct the sampled token object (avoid sharing references with top_logprob_entries)
                 sampled_entry = LogProbEntry(
@@ -503,6 +517,6 @@ class OpenAIServingChat:
            return LogProbs(content=[sampled_entry])

        except Exception as e:
-            api_server_logger.error("Error in build_logprobs_response: %s", e)
+            api_server_logger.error("Error in _build_logprobs_response: %s", e)
             api_server_logger.error(traceback.format_exc())
             return None
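The UTF-8 fallback introduced above matters for tokens that are only a fragment of a multi-byte character: when the decoded string contains the replacement character U+FFFD, the token is reported in a `bytes:\x..` form while its raw bytes are still attached to the entry. A standalone illustration of that transformation (the input string is invented):

```python
token_str = "\ufffd"  # a token that did not decode cleanly
token_bytes = token_str.encode("utf-8", errors="replace")
if "\ufffd" in token_str:
    token_str = "bytes:" + "".join(f"\\x{byte:02x}" for byte in token_bytes)

print(token_str)          # bytes:\xef\xbf\xbd
print(list(token_bytes))  # [239, 191, 189]
```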
@@ -17,7 +17,7 @@
 import asyncio
 import time
 import uuid
-from typing import List
+from typing import List, Optional

 import aiozmq
 import msgpack
@@ -26,6 +26,7 @@ from aiozmq import zmq

 from fastdeploy.engine.request import RequestOutput
 from fastdeploy.entrypoints.openai.protocol import (
+    CompletionLogprobs,
     CompletionRequest,
     CompletionResponse,
     CompletionResponseChoice,
@@ -35,6 +36,7 @@ from fastdeploy.entrypoints.openai.protocol import (
     UsageInfo,
 )
 from fastdeploy.utils import api_server_logger, get_host_ip
+from fastdeploy.worker.output import LogprobsLists


 class OpenAIServingCompletion:
@@ -160,6 +162,8 @@ class OpenAIServingCompletion:

         valid_results = [dict()] * num_choices
         output_tokens = [0] * num_choices
+        aggregated_top_logprobs = [[[], [], []]] * num_choices
+        aggregated_token_ids = [[]] * num_choices
         completion_batched_token_ids = [[] for _ in range(num_choices)]
         current_waiting_time = 0
         while num_choices > 0:
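One detail worth keeping in mind when reading the two new accumulators: multiplying a nested list repeats references to the same inner lists, whereas the comprehension used for `completion_batched_token_ids` builds an independent list per choice. A quick demonstration of the difference:

```python
n = 2
shared = [[[], [], []]] * n           # pattern used for aggregated_top_logprobs
independent = [[] for _ in range(n)]  # pattern used for completion_batched_token_ids

shared[0][0].append("x")
print(shared[1][0])         # ['x']  -> every choice sees the same inner lists
independent[0].append("x")
print(independent[1])       # []     -> each choice has its own list
```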
@@ -182,6 +186,15 @@ class OpenAIServingCompletion:
                     if data.get("error_code", 200) != 200:
                         raise ValueError("{}".format(data["error_msg"]))

+                    output = data["outputs"]
+                    output_top_logprobs = output["top_logprobs"]
+                    if output_top_logprobs is not None:
+                        aggregated_top_logprobs[rid][0].extend(output_top_logprobs[0])
+                        aggregated_top_logprobs[rid][1].extend(output_top_logprobs[1])
+                        aggregated_top_logprobs[rid][2].extend(output_top_logprobs[2])
+
+                    aggregated_token_ids[rid].extend(data["outputs"]["token_ids"])
+
                     self.engine_client.data_processor.process_response_dict(
                         data, stream=False, include_stop_str_in_output=request.include_stop_str_in_output
                     )
@@ -189,6 +202,8 @@ class OpenAIServingCompletion:
                     completion_batched_token_ids[rid].extend(data["outputs"]["token_ids"])
                     if data.get("finished", False):
                         data["output_token_ids"] = output_tokens[rid]
+                        data["outputs"]["top_logprobs"] = aggregated_top_logprobs[rid]
+                        data["outputs"]["token_ids"] = aggregated_token_ids[rid]
                         valid_results[rid] = data
                         num_choices -= 1
                         break
@@ -292,6 +307,10 @@ class OpenAIServingCompletion:
                        arrival_time = res["metrics"]["arrival_time"] - inference_start_time[idx]

                     output = res["outputs"]
+                    output_top_logprobs = output["top_logprobs"]
+                    logprobs_res: Optional[CompletionLogprobs] = None
+                    if request.logprobs and output_top_logprobs is not None:
+                        logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)

                     choices.append(
                         CompletionResponseStreamChoice(
@@ -302,6 +321,7 @@ class OpenAIServingCompletion:
                             tool_calls=output.get("tool_call_content"),
                             reasoning_content=output.get("reasoning_content"),
                             arrival_time=arrival_time,
+                            logprobs=logprobs_res,
                         )
                     )
                     if res["finished"]:
@@ -367,6 +387,7 @@ class OpenAIServingCompletion:
         choices: List[CompletionResponseChoice] = []
         num_prompt_tokens = 0
         num_generated_tokens = 0
+        aggregated_logprobs: Optional[CompletionLogprobs] = None

         for idx in range(len(final_res_batch)):
             final_res = final_res_batch[idx]
@@ -376,6 +397,18 @@ class OpenAIServingCompletion:
             completion_token_ids = completion_batched_token_ids[idx]

             output = final_res["outputs"]
+            output_top_logprobs = output["top_logprobs"]
+
+            if output_top_logprobs is not None:
+                logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)
+                if aggregated_logprobs is None:
+                    aggregated_logprobs = logprobs_res
+                else:
+                    aggregated_logprobs.tokens.extend(logprobs_res.tokens)
+                    aggregated_logprobs.token_logprobs.extend(logprobs_res.token_logprobs)
+                    aggregated_logprobs.top_logprobs.extend(logprobs_res.top_logprobs)
+                    aggregated_logprobs.text_offset.extend(logprobs_res.text_offset)
+
             if request.echo:
                 assert prompt_text is not None
                 if request.max_tokens == 0:
@@ -396,7 +429,7 @@ class OpenAIServingCompletion:
                 completion_token_ids=completion_token_ids if request.return_token_ids else None,
                 reasoning_content=output.get("reasoning_content"),
                 tool_calls=output.get("tool_call_content"),
-                logprobs=None,
+                logprobs=aggregated_logprobs,
                 finish_reason=None,
             )
             choices.append(choice_data)
@@ -419,3 +452,99 @@ class OpenAIServingCompletion:
             choices=choices,
             usage=usage,
         )
+
+    def _create_completion_logprobs(
+        self,
+        output_top_logprobs,
+        request_logprobs: Optional[int] = None,
+        prompt_text_offset: Optional[int] = None,
+    ) -> Optional[CompletionLogprobs]:
+        """Create OpenAI-style logprobs for completions."""
+
+        # Parameter validation
+        if output_top_logprobs is None or len(output_top_logprobs) < 3 or any(not lst for lst in output_top_logprobs):
+            return None
+
+        logprobs_res: Optional[CompletionLogprobs] = None
+        # Iterate over the top-k candidates for each token
+        for logprob_token_ids, logprobs, sampled_token_ranks in zip(
+            output_top_logprobs[0], output_top_logprobs[1], output_top_logprobs[2]
+        ):
+            top_logprobs = LogprobsLists(
+                logprob_token_ids=[logprob_token_ids],
+                logprobs=[logprobs],
+                sampled_token_ranks=[sampled_token_ranks],
+            )
+            # Build the logprobs response for this step
+            step_logprobs_res = self._build_logprobs_response(
+                response_logprobs=top_logprobs,
+                request_top_logprobs=request_logprobs,
+                prompt_text_offset=prompt_text_offset,
+            )
+            if logprobs_res is None:
+                logprobs_res = step_logprobs_res
+            else:
+                # Append the new tokens to the existing logprobs response
+                logprobs_res.tokens.extend(step_logprobs_res.tokens)
+                logprobs_res.token_logprobs.extend(step_logprobs_res.token_logprobs)
+                logprobs_res.top_logprobs.extend(step_logprobs_res.top_logprobs)

+        return logprobs_res
+
+    def _build_logprobs_response(
+        self,
+        response_logprobs: Optional[LogprobsLists] = None,
+        request_top_logprobs: Optional[int] = None,
+        prompt_text_offset: Optional[int] = None,
+    ) -> Optional[CompletionLogprobs]:
+        """
+        Construct a logprobs response object in line with the OpenAI style.
+        Retain the complete top-k candidates and avoid circular references.
+        """
+
+        # Parameter validation
+        if response_logprobs is None or request_top_logprobs is None or request_top_logprobs < 0:
+            return None
+
+        try:
+            # The top-k candidates for the current token
+            topk_token_ids = []
+            topk_logprobs = []
+
+            if response_logprobs.logprob_token_ids and len(response_logprobs.logprob_token_ids) > 0:
+                topk_token_ids = response_logprobs.logprob_token_ids[0][: request_top_logprobs + 1]
+
+            if response_logprobs.logprobs and len(response_logprobs.logprobs) > 0:
+                topk_logprobs = response_logprobs.logprobs[0][: request_top_logprobs + 1]
+
+            # Split the candidates into the sampled token (index 0) and its top alternatives
+            tokens = []
+            token_logprobs = []
+            top_logprobs = {}
+            idx = 0
+            for tid, lp in zip(topk_token_ids, topk_logprobs):
+                token_str = self.engine_client.data_processor.process_logprob_response(
+                    [tid], clean_up_tokenization_spaces=False
+                )
+                if "\ufffd" in token_str:
+                    token_bytes = token_str.encode("utf-8", errors="replace")
+                    token_str = "bytes:" + "".join(f"\\x{byte:02x}" for byte in token_bytes)
+                if idx == 0:
+                    tokens.append(token_str)
+                    token_logprobs.append(lp)
+                else:
+                    top_logprobs[token_str] = lp
+                idx += 1
+
+            # text_offset is temporarily disabled (incorrect offset logic)
+            # text_offset = prompt_text_offset + len(tokens) - 1
+            return CompletionLogprobs(
+                tokens=tokens,
+                token_logprobs=token_logprobs,
+                top_logprobs=[top_logprobs],
+                # text_offset=[text_offset],
+            )
+
+        except Exception as e:
+            api_server_logger.error("Error in _build_logprobs_response: %s", e)
+            return None
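Putting the two completion helpers together: for one generation step, the candidate at index 0 (the sampled token) feeds `tokens`/`token_logprobs`, and the remaining candidates become the per-position `top_logprobs` dict. Below is a dependency-free sketch of that split; the token strings and numbers are invented, and the real code decodes token ids through the data processor first.

```python
def split_step(candidate_tokens, candidate_logprobs, request_top_logprobs):
    """Mimic the index-0 split performed by _build_logprobs_response above."""
    tokens, token_logprobs, top_logprobs = [], [], {}
    pairs = zip(candidate_tokens[: request_top_logprobs + 1],
                candidate_logprobs[: request_top_logprobs + 1])
    for idx, (tok, lp) in enumerate(pairs):
        if idx == 0:              # sampled token
            tokens.append(tok)
            token_logprobs.append(lp)
        else:                     # alternatives reported under top_logprobs
            top_logprobs[tok] = lp
    return tokens, token_logprobs, [top_logprobs]

print(split_step(["Hello", "Hi", "Hey"], [-0.1, -2.3, -3.0], 2))
# (['Hello'], [-0.1], [{'Hi': -2.3, 'Hey': -3.0}])
```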
@@ -394,6 +394,7 @@ class TokenProcessor:
                     logprobs=[topk_logprobs],
                     sampled_token_ranks=[sampled_rank],
                 )
+
                 if token_id in task.eos_token_ids or is_prefill or recovery_stop:
                     result.finished = True
                     if recovery_stop:
@@ -20,6 +20,20 @@ from typing import NamedTuple, Optional
 import paddle


+class Logprob(NamedTuple):
+    """
+    A named tuple containing information about a token's log probability.
+    """
+
+    logprob: float
+    rank: Optional[int] = None
+    decoded_token: Optional[str] = None
+
+
+# [{token_id, logprob}] for tokens sampled from the top-k
+SampleLogprobs = list[dict[int, Logprob]]
+
+
 class LogprobsLists(NamedTuple):
     """ """

@@ -38,6 +52,17 @@ class LogprobsLists(NamedTuple):
             self.sampled_token_ranks[start:end],
         )

+    def slice_columns(self, start: int, end: int):
+        """
+        Slice columns (per-row top-k logprobs and token IDs).
+        Keeps the number of requests unchanged.
+        """
+        return LogprobsLists(
+            [row[start:end] for row in self.logprob_token_ids],
+            [row[start:end] for row in self.logprobs],
+            self.sampled_token_ranks,  # unchanged
+        )
+

 class LogprobsTensors(NamedTuple):
     """ """
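To see how the pieces in this file compose, the sketch below re-declares a minimal `Logprob`/`LogprobsLists` locally (the real definitions live in fastdeploy.worker.output, as added above) and applies the same column slice that `_build_sample_logprobs` relies on; the ids and logprobs are invented.

```python
from typing import NamedTuple, Optional

class Logprob(NamedTuple):
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None

class LogprobsLists(NamedTuple):
    logprob_token_ids: list    # per request: [sampled_id, alt_1, alt_2, ...]
    logprobs: list             # matching log-probabilities
    sampled_token_ranks: list  # rank of the sampled token per request

    def slice_columns(self, start: int, end: int):
        # Keep the number of requests, drop columns outside [start, end)
        return LogprobsLists(
            [row[start:end] for row in self.logprob_token_ids],
            [row[start:end] for row in self.logprobs],
            self.sampled_token_ranks,
        )

lists = LogprobsLists(
    logprob_token_ids=[[1815, 9906, 3017]],
    logprobs=[[-0.1, -2.4, -3.1]],
    sampled_token_ranks=[0],
)

# Drop column 0 (the sampled token), keep the top-2 alternatives, and map them
# to Logprob entries the way _build_sample_logprobs does.
sliced = lists.slice_columns(1, 3)
sample_logprobs = [
    {tid: Logprob(logprob=lp, rank=i + 1) for i, (tid, lp) in enumerate(zip(ids, lps))}
    for ids, lps in zip(sliced.logprob_token_ids, sliced.logprobs)
]
print(sample_logprobs)
# [{9906: Logprob(logprob=-2.4, rank=1, decoded_token=None),
#   3017: Logprob(logprob=-3.1, rank=2, decoded_token=None)}]
```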