mirror of
https://github.com/zeke-chin/cursor-api.git
synced 2025-09-29 12:52:48 +08:00
update: github workflow
This commit is contained in:
1
.github/workflows/build-rs-capi.yml
vendored
1
.github/workflows/build-rs-capi.yml
vendored
@@ -40,7 +40,6 @@ jobs:
|
|||||||
override: true
|
override: true
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: matrix.target == 'aarch64-unknown-linux-gnu'
|
|
||||||
run: |
|
run: |
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install -y gcc-aarch64-linux-gnu g++-aarch64-linux-gnu pkg-config libssl-dev gcc-multilib crossbuild-essential-arm64 musl-tools
|
sudo apt-get install -y gcc-aarch64-linux-gnu g++-aarch64-linux-gnu pkg-config libssl-dev gcc-multilib crossbuild-essential-arm64 musl-tools
|
||||||
|
@@ -20,7 +20,7 @@
|
|||||||
|
|
||||||
## 快速开始
|
## 快速开始
|
||||||
```
|
```
|
||||||
docker run xxxx -p 3000:3000 ghcr.io/xxxx/rs-capi:latest
|
docker run --rm -p 7070:3000 ghcr.io/xxxx/rs-capi:latest
|
||||||
```
|
```
|
||||||
|
|
||||||
docker-compose
|
docker-compose
|
||||||
|
@@ -16,9 +16,9 @@ RUN cargo build --bin rs-capi --release
|
|||||||
|
|
||||||
|
|
||||||
FROM ubuntu:22.04
|
FROM ubuntu:22.04
|
||||||
# RUN apt-get update && apt-get install -y --no-install-recommends \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
# ca-certificates \
|
ca-certificates \
|
||||||
# && rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
COPY --from=builder /workspace/target/release/rs-capi /workspace/rs-capi
|
COPY --from=builder /workspace/target/release/rs-capi /workspace/rs-capi
|
||||||
|
|
||||||
|
822
rs-capi/main.py
822
rs-capi/main.py
@@ -1,822 +0,0 @@
|
|||||||
import json
|
|
||||||
|
|
||||||
from fastapi import FastAPI, Request, Response, HTTPException
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing import List, Optional, Dict, Any
|
|
||||||
import uuid
|
|
||||||
import httpx
|
|
||||||
import os
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
import time
|
|
||||||
import re
|
|
||||||
|
|
||||||
# Load environment variables from a local .env file (if present).
load_dotenv()

app = FastAPI()

# Open CORS policy: any origin, method and header, with credentials,
# so browser-based OpenAI clients can call this proxy directly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
||||||
|
|
||||||
|
|
||||||
# 定义请求模型
|
|
||||||
class Message(BaseModel):
    """A single chat turn in the OpenAI request schema."""

    # Speaker of the turn, e.g. "system", "user" or "assistant".
    role: str
    # Plain-text content of the turn.
    content: str
|
|
||||||
|
|
||||||
|
|
||||||
class ChatRequest(BaseModel):
    """Body of POST /v1/chat/completions (subset of the OpenAI chat schema)."""

    # Target model id, e.g. "gpt-4o" or "claude-3.5-sonnet".
    model: str
    # Ordered conversation history.
    messages: List[Message]
    # When True the response is served as an SSE stream.
    stream: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
