Files
cursor-api-go/main.py
2024-11-26 11:39:29 +08:00

412 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
from fastapi import FastAPI, Request, Response, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
import uuid
import httpx
import os
from dotenv import load_dotenv
import time
import re
# Load environment variables (e.g. PORT, read by the __main__ guard) from a
# local .env file before anything else reads them.
load_dotenv()

app = FastAPI()

# CORS: every origin, method and header is accepted.
# NOTE(review): a "*" origin combined with allow_credentials=True is a very
# permissive combination; requests here authenticate via the Authorization
# header, so credentials may not be needed — confirm before tightening.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Request schema: a single chat message in an incoming completion request.
class Message(BaseModel):
    role: str  # speaker role; used as the "role:" prefix when the prompt is flattened
    content: str  # plain-text body of the message
# Request schema for POST /v1/chat/completions.
class ChatRequest(BaseModel):
    model: str  # model id; embedded in the upstream request body
    messages: List[Message]  # conversation, flattened in list order
    stream: bool = False  # True -> reply as server-sent events
def string_to_hex(text: str, model_name: str) -> bytes:
    """Build the framed binary request body for the upstream StreamChat call.

    Despite the name, the result is raw bytes. The payload is a protobuf
    message whose top-level fields (numbers taken from the wire tags) are:

      * 2  - nested message: field 1 = ``text`` (UTF-8), field 2 = varint 1,
             field 13 = a hard-coded 36-character id string
      * 4  - empty string
      * 5  - hard-coded path string
      * 7  - nested message: field 1 = ``model_name`` (UTF-8), field 4 = empty
      * 9, 15 - hard-coded 36-character id strings
      * 13, 14 - varint 0
      * 16 - varint 1; 22, 24, 28, 29 - varint 0

    The payload is preceded by a 5-byte frame header: a zero flag byte, then
    the payload length as a 4-byte big-endian integer.

    Every length is a proper protobuf varint measured in bytes, so prompts of
    any size and model names of any length yield a well-formed message.

    Args:
        text: Prompt text to embed.
        model_name: Model identifier to embed.

    Returns:
        The complete request body.
    """

    def _varint(value: int) -> bytes:
        """Encode a non-negative int as a protobuf base-128 varint."""
        out = bytearray()
        while value >= 0x80:
            out.append((value & 0x7F) | 0x80)
            value >>= 7
        out.append(value)
        return bytes(out)

    def _len_field(tag: int, payload: bytes) -> bytes:
        """Encode a length-delimited field: one tag byte, varint length, payload."""
        return bytes([tag]) + _varint(len(payload)) + payload

    text_bytes = text.encode('utf-8')
    model_name_bytes = model_name.encode('utf-8')

    # Nested message of outer field 2. Its length is computed from the real
    # contents, so it includes the size of the text's own length varint.
    inner = (
        _len_field(0x0A, text_bytes)
        + b"\x10\x01"
        + _len_field(0x6A, b"241ccd5f-91ba-4118-929a-96bc0161bd2a")
    )
    # Nested message of outer field 7. Its length depends on the model name,
    # so it must be computed rather than fixed.
    model_message = _len_field(0x0A, model_name_bytes) + b"\x22\x00"

    body = (
        _len_field(0x12, inner)
        + b"\x22\x00"
        + _len_field(0x2A, b"/d:/ideaPro/eduboss")
        + _len_field(0x3A, model_message)
        + _len_field(0x4A, b"a87a9a34-21dd-48c7-b44f-af6c3ece6f7e")
        + b"\x68\x00\x70\x00"
        + _len_field(0x7A, b"69377e5a-8c2d-4854-b5d9-e0bb223ac00a")
        + bytes.fromhex("800101B00100C00100E00100E80100")
    )
    # Frame header: flag byte 0, then 4-byte big-endian payload length.
    return b"\x00" + len(body).to_bytes(4, "big") + body
