Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Feature] General support for logprobs (#2974)
* [Feature] support logprobs in chat/completions and completions endpoints
* Temporarily comment out text_offset due to incorrect logic
* Clean up temporary debug prints
* [Feature] support logprobs in offline mode via SamplingParams
* fix: serialize Logprob as dict before zmq send to fix msgpack error
* refactor: remove redundant methods to simplify codebase
* Fix missing fields in CompletionOutput.to_dict affecting msgpack serialization
* refactor: centralize param validation in engine_client to reduce duplication
* revert: rollback changes in offline_demo.py
* revert: rollback changes in offline_demo.py
* [bugfix] fix parameter validation for logprobs
* [bugfix] fix parameter validation for logprobs
* [bugfix] fix parameter validation for logprobs
* [bugfix] fix parameter validation for logprobs

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
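A minimal sketch of how a client might exercise the new parameters once this change is deployed, assuming the standard OpenAI-compatible `/v1/chat/completions` route; the host, port, and model name below are placeholders, not values taken from this commit.

```python
# Hypothetical request against an OpenAI-compatible FastDeploy server;
# URL and model name are placeholders for illustration only.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Hello"}],
        "logprobs": True,   # return logprobs for the sampled tokens
        "top_logprobs": 5,  # plus the top-5 alternatives per position
    },
    timeout=60,
)
for choice in resp.json()["choices"]:
    print(choice["logprobs"])
```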
@@ -76,6 +76,7 @@ class OpenAIServingChat:
             err_msg = f"Only master node can accept completion request, please send request to master node: {self.pod_ips[0]}"
             api_server_logger.error(err_msg)
             return ErrorResponse(message=err_msg, code=400)
 
         if request.user is not None:
             request_id = f"chatcmpl-{request.user}-{uuid.uuid4()}"
         else:
@@ -225,18 +226,11 @@ class OpenAIServingChat:
 
                 output = res["outputs"]
                 delta_text = output["text"]
-                raw_top_logprobs = output["top_logprobs"]
-                logprobs_res = None
-                if raw_top_logprobs is not None:
-                    top_logprobs = LogprobsLists(
-                        logprob_token_ids=raw_top_logprobs[0],
-                        logprobs=raw_top_logprobs[1],
-                        sampled_token_ranks=raw_top_logprobs[2],
-                    )
-                    logprobs_res = self.build_logprobs_response(
-                        request_logprobs=request.logprobs,
-                        response_logprobs=top_logprobs,
-                        request_top_logprobs=request.top_logprobs,
+                output_top_logprobs = output["top_logprobs"]
+                logprobs_res: Optional[LogProbs] = None
+                if request.logprobs and output_top_logprobs is not None:
+                    logprobs_res = self._create_chat_logprobs(
+                        output_top_logprobs, request.logprobs, request.top_logprobs
                     )
 
                 previous_num_tokens += len(output["token_ids"])
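The streaming hunk above reads `output["top_logprobs"]` as three parallel lists: candidate token ids, their log-probabilities, and the sampled token's rank. A standalone sketch of that assumed shape, using a stand-in class rather than FastDeploy's `LogprobsLists`:

```python
# Illustrative stand-in; field names mirror the diff above, values are made up.
from typing import List, NamedTuple

class LogprobsListsSketch(NamedTuple):
    logprob_token_ids: List[List[int]]   # top-k candidate token ids, one row per step
    logprobs: List[List[float]]          # matching log-probabilities
    sampled_token_ranks: List[int]       # rank of the token actually sampled

raw_top_logprobs = (
    [[1101, 42, 7]],          # step 0: candidate token ids
    [[-0.11, -2.30, -4.02]],  # step 0: their logprobs
    [0],                      # step 0: the sampled token was the top candidate
)
step = LogprobsListsSketch(*raw_top_logprobs)
print(step.logprob_token_ids[0], step.logprobs[0], step.sampled_token_ranks[0])
```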
@@ -375,17 +369,10 @@ class OpenAIServingChat:
                 completion_token_ids.extend(data["outputs"]["token_ids"])
                 # The logprob for handling the response
                 output = data["outputs"]
-                raw_top_logprobs = output["top_logprobs"]
-                if raw_top_logprobs is not None:
-                    top_logprobs = LogprobsLists(
-                        logprob_token_ids=raw_top_logprobs[0],
-                        logprobs=raw_top_logprobs[1],
-                        sampled_token_ranks=raw_top_logprobs[2],
-                    )
-                    logprobs_res = self.build_logprobs_response(
-                        request_logprobs=request.logprobs,
-                        response_logprobs=top_logprobs,
-                        request_top_logprobs=request.top_logprobs,
+                output_top_logprobs = output["top_logprobs"]
+                if output_top_logprobs is not None:
+                    logprobs_res = self._create_chat_logprobs(
+                        output_top_logprobs, request.logprobs, request.top_logprobs
                     )
                 if logprobs_res and logprobs_res.content is not None:
                     logprob_contents.extend(logprobs_res.content)
@@ -448,7 +435,36 @@ class OpenAIServingChat:
             usage=usage,
         )
 
-    def build_logprobs_response(
+    def _create_chat_logprobs(
+        self,
+        output_top_logprobs,
+        request_logprobs: Optional[bool] = None,
+        request_top_logprobs: Optional[int] = None,
+    ) -> Optional[LogProbs]:
+        """Create OpenAI-style logprobs for chat completions."""
+        if output_top_logprobs is None or len(output_top_logprobs) < 3 or any(not lst for lst in output_top_logprobs):
+            return None
+        logprobs_res: Optional[LogProbs] = None
+        for logprob_token_ids, logprobs, sampled_token_ranks in zip(
+            output_top_logprobs[0], output_top_logprobs[1], output_top_logprobs[2]
+        ):
+            top_logprobs = LogprobsLists(
+                logprob_token_ids=[logprob_token_ids],
+                logprobs=[logprobs],
+                sampled_token_ranks=[sampled_token_ranks],
+            )
+            step_logprobs_res = self._build_logprobs_response(
+                request_logprobs=request_logprobs,
+                response_logprobs=top_logprobs,
+                request_top_logprobs=request_top_logprobs,
+            )
+            if logprobs_res is None:
+                logprobs_res = step_logprobs_res
+            else:
+                logprobs_res.content.extend(step_logprobs_res.content)
+        return logprobs_res
+
+    def _build_logprobs_response(
         self,
         request_logprobs: bool,
         response_logprobs: Optional[LogprobsLists],
@@ -485,12 +501,10 @@ class OpenAIServingChat:
                 token_str = self.engine_client.data_processor.process_logprob_response(
                     [tid], clean_up_tokenization_spaces=False
                 )
-                # token_bytes = token_str.encode("utf-8", errors="replace")
-                entry = LogProbEntry(
-                    token=token_str,
-                    logprob=lp,
-                    # bytes=list(token_bytes)
-                )
+                token_bytes = token_str.encode("utf-8", errors="replace")
+                if "\ufffd" in token_str:
+                    token_str = "bytes:" + "".join(f"\\x{byte:02x}" for byte in token_bytes)
+                entry = LogProbEntry(token=token_str, logprob=lp, bytes=list(token_bytes))
                 top_logprob_entries.append(entry)
             # Construct the sampled token object (avoid sharing references with top_logprob_entries)
             sampled_entry = LogProbEntry(
@@ -503,6 +517,6 @@ class OpenAIServingChat:
             return LogProbs(content=[sampled_entry])
 
         except Exception as e:
-            api_server_logger.error("Error in build_logprobs_response: %s", e)
+            api_server_logger.error("Error in _build_logprobs_response: %s", e)
             api_server_logger.error(traceback.format_exc())
             return None
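To illustrate the escape path added in `_build_logprobs_response` above: when detokenization yields the Unicode replacement character, the token text is rewritten as an escaped byte string while the raw bytes are still attached to the entry. A self-contained rerun of just that logic, with an arbitrary input:

```python
# Reproduces only the fallback formatting shown in the diff; the input string
# here is an arbitrary example containing the replacement character.
token_str = "He\ufffd"
token_bytes = token_str.encode("utf-8", errors="replace")
if "\ufffd" in token_str:
    token_str = "bytes:" + "".join(f"\\x{byte:02x}" for byte in token_bytes)
print(token_str)          # bytes:\x48\x65\xef\xbf\xbd
print(list(token_bytes))  # raw byte values, as carried in LogProbEntry.bytes
```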