def string_to_hex(text: str, model_name: str) -> bytes:
    """Encode *text* and *model_name* into Cursor's binary StreamChat frame.

    The frame is a Connect-protocol envelope: a 1-byte flag plus a 4-byte
    big-endian length prefix (the ``:010x`` field), followed by a protobuf
    message whose string fields carry base-128 varint length prefixes.
    The long hex literals are fixed session/UUID fields captured from the
    Cursor client.

    Args:
        text: The flattened conversation to send upstream.
        model_name: Model id embedded in the frame (e.g. "gpt-4o").

    Returns:
        The raw request body as bytes.
    """
    text_bytes = text.encode('utf-8')
    text_length = len(text_bytes)

    # Fixed framing constants. FIXED_SUFFIX_LENGTH is presumably the byte
    # count of the constant trailer below plus the embedded model name —
    # captured from the original client traffic, not derived here.
    FIXED_HEADER = 2
    SEPARATOR = 1
    FIXED_SUFFIX_LENGTH = 0xA3 + len(model_name)

    def _varint(value):
        """Protobuf base-128 varint -> (lowercase hex string, byte count).

        Generalizes the original hand-rolled two-byte scheme: payloads of
        16 KiB and larger (values >= 2**14) are now encoded correctly; for
        smaller values the output is byte-identical to the old code.
        """
        encoded = bytearray()
        while True:
            low = value & 0x7F
            value >>= 7
            if value:
                encoded.append(low | 0x80)
            else:
                encoded.append(low)
                break
        return encoded.hex(), len(encoded)

    # Varint length prefix for the text field itself ...
    text_length_field1, text_length_field_size1 = _varint(text_length)
    # ... and for the enclosing message (text plus 0x2A bytes of overhead).
    base_length = text_length + 0x2A
    text_length_field, text_length_field_size = _varint(base_length)

    # Total payload size, used for the 4-byte frame length prefix.
    message_total_length = (FIXED_HEADER + text_length_field_size + SEPARATOR +
                            text_length_field_size1 + text_length + FIXED_SUFFIX_LENGTH)

    model_name_bytes = model_name.encode('utf-8')
    model_name_length_hex = format(len(model_name_bytes), '02X')
    model_name_hex = model_name_bytes.hex().upper()

    hex_string = (
        f"{message_total_length:010x}"
        "12"
        f"{text_length_field}"
        "0A"
        f"{text_length_field1}"
        f"{text_bytes.hex()}"
        "10016A2432343163636435662D393162612D343131382D393239612D3936626330313631626432612"
        "2002A132F643A2F6964656150726F2F656475626F73733A1E0A"
        f"{model_name_length_hex}"
        f"{model_name_hex}"
        "22004A"
        "2461383761396133342D323164642D343863372D623434662D616636633365636536663765"
        "680070007A2436393337376535612D386332642D343835342D623564392D653062623232336163303061"
        "800101B00100C00100E00100E80100"
    ).upper()

    return bytes.fromhex(hex_string)
|
|
||||||
|
|
||||||
|
|
||||||
def chunk_to_utf8_string(chunk: bytes) -> str:
    """Decode one raw StreamChat frame into readable UTF-8 text.

    Control frames (first byte 0x01/0x02, or the 0x60 0x0C marker) yield
    ''. Otherwise everything up to and including the first 0x0A byte is
    dropped, runs introduced by four 0x00 bytes or by 0x0C are skipped,
    remaining NUL/0x0C bytes are filtered out, and the rest is decoded
    leniently as UTF-8.
    """
    if not chunk or len(chunk) < 2:
        return ''

    first, second = chunk[0], chunk[1]
    if first in (0x01, 0x02) or (first == 0x60 and second == 0x0C):
        return ''

    # Debug trace of the raw frame size.
    print(f"chunk length: {len(chunk)}")

    try:
        # Drop the header: everything up to the first 0x0A byte, if any.
        newline_at = chunk.find(0x0A)
        if newline_at != -1:
            chunk = chunk[newline_at + 1:]

        kept = bytearray()
        pos = 0
        size = len(chunk)
        while pos < size:
            if chunk[pos:pos + 4] == b'\x00\x00\x00\x00':
                # Four NULs start a control run; also swallow the low bytes
                # (<= 0x0F) that follow it.
                pos += 4
                while pos < size and chunk[pos] <= 0x0F:
                    pos += 1
            elif chunk[pos] == 0x0C:
                # 0x0C plus any trailing 0x0A bytes is a separator — skip.
                pos += 1
                while pos < size and chunk[pos] == 0x0A:
                    pos += 1
            else:
                kept.append(chunk[pos])
                pos += 1

        # Final sweep: no NUL or 0x0C survives into the text.
        cleaned = bytes(b for b in kept if b not in (0x00, 0x0C))
        if not cleaned:
            return ''

        return cleaned.decode('utf-8', errors='ignore').strip()

    except Exception as e:
        print(f"Error in chunk_to_utf8_string: {str(e)}")
        return ''