def chunk_to_utf8_string(chunk: bytes) -> str:
    """Heuristically extract readable UTF-8 text from one raw upstream chunk.

    Returns '' for chunks shorter than 2 bytes, chunks whose first byte is
    0x01 or 0x02, and chunks starting with 0x60 0x0C. Otherwise:

      1. everything up to and including the first 0x0A byte is discarded
         (the whole chunk is kept if there is none);
      2. a run of four 0x00 bytes is skipped together with any following
         bytes <= 0x0F;
      3. a 0x0C byte is skipped together with any 0x0A bytes right after it;
      4. remaining 0x00 bytes are removed, the rest is decoded as UTF-8
         (invalid sequences dropped) and stripped of surrounding whitespace.

    Any unexpected exception is printed and turned into ''.
    """
    if not chunk or len(chunk) < 2:
        return ''
    lead = chunk[0]
    if lead in (0x01, 0x02) or (lead == 0x60 and chunk[1] == 0x0C):
        return ''

    # Debug output: size of every chunk that reaches the parser.
    print(f"chunk length: {len(chunk)}")

    try:
        # find() returns -1 when 0x0A is absent, so the slice keeps everything.
        payload = chunk[chunk.find(0x0A) + 1:]
        kept = bytearray()
        pos = 0
        end = len(payload)
        while pos < end:
            if payload[pos:pos + 4] == b'\x00\x00\x00\x00':
                pos += 4
                while pos < end and payload[pos] <= 0x0F:
                    pos += 1
            elif payload[pos] == 0x0C:
                pos += 1
                while pos < end and payload[pos] == 0x0A:
                    pos += 1
            else:
                kept.append(payload[pos])
                pos += 1

        cleaned = bytes(kept).replace(b'\x00', b'').replace(b'\x0c', b'')
        if not cleaned:
            return ''
        return cleaned.decode('utf-8', errors='ignore').strip()
    except Exception as e:
        print(f"Error in chunk_to_utf8_string: {str(e)}")
        return ''
async def process_stream(chunks):
    """Yield OpenAI-compatible server-sent events for buffered upstream chunks.

    Args:
        chunks: Iterable of raw upstream response chunks (bytes), already read
            in full by the caller.

    Yields:
        One ``data: {...}\\n\\n`` event per chunk that still has text after
        cleaning, each a ``chat.completion.chunk`` with the text in
        ``delta.content`` and ``finish_reason`` null. These are followed by a
        final chunk with an empty delta and ``finish_reason: "stop"``, then
        ``data: [DONE]\\n\\n``.
    """
    response_id = f"chatcmpl-{str(uuid.uuid4())}"
    # All chunks of one OpenAI stream share the same creation timestamp.
    created = int(time.time())
    control_chars = re.compile(r"[\x00-\x1F\x7F]")

    def _event(delta, finish_reason):
        """Serialize one chat.completion.chunk as an SSE data line."""
        data_body = {
            "id": response_id,
            "object": "chat.completion.chunk",
            "created": created,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
            }],
        }
        return f"data: {json.dumps(data_body, ensure_ascii=False)}\n\n"

    for chunk in chunks:
        text = chunk_to_utf8_string(chunk)
        if not text:
            continue
        text = text.strip()
        if "<|END_USER|>" in text:
            # Keep only what follows the last marker; a single leading letter
            # left right after it is dropped as well.
            text = text.split("<|END_USER|>")[-1].strip()
            if text and text[0].isalpha():
                text = text[1:].strip()
        text = control_chars.sub("", text)
        if text:
            yield _event({"content": text}, None)

    # Terminal chunk so clients see an explicit finish reason before [DONE].
    yield _event({}, "stop")
    yield "data: [DONE]\n\n"
