Mirror of https://github.com/zeke-chin/cursor-api.git (synced 2025-09-27 03:55:58 +08:00)

update: github workflow

.github/workflows/build-rs-capi.yml (vendored)
@@ -40,7 +40,6 @@ jobs:
          override: true

      - name: Install dependencies
        if: matrix.target == 'aarch64-unknown-linux-gnu'
        run: |
          sudo apt-get update
          sudo apt-get install -y gcc-aarch64-linux-gnu g++-aarch64-linux-gnu pkg-config libssl-dev gcc-multilib crossbuild-essential-arm64 musl-tools

@@ -20,7 +20,7 @@

## Quick Start

```
-docker run xxxx -p 3000:3000 ghcr.io/xxxx/rs-capi:latest
+docker run --rm -p 7070:3000 ghcr.io/xxxx/rs-capi:latest
```
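
The container exposes the OpenAI-compatible endpoints implemented in rs-capi (see the `/v1/chat/completions` handler in `rs-capi/main.py` below). A hypothetical smoke test against the mapping above — host port 7070, with a placeholder standing in for your own Cursor session token:

```
# Hypothetical smoke test; the port matches the docker run mapping and the token is a placeholder
curl http://localhost:7070/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer YOUR_CURSOR_TOKEN" \
  -d '{"model": "gpt-4o", "messages": [{"role": "user", "content": "hello"}], "stream": false}'
```
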
docker-compose
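
The README stops at naming docker-compose; a minimal sketch of an equivalent compose file, assuming the same image and port mapping as the `docker run` line above (the service name and restart policy are illustrative):

```
# Hypothetical docker-compose equivalent of the docker run command above
services:
  rs-capi:              # illustrative service name
    image: ghcr.io/xxxx/rs-capi:latest
    ports:
      - "7070:3000"
    restart: unless-stopped
```
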

@@ -16,9 +16,9 @@ RUN cargo build --bin rs-capi --release


 FROM ubuntu:22.04
-# RUN apt-get update && apt-get install -y --no-install-recommends \
-#     ca-certificates \
-#     && rm -rf /var/lib/apt/lists/*
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    ca-certificates \
+    && rm -rf /var/lib/apt/lists/*

 COPY --from=builder /workspace/target/release/rs-capi /workspace/rs-capi

rs-capi/main.py (822 lines)
@@ -1,822 +0,0 @@
import json

from fastapi import FastAPI, Request, Response, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
import uuid
import httpx
import os
from dotenv import load_dotenv
import time
import re

# Load environment variables
load_dotenv()

app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Request models
class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    stream: bool = False


def string_to_hex(text: str, model_name: str) -> bytes:
    """Convert text into the hex-encoded binary frame the upstream service expects."""
    # Convert the input text to UTF-8 bytes
    text_bytes = text.encode('utf-8')
    text_length = len(text_bytes)

    # Fixed constants
    FIXED_HEADER = 2
    SEPARATOR = 1
    FIXED_SUFFIX_LENGTH = 0xA3 + len(model_name)

    # Compute the first length field (varint-style: 7 data bits per byte)
    if text_length < 128:
        text_length_field1 = format(text_length, '02x')
        text_length_field_size1 = 1
    else:
        low_byte1 = (text_length & 0x7F) | 0x80
        high_byte1 = (text_length >> 7) & 0xFF
        text_length_field1 = format(low_byte1, '02x') + format(high_byte1, '02x')
        text_length_field_size1 = 2

    # Compute the base length field
    base_length = text_length + 0x2A
    if base_length < 128:
        text_length_field = format(base_length, '02x')
        text_length_field_size = 1
    else:
        low_byte = (base_length & 0x7F) | 0x80
        high_byte = (base_length >> 7) & 0xFF
        text_length_field = format(low_byte, '02x') + format(high_byte, '02x')
        text_length_field_size = 2

    # Compute the total message length
    message_total_length = (FIXED_HEADER + text_length_field_size + SEPARATOR +
                            text_length_field_size1 + text_length + FIXED_SUFFIX_LENGTH)

    # Build the hex string
    model_name_bytes = model_name.encode('utf-8')
    model_name_length_hex = format(len(model_name_bytes), '02X')
    model_name_hex = model_name_bytes.hex().upper()

    hex_string = (
        f"{message_total_length:010x}"
        "12"
        f"{text_length_field}"
        "0A"
        f"{text_length_field1}"
        f"{text_bytes.hex()}"
        "10016A2432343163636435662D393162612D343131382D393239612D3936626330313631626432612"
        "2002A132F643A2F6964656150726F2F656475626F73733A1E0A"
        f"{model_name_length_hex}"
        f"{model_name_hex}"
        "22004A"
        "2461383761396133342D323164642D343863372D623434662D616636633365636536663765"
        "680070007A2436393337376535612D386332642D343835342D623564392D653062623232336163303061"
        "800101B00100C00100E00100E80100"
    ).upper()

    return bytes.fromhex(hex_string)


def chunk_to_utf8_string(chunk: bytes) -> str:
    """Convert a binary chunk into a UTF-8 string."""
    if not chunk or len(chunk) < 2:
        return ''

    if chunk[0] in [0x01, 0x02] or (chunk[0] == 0x60 and chunk[1] == 0x0C):
        return ''

    # Log the raw chunk length (debugging)
    print(f"chunk length: {len(chunk)}")
    # print(f"chunk hex: {chunk.hex()}")

    try:
        # Drop all bytes up to and including the first 0x0A
        try:
            chunk = chunk[chunk.index(0x0A) + 1:]
        except ValueError:
            pass

        filtered_chunk = bytearray()
        i = 0
        while i < len(chunk):
            # Skip runs that begin with four consecutive 0x00 bytes
            if i + 4 <= len(chunk) and all(chunk[j] == 0x00 for j in range(i, i + 4)):
                i += 4
                while i < len(chunk) and chunk[i] <= 0x0F:
                    i += 1
                continue

            if chunk[i] == 0x0C:
                i += 1
                while i < len(chunk) and chunk[i] == 0x0A:
                    i += 1
            else:
                filtered_chunk.append(chunk[i])
                i += 1

        # Filter out the remaining 0x00 and 0x0C bytes
        filtered_chunk = bytes(b for b in filtered_chunk
                               if b != 0x00 and b != 0x0C)

        if not filtered_chunk:
            return ''

        result = filtered_chunk.decode('utf-8', errors='ignore').strip()
        # print(f"decoded result: {result}")  # debug output
        return result

    except Exception as e:
        print(f"Error in chunk_to_utf8_string: {str(e)}")
        return ''


async def process_stream(chunks):
    """Turn the collected response chunks into an SSE stream."""
    response_id = f"chatcmpl-{str(uuid.uuid4())}"

    # The chunks were already collected from the upstream response:
    # chunks = []
    # async for chunk in response.aiter_raw():
    #     chunks.append(chunk)

    # Process the saved chunks
    for chunk in chunks:
        text = chunk_to_utf8_string(chunk)
        if text:
            # Clean up the text
            text = text.strip()
            if "<|END_USER|>" in text:
                text = text.split("<|END_USER|>")[-1].strip()
                if text and text[0].isalpha():
                    text = text[1:].strip()
            text = re.sub(r"[\x00-\x1F\x7F]", "", text)

            if text:  # make sure the cleaned text is non-empty
                data_body = {
                    "id": response_id,
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "choices": [{
                        "index": 0,
                        "delta": {
                            "content": text
                        }
                    }]
                }
                yield f"data: {json.dumps(data_body, ensure_ascii=False)}\n\n"
                # yield "data: {\n"
                # yield f'  "id": "{response_id}",\n'
                # yield '  "object": "chat.completion.chunk",\n'
                # yield f'  "created": {int(time.time())},\n'
                # yield '  "choices": [{\n'
                # yield '    "index": 0,\n'
                # yield '    "delta": {\n'
                # yield f'      "content": "{text}"\n'
                # yield "    }\n"
                # yield "  }]\n"
                # yield "}\n\n"

    yield "data: [DONE]\n\n"


@app.post("/v1/chat/completions")
async def chat_completions(request: Request, chat_request: ChatRequest):
    # o1 models do not support streaming output
    if chat_request.model.startswith('o1-') and chat_request.stream:
        raise HTTPException(
            status_code=400,
            detail="Model not supported stream"
        )

    # Get and process the auth token
    auth_header = request.headers.get('authorization', '')
    if not auth_header.startswith('Bearer '):
        raise HTTPException(
            status_code=401,
            detail="Invalid authorization header"
        )

    auth_token = auth_header.replace('Bearer ', '')
    if not auth_token:
        raise HTTPException(
            status_code=401,
            detail="Missing authorization token"
        )

    # Handle multiple comma-separated keys
    keys = [key.strip() for key in auth_token.split(',')]
    if keys:
        auth_token = keys[0]  # use the first key

    if '%3A%3A' in auth_token:  # URL-encoded '::'
        auth_token = auth_token.split('%3A%3A')[1]

    # Format the messages
    formatted_messages = "\n".join(
        f"{msg.role}:{msg.content}" for msg in chat_request.messages
    )

    # Build the request payload
    hex_data = string_to_hex(formatted_messages, chat_request.model)

    # Prepare the request headers
    headers = {
        'Content-Type': 'application/connect+proto',
        'Authorization': f'Bearer {auth_token}',
        'Connect-Accept-Encoding': 'gzip,br',
        'Connect-Protocol-Version': '1',
        'User-Agent': 'connect-es/1.4.0',
        'X-Amzn-Trace-Id': f'Root={str(uuid.uuid4())}',
        'X-Cursor-Checksum': 'zo6Qjequ9b9734d1f13c3438ba25ea31ac93d9287248b9d30434934e9fcbfa6b3b22029e/7e4af391f67188693b722eff0090e8e6608bca8fa320ef20a0ccb5d7d62dfdef',
        'X-Cursor-Client-Version': '0.42.3',
        'X-Cursor-Timezone': 'Asia/Shanghai',
        'X-Ghost-Mode': 'false',
        'X-Request-Id': str(uuid.uuid4()),
        'Host': 'api2.cursor.sh'
    }

    async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
        try:
            # Log the headers and binary payload (debugging)
            print(f"headers: {headers}")
            print(hex_data)
            async with client.stream(
                'POST',
                'https://api2.cursor.sh/aiserver.v1.AiService/StreamChat',
                headers=headers,
                content=hex_data,
                timeout=None
            ) as response:
                if chat_request.stream:
                    chunks = []
                    async for chunk in response.aiter_raw():
                        chunks.append(chunk)
                    return StreamingResponse(
                        process_stream(chunks),
                        media_type="text/event-stream"
                    )
                else:
                    # Non-streaming response handling
                    text = ''
                    async for chunk in response.aiter_raw():
                        # print('chunk:', chunk.hex())
                        print('chunk length:', len(chunk))

                        res = chunk_to_utf8_string(chunk)
                        # print('res:', res)
                        if res:
                            text += res

                    # Clean up the response text
                    text = re.sub(r'^.*<\|END_USER\|>', '', text, flags=re.DOTALL)
                    text = re.sub(r'^\n[a-zA-Z]?', '', text).strip()
                    text = re.sub(r'[\x00-\x1F\x7F]', '', text)

                    return {
                        "id": f"chatcmpl-{str(uuid.uuid4())}",
                        "object": "chat.completion",
                        "created": int(time.time()),
                        "model": chat_request.model,
                        "choices": [{
                            "index": 0,
                            "message": {
                                "role": "assistant",
                                "content": text
                            },
                            "finish_reason": "stop"
                        }],
                        "usage": {
                            "prompt_tokens": 0,
                            "completion_tokens": 0,
                            "total_tokens": 0
                        }
                    }

        except Exception as e:
            print(f"Error: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail="Internal server error"
            )


@app.post("/models")
async def models():
    return {
        "object": "list",
        "data": [
            {
                "id": "claude-3-5-sonnet-20241022",
                "object": "model",
                "created": 1713744000,
                "owned_by": "anthropic"
            },
            {
                "id": "claude-3-opus",
                "object": "model",
                "created": 1709251200,
                "owned_by": "anthropic"
            },
            {
                "id": "claude-3.5-haiku",
                "object": "model",
                "created": 1711929600,
                "owned_by": "anthropic"
            },
            {
                "id": "claude-3.5-sonnet",
                "object": "model",
                "created": 1711929600,
                "owned_by": "anthropic"
            },
            {
                "id": "cursor-small",
                "object": "model",
                "created": 1712534400,
                "owned_by": "cursor"
            },
            {
                "id": "gpt-3.5-turbo",
                "object": "model",
                "created": 1677649200,
                "owned_by": "openai"
            },
            {
                "id": "gpt-4",
                "object": "model",
                "created": 1687392000,
                "owned_by": "openai"
            },
            {
                "id": "gpt-4-turbo-2024-04-09",
                "object": "model",
                "created": 1712620800,
                "owned_by": "openai"
            },
            {
                "id": "gpt-4o",
                "object": "model",
                "created": 1712620800,
                "owned_by": "openai"
            },
            {
                "id": "gpt-4o-mini",
                "object": "model",
                "created": 1712620800,
                "owned_by": "openai"
            },
            {
                "id": "o1-mini",
                "object": "model",
                "created": 1712620800,
                "owned_by": "openai"
            },
            {
                "id": "o1-preview",
                "object": "model",
                "created": 1712620800,
                "owned_by": "openai"
            }
        ]
    }


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", "3001"))
    uvicorn.run("main:app", host="0.0.0.0", port=port, reload=True, timeout_keep_alive=30)

@@ -7,6 +7,7 @@ use axum::{
     routing::post,
     Json, Router,
 };
+use std::error::Error;
 use tower_http::trace::TraceLayer;

 use bytes::Bytes;

@@ -262,7 +263,16 @@ async fn chat_completions(
     let client = reqwest::Client::builder()
         .timeout(Duration::from_secs(300))
         .build()
-        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
+        .map_err(|e| {
+            tracing::error!("Failed to create HTTP client: {:?}", e);
+            tracing::error!(error = %e, "error details");
+
+            if let Some(source) = e.source() {
+                tracing::error!(source = %source, "error source");
+            }
+
+            StatusCode::INTERNAL_SERVER_ERROR
+        })?;

     let response = client
         .post("https://api2.cursor.sh/aiserver.v1.AiService/StreamChat")

@@ -270,7 +280,32 @@ async fn chat_completions(
         .body(hex_data)
         .send()
         .await
-        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
+        .map_err(|e| {
+            tracing::error!("Request failed: {:?}", e);
+            tracing::error!(error = %e, "error details");
+
+            // Timeout error?
+            if e.is_timeout() {
+                tracing::error!("request timed out");
+            }
+
+            // Connection error?
+            if e.is_connect() {
+                tracing::error!("connection failed");
+            }
+
+            // Request URL, if available
+            if let Some(url) = e.url() {
+                tracing::error!(url = %url, "request URL");
+            }
+
+            // HTTP status code, if available
+            if let Some(status) = e.status() {
+                tracing::error!(status = %status, "HTTP status code");
+            }
+
+            StatusCode::INTERNAL_SERVER_ERROR
+        })?;

     if chat_request.stream {
         let mut chunks = Vec::new();

@@ -347,7 +382,72 @@ async fn models() -> Json<serde_json::Value> {
             "created": 1713744000,
             "owned_by": "anthropic"
         },
-        // ... other models
+        {
+            "id": "claude-3-opus",
+            "object": "model",
+            "created": 1709251200,
+            "owned_by": "anthropic"
+        },
+        {
+            "id": "claude-3.5-haiku",
+            "object": "model",
+            "created": 1711929600,
+            "owned_by": "anthropic"
+        },
+        {
+            "id": "claude-3.5-sonnet",
+            "object": "model",
+            "created": 1711929600,
+            "owned_by": "anthropic"
+        },
+        {
+            "id": "cursor-small",
+            "object": "model",
+            "created": 1712534400,
+            "owned_by": "cursor"
+        },
+        {
+            "id": "gpt-3.5-turbo",
+            "object": "model",
+            "created": 1677649200,
+            "owned_by": "openai"
+        },
+        {
+            "id": "gpt-4",
+            "object": "model",
+            "created": 1687392000,
+            "owned_by": "openai"
+        },
+        {
+            "id": "gpt-4-turbo-2024-04-09",
+            "object": "model",
+            "created": 1712620800,
+            "owned_by": "openai"
+        },
+        {
+            "id": "gpt-4o",
+            "object": "model",
+            "created": 1712620800,
+            "owned_by": "openai"
+        },
+        {
+            "id": "gpt-4o-mini",
+            "object": "model",
+            "created": 1712620800,
+            "owned_by": "openai"
+        },
+        {
+            "id": "o1-mini",
+            "object": "model",
+            "created": 1712620800,
+            "owned_by": "openai"
+        },
+        {
+            "id": "o1-preview",
+            "object": "model",
+            "created": 1712620800,
+            "owned_by": "openai"
+        }
     ]
 }))
 }