fix list_models bug

This commit is contained in:
A23187
2025-04-08 20:45:24 +08:00
parent b4b8fb970a
commit d25d760e17

View File

@@ -5,7 +5,7 @@ from datetime import datetime
import httpx_sse import httpx_sse
from fastapi import FastAPI, Header from fastapi import FastAPI, Header
from fastapi.responses import StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from httpx import AsyncClient from httpx import AsyncClient
from .env import ( from .env import (
@@ -36,7 +36,7 @@ app = FastAPI(
@app.get("/v1/models") @app.get("/v1/models")
async def list_models(ide_token: str = Header(TRAE_IDE_TOKEN, alias="Authorization")) -> list[Model]: async def list_models(ide_token: str = Header(TRAE_IDE_TOKEN, alias="Authorization")) -> JSONResponse:
ide_token = ide_token.removeprefix("Bearer ") ide_token = ide_token.removeprefix("Bearer ")
async with AsyncClient() as client: async with AsyncClient() as client:
response = await client.get( response = await client.get(
@@ -56,7 +56,12 @@ async def list_models(ide_token: str = Header(TRAE_IDE_TOKEN, alias="Authorizati
"x-os-version": TRAE_OS_VERSION, "x-os-version": TRAE_OS_VERSION,
}, },
) )
return [Model(created=0, id=model["name"]) for model in response.json()["model_configs"]] return JSONResponse(
{
"object": "list",
"data": [Model(created=0, id=model["name"]).model_dump() for model in response.json()["model_configs"]],
}
)
@app.post("/v1/chat/completions") @app.post("/v1/chat/completions")
@@ -64,7 +69,7 @@ async def create_chat_completions(
request: ChatCompletionRequest, ide_token: str = Header(TRAE_IDE_TOKEN, alias="Authorization") request: ChatCompletionRequest, ide_token: str = Header(TRAE_IDE_TOKEN, alias="Authorization")
) -> StreamingResponse: ) -> StreamingResponse:
ide_token = ide_token.removeprefix("Bearer ") ide_token = ide_token.removeprefix("Bearer ")
current_turn = sum(1 for msg in request.messages if msg.role == "user") current_turn = sum(1 for msg in request.messages[:-1] if msg.role == "user")
last_assistant_message = next(filter(lambda msg: msg.role == "assistant", reversed(request.messages)), None) last_assistant_message = next(filter(lambda msg: msg.role == "assistant", reversed(request.messages)), None)
async def stream_response(): async def stream_response():
@@ -87,7 +92,14 @@ async def create_chat_completions(
"x-os-version": TRAE_OS_VERSION, "x-os-version": TRAE_OS_VERSION,
}, },
json={ json={
"chat_history": [msg.model_dump() for msg in request.messages[:-1]], "chat_history": [
{
**msg.model_dump(),
"status": "success",
"locale": "zh-cn",
}
for msg in request.messages[:-1]
],
"context_resolvers": [], "context_resolvers": [],
"conversation_id": str(uuid.uuid4()), "conversation_id": str(uuid.uuid4()),
"current_turn": current_turn, "current_turn": current_turn,