@app.post("/v1/chat/completions")
async def chat_completions(request: Request, chat_request: ChatRequest):
    """Proxy an OpenAI-style chat completion request to Cursor's StreamChat API.

    Token handling: the bearer token comes from the ``Authorization`` header.
    Only the first of any comma-separated keys is used, and if it contains
    ``%3A%3A`` only the part after that is kept. The result is forwarded
    upstream.

    Message handling: messages are flattened to ``role:content`` lines and
    encoded with ``string_to_hex``.

    Returns:
        A ``StreamingResponse`` of server-sent events when
        ``chat_request.stream`` is true. Otherwise an OpenAI
        ``chat.completion`` dict, with token usage always reported as 0.

    Raises:
        HTTPException:
            * 400 if an ``o1-`` model is asked to stream.
            * 401 for a malformed header or an empty token.
            * The upstream status for an upstream 4xx; 502 for any other
              non-200 upstream status.
            * 500 for any other failure while talking to the upstream.
    """
    # o1 models do not support streamed output.
    if chat_request.model.startswith('o1-') and chat_request.stream:
        raise HTTPException(
            status_code=400,
            detail="Model not supported stream"
        )

    auth_header = request.headers.get('authorization', '')
    if not auth_header.startswith('Bearer '):
        raise HTTPException(
            status_code=401,
            detail="Invalid authorization header"
        )
    # Slice off only the leading scheme; str.replace would also delete any
    # later "Bearer " inside the token.
    auth_token = auth_header[len('Bearer '):]
    # Several comma-separated keys may be supplied; only the first is used.
    auth_token = auth_token.split(',')[0].strip()
    # "%3A" is a URL-encoded ':'; keep the part after the first "%3A%3A".
    if '%3A%3A' in auth_token:
        auth_token = auth_token.split('%3A%3A')[1]
    # Checked after all reductions so an empty first key (e.g. "Bearer ,abc")
    # is rejected instead of being forwarded upstream as an empty token.
    if not auth_token:
        raise HTTPException(
            status_code=401,
            detail="Missing authorization token"
        )

    # Flatten the conversation into "role:content" lines.
    formatted_messages = "\n".join(
        f"{msg.role}:{msg.content}" for msg in chat_request.messages
    )
    # Binary connect+proto request body.
    hex_data = string_to_hex(formatted_messages, chat_request.model)

    headers = {
        'Content-Type': 'application/connect+proto',
        'Authorization': f'Bearer {auth_token}',
        'Connect-Accept-Encoding': 'gzip,br',
        'Connect-Protocol-Version': '1',
        'User-Agent': 'connect-es/1.4.0',
        'X-Amzn-Trace-Id': f'Root={str(uuid.uuid4())}',
        'X-Cursor-Checksum': 'zo6Qjequ9b9734d1f13c3438ba25ea31ac93d9287248b9d30434934e9fcbfa6b3b22029e/7e4af391f67188693b722eff0090e8e6608bca8fa320ef20a0ccb5d7d62dfdef',
        'X-Cursor-Client-Version': '0.42.3',
        'X-Cursor-Timezone': 'Asia/Shanghai',
        'X-Ghost-Mode': 'false',
        'X-Request-Id': str(uuid.uuid4()),
        'Host': 'api2.cursor.sh'
    }

    async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
        try:
            # Log headers with the user's token redacted. The body is not
            # logged, because it contains the whole conversation.
            log_headers = {
                name: ('Bearer <redacted>' if name == 'Authorization' else value)
                for name, value in headers.items()
            }
            print(f"headers: {log_headers}")
            print(f"request body: {len(hex_data)} bytes")

            # No per-request timeout override, so the client's 300 s timeout
            # also protects against a stalled upstream.
            async with client.stream(
                'POST',
                'https://api2.cursor.sh/aiserver.v1.AiService/StreamChat',
                headers=headers,
                content=hex_data,
            ) as response:
                if response.status_code != 200:
                    error_body = await response.aread()
                    print(f"Upstream error {response.status_code}: {error_body[:500]!r}")
                    # Pass client errors (e.g. a rejected token) through as-is;
                    # report anything else as a gateway failure.
                    status = (response.status_code
                              if 400 <= response.status_code < 500 else 502)
                    raise HTTPException(
                        status_code=status,
                        detail=f"Upstream request failed with status {response.status_code}"
                    )

                if chat_request.stream:
                    # Read the whole upstream body here: the httpx client is
                    # closed when this handler returns, before the
                    # StreamingResponse body is iterated.
                    chunks = [chunk async for chunk in response.aiter_raw()]
                    return StreamingResponse(
                        process_stream(chunks),
                        media_type="text/event-stream"
                    )

                # Non-streaming: decode every chunk and concatenate the text.
                parts = []
                async for chunk in response.aiter_raw():
                    res = chunk_to_utf8_string(chunk)
                    if res:
                        parts.append(res)
                text = ''.join(parts)

                # Remove everything up to and including the last
                # "<|END_USER|>". Then drop a leading newline with at most one
                # letter after it, and finally strip control characters.
                text = re.sub(r'^.*<\|END_USER\|>', '', text, flags=re.DOTALL)
                text = re.sub(r'^\n[a-zA-Z]?', '', text).strip()
                text = re.sub(r'[\x00-\x1F\x7F]', '', text)

                return {
                    "id": f"chatcmpl-{str(uuid.uuid4())}",
                    "object": "chat.completion",
                    "created": int(time.time()),
                    "model": chat_request.model,
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": text
                        },
                        "finish_reason": "stop"
                    }],
                    "usage": {
                        "prompt_tokens": 0,
                        "completion_tokens": 0,
                        "total_tokens": 0
                    }
                }
        except HTTPException:
            # Deliberate HTTP errors raised above keep their status code.
            raise
        except Exception as e:
            print(f"Error: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail="Internal server error"
            ) from e
