# mirror of https://github.com/zeke-chin/cursor-api.git
# synced 2025-09-26 19:51:11 +08:00
import json
from fastapi import FastAPI, Request, Response, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
import uuid
import httpx
import os
from dotenv import load_dotenv
import time
import re

# Load environment variables
load_dotenv()

app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Define request models
class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    stream: bool = False


def string_to_hex(text: str, model_name: str) -> bytes:
    """Convert message text into the hex-built binary request body."""
    # Convert the input text to UTF-8 bytes
    text_bytes = text.encode('utf-8')
    text_length = len(text_bytes)

    # Fixed constants
    FIXED_HEADER = 2
    SEPARATOR = 1
    FIXED_SUFFIX_LENGTH = 0xA3 + len(model_name)

    # Compute the first length field: values below 128 fit in one byte;
    # otherwise the low 7 bits go first with the high bit set, then the rest
    if text_length < 128:
        text_length_field1 = format(text_length, '02x')
        text_length_field_size1 = 1
    else:
        low_byte1 = (text_length & 0x7F) | 0x80
        high_byte1 = (text_length >> 7) & 0xFF
        text_length_field1 = format(low_byte1, '02x') + format(high_byte1, '02x')
        text_length_field_size1 = 2

    # Compute the base length field (text length plus 0x2A bytes of overhead)
    base_length = text_length + 0x2A
    if base_length < 128:
        text_length_field = format(base_length, '02x')
        text_length_field_size = 1
    else:
        low_byte = (base_length & 0x7F) | 0x80
        high_byte = (base_length >> 7) & 0xFF
        text_length_field = format(low_byte, '02x') + format(high_byte, '02x')
        text_length_field_size = 2

    # Compute the total message length
    message_total_length = (FIXED_HEADER + text_length_field_size + SEPARATOR +
                            text_length_field_size1 + text_length + FIXED_SUFFIX_LENGTH)

    # Build the hex string
    model_name_bytes = model_name.encode('utf-8')
    model_name_length_hex = format(len(model_name_bytes), '02X')
    model_name_hex = model_name_bytes.hex().upper()

    hex_string = (
        # 10 hex digits -> a 5-byte prefix: one zero byte, then the total
        # length as a 4-byte big-endian integer
        f"{message_total_length:010x}"
        "12"
        f"{text_length_field}"
        "0A"
        f"{text_length_field1}"
        f"{text_bytes.hex()}"
        "10016A2432343163636435662D393162612D343131382D393239612D3936626330313631626432612"
        "2002A132F643A2F6964656150726F2F656475626F73733A1E0A"
        f"{model_name_length_hex}"
        f"{model_name_hex}"
        "22004A"
        "2461383761396133342D323164642D343863372D623434662D616636633365636536663765"
        "680070007A2436393337376535612D386332642D343835342D623564392D653062623232336163303061"
        "800101B00100C00100E00100E80100"
    ).upper()

    return bytes.fromhex(hex_string)
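

# Added illustration (not in the original repo): the inverse of the 7-bit
# length-field encoding used above. Values below 128 occupy one byte; larger
# values store the low 7 bits first with the high bit set as a continuation
# flag, then the remaining bits in a second byte.
def _decode_length_field(data: bytes) -> int:
    """Decode the 1- or 2-byte length field that string_to_hex emits."""
    if data[0] < 0x80:
        return data[0]
    return (data[0] & 0x7F) | (data[1] << 7)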


def chunk_to_utf8_string(chunk: bytes) -> str:
    """Convert a binary response chunk into a UTF-8 string."""
    if not chunk or len(chunk) < 2:
        return ''

    # Skip flagged frames and a known non-text prefix
    if chunk[0] in [0x01, 0x02] or (chunk[0] == 0x60 and chunk[1] == 0x0C):
        return ''

    # Log the raw chunk length (for debugging)
    print(f"chunk length: {len(chunk)}")

    try:
        # Drop every byte up to and including the first 0x0A
        try:
            chunk = chunk[chunk.index(0x0A) + 1:]
        except ValueError:
            pass

        filtered_chunk = bytearray()
        i = 0
        while i < len(chunk):
            # Check for four consecutive 0x00 bytes, then skip the low bytes
            # that follow them
            if i + 4 <= len(chunk) and all(chunk[j] == 0x00 for j in range(i, i + 4)):
                i += 4
                while i < len(chunk) and chunk[i] <= 0x0F:
                    i += 1
                continue

            if chunk[i] == 0x0C:
                i += 1
                while i < len(chunk) and chunk[i] == 0x0A:
                    i += 1
            else:
                filtered_chunk.append(chunk[i])
                i += 1

        # Filter out any remaining 0x00 and 0x0C bytes
        filtered_chunk = bytes(b for b in filtered_chunk
                               if b != 0x00 and b != 0x0C)

        if not filtered_chunk:
            return ''

        result = filtered_chunk.decode('utf-8', errors='ignore').strip()
        return result

    except Exception as e:
        print(f"Error in chunk_to_utf8_string: {str(e)}")
        return ''
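

# Added self-check (illustrative, not from the original repo): everything up
# to the first 0x0A byte is dropped and the remaining UTF-8 text survives.
def _demo_chunk_decode() -> None:
    assert chunk_to_utf8_string(b'\x00\x00\x00\x00\x05\nHello') == 'Hello'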


async def process_stream(chunks):
    """Replay saved upstream chunks as an OpenAI-style SSE stream."""
    response_id = f"chatcmpl-{str(uuid.uuid4())}"

    for chunk in chunks:
        text = chunk_to_utf8_string(chunk)
        if text:
            # Clean the text
            text = text.strip()
            if "<|END_USER|>" in text:
                text = text.split("<|END_USER|>")[-1].strip()
                if text and text[0].isalpha():
                    text = text[1:].strip()
            text = re.sub(r"[\x00-\x1F\x7F]", "", text)

            if text:  # make sure the cleaned text is not empty
                data_body = {
                    "id": response_id,
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "choices": [{
                        "index": 0,
                        "delta": {
                            "content": text
                        }
                    }]
                }
                yield f"data: {json.dumps(data_body, ensure_ascii=False)}\n\n"

    yield "data: [DONE]\n\n"


@app.post("/v1/chat/completions")
async def chat_completions(request: Request, chat_request: ChatRequest):
    # o1 models do not support streaming output
    if chat_request.model.startswith('o1-') and chat_request.stream:
        raise HTTPException(
            status_code=400,
            detail="Model does not support streaming"
        )

    # Extract and validate the authorization token
    auth_header = request.headers.get('authorization', '')
    if not auth_header.startswith('Bearer '):
        raise HTTPException(
            status_code=401,
            detail="Invalid authorization header"
        )

    auth_token = auth_header.replace('Bearer ', '')
    if not auth_token:
        raise HTTPException(
            status_code=401,
            detail="Missing authorization token"
        )

    # Support multiple comma-separated keys; use the first one
    keys = [key.strip() for key in auth_token.split(',')]
    if keys:
        auth_token = keys[0]

    # Tokens of the form "<user>%3A%3A<token>" (URL-encoded "::") carry the
    # actual token after the separator
    if '%3A%3A' in auth_token:
        auth_token = auth_token.split('%3A%3A')[1]

    # Flatten the messages into "role:content" lines
    formatted_messages = "\n".join(
        f"{msg.role}:{msg.content}" for msg in chat_request.messages
    )

    # Build the binary request payload
    hex_data = string_to_hex(formatted_messages, chat_request.model)

    # Prepare the request headers
    headers = {
        'Content-Type': 'application/connect+proto',
        'Authorization': f'Bearer {auth_token}',
        'Connect-Accept-Encoding': 'gzip,br',
        'Connect-Protocol-Version': '1',
        'User-Agent': 'connect-es/1.4.0',
        'X-Amzn-Trace-Id': f'Root={str(uuid.uuid4())}',
        'X-Cursor-Checksum': 'zo6Qjequ9b9734d1f13c3438ba25ea31ac93d9287248b9d30434934e9fcbfa6b3b22029e/7e4af391f67188693b722eff0090e8e6608bca8fa320ef20a0ccb5d7d62dfdef',
        'X-Cursor-Client-Version': '0.42.3',
        'X-Cursor-Timezone': 'Asia/Shanghai',
        'X-Ghost-Mode': 'false',
        'X-Request-Id': str(uuid.uuid4()),
        'Host': 'api2.cursor.sh'
    }

    async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
        try:
            # Debug: log the headers and the binary payload
            print(f"headers: {headers}")
            print(hex_data)
            async with client.stream(
                'POST',
                'https://api2.cursor.sh/aiserver.v1.AiService/StreamChat',
                headers=headers,
                content=hex_data,
                timeout=None
            ) as response:
                if chat_request.stream:
                    # Read all chunks up front, then replay them as SSE
                    chunks = []
                    async for chunk in response.aiter_raw():
                        chunks.append(chunk)
                    return StreamingResponse(
                        process_stream(chunks),
                        media_type="text/event-stream"
                    )
                else:
                    # Non-streaming: accumulate the full response text
                    text = ''
                    async for chunk in response.aiter_raw():
                        print('chunk length:', len(chunk))
                        res = chunk_to_utf8_string(chunk)
                        if res:
                            text += res

                    # Clean the response text
                    text = re.sub(r'^.*<\|END_USER\|>', '', text, flags=re.DOTALL)
                    text = re.sub(r'^\n[a-zA-Z]?', '', text).strip()
                    text = re.sub(r'[\x00-\x1F\x7F]', '', text)

                    return {
                        "id": f"chatcmpl-{str(uuid.uuid4())}",
                        "object": "chat.completion",
                        "created": int(time.time()),
                        "model": chat_request.model,
                        "choices": [{
                            "index": 0,
                            "message": {
                                "role": "assistant",
                                "content": text
                            },
                            "finish_reason": "stop"
                        }],
                        "usage": {
                            "prompt_tokens": 0,
                            "completion_tokens": 0,
                            "total_tokens": 0
                        }
                    }

        except Exception as e:
            print(f"Error: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail="Internal server error"
            )
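

# Added usage sketch: because the endpoint mirrors the OpenAI chat-completions
# schema, the official openai Python client (v1+) can be pointed at it. The
# model name and port are illustrative; api_key must be a real token.
#
#   from openai import OpenAI
#   client = OpenAI(base_url="http://127.0.0.1:3001/v1", api_key="<your-token>")
#   resp = client.chat.completions.create(
#       model="claude-3.5-sonnet",
#       messages=[{"role": "user", "content": "hi"}],
#   )
#   print(resp.choices[0].message.content)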


@app.post("/models")
async def models():
    return {
        "object": "list",
        "data": [
            {
                "id": "claude-3-5-sonnet-20241022",
                "object": "model",
                "created": 1713744000,
                "owned_by": "anthropic"
            },
            {
                "id": "claude-3-opus",
                "object": "model",
                "created": 1709251200,
                "owned_by": "anthropic"
            },
            {
                "id": "claude-3.5-haiku",
                "object": "model",
                "created": 1711929600,
                "owned_by": "anthropic"
            },
            {
                "id": "claude-3.5-sonnet",
                "object": "model",
                "created": 1711929600,
                "owned_by": "anthropic"
            },
            {
                "id": "cursor-small",
                "object": "model",
                "created": 1712534400,
                "owned_by": "cursor"
            },
            {
                "id": "gpt-3.5-turbo",
                "object": "model",
                "created": 1677649200,
                "owned_by": "openai"
            },
            {
                "id": "gpt-4",
                "object": "model",
                "created": 1687392000,
                "owned_by": "openai"
            },
            {
                "id": "gpt-4-turbo-2024-04-09",
                "object": "model",
                "created": 1712620800,
                "owned_by": "openai"
            },
            {
                "id": "gpt-4o",
                "object": "model",
                "created": 1712620800,
                "owned_by": "openai"
            },
            {
                "id": "gpt-4o-mini",
                "object": "model",
                "created": 1712620800,
                "owned_by": "openai"
            },
            {
                "id": "o1-mini",
                "object": "model",
                "created": 1712620800,
                "owned_by": "openai"
            },
            {
                "id": "o1-preview",
                "object": "model",
                "created": 1712620800,
                "owned_by": "openai"
            }
        ]
    }
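

# Note (added): this route is registered as POST, not GET, so OpenAI-style
# clients that GET /v1/models will receive a 405. A quick manual check:
#
#   import httpx
#   print([m["id"] for m in httpx.post("http://127.0.0.1:3001/models").json()["data"]])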


if __name__ == "__main__":
    import uvicorn
    port = int(os.getenv("PORT", "3001"))
    uvicorn.run("main:app", host="0.0.0.0", port=port, reload=True, timeout_keep_alive=30)