Mirror of https://github.com/PaddlePaddle/FastDeploy.git
Commit: Sync v2.0 version of code to github repo
@@ -13,74 +13,98 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import shutil
import os
import threading
import time
from contextlib import asynccontextmanager
from multiprocessing import current_process

import uvicorn
import zmq
import os
import sys
import ctypes
import signal
from fastapi import FastAPI, APIRouter, Request
import threading
from fastapi import FastAPI, Request
from multiprocessing import current_process
from fastapi.responses import JSONResponse, Response, StreamingResponse
from contextlib import asynccontextmanager
from prometheus_client import CONTENT_TYPE_LATEST
from fastdeploy.metrics.metrics import cleanup_prometheus_files, main_process_metrics, EXCLUDE_LABELS, \
get_filtered_metrics
from fastdeploy.utils import FlexibleArgumentParser, api_server_logger, is_port_available

from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.engine import LLMEngine
from fastdeploy.entrypoints.openai.protocol import (
CompletionRequest,
ChatCompletionRequest,
ErrorResponse,
ChatCompletionResponse,
CompletionResponse
)

from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion
from fastdeploy.entrypoints.engine_client import EngineClient
from fastdeploy.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
ErrorResponse)
from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
from fastdeploy.entrypoints.openai.serving_completion import \
OpenAIServingCompletion
from fastdeploy.metrics.metrics import (EXCLUDE_LABELS,
cleanup_prometheus_files,
get_filtered_metrics,
main_process_metrics)
from fastdeploy.utils import (FlexibleArgumentParser, api_server_logger,
console_logger, is_port_available,
retrive_model_from_server)
parser = FlexibleArgumentParser()
parser.add_argument("--port", default=9904, type=int, help="port to the http server")
parser.add_argument("--host", default="0.0.0.0", type=str, help="host to the http server")
parser.add_argument("--port",
default=8000,
type=int,
help="port to the http server")
parser.add_argument("--host",
default="0.0.0.0",
type=str,
help="host to the http server")
parser.add_argument("--workers", default=1, type=int, help="number of workers")
parser.add_argument("--metrics-port", default=8000, type=int, help="port for metrics server")
parser.add_argument("--metrics-port",
default=8001,
type=int,
help="port for metrics server")
parser.add_argument("--controller-port",
default=-1,
type=int,
help="port for controller server")
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
args.model = retrive_model_from_server(args.model)

llm_engine = None


def load_engine():
"""
Initialize and load the LLM engine.

Raises:
SystemExit: If engine initialization fails
load engine
"""
api_server_logger.info(f"FastDeploy LLM API server starting... {os.getpid()}")
engine_args = EngineArgs.from_cli_args(args)
llm_engine = LLMEngine.from_engine_args(engine_args)
global llm_engine
if llm_engine is not None:
return llm_engine

if not llm_engine.start(api_server_pid=os.getpid()):
api_server_logger.error("Failed to initialize FastDeploy LLM engine, service exit now!")
exit(-1)
else:
api_server_logger.info(f"FastDeploy LLM engine initialized!\n")
api_server_logger.info(
f"FastDeploy LLM API server starting... {os.getpid()}")
engine_args = EngineArgs.from_cli_args(args)
engine = LLMEngine.from_engine_args(engine_args)

if not engine.start(api_server_pid=os.getpid()):
api_server_logger.error(
"Failed to initialize FastDeploy LLM engine, service exit now!")
return None

api_server_logger.info("FastDeploy LLM engine initialized!\n")
console_logger.info(
f"Launching metrics service at http://{args.host}:{args.metrics_port}/metrics"
)
console_logger.info(
f"Launching chat completion service at http://{args.host}:{args.port}/v1/chat/completions"
)
console_logger.info(
f"Launching completion service at http://{args.host}:{args.port}/v1/completions"
)
llm_engine = engine
return engine


@asynccontextmanager
async def lifespan(app: FastAPI):
"""
Async context manager for FastAPI application lifespan events.

Args:
app (FastAPI): The FastAPI application instance

Yields:
None: After setting up engine client and handlers
async context manager for FastAPI lifespan
"""

if args.tokenizer is None:
@@ -90,7 +114,11 @@ async def lifespan(app: FastAPI):
else:
pid = os.getpid()
api_server_logger.info(f"{pid}")
engine_client = EngineClient(args.tokenizer, args.max_model_len, args.tensor_parallel_size, pid, args.enable_mm)
engine_client = EngineClient(args.tokenizer, args.max_model_len,
args.tensor_parallel_size, pid,
args.limit_mm_per_prompt,
args.mm_processor_kwargs, args.enable_mm,
args.reasoning_parser)
app.state.dynamic_load_weight = args.dynamic_load_weight
chat_handler = OpenAIServingChat(engine_client, pid)
completion_handler = OpenAIServingCompletion(engine_client, pid)
@@ -116,15 +144,7 @@ app = FastAPI(lifespan=lifespan)
# TODO: pass the real engine state, obtained via pid
@app.get("/health")
def health(request: Request) -> Response:
"""
Perform health check of the engine service.

Args:
request (Request): FastAPI request object

Returns:
Response: HTTP 200 if healthy, 404/304 if errors occur
"""
"""Health check."""

status, msg = app.state.engine_client.check_health()
if not status:
@@ -174,36 +194,21 @@ async def list_all_routes():

@app.api_route("/ping", methods=["GET", "POST"])
def ping(raw_request: Request) -> Response:
"""
Ping endpoint for service availability check.

Args:
raw_request (Request): FastAPI request object

Returns:
Response: Same as health check response
"""
"""Ping check. Endpoint required for SageMaker"""
return health(raw_request)


@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
"""
Create chat completion based on the given request.

Args:
request (ChatCompletionRequest): Chat completion request parameters

Returns:
Union[JSONResponse, StreamingResponse]: Response containing either:
- Error details if failed
- Chat completion results
- Stream of completion events
Create a chat completion for the provided prompt and parameters.
"""
if app.state.dynamic_load_weight:
status, msg = app.state.engine_client.is_workers_alive()
if not status:
return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
return JSONResponse(
content={"error": "Worker Service Not Healthy"},
status_code=304)
generator = await app.state.chat_handler.create_chat_completion(request)

if isinstance(generator, ErrorResponse):
@@ -219,21 +224,14 @@ async def create_chat_completion(request: ChatCompletionRequest):
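For orientation, a minimal client call against the chat endpoint above might look like the sketch below. It assumes a server already running with the defaults set earlier in this diff (host 0.0.0.0, port 8000), uses the third-party requests package as the client, and the model name is a placeholder.

import requests  # client-side dependency assumed for this sketch; not part of the diff

payload = {
    "model": "default",  # placeholder model name
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": False,
}
# POST to the route registered by create_chat_completion() above
resp = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])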
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
"""
Create text completion based on the given request.

Args:
request (CompletionRequest): Completion request parameters

Returns:
Union[JSONResponse, StreamingResponse]: Response containing either:
- Error details if failed
- Completion results
- Stream of completion events
Create a completion for the provided prompt and parameters.
"""
if app.state.dynamic_load_weight:
status, msg = app.state.engine_client.is_workers_alive()
if not status:
return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
return JSONResponse(
content={"error": "Worker Service Not Healthy"},
status_code=304)

generator = await app.state.completion_handler.create_completion(request)
if isinstance(generator, ErrorResponse):
@@ -248,13 +246,7 @@ async def create_completion(request: CompletionRequest):
@app.get("/update_model_weight")
def update_model_weight(request: Request) -> Response:
"""
Update model weights dynamically if enabled.

Args:
request (Request): FastAPI request object

Returns:
Response: HTTP 200 if successful, 404 if failed or disabled
update model weight
"""
if app.state.dynamic_load_weight:
status, msg = app.state.engine_client.update_model_weight()
@@ -262,45 +254,34 @@ def update_model_weight(request: Request) -> Response:
return Response(content=msg, status_code=404)
return Response(status_code=200)
else:
return Response(content="Dynamic Load Weight Disabled.", status_code=404)
return Response(content="Dynamic Load Weight Disabled.",
status_code=404)
@app.get("/clear_load_weight")
def clear_load_weight(request: Request) -> Response:
"""
Clear dynamically loaded model weights if enabled.

Args:
request (Request): FastAPI request object

Returns:
Response: HTTP 200 if successful, 404 if failed or disabled
clear model weight
"""
if app.state.dynamic_load_weight:
status, msg = app.state.engine_client.clear_load_weight()
status, msg = app.state.engine_client.clear_load_weight()
if not status:
return Response(content=msg, status_code=404)
return Response(status_code=200)
else:
return Response(content="Dynamic Load Weight Disabled.", status_code=404)
return Response(content="Dynamic Load Weight Disabled.",
status_code=404)


def launch_api_server(args) -> None:
"""
Launch the API server with given configuration.

Args:
args: Command line arguments containing server configuration

Raises:
Exception: If server launch fails
Launch the HTTP server
"""
api_server_logger.info(f"launch Fastdeploy api server... port: {args.port}")
api_server_logger.info(
f"launch Fastdeploy api server... port: {args.port}")
api_server_logger.info(f"args: {args.__dict__}")

try:
prom_dir = cleanup_prometheus_files(True)
os.environ["PROMETHEUS_MULTIPROC_DIR"] = prom_dir
metrics_server_thread = threading.Thread(target=run_main_metrics_server, daemon=True)
metrics_server_thread.start()
uvicorn.run(app="fastdeploy.entrypoints.openai.api_server:app",
host=args.host,
port=args.port,
@@ -310,53 +291,93 @@ def launch_api_server(args) -> None:
api_server_logger.error(f"launch sync http server error, {e}")
main_app = FastAPI()
metrics_app = FastAPI()


@main_app.get("/metrics")
@metrics_app.get("/metrics")
async def metrics():
"""
metrics
"""
metrics_text = get_filtered_metrics(
EXCLUDE_LABELS,
extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=args.workers)
)
extra_register_func=lambda reg: main_process_metrics.register_all(
reg, workers=args.workers))
return Response(metrics_text, media_type=CONTENT_TYPE_LATEST)


def run_main_metrics_server():
def run_metrics_server():
"""
Run metrics server in main process.

Starts a Uvicorn server for Prometheus metrics endpoint.
run metrics server
"""

uvicorn.run(
main_app,
host="0.0.0.0",
port=args.metrics_port,
log_level="error"
)
uvicorn.run(metrics_app,
host="0.0.0.0",
port=args.metrics_port,
log_level="error")


def launch_metrics_server():
"""Metrics server running the sub thread"""
prom_dir = cleanup_prometheus_files(True)
os.environ["PROMETHEUS_MULTIPROC_DIR"] = prom_dir
metrics_server_thread = threading.Thread(target=run_metrics_server,
daemon=True)
metrics_server_thread.start()
time.sleep(1)
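As a rough illustration of how the /metrics endpoint above can be consumed, the sketch below scrapes it and walks the Prometheus exposition format; it assumes the metrics server is reachable on the new default port 8001 and uses requests plus the parser that ships with prometheus_client.

import requests
from prometheus_client.parser import text_string_to_metric_families

text = requests.get("http://127.0.0.1:8001/metrics", timeout=5).text
for family in text_string_to_metric_families(text):
    for sample in family.samples:
        print(sample.name, sample.labels, sample.value)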
controller_app = FastAPI()


@controller_app.post("/controller/reset_scheduler")
def reset_scheduler():
"""
reset scheduler
"""
global llm_engine

if llm_engine is None:
return Response("Engine not loaded", status_code=500)
llm_engine.reset_scheduler()
return Response("Scheduler Reset Successfully", status_code=200)


def run_controller_server():
"""
run controller server
"""
uvicorn.run(controller_app,
host="0.0.0.0",
port=args.controller_port,
log_level="error")


def launch_controller_server():
"""Controller server running the sub thread"""
if args.controller_port < 0:
return

controller_server_thread = threading.Thread(target=run_controller_server,
daemon=True)
controller_server_thread.start()
time.sleep(1)


def main():
"""
Main entry point for the API server.

Steps:
1. Check port availability
2. Load LLM engine
3. Launch API server

Raises:
Exception: If ports are unavailable
"""
"""Main function."""
if not is_port_available(args.host, args.port):
raise Exception(f"The parameter `port`:{args.port} is already in use.")
if not is_port_available(args.host, args.metrics_port):
raise Exception(f"The parameter `metrics_port`:{args.metrics_port} is already in use.")
load_engine()
raise Exception(
f"The parameter `metrics_port`:{args.metrics_port} is already in use."
)

if load_engine() is None:
return

launch_controller_server()
launch_metrics_server()
launch_api_server(args)
@@ -15,27 +15,20 @@
"""

from __future__ import annotations
import time
from typing import Any, ClassVar, Literal, Optional, Union, List, Dict

from fastapi import UploadFile
from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter,
ValidationInfo, field_validator, model_validator)
from typing_extensions import TypeAlias
import json
import time
from typing import Any, List, Literal, Optional, Union

from pydantic import BaseModel, Field, model_validator

#from openai.types.chat import ChatCompletionMessageParam
from fastdeploy.entrypoints.chat_utils import ChatCompletionMessageParam, parse_chat_messages
from fastdeploy.engine.sampling_params import SamplingParams
# from fastdeploy.entrypoints.chat_utils import ChatCompletionMessageParam


class ErrorResponse(BaseModel):
"""
Standard error response format following OpenAI API specification.

Attributes:
object (str): Always "error"
message (str): Human-readable error message
code (int): HTTP status code
Error response from OpenAI API.
"""
object: str = "error"
message: str
@@ -44,23 +37,14 @@ class ErrorResponse(BaseModel):

class PromptTokenUsageInfo(BaseModel):
"""
Token usage information specific to prompt processing.

Attributes:
cached_tokens (Optional[int]): Number of tokens served from cache
Prompt-related token usage info.
"""
cached_tokens: Optional[int] = None


class UsageInfo(BaseModel):
"""
Token usage statistics for API requests.

Attributes:
prompt_tokens (int): Number of tokens in the prompt
total_tokens (int): Total tokens used (prompt + completion)
completion_tokens (Optional[int]): Tokens generated in completion
prompt_tokens_details (Optional[PromptTokenUsageInfo]): Detailed prompt token info
Usage info for a single request.
"""
prompt_tokens: int = 0
total_tokens: int = 0
@@ -68,45 +52,82 @@ class UsageInfo(BaseModel):
prompt_tokens_details: Optional[PromptTokenUsageInfo] = None
class FunctionCall(BaseModel):
"""
Function call.
"""
name: str
arguments: str


class ToolCall(BaseModel):
"""
Tool call.
"""
id: str = None
type: Literal["function"] = "function"
function: FunctionCall
index: int


class DeltaFunctionCall(BaseModel):
"""
Delta function call.
"""
name: Optional[str] = None
arguments: Optional[str] = None


# a tool call delta where everything is optional
class DeltaToolCall(BaseModel):
"""
Delta tool call.
"""
id: Optional[str] = None
type: Optional[Literal["function"]] = None
index: int
function: Optional[DeltaFunctionCall] = None


class FunctionDefinition(BaseModel):
"""
Function definition.
"""
name: str
description: Optional[str] = None
parameters: Optional[dict[str, Any]] = None


class ChatCompletionToolsParam(BaseModel):
"""
Chat completion tools parameter.
"""
type: Literal["function"] = "function"
function: FunctionDefinition


class ChatMessage(BaseModel):
"""
Single message in a chat conversation.

Attributes:
role (str): Role of the message sender (system/user/assistant)
content (str): Text content of the message
reasoning_content (Optional[str]): Additional reasoning/explanation
Chat message.
"""
role: str
content: str
reasoning_content: Optional[str] = None
tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None


class ChatCompletionResponseChoice(BaseModel):
"""
Single choice in a chat completion response.

Attributes:
index (int): Choice index
message (ChatMessage): Generated chat message
finish_reason (Optional[Literal["stop", "length"]]): Reason for stopping generation
Chat completion response choice.
"""
index: int
message: ChatMessage
finish_reason: Optional[Literal["stop", "length"]]
finish_reason: Optional[Literal["stop", "length", "tool_calls"]]


class ChatCompletionResponse(BaseModel):
"""
Standard chat completion response format.

Attributes:
id (str): Unique request identifier
object (str): Always "chat.completion"
created (int): Unix timestamp of creation
model (str): Model name used
choices (List[ChatCompletionResponseChoice]): Generated response choices
usage (UsageInfo): Token usage statistics
Chat completion response.
"""
id: str
object: str = "chat.completion"
@@ -118,47 +139,28 @@ class ChatCompletionResponse(BaseModel):

class DeltaMessage(BaseModel):
"""
Incremental message update for streaming responses.

Attributes:
role (Optional[str]): Role of the message sender
content (Optional[str]): Partial message content
token_ids (Optional[List[int]]): Token IDs for the delta content
reasoning_content (Optional[str]): Partial reasoning content
Delta message for chat completion stream response.
"""
role: Optional[str] = None
content: Optional[str] = None
token_ids: Optional[List[int]] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None


class ChatCompletionResponseStreamChoice(BaseModel):
"""
Streaming choice in a chat completion response.

Attributes:
index (int): Choice index
delta (DeltaMessage): Incremental message update
finish_reason (Optional[Literal["stop", "length"]]): Reason for stopping
arrival_time (Optional[float]): Timestamp when chunk was generated
Chat completion response choice for stream response.
"""
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
arrival_time: Optional[float] = None


class ChatCompletionStreamResponse(BaseModel):
"""
Streaming chat completion response format.

Attributes:
id (str): Unique request identifier
object (str): Always "chat.completion.chunk"
created (int): Unix timestamp of creation
model (str): Model name used
choices (List[ChatCompletionResponseStreamChoice]): Streaming choices
usage (Optional[UsageInfo]): Token usage (if enabled in stream options)
Chat completion response for stream response.
"""
id: str
object: str = "chat.completion.chunk"
@@ -170,16 +172,7 @@ class ChatCompletionStreamResponse(BaseModel):

class CompletionResponseChoice(BaseModel):
"""
Single choice in a text completion response.

Attributes:
index (int): Choice index
text (str): Generated text
token_ids (Optional[List[int]]): Token IDs for generated text
arrival_time (Optional[float]): Timestamp when generated
logprobs (Optional[int]): Log probabilities
reasoning_content (Optional[str]): Additional reasoning
finish_reason (Optional[Literal["stop", "length"]]): Reason for stopping
Completion response choice.
"""
index: int
text: str
@@ -187,20 +180,13 @@ class CompletionResponseChoice(BaseModel):
arrival_time: Optional[float] = None
logprobs: Optional[int] = None
reasoning_content: Optional[str] = None
finish_reason: Optional[Literal["stop", "length"]]
finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None


class CompletionResponse(BaseModel):
"""
Standard text completion response format.

Attributes:
id (str): Unique request identifier
object (str): Always "text_completion"
created (int): Unix timestamp of creation
model (str): Model name used
choices (List[CompletionResponseChoice]): Generated response choices
usage (UsageInfo): Token usage statistics
Completion response.
"""
id: str
object: str = "text_completion"
@@ -212,16 +198,7 @@ class CompletionResponse(BaseModel):

class CompletionResponseStreamChoice(BaseModel):
"""
Streaming choice in a text completion response.

Attributes:
index (int): Choice index
text (str): Partial generated text
arrival_time (float): Timestamp when chunk was generated
token_ids (Optional[List[int]]): Token IDs for partial text
logprobs (Optional[float]): Log probabilities
reasoning_content (Optional[str]): Partial reasoning
finish_reason (Optional[Literal["stop", "length"]]): Reason for stopping
Completion response choice for stream response.
"""
index: int
text: str
@@ -229,20 +206,13 @@ class CompletionResponseStreamChoice(BaseModel):
token_ids: Optional[List[int]] = None
logprobs: Optional[float] = None
reasoning_content: Optional[str] = None
finish_reason: Optional[Literal["stop", "length"]] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None


class CompletionStreamResponse(BaseModel):
"""
Streaming text completion response format.

Attributes:
id (str): Unique request identifier
object (str): Always "text_completion"
created (int): Unix timestamp of creation
model (str): Model name used
choices (List[CompletionResponseStreamChoice]): Streaming choices
usage (Optional[UsageInfo]): Token usage (if enabled in stream options)
Completion response for stream response.
"""
id: str
object: str = "text_completion"
@@ -254,41 +224,55 @@ class CompletionStreamResponse(BaseModel):

class StreamOptions(BaseModel):
"""
Configuration options for streaming responses.

Attributes:
include_usage (Optional[bool]): Whether to include usage stats
continuous_usage_stats (Optional[bool]): Whether to send incremental usage
Stream options.
"""
include_usage: Optional[bool] = True
continuous_usage_stats: Optional[bool] = False


class StructuralTag(BaseModel):
"""
Structural tag.
"""
begin: str
structural_tag_schema: Optional[dict[str, Any]] = Field(default=None,
alias="schema")
end: str


class JsonSchemaResponseFormat(BaseModel):
"""
Json schema for ResponseFormat.
"""
name: str
description: Optional[str] = None
json_schema: Optional[dict[str, Any]] = Field(default=None, alias='schema')
strict: Optional[bool] = None


class StructuralTagResponseFormat(BaseModel):
"""
Structural tag for ResponseFormat.
"""
type: Literal["structural_tag"]
structures: list[StructuralTag]
triggers: list[str]


class ResponseFormat(BaseModel):
"""
response_format type.
"""
type: Literal["text", "json_object", "json_schema"]
json_schema: Optional[JsonSchemaResponseFormat] = None


AnyResponseFormat = Union[ResponseFormat, StructuralTagResponseFormat]
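A small usage sketch for the response-format models above: because json_schema (and structural_tag_schema) carry alias="schema", incoming payloads use the "schema" key and should be dumped back with by_alias=True. This assumes pydantic v2 semantics, which the model_validator/model_dump calls elsewhere in this diff already imply.

from fastdeploy.entrypoints.openai.protocol import ResponseFormat  # module path as imported elsewhere in this diff

payload = {
    "type": "json_schema",
    "json_schema": {"name": "person", "schema": {"type": "object", "properties": {"age": {"type": "integer"}}}},
}
rf = ResponseFormat.model_validate(payload)   # the "schema" key fills json_schema via its alias
print(rf.json_schema.json_schema)             # {'type': 'object', ...}
print(rf.model_dump(by_alias=True))           # round-trips with the "schema" key restored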
class CompletionRequest(BaseModel):
"""
Text completion request parameters following OpenAI API specification.

Attributes:
model (Optional[str]): Model name (default: "default")
prompt (Union[List[int], List[List[int]], str, List[str]]): Input prompt(s)
best_of (Optional[int]): Number of samples to generate
echo (Optional[bool]): Whether to echo the prompt
frequency_penalty (Optional[float]): Penalize repeated tokens
logprobs (Optional[int]): Number of logprobs to return
max_tokens (Optional[int]): Maximum tokens to generate (default: 16)
n (int): Number of completions (default: 1)
presence_penalty (Optional[float]): Penalize new tokens
seed (Optional[int]): Random seed
stop (Optional[Union[str, List[str]]]): Stop sequences
stream (Optional[bool]): Whether to stream response
stream_options (Optional[StreamOptions]): Streaming configuration
suffix (Optional[dict]): Suffix to append
temperature (Optional[float]): Sampling temperature
top_p (Optional[float]): Nucleus sampling probability
user (Optional[str]): User identifier
repetition_penalty (Optional[float]): Repetition penalty factor
stop_token_ids (Optional[List[int]]): Token IDs to stop generation
Completion request to the engine.
"""
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
@@ -296,11 +280,11 @@ class CompletionRequest(BaseModel):
prompt: Union[List[int], List[List[int]], str, List[str]]
best_of: Optional[int] = None
echo: Optional[bool] = False
frequency_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = None
logprobs: Optional[int] = None
max_tokens: Optional[int] = 16
max_tokens: Optional[int] = None
n: int = 1
presence_penalty: Optional[float] = 0.0
presence_penalty: Optional[float] = None
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
@@ -310,12 +294,17 @@ class CompletionRequest(BaseModel):
top_p: Optional[float] = None
user: Optional[str] = None

response_format: Optional[AnyResponseFormat] = None
guided_json: Optional[Union[str, dict, BaseModel]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[list[str]] = None
guided_grammar: Optional[str] = None

# doc: begin-completion-sampling-params
repetition_penalty: Optional[float] = None
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
# doc: end-completion-sampling-params

# doc: end-completion-sampling-params

def to_dict_for_infer(self, request_id=None, prompt=None):
"""
@@ -340,8 +329,31 @@ class CompletionRequest(BaseModel):
req_dict["prompt_token_ids"] = prompt
del req_dict["prompt"]

return req_dict
guided_json_object = None
if self.response_format is not None:
if self.response_format.type == "json_object":
guided_json_object = True
elif self.response_format.type == "json_schema":
json_schema = self.response_format.json_schema.json_schema
assert json_schema is not None, "response_format.json_schema can not be None"
if isinstance(json_schema, (BaseModel, type(BaseModel))):
self.guided_json = json_schema.model_json_schema()
else:
self.guided_json = json_schema

if guided_json_object:
req_dict["guided_json_object"] = guided_json_object

guided_schema = [
"guided_json", "guided_regex", "guided_choice", "guided_grammar",
"structural_tag"
]
for key in guided_schema:
item = getattr(self, key, None)
if item is not None:
req_dict[key] = item

return req_dict

@model_validator(mode="before")
@classmethod
@@ -353,44 +365,40 @@ class CompletionRequest(BaseModel):
raise ValueError(
"Stream options can only be defined when `stream=True`.")

guided_count = sum([
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None,
"guided_grammar" in data and data["guided_grammar"] is not None
])

if guided_count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex', 'guided_choice', 'guided_grammar')."
)

return data
class ChatCompletionRequest(BaseModel):
"""
Chat completion request parameters following OpenAI API specification.

Attributes:
messages (Union[List[ChatCompletionMessageParam], List[int]]): Conversation history
model (Optional[str]): Model name (default: "default")
frequency_penalty (Optional[float]): Penalize repeated tokens
max_tokens (Optional[int]): Deprecated - max tokens to generate
max_completion_tokens (Optional[int]): Max tokens in completion
n (Optional[int]): Number of completions (default: 1)
presence_penalty (Optional[float]): Penalize new tokens
seed (Optional[int]): Random seed
stop (Optional[Union[str, List[str]]]): Stop sequences
stream (Optional[bool]): Whether to stream response
stream_options (Optional[StreamOptions]): Streaming configuration
temperature (Optional[float]): Sampling temperature
top_p (Optional[float]): Nucleus sampling probability
user (Optional[str]): User identifier
metadata (Optional[dict]): Additional metadata
repetition_penalty (Optional[float]): Repetition penalty factor
stop_token_ids (Optional[List[int]]): Token IDs to stop generation
Chat completion request to the engine.
"""
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
messages: Union[List[ChatCompletionMessageParam], List[int]]
messages: Union[List[Any], List[int]]
tools: Optional[List[ChatCompletionToolsParam]] = None
model: Optional[str] = "default"
frequency_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = None
# remove max_tokens when field is removed from OpenAI API
max_tokens: Optional[int] = Field(
default=None,
deprecated='max_tokens is deprecated in favor of the max_completion_tokens field')
deprecated=
'max_tokens is deprecated in favor of the max_completion_tokens field')
max_completion_tokens: Optional[int] = None
n: Optional[int] = 1
presence_penalty: Optional[float] = 0.0
presence_penalty: Optional[float] = None
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
@@ -400,9 +408,17 @@ class ChatCompletionRequest(BaseModel):
user: Optional[str] = None
metadata: Optional[dict] = None

response_format: Optional[AnyResponseFormat] = None
guided_json: Optional[Union[str, dict, BaseModel]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[list[str]] = None
guided_grammar: Optional[str] = None
structural_tag: Optional[str] = None

# doc: begin-chat-completion-sampling-params
repetition_penalty: Optional[float] = None
stop_token_ids: Optional[List[int]] = Field(default_factory=list)

# doc: end-chat-completion-sampling-params

def to_dict_for_infer(self, request_id=None):
@@ -430,6 +446,36 @@ class ChatCompletionRequest(BaseModel):
req_dict["prompt"] = req_dict["messages"][0]["content"]
del req_dict["messages"]

guided_json_object = None
if self.response_format is not None:
if self.response_format.type == "json_object":
guided_json_object = True
elif self.response_format.type == "json_schema":
json_schema = self.response_format.json_schema.json_schema
assert json_schema is not None, "response_format.json_schema can not be None"
if isinstance(json_schema, (BaseModel, type(BaseModel))):
self.guided_json = json_schema.model_json_schema()
else:
self.guided_json = json_schema
elif self.response_format.type == "structural_tag":
structural_tag = self.response_format
assert structural_tag is not None and isinstance(
structural_tag, StructuralTagResponseFormat)
self.structural_tag = json.dumps(
structural_tag.model_dump(by_alias=True))

if guided_json_object:
req_dict["guided_json_object"] = guided_json_object

guided_schema = [
"guided_json", "guided_regex", "guided_choice", "guided_grammar",
"structural_tag"
]
for key in guided_schema:
item = getattr(self, key, None)
if item is not None:
req_dict[key] = item

return req_dict

@model_validator(mode="before")
@@ -442,4 +488,18 @@ class ChatCompletionRequest(BaseModel):
raise ValueError(
"Stream options can only be defined when `stream=True`.")

guided_count = sum([
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None,
"guided_grammar" in data and data["guided_grammar"] is not None,
"structural_tag" in data and data["structural_tag"] is not None
])

if guided_count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex', 'guided_choice', 'guided_grammar', 'structural_tag')."
)

return data
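To make the guided-decoding plumbing above concrete, here is a hedged sketch of building a request and inspecting the engine-side dict; it relies only on fields and methods shown in this diff, and the parts of to_dict_for_infer outside these hunks are assumed to copy the set fields through. Exactly one of guided_json / guided_regex / guided_choice / guided_grammar / structural_tag may be set, otherwise the validator above raises.

from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest

req = ChatCompletionRequest(
    messages=[{"role": "user", "content": "Return a JSON object with a 'city' field."}],
    guided_json={"type": "object", "properties": {"city": {"type": "string"}}},
)
d = req.to_dict_for_infer(request_id="chatcmpl-demo")  # request_id is optional per the signature above
print(d["guided_json"])  # the guided_* value is forwarded into the engine-side dict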
@@ -33,6 +33,7 @@ from fastdeploy.entrypoints.openai.protocol import (
ChatCompletionResponseStreamChoice,
ChatMessage,
UsageInfo,
PromptTokenUsageInfo,
ChatCompletionResponse,
ErrorResponse,
)
@@ -46,13 +47,7 @@ from fastdeploy.engine.request import RequestOutput

class OpenAIServingChat:
"""
Implementation of OpenAI-compatible chat completion API endpoints.

Handles both streaming and non-streaming chat completion requests.

Attributes:
engine_client: Client for communicating with the LLM engine
pid: Process ID for ZMQ communication
OpenAI-style chat completions serving
"""

def __init__(self, engine_client, pid):
@@ -64,16 +59,7 @@ class OpenAIServingChat:
request: ChatCompletionRequest
):
"""
Create chat completion based on the given request.

Args:
request (ChatCompletionRequest): Chat completion request parameters

Returns:
Union[AsyncGenerator, ChatCompletionResponse, ErrorResponse]:
- Streaming generator if request.stream=True
- Full completion response if request.stream=False
- ErrorResponse if validation fails
Create a new chat completion using the specified parameters.
"""
if request.user is not None:
request_id = f"chatcmpl-{request.user}-{uuid.uuid4()}"
@@ -84,33 +70,27 @@ class OpenAIServingChat:
try:
current_req_dict = request.to_dict_for_infer(request_id)
current_req_dict["arrival_time"] = time.time()
self.engine_client.format_and_add_data(current_req_dict)

except ValueError as e:
prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
except Exception as e:
return ErrorResponse(code=400, message=str(e))

del current_req_dict

if request.stream:
return self.chat_completion_stream_generator(
request, request_id, request.model)
request, request_id,
request.model,
prompt_token_ids)
else:
try:
return await self.chat_completion_full_generator(
request, request_id, request.model)
except ValueError as e:
request, request_id,
request.model,
prompt_token_ids)
except Exception as e:
return ErrorResponse(code=400, message=str(e))

def _create_streaming_error_response(self, message: str) -> str:
"""
Create an error response in streaming format.

Args:
message (str): Error message to include

Returns:
str: JSON-formatted error response
"""
error_response = ErrorResponse(
code=400,
message=message,
@@ -121,25 +101,11 @@ class OpenAIServingChat:
self,
request: ChatCompletionRequest,
request_id: str,
model_name: str
model_name: str,
prompt_token_ids: list()
):
"""
Generator for streaming chat completion responses.

Args:
request (ChatCompletionRequest): Original request parameters
request_id (str): Unique request identifier
model_name (str): Name of the model being used

Yields:
str: Server-Sent Events (SSE) formatted chunks containing:
- Partial completion results
- Usage statistics (if enabled)
- Error messages (if any)

Note:
Uses ZMQ for inter-process communication with the engine.
Maintains streaming protocol compatibility with OpenAI API.
Streaming chat completion generator.
"""
created_time = int(time.time())
chunk_object_type: str = "chat.completion.chunk"
@@ -148,6 +114,7 @@ class OpenAIServingChat:
num_prompt_tokens = 0
num_choices = 1
max_streaming_response_tokens = 1
enable_thinking = None
if request.metadata is not None and request.metadata.get("max_streaming_response_tokens", 1) > 1:
max_streaming_response_tokens = request.metadata["max_streaming_response_tokens"]

@@ -172,23 +139,32 @@ class OpenAIServingChat:
)
dealer.write([b"", request_id.encode('utf-8')])
choices = []
current_waiting_time = 0
while num_choices > 0:
try:
raw_data = await asyncio.wait_for(dealer.read(), timeout=300)
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
current_waiting_time = 0
except asyncio.TimeoutError:
status, msg = self.engine_client.check_health()
if not status:
if choices:
chunk.choices = choices
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
raise ValueError(f"Engine is not healthy: {msg}")
else:
continue

current_waiting_time += 10
if current_waiting_time == 300:
status, msg = self.engine_client.check_health()
if not status:
if choices:
chunk.choices = choices
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
raise ValueError(f"Engine is not healthy: {msg}")
else:
current_waiting_time = 0
await asyncio.sleep(0.1)
continue

res = json.loads(raw_data[-1].decode('utf-8'))
if res.get("error_code", 200) != 200:
raise ValueError("{}".format(res["error_msg"]))
self.engine_client.data_processor.process_response_dict(res, stream=True)
if request.metadata is not None:
enable_thinking = request.metadata.get("enable_thinking")
self.engine_client.data_processor.process_response_dict(
res, stream=True, enable_thinking=enable_thinking)

if res['metrics']['first_token_time'] is not None:
arrival_time = res['metrics']['first_token_time']
@@ -196,15 +172,15 @@ class OpenAIServingChat:
else:
arrival_time = res['metrics']['arrival_time'] - inference_start_time
if first_iteration:
num_prompt_tokens = len(res["prompt_token_ids"])
num_prompt_tokens = len(prompt_token_ids)
num_cached_tokens = res.get("num_cached_tokens", 0)
for i in range(num_choices):
choice = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role="assistant", content="", reasoning_content="")
delta=DeltaMessage(role="assistant", content="", reasoning_content="", tool_calls=None)
)
if request.metadata is not None and request.metadata.get("training", False):
choice.delta.token_ids = list(res["prompt_token_ids"])
choice.delta.token_ids = prompt_token_ids
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
@@ -216,7 +192,8 @@ class OpenAIServingChat:
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens
total_tokens=num_prompt_tokens,
prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cached_tokens)
)
yield f"data: {chunk.model_dump_json(exclude_unset=True)} \n\n"
first_iteration = False
@@ -226,18 +203,21 @@ class OpenAIServingChat:

previous_num_tokens += len(output["token_ids"])
delta_message = DeltaMessage(content=delta_text, reasoning_content=output.get("reasoning_content"), \
token_ids=output.get("token_ids"))
token_ids=output.get("token_ids"), tool_calls=output.get("tool_call_content", []))

choice = ChatCompletionResponseStreamChoice(
index=output["index"],
index=0,
delta=delta_message,
arrival_time=arrival_time
)
if res["finished"]:
num_choices -= 1
work_process_metrics.e2e_request_latency.observe(time.time() - res["metrics"]["request_start_time"])
if request.max_tokens is None or output["index"] + 1 != request.max_tokens:
if request.max_tokens is None or previous_num_tokens != request.max_tokens:
choice.finish_reason = "stop"
if self.engine_client.reasoning_parser == "ernie_x1" and \
output.get("finish_reason", "") == "tool_calls":
choice.finish_reason = "tool_calls"
else:
choice.finish_reason = "length"
@@ -285,27 +265,15 @@ class OpenAIServingChat:
self,
request: ChatCompletionRequest,
request_id: str,
model_name: str
model_name: str,
prompt_token_ids: list()
):
"""
Generate complete chat response in one-shot mode.

Args:
request (ChatCompletionRequest): Original request parameters
request_id (str): Unique request identifier
model_name (str): Name of the model being used

Returns:
ChatCompletionResponse: Complete chat response with:
- Generated message
- Usage statistics
- Finish reason

Raises:
ValueError: If engine communication fails or times out
Full chat completion generator.
"""
created_time = int(time.time())
final_res = None
enable_thinking = None
try:
dealer = await aiozmq.create_zmq_stream(
zmq.DEALER,
@@ -314,20 +282,29 @@ class OpenAIServingChat:
dealer.write([b"", request_id.encode('utf-8')])
final_res = None
previous_num_tokens = 0
current_waiting_time = 0
while True:
try:
raw_data = await asyncio.wait_for(dealer.read(), timeout=300)
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
current_waiting_time = 0
except asyncio.TimeoutError:
status, msg = self.engine_client.check_health()
if not status:
raise ValueError(f"Engine is not healthy: {msg}")
else:
continue
current_waiting_time += 10
if current_waiting_time == 300:
status, msg = self.engine_client.check_health()
if not status:
raise ValueError(f"Engine is not healthy: {msg}")
else:
current_waiting_time = 0
await asyncio.sleep(0.1)
continue

data = json.loads(raw_data[-1].decode('utf-8'))
if data.get("error_code", 200) != 200:
raise ValueError("{}".format(data["error_msg"]))
data = self.engine_client.data_processor.process_response_dict(data, stream=False)
if request.metadata is not None:
enable_thinking = request.metadata.get("enable_thinking")
data = self.engine_client.data_processor.process_response_dict(
data, stream=False, enable_thinking=enable_thinking)
# api_server_logger.debug(f"Client {request_id} received: {data}")
previous_num_tokens += len(data["outputs"]["token_ids"])
if data["finished"]:
@@ -342,27 +319,31 @@ class OpenAIServingChat:
role="assistant",
content=output["text"],
reasoning_content=output.get("reasoning_content"),
tool_calls=output.get("tool_call_content"),
token_ids=output.get("token_ids")
)

choice = ChatCompletionResponseChoice(
index=output["index"],
index=0,
message=message,
finish_reason=None
)
if request.max_tokens is None or output["index"] + 1 != request.max_tokens:

if request.max_tokens is None or previous_num_tokens != request.max_tokens:
choice.finish_reason = "stop"
if self.engine_client.reasoning_parser == "ernie_x1" and \
output.get("finish_reason", "") == "tool_calls":
choice.finish_reason = "tool_calls"
else:
choice.finish_reason = "length"
choices.append(choice)

num_prompt_tokens = len(final_res["prompt_token_ids"])
num_prompt_tokens = len(prompt_token_ids)
num_generated_tokens = previous_num_tokens
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens
total_tokens=num_prompt_tokens + num_generated_tokens,
prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=final_res.get("num_cached_tokens", 0))
)
work_process_metrics.e2e_request_latency.observe(time.time() - final_res["metrics"]["request_start_time"])
return ChatCompletionResponse(
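The streaming generator above emits Server-Sent-Events lines of the form "data: {json}\n\n". A hedged client-side sketch for consuming that stream (again with requests, against the default port from this diff; any end-of-stream sentinel lies outside the hunks shown here):

import json
import requests

body = {"model": "default", "messages": [{"role": "user", "content": "Hello"}], "stream": True}
with requests.post("http://127.0.0.1:8000/v1/chat/completions", json=body, stream=True, timeout=600) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        try:
            chunk = json.loads(line[len("data: "):])
        except json.JSONDecodeError:
            continue  # skip any non-JSON control lines
        for choice in chunk.get("choices", []):
            print(choice.get("delta", {}).get("content") or "", end="", flush=True)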
@@ -26,44 +26,31 @@ from typing import Optional, Union, cast, TypeVar, List
import uuid
from fastapi import Request

from fastdeploy.entrypoints.openai.protocol import ErrorResponse, CompletionRequest, CompletionResponse, CompletionStreamResponse, CompletionResponseStreamChoice, CompletionResponseChoice,UsageInfo
from fastdeploy.entrypoints.openai.protocol import (
ErrorResponse,
CompletionRequest,
CompletionResponse,
CompletionStreamResponse,
CompletionResponseStreamChoice,
CompletionResponseChoice,
UsageInfo,
DeltaToolCall,
DeltaFunctionCall,
ToolCall,
FunctionCall
)
from fastdeploy.utils import api_server_logger
from fastdeploy.engine.request import RequestOutput


class OpenAIServingCompletion:
"""
Implementation of OpenAI-compatible text completion API endpoints.

Handles both streaming and non-streaming text completion requests.

Attributes:
engine_client: Client for communicating with the LLM engine
pid: Process ID for ZMQ communication
"""
def __init__(self, engine_client, pid):
"""
Initialize the completion service.

Args:
engine_client: Client for engine communication
pid: Process ID for ZMQ routing
"""
self.engine_client = engine_client
self.pid = pid

async def create_completion(self, request: CompletionRequest):
"""
Create text completion based on the given request.

Args:
request (CompletionRequest): Completion request parameters

Returns:
Union[AsyncGenerator, CompletionResponse, ErrorResponse]:
- Streaming generator if request.stream=True
- Full completion response if request.stream=False
- ErrorResponse if validation fails
Create a completion for the given prompt.
"""
created_time = int(time.time())
if request.user is not None:
@@ -97,14 +84,16 @@ class OpenAIServingCompletion:
num_choices = len(request_prompts)

api_server_logger.info(f"start inference for request {num_choices}")

prompt_batched_token_ids = []
try:
for idx, prompt in enumerate(request_prompts):
request_id_idx = f"{request_id}-{idx}"
current_req_dict = request.to_dict_for_infer(request_id_idx, prompt)
try:
current_req_dict["arrival_time"] = time.time()
self.engine_client.format_and_add_data(current_req_dict)
prompt_batched_token_ids.append(
self.engine_client.format_and_add_data(current_req_dict)
)
except Exception as e:
return ErrorResponse(message=str(e), code=400)

@@ -116,7 +105,8 @@ class OpenAIServingCompletion:
num_choices = num_choices,
request_id=request_id,
created_time=created_time,
model_name=request.model
model_name=request.model,
prompt_batched_token_ids=prompt_batched_token_ids
)
else:
try:
@@ -125,12 +115,13 @@ class OpenAIServingCompletion:
num_choices=num_choices,
request_id=request_id,
created_time=created_time,
model_name=request.model
model_name=request.model,
prompt_batched_token_ids=prompt_batched_token_ids
)
except ValueError as e:
except Exception as e:
return ErrorResponse(code=400, message=str(e))

except ValueError as e:
except Exception as e:
return ErrorResponse(message=str(e), code=400)
@@ -141,25 +132,10 @@ class OpenAIServingCompletion:
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
prompt_batched_token_ids: list()
|
||||
):
|
||||
"""
|
||||
Generate complete text response in one-shot mode.
|
||||
|
||||
Args:
|
||||
request (CompletionRequest): Original request parameters
|
||||
num_choices (int): Number of prompt variations
|
||||
request_id (str): Unique request identifier
|
||||
created_time (int): Unix timestamp of creation
|
||||
model_name (str): Name of the model being used
|
||||
|
||||
Returns:
|
||||
CompletionResponse: Complete text response with:
|
||||
- Generated text
|
||||
- Usage statistics
|
||||
- Finish reason
|
||||
|
||||
Raises:
|
||||
ValueError: If engine communication fails or times out
|
||||
Process the full completion request with multiple choices.
|
||||
"""
|
||||
dealer = None
|
||||
try:
|
||||
@@ -175,22 +151,28 @@ class OpenAIServingCompletion:
|
||||
|
||||
valid_results = [dict()] * num_choices
|
||||
output_tokens = [0] * num_choices
|
||||
current_waiting_time = 0
|
||||
while num_choices > 0:
|
||||
try:
|
||||
raw_data = await asyncio.wait_for(dealer.read(), timeout=300)
|
||||
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
|
||||
current_waiting_time = 0
|
||||
except asyncio.TimeoutError:
|
||||
status, msg = self.engine_client.check_health()
|
||||
if not status:
|
||||
raise ValueError(f"Engine is not healthy: {msg}")
|
||||
else:
|
||||
continue
|
||||
current_waiting_time += 10
|
||||
if current_waiting_time == 300:
|
||||
status, msg = self.engine_client.check_health()
|
||||
if not status:
|
||||
raise ValueError(f"Engine is not healthy: {msg}")
|
||||
else:
|
||||
current_waiting_time = 0
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
data = json.loads(raw_data[-1].decode("utf-8"))
|
||||
rid = int(data["request_id"].split("-")[-1])
|
||||
if data.get("error_code", 200) != 200:
|
||||
raise ValueError("{}".format(data["error_msg"]))
|
||||
|
||||
self.engine_client.data_processor.process_response_dict(
|
||||
data, stream=False
|
||||
)
|
||||
data, stream=False)
|
||||
output_tokens[rid] += len(data["outputs"]["token_ids"])
|
||||
if data.get("finished", False):
|
||||
data["output_token_ids"] = output_tokens[rid]
|
||||
@@ -202,7 +184,8 @@ class OpenAIServingCompletion:
|
||||
request=request,
|
||||
request_id=request_id,
|
||||
created_time=created_time,
|
||||
model_name=model_name
|
||||
model_name=model_name,
|
||||
prompt_batched_token_ids=prompt_batched_token_ids
|
||||
)
|
||||
except Exception as e:
|
||||
api_server_logger.error(
|
||||
@@ -220,27 +203,11 @@ class OpenAIServingCompletion:
num_choices: int,
request_id: str,
created_time: int,
model_name: str
model_name: str,
prompt_batched_token_ids: list()
):
"""
Generator for streaming text completion responses.

Args:
request (CompletionRequest): Original request parameters
num_choices (int): Number of prompt variations
request_id (str): Unique request identifier
created_time (int): Unix timestamp of creation
model_name (str): Name of the model being used

Yields:
str: Server-Sent Events (SSE) formatted chunks containing:
- Partial completion results
- Usage statistics (if enabled)
- Error messages (if any)

Note:
Uses ZMQ for inter-process communication with the engine.
Maintains streaming protocol compatibility with OpenAI API.
Process the stream completion request.
"""
try:
dealer = await aiozmq.create_zmq_stream(
@@ -259,16 +226,21 @@ class OpenAIServingCompletion:
max_streaming_response_tokens = request.suffix["max_streaming_response_tokens"]
choices = []

current_waiting_time = 0
while num_choices > 0:
try:
raw_data = await asyncio.wait_for(dealer.read(), timeout=300)
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
current_waiting_time = 0
except asyncio.TimeoutError:
status, msg = self.engine_client.check_health()
if not status:
raise ValueError(f"Engine is not healthy: {msg}")
else:
continue
current_waiting_time += 10
if current_waiting_time == 300:
status, msg = self.engine_client.check_health()
if not status:
raise ValueError(f"Engine is not healthy: {msg}")
else:
current_waiting_time = 0
await asyncio.sleep(0.1)
continue

res = json.loads(raw_data[-1].decode('utf-8'))
@@ -285,14 +257,15 @@ class OpenAIServingCompletion:
choices=[CompletionResponseStreamChoice(
index=idx,
text="",
token_ids=list(res["prompt_token_ids"])
token_ids=list(prompt_batched_token_ids[idx])
)]
)
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
first_iteration[idx] = False

self.engine_client.data_processor.process_response_dict(res, stream=True)
self.engine_client.data_processor.process_response_dict(
res, stream=True)
if res['metrics'].get('first_token_time') is not None:
arrival_time = res['metrics']['first_token_time']
inference_start_time[idx] = res['metrics']['inference_start_time']
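Each streamed chunk is serialized to JSON and framed as a Server-Sent Events message ("data: ...\n\n"), which is what the yield above produces via model_dump_json. A minimal sketch of that framing, using a plain dict in place of the pydantic chunk model:

import json

def to_sse(payload: dict) -> str:
    """Frame one JSON payload as a Server-Sent Events data message."""
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

# Example: a first-iteration chunk that only carries the prompt token ids.
chunk = {
    "id": "cmpl-0",
    "object": "text_completion",
    "choices": [{"index": 0, "text": "", "token_ids": [1, 2, 3]}],
}
print(to_sse(chunk), end="")
# OpenAI-compatible streams are conventionally terminated with "data: [DONE]\n\n".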
@@ -306,12 +279,16 @@ class OpenAIServingCompletion:
index=idx,
text=output["text"],
token_ids=output.get("token_ids"),
tool_calls=output.get("tool_call_content"),
reasoning_content=output.get("reasoning_content"),
arrival_time=arrival_time
))
if res["finished"]:
if request.max_tokens is None or output_tokens[idx] + 1 != request.max_tokens:
chunk.choices[0].finish_reason = "stop"
if self.engine_client.reasoning_parser == "ernie_x1" and \
output.get("finish_reason", "") == "tool_calls":
chunk.choices[0].finish_reason = "tool_calls"
else:
chunk.choices[0].finish_reason = "length"
@@ -337,7 +314,7 @@ class OpenAIServingCompletion:
model=model_name,
choices=[],
usage=UsageInfo(
prompt_tokens=len(res.get("prompt_token_ids", [])),
prompt_tokens=len(prompt_batched_token_ids[idx]),
completion_tokens=output_tokens[idx]
)
)
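Prompt token counts now come from prompt_batched_token_ids rather than the response payload, so usage can be reported even when a response dict omits prompt_token_ids. A small sketch of the per-choice accounting with a plain dict standing in for UsageInfo; total_tokens is the conventional OpenAI field and is an assumption here:

# Sketch: prompt length comes from the token ids captured at request time.
def build_usage(idx, prompt_batched_token_ids, output_tokens):
    prompt_tokens = len(prompt_batched_token_ids[idx])
    completion_tokens = output_tokens[idx]
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }

print(build_usage(0, [[11, 12, 13]], [5]))
# {'prompt_tokens': 3, 'completion_tokens': 5, 'total_tokens': 8}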
@@ -360,28 +337,15 @@ class OpenAIServingCompletion:
request_id: str,
created_time: int,
model_name: str,
prompt_batched_token_ids: list()
) -> CompletionResponse:
"""
Convert raw engine outputs to OpenAI-compatible completion response.

Args:
final_res_batch (List[RequestOutput]): Batch of engine responses
request (CompletionRequest): Original request parameters
request_id (str): Unique request identifier
created_time (int): Unix timestamp of creation
model_name (str): Name of the model being used

Returns:
CompletionResponse: Formatted completion response with:
- Generated text choices
- Token usage statistics
"""
choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
num_generated_tokens = 0

for final_res in final_res_batch:
prompt_token_ids = final_res["prompt_token_ids"]
for idx in range(len(final_res_batch)):
final_res = final_res_batch[idx]
prompt_token_ids = prompt_batched_token_ids[idx]
assert prompt_token_ids is not None
prompt_text = final_res["prompt"]
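The loop now iterates by index so each engine result can be paired with the prompt token ids collected at request time. An equivalent pairing can also be written with zip plus a length check; this is a sketch under that assumption, not the FastDeploy implementation:

# Sketch: pair each engine result with the prompt token ids gathered earlier;
# behaviorally equivalent to the index-based loop in the diff.
final_res_batch = [{"prompt": "a", "outputs": {}}, {"prompt": "b", "outputs": {}}]
prompt_batched_token_ids = [[1, 2], [3, 4, 5]]

assert len(final_res_batch) == len(prompt_batched_token_ids)
for final_res, prompt_token_ids in zip(final_res_batch, prompt_batched_token_ids):
    assert prompt_token_ids is not None
    print(final_res["prompt"], len(prompt_token_ids))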
@@ -402,6 +366,7 @@ class OpenAIServingCompletion:
index=len(choices),
text=output_text,
reasoning_content=output.get('reasoning_content'),
tool_calls=output.get("tool_call_content"),
logprobs=None,
finish_reason=None
)
82
fastdeploy/entrypoints/openai/test_openai.py
Normal file
@@ -0,0 +1,82 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import openai

ip = "0.0.0.0"
service_http_port = "9908"  # port the service is configured on

client = openai.Client(base_url=f"http://{ip}:{service_http_port}/v1", api_key="EMPTY_API_KEY")

# Non-streaming completion
response = client.completions.create(
    model="default",
    prompt="There are 50 kinds of fruits, include apple, banana, pineapple",
    max_tokens=100,
    seed=13,
    stream=False,
)

print(response)
print("\n")

# Streaming completion
response = client.completions.create(
    model="default",
    prompt="Hello, how are you?",
    max_tokens=100,
    stream=True,
)

for chunk in response:
    print(chunk.choices[0].text, end='')
print("\n")

# Chat completion
# Non-streaming
response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "user", "content": "Hello, who are you"},
        {"role": "system", "content": "I'm a helpful AI assistant."},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=1,
    max_tokens=64,
    stream=False,
)

print(response)
print("\n")

# Streaming
response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "user", "content": "Hello, who are you"},
        {"role": "system", "content": "I'm a helpful AI assistant."},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=1,
    max_tokens=64,
    stream=True,
)

for chunk in response:
    if chunk.choices[0].delta is not None:
        print(chunk.choices[0].delta, end='')
print("\n")
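The streaming chat example above prints the raw delta objects; in practice the text lives in delta.content and is usually accumulated into a full reply. A short sketch of that accumulation, reusing the client and model name from the test script above:

# Sketch: collect the streamed chat deltas into one string.
full_text = ""
stream = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "List 3 countries and their capitals."}],
    max_tokens=64,
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta is not None and delta.content:
        full_text += delta.content
print(full_text)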