|
|
||||||
|
|
||||||
|
|
||||||
async def process_stream(chunks):
    """Re-emit decoded frames as OpenAI-style `chat.completion.chunk` SSE events.

    Each chunk is decoded with chunk_to_utf8_string, cleaned of echoed
    prompt text and control characters, and yielded as a `data:` line;
    the stream is terminated with `data: [DONE]`.
    """
    stream_id = f"chatcmpl-{str(uuid.uuid4())}"

    for raw in chunks:
        piece = chunk_to_utf8_string(raw)
        if not piece:
            continue

        piece = piece.strip()
        # Drop the echoed prompt up to the END_USER marker, if present.
        if "<|END_USER|>" in piece:
            piece = piece.split("<|END_USER|>")[-1].strip()
        # Heuristic: shave a stray leading letter left over from the marker.
        if piece and piece[0].isalpha():
            piece = piece[1:].strip()
        # Strip ASCII control characters.
        piece = re.sub(r"[\x00-\x1F\x7F]", "", piece)

        if not piece:
            continue

        payload = {
            "id": stream_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "choices": [{
                "index": 0,
                "delta": {
                    "content": piece
                }
            }]
        }
        yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    yield "data: [DONE]\n\n"
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
|
||||||
async def chat_completions(request: Request, chat_request: ChatRequest):
|
|
||||||
# 验证o1模型不支持流式输出
|
|
||||||
if chat_request.model.startswith('o1-') and chat_request.stream:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Model not supported stream"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取并处理认证令牌
|
|
||||||
auth_header = request.headers.get('authorization', '')
|
|
||||||
if not auth_header.startswith('Bearer '):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=401,
|
|
||||||
detail="Invalid authorization header"
|
|
||||||
)
|
|
||||||
|
|
||||||
auth_token = auth_header.replace('Bearer ', '')
|
|
||||||
if not auth_token:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=401,
|
|
||||||
detail="Missing authorization token"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 处理多个密钥
|
|
||||||
keys = [key.strip() for key in auth_token.split(',')]
|
|
||||||
if keys:
|
|
||||||
auth_token = keys[0] # 使用第一个密钥
|
|
||||||
|
|
||||||
if '%3A%3A' in auth_token:
|
|
||||||
auth_token = auth_token.split('%3A%3A')[1]
|
|
||||||
|
|
||||||
# 格式化消息
|
|
||||||
formatted_messages = "\n".join(
|
|
||||||
f"{msg.role}:{msg.content}" for msg in chat_request.messages
|
|
||||||
)
|
|
||||||
|
|
||||||
# 生成请求数据
|
|
||||||
hex_data = string_to_hex(formatted_messages, chat_request.model)
|
|
||||||
|
|
||||||
# 准备请求头
|
|
||||||
headers = {
|
|
||||||
'Content-Type': 'application/connect+proto',
|
|
||||||
'Authorization': f'Bearer {auth_token}',
|
|
||||||
'Connect-Accept-Encoding': 'gzip,br',
|
|
||||||
'Connect-Protocol-Version': '1',
|
|
||||||
'User-Agent': 'connect-es/1.4.0',
|
|
||||||
'X-Amzn-Trace-Id': f'Root={str(uuid.uuid4())}',
|
|
||||||
'X-Cursor-Checksum': 'zo6Qjequ9b9734d1f13c3438ba25ea31ac93d9287248b9d30434934e9fcbfa6b3b22029e/7e4af391f67188693b722eff0090e8e6608bca8fa320ef20a0ccb5d7d62dfdef',
|
|
||||||
'X-Cursor-Client-Version': '0.42.3',
|
|
||||||
'X-Cursor-Timezone': 'Asia/Shanghai',
|
|
||||||
'X-Ghost-Mode': 'false',
|
|
||||||
'X-Request-Id': str(uuid.uuid4()),
|
|
||||||
'Host': 'api2.cursor.sh'
|
|
||||||
}
|
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
|
|
||||||
try:
|
|
||||||
# 使用 stream=True 参数
|
|
||||||
# 打印 headers 和 二进制 data
|
|
||||||
print(f"headers: {headers}")
|
|
||||||
print(hex_data)
|
|
||||||
async with client.stream(
|
|
||||||
'POST',
|
|
||||||
'https://api2.cursor.sh/aiserver.v1.AiService/StreamChat',
|
|
||||||
headers=headers,
|
|
||||||
content=hex_data,
|
|
||||||
timeout=None
|
|
||||||
) as response:
|
|
||||||
if chat_request.stream:
|
|
||||||
chunks = []
|
|
||||||
async for chunk in response.aiter_raw():
|
|
||||||
chunks.append(chunk)
|
|
||||||
return StreamingResponse(
|
|
||||||
process_stream(chunks),
|
|
||||||
media_type="text/event-stream"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 非流式响应处理
|
|
||||||
text = ''
|
|
||||||
async for chunk in response.aiter_raw():
|
|
||||||
# print('chunk:', chunk.hex())
|
|
||||||
print('chunk length:', len(chunk))
|
|
||||||
|
|
||||||
res = chunk_to_utf8_string(chunk)
|
|
||||||
# print('res:', res)
|
|
||||||
if res:
|
|
||||||
text += res
|
|
||||||
|
|
||||||
# 清理响应文本
|
|
||||||
import re
|
|
||||||
text = re.sub(r'^.*<\|END_USER\|>', '', text, flags=re.DOTALL)
|
|
||||||
text = re.sub(r'^\n[a-zA-Z]?', '', text).strip()
|
|
||||||
text = re.sub(r'[\x00-\x1F\x7F]', '', text)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"id": f"chatcmpl-{str(uuid.uuid4())}",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": int(time.time()),
|
|
||||||
"model": chat_request.model,
|
|
||||||
"choices": [{
|
|
||||||
"index": 0,
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": text
|
|
||||||
},
|
|
||||||
"finish_reason": "stop"
|
|
||||||
}],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"completion_tokens": 0,
|
|
||||||
"total_tokens": 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {str(e)}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail="Internal server error"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/models")
|
|
||||||
async def models():
|
|
||||||
return {
|
|
||||||
"object": "list",
|
|
||||||
"data": [
|
|
||||||
{
|
|
||||||
"id": "claude-3-5-sonnet-20241022",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1713744000,
|
|
||||||
"owned_by": "anthropic"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "claude-3-opus",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1709251200,
|
|
||||||
"owned_by": "anthropic"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "claude-3.5-haiku",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1711929600,
|
|
||||||
"owned_by": "anthropic"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "claude-3.5-sonnet",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1711929600,
|
|
||||||
"owned_by": "anthropic"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "cursor-small",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1712534400,
|
|
||||||
"owned_by": "cursor"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "gpt-3.5-turbo",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1677649200,
|
|
||||||
"owned_by": "openai"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "gpt-4",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1687392000,
|
|
||||||
"owned_by": "openai"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "gpt-4-turbo-2024-04-09",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1712620800,
|
|
||||||
"owned_by": "openai"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "gpt-4o",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1712620800,
|
|
||||||
"owned_by": "openai"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "gpt-4o-mini",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1712620800,
|
|
||||||
"owned_by": "openai"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "o1-mini",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1712620800,
|
|
||||||
"owned_by": "openai"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "o1-preview",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1712620800,
|
|
||||||
"owned_by": "openai"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    import uvicorn

    # The PORT environment variable overrides the default listening port.
    serve_port = int(os.getenv("PORT", "3001"))
    # reload=True gives dev-style auto-reload; keep-alive is capped at 30s.
    uvicorn.run("main:app", host="0.0.0.0", port=serve_port, reload=True, timeout_keep_alive=30)
|
|
||||||
import json
|
|
||||||
|
|
||||||
from fastapi import FastAPI, Request, Response, HTTPException
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing import List, Optional, Dict, Any
|
|
||||||
import uuid
|
|
||||||
import httpx
|
|
||||||
import os
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
import time
|
|
||||||
import re
|
|
||||||
|
|
||||||
# Load environment variables from a local .env file (if present).
load_dotenv()

app = FastAPI()

# Open CORS policy: any origin, method and header, with credentials,
# so browser-based OpenAI clients can call this proxy directly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
||||||
|
|
||||||
|
|
||||||
# 定义请求模型
|
|
||||||
class Message(BaseModel):
    """A single chat turn in the OpenAI request schema."""

    # Speaker of the turn, e.g. "system", "user" or "assistant".
    role: str
    # Plain-text content of the turn.
    content: str
|
|
||||||
|
|
||||||
|
|
||||||
class ChatRequest(BaseModel):
    """Body of POST /v1/chat/completions (subset of the OpenAI chat schema)."""

    # Target model id, e.g. "gpt-4o" or "claude-3.5-sonnet".
    model: str
    # Ordered conversation history.
    messages: List[Message]
    # When True the response is served as an SSE stream.
    stream: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
def string_to_hex(text: str, model_name: str) -> bytes:
    """Encode *text* and *model_name* into Cursor's binary StreamChat frame.

    The frame is a Connect-protocol envelope: a 1-byte flag plus a 4-byte
    big-endian length prefix (the ``:010x`` field), followed by a protobuf
    message whose string fields carry base-128 varint length prefixes.
    The long hex literals are fixed session/UUID fields captured from the
    Cursor client.

    Args:
        text: The flattened conversation to send upstream.
        model_name: Model id embedded in the frame (e.g. "gpt-4o").

    Returns:
        The raw request body as bytes.
    """
    text_bytes = text.encode('utf-8')
    text_length = len(text_bytes)

    # Fixed framing constants. FIXED_SUFFIX_LENGTH is presumably the byte
    # count of the constant trailer below plus the embedded model name —
    # captured from the original client traffic, not derived here.
    FIXED_HEADER = 2
    SEPARATOR = 1
    FIXED_SUFFIX_LENGTH = 0xA3 + len(model_name)

    def _varint(value):
        """Protobuf base-128 varint -> (lowercase hex string, byte count).

        Generalizes the original hand-rolled two-byte scheme: payloads of
        16 KiB and larger (values >= 2**14) are now encoded correctly; for
        smaller values the output is byte-identical to the old code.
        """
        encoded = bytearray()
        while True:
            low = value & 0x7F
            value >>= 7
            if value:
                encoded.append(low | 0x80)
            else:
                encoded.append(low)
                break
        return encoded.hex(), len(encoded)

    # Varint length prefix for the text field itself ...
    text_length_field1, text_length_field_size1 = _varint(text_length)
    # ... and for the enclosing message (text plus 0x2A bytes of overhead).
    base_length = text_length + 0x2A
    text_length_field, text_length_field_size = _varint(base_length)

    # Total payload size, used for the 4-byte frame length prefix.
    message_total_length = (FIXED_HEADER + text_length_field_size + SEPARATOR +
                            text_length_field_size1 + text_length + FIXED_SUFFIX_LENGTH)

    model_name_bytes = model_name.encode('utf-8')
    model_name_length_hex = format(len(model_name_bytes), '02X')
    model_name_hex = model_name_bytes.hex().upper()

    hex_string = (
        f"{message_total_length:010x}"
        "12"
        f"{text_length_field}"
        "0A"
        f"{text_length_field1}"
        f"{text_bytes.hex()}"
        "10016A2432343163636435662D393162612D343131382D393239612D3936626330313631626432612"
        "2002A132F643A2F6964656150726F2F656475626F73733A1E0A"
        f"{model_name_length_hex}"
        f"{model_name_hex}"
        "22004A"
        "2461383761396133342D323164642D343863372D623434662D616636633365636536663765"
        "680070007A2436393337376535612D386332642D343835342D623564392D653062623232336163303061"
        "800101B00100C00100E00100E80100"
    ).upper()

    return bytes.fromhex(hex_string)
|
|
||||||
|
|
||||||
|
|
||||||
def chunk_to_utf8_string(chunk: bytes) -> str:
    """Decode one raw StreamChat frame into readable UTF-8 text.

    Control frames (first byte 0x01/0x02, or the 0x60 0x0C marker) yield
    ''. Otherwise everything up to and including the first 0x0A byte is
    dropped, runs introduced by four 0x00 bytes or by 0x0C are skipped,
    remaining NUL/0x0C bytes are filtered out, and the rest is decoded
    leniently as UTF-8.
    """
    if not chunk or len(chunk) < 2:
        return ''

    first, second = chunk[0], chunk[1]
    if first in (0x01, 0x02) or (first == 0x60 and second == 0x0C):
        return ''

    # Debug trace of the raw frame size.
    print(f"chunk length: {len(chunk)}")

    try:
        # Drop the header: everything up to the first 0x0A byte, if any.
        newline_at = chunk.find(0x0A)
        if newline_at != -1:
            chunk = chunk[newline_at + 1:]

        kept = bytearray()
        pos = 0
        size = len(chunk)
        while pos < size:
            if chunk[pos:pos + 4] == b'\x00\x00\x00\x00':
                # Four NULs start a control run; also swallow the low bytes
                # (<= 0x0F) that follow it.
                pos += 4
                while pos < size and chunk[pos] <= 0x0F:
                    pos += 1
            elif chunk[pos] == 0x0C:
                # 0x0C plus any trailing 0x0A bytes is a separator — skip.
                pos += 1
                while pos < size and chunk[pos] == 0x0A:
                    pos += 1
            else:
                kept.append(chunk[pos])
                pos += 1

        # Final sweep: no NUL or 0x0C survives into the text.
        cleaned = bytes(b for b in kept if b not in (0x00, 0x0C))
        if not cleaned:
            return ''

        return cleaned.decode('utf-8', errors='ignore').strip()

    except Exception as e:
        print(f"Error in chunk_to_utf8_string: {str(e)}")
        return ''
|
|
||||||
|
|
||||||
|
|
||||||
async def process_stream(chunks):
    """Re-emit decoded frames as OpenAI-style `chat.completion.chunk` SSE events.

    Each chunk is decoded with chunk_to_utf8_string, cleaned of echoed
    prompt text and control characters, and yielded as a `data:` line;
    the stream is terminated with `data: [DONE]`.
    """
    stream_id = f"chatcmpl-{str(uuid.uuid4())}"

    for raw in chunks:
        piece = chunk_to_utf8_string(raw)
        if not piece:
            continue

        piece = piece.strip()
        # Drop the echoed prompt up to the END_USER marker, if present.
        if "<|END_USER|>" in piece:
            piece = piece.split("<|END_USER|>")[-1].strip()
        # Heuristic: shave a stray leading letter left over from the marker.
        if piece and piece[0].isalpha():
            piece = piece[1:].strip()
        # Strip ASCII control characters.
        piece = re.sub(r"[\x00-\x1F\x7F]", "", piece)

        if not piece:
            continue

        payload = {
            "id": stream_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "choices": [{
                "index": 0,
                "delta": {
                    "content": piece
                }
            }]
        }
        yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    yield "data: [DONE]\n\n"
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
|
||||||
async def chat_completions(request: Request, chat_request: ChatRequest):
|
|
||||||
# 验证o1模型不支持流式输出
|
|
||||||
if chat_request.model.startswith('o1-') and chat_request.stream:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Model not supported stream"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取并处理认证令牌
|
|
||||||
auth_header = request.headers.get('authorization', '')
|
|
||||||
if not auth_header.startswith('Bearer '):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=401,
|
|
||||||
detail="Invalid authorization header"
|
|
||||||
)
|
|
||||||
|
|
||||||
auth_token = auth_header.replace('Bearer ', '')
|
|
||||||
if not auth_token:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=401,
|
|
||||||
detail="Missing authorization token"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 处理多个密钥
|
|
||||||
keys = [key.strip() for key in auth_token.split(',')]
|
|
||||||
if keys:
|
|
||||||
auth_token = keys[0] # 使用第一个密钥
|
|
||||||
|
|
||||||
if '%3A%3A' in auth_token:
|
|
||||||
auth_token = auth_token.split('%3A%3A')[1]
|
|
||||||
|
|
||||||
# 格式化消息
|
|
||||||
formatted_messages = "\n".join(
|
|
||||||
f"{msg.role}:{msg.content}" for msg in chat_request.messages
|
|
||||||
)
|
|
||||||
|
|
||||||
# 生成请求数据
|
|
||||||
hex_data = string_to_hex(formatted_messages, chat_request.model)
|
|
||||||
|
|
||||||
# 准备请求头
|
|
||||||
headers = {
|
|
||||||
'Content-Type': 'application/connect+proto',
|
|
||||||
'Authorization': f'Bearer {auth_token}',
|
|
||||||
'Connect-Accept-Encoding': 'gzip,br',
|
|
||||||
'Connect-Protocol-Version': '1',
|
|
||||||
'User-Agent': 'connect-es/1.4.0',
|
|
||||||
'X-Amzn-Trace-Id': f'Root={str(uuid.uuid4())}',
|
|
||||||
'X-Cursor-Checksum': 'zo6Qjequ9b9734d1f13c3438ba25ea31ac93d9287248b9d30434934e9fcbfa6b3b22029e/7e4af391f67188693b722eff0090e8e6608bca8fa320ef20a0ccb5d7d62dfdef',
|
|
||||||
'X-Cursor-Client-Version': '0.42.3',
|
|
||||||
'X-Cursor-Timezone': 'Asia/Shanghai',
|
|
||||||
'X-Ghost-Mode': 'false',
|
|
||||||
'X-Request-Id': str(uuid.uuid4()),
|
|
||||||
'Host': 'api2.cursor.sh'
|
|
||||||
}
|
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
|
|
||||||
try:
|
|
||||||
# 使用 stream=True 参数
|
|
||||||
# 打印 headers 和 二进制 data
|
|
||||||
print(f"headers: {headers}")
|
|
||||||
print(hex_data)
|
|
||||||
async with client.stream(
|
|
||||||
'POST',
|
|
||||||
'https://api2.cursor.sh/aiserver.v1.AiService/StreamChat',
|
|
||||||
headers=headers,
|
|
||||||
content=hex_data,
|
|
||||||
timeout=None
|
|
||||||
) as response:
|
|
||||||
if chat_request.stream:
|
|
||||||
chunks = []
|
|
||||||
async for chunk in response.aiter_raw():
|
|
||||||
chunks.append(chunk)
|
|
||||||
return StreamingResponse(
|
|
||||||
process_stream(chunks),
|
|
||||||
media_type="text/event-stream"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 非流式响应处理
|
|
||||||
text = ''
|
|
||||||
async for chunk in response.aiter_raw():
|
|
||||||
# print('chunk:', chunk.hex())
|
|
||||||
print('chunk length:', len(chunk))
|
|
||||||
|
|
||||||
res = chunk_to_utf8_string(chunk)
|
|
||||||
# print('res:', res)
|
|
||||||
if res:
|
|
||||||
text += res
|
|
||||||
|
|
||||||
# 清理响应文本
|
|
||||||
import re
|
|
||||||
text = re.sub(r'^.*<\|END_USER\|>', '', text, flags=re.DOTALL)
|
|
||||||
text = re.sub(r'^\n[a-zA-Z]?', '', text).strip()
|
|
||||||
text = re.sub(r'[\x00-\x1F\x7F]', '', text)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"id": f"chatcmpl-{str(uuid.uuid4())}",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": int(time.time()),
|
|
||||||
"model": chat_request.model,
|
|
||||||
"choices": [{
|
|
||||||
"index": 0,
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": text
|
|
||||||
},
|
|
||||||
"finish_reason": "stop"
|
|
||||||
}],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"completion_tokens": 0,
|
|
||||||
"total_tokens": 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {str(e)}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail="Internal server error"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/models")
|
|
||||||
async def models():
|
|
||||||
return {
|
|
||||||
"object": "list",
|
|
||||||
"data": [
|
|
||||||
{
|
|
||||||
"id": "claude-3-5-sonnet-20241022",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1713744000,
|
|
||||||
"owned_by": "anthropic"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "claude-3-opus",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1709251200,
|
|
||||||
"owned_by": "anthropic"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "claude-3.5-haiku",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1711929600,
|
|
||||||
"owned_by": "anthropic"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "claude-3.5-sonnet",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1711929600,
|
|
||||||
"owned_by": "anthropic"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "cursor-small",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1712534400,
|
|
||||||
"owned_by": "cursor"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "gpt-3.5-turbo",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1677649200,
|
|
||||||
"owned_by": "openai"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "gpt-4",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1687392000,
|
|
||||||
"owned_by": "openai"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "gpt-4-turbo-2024-04-09",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1712620800,
|
|
||||||
"owned_by": "openai"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "gpt-4o",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1712620800,
|
|
||||||
"owned_by": "openai"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "gpt-4o-mini",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1712620800,
|
|
||||||
"owned_by": "openai"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "o1-mini",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1712620800,
|
|
||||||
"owned_by": "openai"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "o1-preview",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1712620800,
|
|
||||||
"owned_by": "openai"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    import uvicorn

    # The PORT environment variable overrides the default listening port.
    serve_port = int(os.getenv("PORT", "3001"))
    # reload=True gives dev-style auto-reload; keep-alive is capped at 30s.
    uvicorn.run("main:app", host="0.0.0.0", port=serve_port, reload=True, timeout_keep_alive=30)
|
|
@@ -7,6 +7,7 @@ use axum::{
|
|||||||
routing::post,
|
routing::post,
|
||||||
Json, Router,
|
Json, Router,
|
||||||
};
|
};
|
||||||
|
use std::error::Error;
|
||||||
use tower_http::trace::TraceLayer;
|
use tower_http::trace::TraceLayer;
|
||||||
|
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
@@ -262,7 +263,16 @@ async fn chat_completions(
|
|||||||
let client = reqwest::Client::builder()
|
let client = reqwest::Client::builder()
|
||||||
.timeout(Duration::from_secs(300))
|
.timeout(Duration::from_secs(300))
|
||||||
.build()
|
.build()
|
||||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
.map_err(|e| {
|
||||||
|
tracing::error!("创建HTTP客户端失败: {:?}", e);
|
||||||
|
tracing::error!(error = %e, "错误详情");
|
||||||
|
|
||||||
|
if let Some(source) = e.source() {
|
||||||
|
tracing::error!(source = %source, "错误源");
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
|
})?;
|
||||||
|
|
||||||
let response = client
|
let response = client
|
||||||
.post("https://api2.cursor.sh/aiserver.v1.AiService/StreamChat")
|
.post("https://api2.cursor.sh/aiserver.v1.AiService/StreamChat")
|
||||||
@@ -270,7 +280,32 @@ async fn chat_completions(
|
|||||||
.body(hex_data)
|
.body(hex_data)
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
.map_err(|e| {
|
||||||
|
tracing::error!("请求失败: {:?}", e);
|
||||||
|
tracing::error!(error = %e, "错误详情");
|
||||||
|
|
||||||
|
// 如果是超时错误
|
||||||
|
if e.is_timeout() {
|
||||||
|
tracing::error!("请求超时");
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果是连接错误
|
||||||
|
if e.is_connect() {
|
||||||
|
tracing::error!("连接失败");
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果有请求信息
|
||||||
|
if let Some(url) = e.url() {
|
||||||
|
tracing::error!(url = %url, "请求URL");
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果有状态码
|
||||||
|
if let Some(status) = e.status() {
|
||||||
|
tracing::error!(status = %status, "HTTP状态码");
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
|
})?;
|
||||||
|
|
||||||
if chat_request.stream {
|
if chat_request.stream {
|
||||||
let mut chunks = Vec::new();
|
let mut chunks = Vec::new();
|
||||||
@@ -347,7 +382,72 @@ async fn models() -> Json<serde_json::Value> {
|
|||||||
"created": 1713744000,
|
"created": 1713744000,
|
||||||
"owned_by": "anthropic"
|
"owned_by": "anthropic"
|
||||||
},
|
},
|
||||||
// ... 其他模型
|
{
|
||||||
|
"id": "claude-3-opus",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1709251200,
|
||||||
|
"owned_by": "anthropic"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "claude-3.5-haiku",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1711929600,
|
||||||
|
"owned_by": "anthropic"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "claude-3.5-sonnet",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1711929600,
|
||||||
|
"owned_by": "anthropic"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "cursor-small",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1712534400,
|
||||||
|
"owned_by": "cursor"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "gpt-3.5-turbo",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1677649200,
|
||||||
|
"owned_by": "openai"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "gpt-4",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1687392000,
|
||||||
|
"owned_by": "openai"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "gpt-4-turbo-2024-04-09",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1712620800,
|
||||||
|
"owned_by": "openai"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "gpt-4o",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1712620800,
|
||||||
|
"owned_by": "openai"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "gpt-4o-mini",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1712620800,
|
||||||
|
"owned_by": "openai"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "o1-mini",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1712620800,
|
||||||
|
"owned_by": "openai"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "o1-preview",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1712620800,
|
||||||
|
"owned_by": "openai"
|
||||||
|
}
|
||||||
]
|
]
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user