diff --git a/augment2api_server.py b/augment2api_server.py index 22a05dd..20ce99f 100644 --- a/augment2api_server.py +++ b/augment2api_server.py @@ -34,11 +34,19 @@ logger = logging.getLogger(__name__) # 模型定义 ################################################# +# 新增:支持OpenAI新格式的内容项定义 +class ContentItem(BaseModel): + """表示OpenAI聊天API中的内容项""" + type: str # 例如 "text", "image_url" 等 + text: Optional[str] = None + # 可以在这里添加其他内容类型的字段,如image_url等 + # OpenAI API 请求模型 class ChatMessage(BaseModel): """表示OpenAI聊天API中的单条消息""" role: Literal["system", "user", "assistant", "function"] - content: Optional[str] = None + # 修改:content字段现在可以是字符串或内容项数组 + content: Optional[Union[str, List[ContentItem]]] = None name: Optional[str] = None class ChatCompletionRequest(BaseModel): @@ -192,6 +200,48 @@ def estimate_tokens(text): chinese_chars = sum(1 for char in text if '\u4e00' <= char <= '\u9fff') if text else 0 return int(words * 1.3 + chinese_chars) +def map_model_name(openai_model: str) -> Optional[str]: + """ + 将OpenAI模型名称映射到Augment模型名称 + + Args: + openai_model: OpenAI格式的模型名称 + + Returns: + Augment格式的模型名称,或None表示使用自动选择 + """ + # 模型名称映射规则 + if openai_model == "augment-auto": + # 使用null表示自动选择模型 + return None + elif openai_model.startswith("claude-"): + # Claude模型名称,添加augment-前缀 + return f"augment-{openai_model}" + elif openai_model.startswith("augment-"): + # 已经是Augment格式的名称,直接使用 + return openai_model + else: + # 其他名称默认使用自动选择 + logger.info(f"未知模型名称 '{openai_model}',使用自动选择") + return None + +# 新增:处理内容数组的函数 +def process_content_array(content_array: List[ContentItem]) -> str: + """ + 将内容数组转换为单个字符串 + + Args: + content_array: 内容项数组 + + Returns: + 合并后的文本内容 + """ + result = "" + for item in content_array: + if item.type == "text" and item.text: + result += item.text + return result + def convert_to_augment_request(openai_request: ChatCompletionRequest) -> AugmentChatRequest: """ 将OpenAI API请求转换为Augment API请求 @@ -208,6 +258,13 @@ def convert_to_augment_request(openai_request: ChatCompletionRequest) -> Augment chat_history = [] system_message = None + # 预处理所有消息,处理内容数组 + for i in range(len(openai_request.messages)): + msg = openai_request.messages[i] + if isinstance(msg.content, list): + # 将内容数组转换为单个字符串 + openai_request.messages[i].content = process_content_array(msg.content) + # 处理消息历史记录 for i in range(len(openai_request.messages) - 1): msg = openai_request.messages[i] @@ -247,8 +304,12 @@ def convert_to_augment_request(openai_request: ChatCompletionRequest) -> Augment detail="At least one user message is required" ) + # 映射模型名称 + augment_model = map_model_name(openai_request.model) + # 准备Augment请求体 augment_request = AugmentChatRequest( + model=augment_model, message=current_message, chat_history=chat_history, mode="CHAT" @@ -264,7 +325,7 @@ def convert_to_augment_request(openai_request: ChatCompletionRequest) -> Augment # FastAPI应用 ################################################# -def create_app(augment_base_url, chat_endpoint, timeout): +def create_app(augment_base_url, chat_endpoint, timeout, max_connections, max_keepalive, keepalive_expiry): """ 创建并配置FastAPI应用 @@ -272,6 +333,9 @@ def create_app(augment_base_url, chat_endpoint, timeout): augment_base_url: Augment API基础URL chat_endpoint: 聊天端点路径 timeout: 请求超时时间 + max_connections: 连接池最大连接数 + max_keepalive: 保持活动的连接数 + keepalive_expiry: 连接保持活动的时间(秒) Returns: 配置好的FastAPI应用 @@ -290,6 +354,31 @@ def create_app(augment_base_url, chat_endpoint, timeout): allow_methods=["*"], allow_headers=["*"], ) + + # HTTP客户端连接池 + http_client = None + + @app.on_event("startup") + async def startup_event(): + """应用启动时初始化HTTP客户端连接池""" + nonlocal http_client + http_client = httpx.AsyncClient( + timeout=timeout, + limits=httpx.Limits( + max_connections=max_connections, + max_keepalive_connections=max_keepalive, + keepalive_expiry=keepalive_expiry + ) + ) + logger.info(f"已初始化HTTP客户端连接池: 最大连接数={max_connections}, 保持活动连接数={max_keepalive}, 连接过期时间={keepalive_expiry}秒") + + @app.on_event("shutdown") + async def shutdown_event(): + """应用关闭时关闭HTTP客户端连接池""" + nonlocal http_client + if http_client: + await http_client.aclose() + logger.info("已关闭HTTP客户端连接池") ################################################# # 中间件和依赖项 @@ -366,11 +455,11 @@ def create_app(augment_base_url, chat_endpoint, timeout): @app.get("/v1/models") async def list_models(): """列出支持的模型""" - # 返回一个虚拟的模型列表 + # 返回支持的模型列表,包含Augment支持的模型 models = [ - ModelInfo(id="gpt-3.5-turbo", created=int(time.time())), - ModelInfo(id="gpt-4", created=int(time.time())), - ModelInfo(id="augment-default", created=int(time.time())), + ModelInfo(id="augment-auto", created=int(time.time())), + ModelInfo(id="claude-3.7-sonnet", created=int(time.time())), + ModelInfo(id="augment-claude-3.7-sonnet", created=int(time.time())), ] return ModelListResponse(data=models) @@ -397,17 +486,17 @@ def create_app(augment_base_url, chat_endpoint, timeout): try: # 转换为Augment请求格式 augment_request = convert_to_augment_request(request) - logger.debug(f"Converted request: {augment_request.dict()}") + logger.debug(f"Converted request: {augment_request.model_dump(exclude_none=True)}") # 决定是否使用流式响应 if request.stream: return StreamingResponse( - stream_augment_response(augment_base_url, api_key, augment_request, request.model, chat_endpoint, timeout), + stream_augment_response(http_client, augment_base_url, api_key, augment_request, request.model, chat_endpoint), media_type="text/event-stream" ) else: # 同步请求处理 - return await handle_sync_request(augment_base_url, api_key, augment_request, request.model, chat_endpoint, timeout) + return await handle_sync_request(http_client, augment_base_url, api_key, augment_request, request.model, chat_endpoint) except httpx.TimeoutException: logger.error("Request to Augment API timed out") @@ -454,198 +543,202 @@ def create_app(augment_base_url, chat_endpoint, timeout): return app -async def handle_sync_request(base_url, api_key, augment_request, model_name, chat_endpoint, timeout): +async def handle_sync_request(client, base_url, api_key, augment_request, model_name, chat_endpoint): """ 处理同步请求 Args: + client: HTTP客户端连接池 base_url: Augment API基础URL api_key: API密钥 augment_request: Augment API请求对象 model_name: 模型名称 chat_endpoint: 聊天端点 - timeout: 请求超时时间 Returns: OpenAI格式的聊天完成响应 """ - async with httpx.AsyncClient(timeout=timeout) as client: - response = await client.post( + # 排除None值,确保正确的JSON格式 + request_json = augment_request.model_dump(exclude_none=True) + + response = await client.post( + f"{base_url.rstrip('/')}/{chat_endpoint}", + json=request_json, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + "User-Agent": "Augment.openai-adapter/1.0.0", + "Accept": "*/*" + } + ) + + if response.status_code != 200: + logger.error(f"Augment API error: {response.status_code} - {response.text}") + raise HTTPException( + status_code=response.status_code, + detail={ + "error": { + "message": f"Augment API error: {response.text}", + "type": "api_error", + "param": None, + "code": "api_error" + } + } + ) + + # 处理流式响应,合并为完整响应 + full_response = "" + for line in response.text.split("\n"): + if line.strip(): + try: + data = json.loads(line) + if "text" in data and data["text"]: + full_response += data["text"] + except json.JSONDecodeError: + logger.warning(f"Failed to parse JSON: {line}") + + # 估算token使用情况 + prompt_tokens = estimate_tokens(augment_request.message) + completion_tokens = estimate_tokens(full_response) + + # 构建OpenAI格式响应 + return ChatCompletionResponse( + id=f"chatcmpl-{generate_id()}", + created=int(time.time()), + model=model_name, + choices=[ + ChatCompletionResponseChoice( + index=0, + message=ChatMessage( + role="assistant", + content=full_response + ), + finish_reason="stop" + ) + ], + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens + ) + ) + +async def stream_augment_response(client, base_url, api_key, augment_request, model_name, chat_endpoint): + """ + 处理流式响应 + + Args: + client: HTTP客户端连接池 + base_url: Augment API基础URL + api_key: API密钥 + augment_request: Augment API请求对象 + model_name: 模型名称 + chat_endpoint: 聊天端点 + + Yields: + 流式响应的数据块 + """ + try: + # 排除None值,确保正确的JSON格式 + request_json = augment_request.model_dump(exclude_none=True) + + async with client.stream( + "POST", f"{base_url.rstrip('/')}/{chat_endpoint}", - json=augment_request.dict(), + json=request_json, headers={ "Content-Type": "application/json", "Authorization": f"Bearer {api_key}", "User-Agent": "Augment.openai-adapter/1.0.0", "Accept": "*/*" } - ) - - if response.status_code != 200: - logger.error(f"Augment API error: {response.status_code} - {response.text}") - raise HTTPException( - status_code=response.status_code, - detail={ - "error": { - "message": f"Augment API error: {response.text}", - "type": "api_error", - "param": None, - "code": "api_error" - } - } + ) as response: + + if response.status_code != 200: + error_detail = await response.aread() + logger.error(f"Augment API error: {response.status_code} - {error_detail}") + error_message = f"Error from Augment API: {error_detail.decode('utf-8', errors='replace')}" + yield f"data: {json.dumps({'error': error_message})}\n\n" + return + + # 生成唯一ID + chat_id = f"chatcmpl-{generate_id()}" + created_time = int(time.time()) + + # 初始化响应 + init_response = ChatCompletionStreamResponse( + id=chat_id, + created=created_time, + model=model_name, + choices=[ + ChatCompletionStreamResponseChoice( + index=0, + delta={"role": "assistant"}, + finish_reason=None + ) + ] ) - - # 处理流式响应,合并为完整响应 - full_response = "" - for line in response.text.split("\n"): - if line.strip(): + init_data = json.dumps(init_response.model_dump()) + yield f"data: {init_data}\n\n" + + # 处理流式响应 + buffer = "" + async for line in response.aiter_lines(): + if not line.strip(): + continue + try: - data = json.loads(line) - if "text" in data and data["text"]: - full_response += data["text"] + # 解析Augment响应格式 + chunk = json.loads(line) + if "text" in chunk and chunk["text"]: + content = chunk["text"] + + # 发送增量更新 + stream_response = ChatCompletionStreamResponse( + id=chat_id, + created=created_time, + model=model_name, + choices=[ + ChatCompletionStreamResponseChoice( + index=0, + delta={"content": content}, + finish_reason=None + ) + ] + ) + response_data = json.dumps(stream_response.model_dump()) + yield f"data: {response_data}\n\n" except json.JSONDecodeError: logger.warning(f"Failed to parse JSON: {line}") - - # 估算token使用情况 - prompt_tokens = estimate_tokens(augment_request.message) - completion_tokens = estimate_tokens(full_response) - - # 构建OpenAI格式响应 - return ChatCompletionResponse( - id=f"chatcmpl-{generate_id()}", - created=int(time.time()), - model=model_name, - choices=[ - ChatCompletionResponseChoice( - index=0, - message=ChatMessage( - role="assistant", - content=full_response - ), - finish_reason="stop" - ) - ], - usage=Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens + + # 发送完成信号 + final_response = ChatCompletionStreamResponse( + id=chat_id, + created=created_time, + model=model_name, + choices=[ + ChatCompletionStreamResponseChoice( + index=0, + delta={}, + finish_reason="stop" + ) + ] ) - ) - -async def stream_augment_response(base_url, api_key, augment_request, model_name, chat_endpoint, timeout): - """ - 处理流式响应 - - Args: - base_url: Augment API基础URL - api_key: API密钥 - augment_request: Augment API请求对象 - model_name: 模型名称 - chat_endpoint: 聊天端点 - timeout: 请求超时时间 - - Yields: - 流式响应的数据块 - """ - async with httpx.AsyncClient(timeout=timeout) as client: - try: - async with client.stream( - "POST", - f"{base_url.rstrip('/')}/{chat_endpoint}", - json=augment_request.dict(), - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}", - "User-Agent": "Augment.openai-adapter/1.0.0", - "Accept": "*/*" - } - ) as response: - - if response.status_code != 200: - error_detail = await response.aread() - logger.error(f"Augment API error: {response.status_code} - {error_detail}") - error_message = f"Error from Augment API: {error_detail.decode('utf-8', errors='replace')}" - yield f"data: {json.dumps({'error': error_message})}\n\n" - return - - # 生成唯一ID - chat_id = f"chatcmpl-{generate_id()}" - created_time = int(time.time()) - - # 初始化响应 - init_response = ChatCompletionStreamResponse( - id=chat_id, - created=created_time, - model=model_name, - choices=[ - ChatCompletionStreamResponseChoice( - index=0, - delta={"role": "assistant"}, - finish_reason=None - ) - ] - ) - init_data = json.dumps(init_response.dict()) - yield f"data: {init_data}\n\n" - - # 处理流式响应 - buffer = "" - async for line in response.aiter_lines(): - if not line.strip(): - continue - - try: - # 解析Augment响应格式 - chunk = json.loads(line) - if "text" in chunk and chunk["text"]: - content = chunk["text"] - - # 发送增量更新 - stream_response = ChatCompletionStreamResponse( - id=chat_id, - created=created_time, - model=model_name, - choices=[ - ChatCompletionStreamResponseChoice( - index=0, - delta={"content": content}, - finish_reason=None - ) - ] - ) - response_data = json.dumps(stream_response.dict()) - yield f"data: {response_data}\n\n" - except json.JSONDecodeError: - logger.warning(f"Failed to parse JSON: {line}") - - # 发送完成信号 - final_response = ChatCompletionStreamResponse( - id=chat_id, - created=created_time, - model=model_name, - choices=[ - ChatCompletionStreamResponseChoice( - index=0, - delta={}, - finish_reason="stop" - ) - ] - ) - final_data = json.dumps(final_response.dict()) - yield f"data: {final_data}\n\n" - - # 发送[DONE]标记 - yield "data: [DONE]\n\n" - - except httpx.TimeoutException: - logger.error("Request to Augment API timed out") - yield f"data: {json.dumps({'error': 'Request to Augment API timed out'})}\n\n" - except httpx.HTTPError as e: - logger.error(f"HTTP error: {str(e)}") - yield f"data: {json.dumps({'error': f'Error communicating with Augment API: {str(e)}'})}\n\n" - except Exception as e: - logger.exception("Unexpected error") - yield f"data: {json.dumps({'error': f'Internal server error: {str(e)}'})}\n\n" + final_data = json.dumps(final_response.model_dump()) + yield f"data: {final_data}\n\n" + + # 发送[DONE]标记 + yield "data: [DONE]\n\n" + + except httpx.TimeoutException: + logger.error("Request to Augment API timed out") + yield f"data: {json.dumps({'error': 'Request to Augment API timed out'})}\n\n" + except httpx.HTTPError as e: + logger.error(f"HTTP error: {str(e)}") + yield f"data: {json.dumps({'error': f'Error communicating with Augment API: {str(e)}'})}\n\n" + except Exception as e: + logger.exception("Unexpected error") + yield f"data: {json.dumps({'error': f'Internal server error: {str(e)}'})}\n\n" def parse_args(): """解析命令行参数""" @@ -698,6 +791,28 @@ def parse_args(): help="Augment API租户ID (域名前缀)" ) + # 连接池相关参数 + parser.add_argument( + "--max-connections", + type=int, + default=100, + help="HTTP连接池最大连接数" + ) + + parser.add_argument( + "--max-keepalive", + type=int, + default=20, + help="HTTP连接池保持活动的连接数" + ) + + parser.add_argument( + "--keepalive-expiry", + type=float, + default=60.0, + help="HTTP连接池连接保持活动的时间(秒)" + ) + return parser.parse_args() ################################################# @@ -725,13 +840,17 @@ def main(): app = create_app( augment_base_url=augment_base_url, chat_endpoint=args.chat_endpoint, - timeout=args.timeout + timeout=args.timeout, + max_connections=args.max_connections, + max_keepalive=args.max_keepalive, + keepalive_expiry=args.keepalive_expiry ) # 启动应用 logger.info(f"Starting server on {args.host}:{args.port}") logger.info(f"Using Augment base URL: {augment_base_url}") logger.info(f"Using Augment chat endpoint: {args.chat_endpoint}") + logger.info(f"HTTP连接池配置: 最大连接数={args.max_connections}, 保持活动连接数={args.max_keepalive}, 连接过期时间={args.keepalive_expiry}秒") uvicorn.run( app,