Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-04 08:16:42 +08:00)
[LLM] support send batch data and aggregate data (#2860)
* [LLM] support send batch data and aggregate data
* [LLM] fix ci bugs
* [LLM] fix ci bugs
* [LLM] fix ci bugs
* [LLM] fix ci bugs
* [LLM] update
```diff
@@ -20,7 +20,7 @@ import time
 from dataclasses import asdict, dataclass, fields
 from typing import Any, Dict, Optional, Union
 
-import numpy
+import numpy as np
 
 from fastdeploy.engine.sampling_params import SamplingParams
 from fastdeploy.utils import data_processor_logger
```
```diff
@@ -181,7 +181,7 @@ class Request:
                 f"sampling_params={self.sampling_params})")
 
 
-@dataclass
+@dataclass(slots=True)
 class CompletionOutput:
     """The output data of one completion output of a request.
 
```
```diff
@@ -235,7 +235,7 @@ class CompletionOutput:
                 f"reasoning_content={self.reasoning_content!r}")
 
 
-@dataclass
+@dataclass(slots=True)
 class RequestMetrics:
     """Metrics associated with a request.
 
```
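For background on the two hunks above: `@dataclass(slots=True)` (available since Python 3.10) tells the dataclass machinery to generate `__slots__` from the declared fields, so instances of `CompletionOutput` and `RequestMetrics` store attributes in fixed slots instead of a per-instance `__dict__`. That lowers per-object memory, which adds up when a batch produces many output objects. A minimal sketch of the behavior, using a hypothetical `Point` class rather than anything from FastDeploy:

```python
from dataclasses import dataclass

@dataclass(slots=True)
class Point:
    x: int
    y: int

p = Point(1, 2)

# Slotted instances carry no per-instance __dict__, so memory per
# object is lower and typo'd attribute writes fail loudly:
try:
    p.z = 3
except AttributeError:
    print("cannot add attributes outside __slots__")

# The decorator generates __slots__ from the declared fields:
print(Point.__slots__)  # ('x', 'y')
```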
```diff
@@ -310,6 +310,10 @@ class RequestOutput:
             None if decoder-only.
         num_cached_tokens: The number of tokens with prefix cache hit.
     """
+    __slots__ = (
+        'request_id', 'prompt', 'prompt_token_ids', 'outputs',
+        'finished', 'metrics', 'num_cached_tokens', 'error_code', 'error_msg'
+    )
 
     def __init__(
         self,
```
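`RequestOutput` is a hand-written class rather than a dataclass, so it gets the same treatment via an explicit `__slots__` tuple. The contract is that every attribute `__init__` assigns on `self` must appear in the tuple. A sketch of the same pattern with a hypothetical cut-down class:

```python
class Output:
    # Every instance attribute must be listed here; assigning any
    # other name on an instance raises AttributeError.
    __slots__ = ('request_id', 'finished')

    def __init__(self, request_id: str, finished: bool = False):
        self.request_id = request_id
        self.finished = finished

out = Output("req-1")
out.finished = True   # fine: 'finished' is a declared slot
# out.extra = 1       # would raise AttributeError: no such slot
```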
```diff
@@ -333,6 +337,12 @@ class RequestOutput:
         self.error_code = error_code
         self.error_msg = error_msg
 
+
+        if prompt_token_ids is None:
+            self.prompt_token_ids = []
+        elif isinstance(self.prompt_token_ids, np.ndarray):
+            self.prompt_token_ids = self.prompt_token_ids.tolist()
+
     def add(self, next_output: "RequestOutput") -> None:
         """Merge RequestOutput into this one"""
 
```
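Moving this coercion into `__init__` means every instance holds a plain list from the start, which matters because `np.ndarray` is not JSON-serializable and its elements are numpy scalars rather than native ints. A standalone sketch of the pattern (the function name here is illustrative, not FastDeploy API):

```python
import json

import numpy as np

def normalize_token_ids(prompt_token_ids):
    """Coerce None or an ndarray into a plain Python list."""
    if prompt_token_ids is None:
        return []
    if isinstance(prompt_token_ids, np.ndarray):
        # tolist() converts both the array and its np.int64 elements
        # into native Python types that json.dumps accepts.
        return prompt_token_ids.tolist()
    return prompt_token_ids

ids = normalize_token_ids(np.array([101, 2023, 102]))
print(json.dumps({"prompt_token_ids": ids}))  # {"prompt_token_ids": [101, 2023, 102]}
```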
```diff
@@ -365,11 +375,6 @@
 
     def to_dict(self):
         """convert RequestOutput into a serializable dict """
-        if self.prompt_token_ids is None:
-            self.prompt_token_ids = []
-
-        if type(self.prompt_token_ids) is numpy.ndarray:
-            self.prompt_token_ids = self.prompt_token_ids.tolist()
-
         return {
             "request_id": self.request_id,
```
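With the coercion done once at construction time, `to_dict` no longer mutates the instance before serializing; it can simply read the already-normalized fields. For a slotted class, one generic way to build such a dict (a sketch only; FastDeploy's actual `to_dict` enumerates its fields explicitly, as the hunk above shows) is:

```python
class Slotted:
    __slots__ = ('request_id', 'finished')

    def __init__(self, request_id, finished=False):
        self.request_id = request_id
        self.finished = finished

    def to_dict(self):
        # __slots__ enumerates exactly the attributes an instance can
        # carry, so it doubles as a simple serialization schema.
        return {name: getattr(self, name) for name in self.__slots__}

print(Slotted("req-1").to_dict())
# {'request_id': 'req-1', 'finished': False}
```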