@app.get("/v1/models")
@app.get("/models")
@app.post("/models")
async def models():
    """Return the hard-coded model list in OpenAI ``list`` format.

    OpenAI-compatible clients query ``GET /v1/models``, so the list is served
    there and on ``GET /models``. The original ``POST /models`` route is kept
    for backward compatibility. The entries, including their ``created``
    timestamps, are static values.
    """
    return {
        "object": "list",
        "data": [
            {
                "id": "claude-3-5-sonnet-20241022",
                "object": "model",
                "created": 1713744000,
                "owned_by": "anthropic"
            },
            {
                "id": "claude-3-opus",
                "object": "model",
                "created": 1709251200,
                "owned_by": "anthropic"
            },
            {
                "id": "claude-3.5-haiku",
                "object": "model",
                "created": 1711929600,
                "owned_by": "anthropic"
            },
            {
                "id": "claude-3.5-sonnet",
                "object": "model",
                "created": 1711929600,
                "owned_by": "anthropic"
            },
            {
                "id": "cursor-small",
                "object": "model",
                "created": 1712534400,
                "owned_by": "cursor"
            },
            {
                "id": "gpt-3.5-turbo",
                "object": "model",
                "created": 1677649200,
                "owned_by": "openai"
            },
            {
                "id": "gpt-4",
                "object": "model",
                "created": 1687392000,
                "owned_by": "openai"
            },
            {
                "id": "gpt-4-turbo-2024-04-09",
                "object": "model",
                "created": 1712620800,
                "owned_by": "openai"
            },
            {
                "id": "gpt-4o",
                "object": "model",
                "created": 1712620800,
                "owned_by": "openai"
            },
            {
                "id": "gpt-4o-mini",
                "object": "model",
                "created": 1712620800,
                "owned_by": "openai"
            },
            {
                "id": "o1-mini",
                "object": "model",
                "created": 1712620800,
                "owned_by": "openai"
            },
            {
                "id": "o1-preview",
                "object": "model",
                "created": 1712620800,
                "owned_by": "openai"
            }
        ]
    }
if __name__ == "__main__":
    # uvicorn is only needed when this file is executed directly.
    import uvicorn

    # Serve on all interfaces; PORT comes from the environment (default 3001).
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=int(os.getenv("PORT", "3001")),
        reload=True,
        timeout_keep_alive=30,
    )