diff --git a/src/app.py b/src/app.py index fbea226..3c81474 100644 --- a/src/app.py +++ b/src/app.py @@ -5,7 +5,7 @@ from datetime import datetime import httpx_sse from fastapi import FastAPI, Header -from fastapi.responses import StreamingResponse +from fastapi.responses import JSONResponse, StreamingResponse from httpx import AsyncClient from .env import ( @@ -36,7 +36,7 @@ app = FastAPI( @app.get("/v1/models") -async def list_models(ide_token: str = Header(TRAE_IDE_TOKEN, alias="Authorization")) -> list[Model]: +async def list_models(ide_token: str = Header(TRAE_IDE_TOKEN, alias="Authorization")) -> JSONResponse: ide_token = ide_token.removeprefix("Bearer ") async with AsyncClient() as client: response = await client.get( @@ -56,7 +56,12 @@ async def list_models(ide_token: str = Header(TRAE_IDE_TOKEN, alias="Authorizati "x-os-version": TRAE_OS_VERSION, }, ) - return [Model(created=0, id=model["name"]) for model in response.json()["model_configs"]] + return JSONResponse( + { + "object": "list", + "data": [Model(created=0, id=model["name"]).model_dump() for model in response.json()["model_configs"]], + } + ) @app.post("/v1/chat/completions") @@ -64,7 +69,7 @@ async def create_chat_completions( request: ChatCompletionRequest, ide_token: str = Header(TRAE_IDE_TOKEN, alias="Authorization") ) -> StreamingResponse: ide_token = ide_token.removeprefix("Bearer ") - current_turn = sum(1 for msg in request.messages if msg.role == "user") + current_turn = sum(1 for msg in request.messages[:-1] if msg.role == "user") last_assistant_message = next(filter(lambda msg: msg.role == "assistant", reversed(request.messages)), None) async def stream_response(): @@ -87,7 +92,14 @@ async def create_chat_completions( "x-os-version": TRAE_OS_VERSION, }, json={ - "chat_history": [msg.model_dump() for msg in request.messages[:-1]], + "chat_history": [ + { + **msg.model_dump(), + "status": "success", + "locale": "zh-cn", + } + for msg in request.messages[:-1] + ], "context_resolvers": [], "conversation_id": str(uuid.uuid4()), "current_turn": current_turn,