修复一些bug

This commit is contained in:
wisdgod
2024-12-25 07:02:50 +08:00
parent 3ab975a5b3
commit ea1acb555f
8 changed files with 384 additions and 223 deletions

View File

@@ -2,9 +2,9 @@ name: Docker Build and Push
on: on:
workflow_dispatch: workflow_dispatch:
# push: push:
# tags: tags:
# - 'v*' - 'v*'
env: env:
IMAGE_NAME: ${{ github.repository_owner }}/cursor-api IMAGE_NAME: ${{ github.repository_owner }}/cursor-api

2
Cargo.lock generated
View File

@@ -268,7 +268,7 @@ dependencies = [
[[package]] [[package]]
name = "cursor-api" name = "cursor-api"
version = "0.1.0" version = "0.1.1"
dependencies = [ dependencies = [
"axum", "axum",
"base64", "base64",

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "cursor-api" name = "cursor-api"
version = "0.1.0" version = "0.1.1"
edition = "2021" edition = "2021"
authors = ["wisdgod <nav@wisdgod.com>"] authors = ["wisdgod <nav@wisdgod.com>"]
@@ -28,13 +28,12 @@ tokio-stream = { version = "0.1.17", features = ["time"] }
tower-http = { version = "0.6.2", features = ["cors"] } tower-http = { version = "0.6.2", features = ["cors"] }
uuid = { version = "1.11.0", features = ["v4"] } uuid = { version = "1.11.0", features = ["v4"] }
# 优化设置
[profile.release] [profile.release]
lto = true # 启用链接时优化 lto = true
codegen-units = 1 # 减少并行编译单元以提高优化 codegen-units = 1
panic = 'abort' # 在 panic 时直接终止,减小二进制大小 panic = 'abort'
strip = true # 移除调试符号 strip = true
opt-level = 3 # 最高优化级别 opt-level = 3
# 构建脚本设置 # 构建脚本设置
[package.metadata.cross.target.x86_64-unknown-linux-gnu] [package.metadata.cross.target.x86_64-unknown-linux-gnu]

14
Cross.toml Normal file
View File

@@ -0,0 +1,14 @@
[target.x86_64-unknown-linux-gnu]
pre-build = [
"set -e",
"apt-get update",
"apt-get install -y --no-install-recommends build-essential protobuf-compiler pkg-config libssl-dev nodejs npm",
"rm -rf /var/lib/apt/lists/*"
]
[target.x86_64-unknown-freebsd]
pre-build = [
"pkg update",
"pkg install -y node20 www/npm protobuf ca_root_nss bash gmake pkgconf openssl",
"export SSL_CERT_FILE=/etc/ssl/cert.pem"
]

150
README.md
View File

@@ -4,7 +4,7 @@
1. 访问 [www.cursor.com](https://www.cursor.com) 并完成注册登录(赠送 250 次快速响应,可通过删除账号再注册重置) 1. 访问 [www.cursor.com](https://www.cursor.com) 并完成注册登录(赠送 250 次快速响应,可通过删除账号再注册重置)
2. 在浏览器中打开开发者工具F12 2. 在浏览器中打开开发者工具F12
3. 找到 应用-Cookies 中名为 `WorkosCursorSessionToken` 的值并保存(相当于 openai 的密钥) 3. 找到 Application-Cookies 中名为 `WorkosCursorSessionToken` 的值并保存(相当于 openai 的密钥)
## 接口说明 ## 接口说明
@@ -12,41 +12,53 @@
- 接口地址:`/v1/chat/completions` - 接口地址:`/v1/chat/completions`
- 请求方法POST - 请求方法POST
- 认证方式Bearer Token(支持两种认证方式) - 认证方式Bearer Token
1. 使用环境变量 `AUTH_TOKEN` 进行认证 1. 使用环境变量 `AUTH_TOKEN` 进行认证
2. 使用 `.token` 文件中的令牌列表进行轮询认证 2. 使用 `.token` 文件中的令牌列表进行轮询认证
### 获取模型列表 ### Token管理接口
#### 简易Token信息管理页面
- 接口地址:`/tokeninfo`
- 请求方法GET
- 响应格式HTML页面
- 功能:获取 .token 和 .token-list 文件内容,并允许用户方便地使用 API 修改文件内容
#### 更新Token信息
- 接口地址:`/update-tokeninfo`
- 请求方法GET
- 认证方式:不需要
- 功能:请求内容不包括文件内容,直接修改文件,调用重载函数
#### 更新Token信息
- 接口地址:`/update-tokeninfo`
- 请求方法POST
- 认证方式Bearer Token
- 功能:请求内容包括文件内容,间接修改文件,调用重载函数
#### 获取Token信息
- 接口地址:`/get-tokeninfo`
- 请求方法POST
- 认证方式Bearer Token
### 其他接口
#### 获取模型列表
- 接口地址:`/v1/models` - 接口地址:`/v1/models`
- 请求方法GET - 请求方法GET
### 获取环境变量中的x-cursor-checksum #### 获取随机x-cursor-checksum
- 接口地址:`/env-checksum`
- 请求方法GET
### 获取随机x-cursor-checksum
- 接口地址:`/checksum` - 接口地址:`/checksum`
- 请求方法GET - 请求方法GET
### 健康检查接口 #### 健康检查接口
- 接口地址:`/` - 接口地址:`/`
- 请求方法GET - 请求方法GET
### 获取日志接口 #### 获取日志接口
- 接口地址:`/logs` - 接口地址:`/logs`
- 请求方法GET - 请求方法GET
### Token管理接口
- 获取Token信息页面`/tokeninfo`
- 更新Token信息`/update-tokeninfo`
- 获取Token信息`/get-tokeninfo`
## 配置说明 ## 配置说明
### 环境变量 ### 环境变量
@@ -60,32 +72,73 @@
### Token文件格式 ### Token文件格式
1. `.token` 文件每行一个token支持以下格式 1. `.token` 文件每行一个token支持以下格式
``` ```
token1 token1
alias::token2 alias::token2
``` ```
alias 可以是任意值,用于区分不同的 token更方便管理WorkosCursorSessionToken 是相同格式
该文件将自动向.token-list文件中追加token同时自动生成checksum
2. `.token-list` 文件每行为token和checksum的对应关系 2. `.token-list` 文件每行为token和checksum的对应关系
``` ```
token1,checksum1 token1,checksum1
token2,checksum2 token2,checksum2
``` ```
该文件可以被自动管理,但用户仅可在确认自己拥有修改能力时修改,一般仅有以下情况需要手动修改:
- 需要删除某个 token
- 需要使用已有 checksum 来对应某一个 token
### 模型列表
写死了,后续也不会会支持自定义模型列表
```
cursor-small
claude-3-opus
cursor-fast
gpt-3.5-turbo
gpt-4-turbo-2024-04-09
gpt-4
gpt-4o-128k
gemini-1.5-flash-500k
claude-3-haiku-200k
claude-3-5-sonnet-200k
claude-3-5-sonnet-20240620
claude-3-5-sonnet-20241022
gpt-4o-mini
o1-mini
o1-preview
o1
claude-3.5-haiku
gemini-exp-1206
gemini-2.0-flash-thinking-exp
gemini-2.0-flash-exp
```
## 部署 ## 部署
### 本地部署 ### 本地部署
#### 从源码编译 #### 从源码编译
需要安装 Rust 工具链和 protobuf 编译器 需要安装 Rust 工具链和依赖
```bash ```bash
# 安装依赖Debian/Ubuntu # 安装rust
apt-get install -y build-essential protobuf-compiler curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
# 编译并运行 # 安装依赖Debian/Ubuntu
apt-get install -y build-essential protobuf-compiler pkg-config libssl-dev nodejs npm
# 原生编译
cargo build --release cargo build --release
./target/release/cursor-api
# 交叉编译以x86_64-unknown-linux-gnu为例老实说这也算原生编译因为使用了docker
cross build --target x86_64-unknown-linux-gnu --release
``` ```
#### 使用预编译二进制 #### 使用预编译二进制
@@ -109,28 +162,48 @@ docker run -p 3000:3000 cursor-api
### huggingface部署 ### huggingface部署
1. duplicate项目: 前提一个huggingface账号
[huggingface链接](https://huggingface.co/login?next=%2Fspaces%2Fstevenrk%2Fcursor%3Fduplicate%3Dtrue)
1. 创建一个Space并创建一个Dockerfile文件内容如下
```Dockerfile
FROM wisdgod/cursor-api:latest
# 可能你要覆盖原镜像的环境变量但都可以在下面的第2步中配置
ENV PORT=7860
```
2. 配置环境变量 2. 配置环境变量
在你的space中点击settings找到`Variables and secrets`添加Variables 在你的 Space 中,点击 Settings找到 `Variables and secrets`,添加 Variables
- name: `AUTH_TOKEN` (注意大写)
- value: 你随意 ```env
# 可选,用于配置服务器端口
PORT=3000
# 必选,用于配置路由前缀,比如/api,/hf,/proxy等等
ROUTE_PREFIX=
# 必选用于API认证
AUTH_TOKEN=
# 可选用于配置token文件路径
TOKEN_FILE=.token
# 可选用于配置token列表文件路径
TOKEN_LIST_FILE=.token-list
```
3. 重新部署 3. 重新部署
点击`Factory rebuild`,等待部署完成 点击`Factory rebuild`,等待部署完成
4. 接口地址(`Embed this Space`中查看): 4. 接口地址(`Embed this Space`中查看):
``` ```
https://{username}-{space-name}.hf.space/v1/models https://{username}-{space-name}.hf.space/v1/models
``` ```
## 注意事项 ## 注意事项
1. 请妥善保管您的 AuthToken不要泄露给他人 1. 请妥善保管您的任何 Token不要泄露给他人。若发现泄露,请及时更改
2. 配置 AUTH_TOKEN 环境变量以增加安全性 2. 请遵守本项目许可证,你仅拥有使用本项目的权利,不得用于商业用途
3. 本项目仅供学习研究使用,请遵守 Cursor 的使用条款 3. 本项目仅供学习研究使用,请遵守 Cursor 的使用条款
## 开发 ## 开发
@@ -147,12 +220,13 @@ docker run -p 3000:3000 cursor-api
./scripts/build.sh --cross ./scripts/build.sh --cross
``` ```
支持的目标平台: 支持的平台:
- x86_64-unknown-linux-gnu
- x86_64-pc-windows-msvc - linux x86_64
- aarch64-unknown-linux-gnu - windows x86_64
- x86_64-apple-darwin - macos x86_64
- aarch64-apple-darwin - freebsd x86_64
- docker (only for linux x86_64)
### 获取token ### 获取token

View File

@@ -22,8 +22,8 @@ check_requirements() {
fi fi
done done
# Linux 特定检查 # cross 工具检查(仅在 Linux 上需要)
if [[ $USE_CROSS == true ]] && ! command -v cross &>/dev/null; then if [[ "$OS" == "Linux" ]] && ! command -v cross &>/dev/null; then
missing_tools+=("cross") missing_tools+=("cross")
fi fi
@@ -46,6 +46,22 @@ show_help() {
EOF EOF
} }
# 判断是否使用 cross
should_use_cross() {
local target=$1
# 如果不是 Linux 环境,直接返回 false
if [[ "$OS" != "Linux" ]]; then
return 1
fi
# 在 Linux 环境下,以下目标不使用 cross
# 1. Linux 上的 x86_64-unknown-linux-gnu
if [[ "$target" == "x86_64-unknown-linux-gnu" ]]; then
return 1
fi
return 0
}
# 并行构建函数 # 并行构建函数
build_target() { build_target() {
local target=$1 local target=$1
@@ -57,42 +73,41 @@ build_target() {
# 确定文件后缀 # 确定文件后缀
[[ $target == *"windows"* ]] && extension=".exe" [[ $target == *"windows"* ]] && extension=".exe"
# 设置目标特定的环境变量 # 判断是否使用 cross
local build_env=() if should_use_cross "$target"; then
if [[ $target == "aarch64-unknown-linux-gnu" ]]; then env RUSTFLAGS="$rustflags" cross build --target "$target" --release
build_env+=(
"CC_aarch64_unknown_linux_gnu=aarch64-linux-gnu-gcc"
"CXX_aarch64_unknown_linux_gnu=aarch64-linux-gnu-g++"
"CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_LINKER=aarch64-linux-gnu-gcc"
"PKG_CONFIG_PATH=/usr/lib/aarch64-linux-gnu/pkgconfig"
"PKG_CONFIG_ALLOW_CROSS=1"
"OPENSSL_DIR=/usr"
"OPENSSL_INCLUDE_DIR=/usr/include"
"OPENSSL_LIB_DIR=/usr/lib/aarch64-linux-gnu"
)
fi
# 判断是否使用 cross仅在 Linux 上)
if [[ $target != "$CURRENT_TARGET" ]]; then
env ${build_env[@]+"${build_env[@]}"} RUSTFLAGS="$rustflags" cargo build --target "$target" --release
else else
env ${build_env[@]+"${build_env[@]}"} RUSTFLAGS="$rustflags" cargo build --release if [[ $target != "$CURRENT_TARGET" ]]; then
env RUSTFLAGS="$rustflags" cargo build --target "$target" --release
else
env RUSTFLAGS="$rustflags" cargo build --release
fi
fi fi
# 移动编译产物到 release 目录 # 移动编译产物到 release 目录
local binary_name="cursor-api" local binary_name="cursor-api"
[[ $USE_STATIC == true ]] && binary_name+="-static" [[ $USE_STATIC == true ]] && binary_name+="-static"
if [[ -f "target/$target/release/cursor-api$extension" ]]; then local binary_path
cp "target/$target/release/cursor-api$extension" "release/${binary_name}-$target$extension" if [[ $target == "$CURRENT_TARGET" ]]; then
binary_path="target/release/cursor-api$extension"
else
binary_path="target/$target/release/cursor-api$extension"
fi
if [[ -f "$binary_path" ]]; then
cp "$binary_path" "release/${binary_name}-$target$extension"
info "完成构建 $target" info "完成构建 $target"
else else
warn "构建产物未找到: $target" warn "构建产物未找到: $target"
warn "查找路径: $binary_path"
warn "当前目录内容:"
ls -R target/
return 1 return 1
fi fi
} }
# 获取 CPU 架构 # 获取 CPU 架构和操作系统
ARCH=$(uname -m | sed 's/^aarch64\|arm64$/aarch64/;s/^x86_64\|x86-64\|x64\|amd64$/x86_64/') ARCH=$(uname -m | sed 's/^aarch64\|arm64$/aarch64/;s/^x86_64\|x86-64\|x64\|amd64$/x86_64/')
OS=$(uname -s) OS=$(uname -s)
@@ -104,7 +119,7 @@ get_target() {
"Darwin") echo "${arch}-apple-darwin" ;; "Darwin") echo "${arch}-apple-darwin" ;;
"Linux") echo "${arch}-unknown-linux-gnu" ;; "Linux") echo "${arch}-unknown-linux-gnu" ;;
"MINGW"*|"MSYS"*|"CYGWIN"*|"Windows_NT") echo "${arch}-pc-windows-msvc" ;; "MINGW"*|"MSYS"*|"CYGWIN"*|"Windows_NT") echo "${arch}-pc-windows-msvc" ;;
"FreeBSD") echo "x86_64-unknown-freebsd" ;; "FreeBSD") echo "${arch}-unknown-freebsd" ;;
*) error "不支持的系统: $os" ;; *) error "不支持的系统: $os" ;;
esac esac
} }
@@ -118,21 +133,31 @@ CURRENT_TARGET=$(get_target "$ARCH" "$OS")
# 获取系统对应的所有目标 # 获取系统对应的所有目标
get_targets() { get_targets() {
case "$1" in case "$1" in
"linux") echo "x86_64-unknown-linux-gnu aarch64-unknown-linux-gnu" ;; "linux")
"windows") echo "x86_64-pc-windows-msvc aarch64-pc-windows-msvc" ;; # Linux 构建所有 Linux 目标和 FreeBSD 目标
"macos") echo "x86_64-apple-darwin aarch64-apple-darwin" ;; echo "x86_64-unknown-linux-gnu x86_64-unknown-freebsd"
"freebsd") echo "x86_64-unknown-freebsd" ;; ;;
"freebsd")
# FreeBSD 只构建当前架构的 FreeBSD 目标
echo "${ARCH}-unknown-freebsd"
;;
"windows")
# Windows 构建所有 Windows 目标
echo "x86_64-pc-windows-msvc"
;;
"macos")
# macOS 构建所有 macOS 目标
echo "x86_64-apple-darwin aarch64-apple-darwin"
;;
*) error "不支持的系统组: $1" ;; *) error "不支持的系统组: $1" ;;
esac esac
} }
# 解析参数 # 解析参数
USE_CROSS=false
USE_STATIC=false USE_STATIC=false
while [[ $# -gt 0 ]]; do while [[ $# -gt 0 ]]; do
case $1 in case $1 in
--cross) USE_CROSS=true ;;
--static) USE_STATIC=true ;; --static) USE_STATIC=true ;;
--help) show_help; exit 0 ;; --help) show_help; exit 0 ;;
*) error "未知参数: $1" ;; *) error "未知参数: $1" ;;
@@ -144,19 +169,21 @@ done
check_requirements check_requirements
# 确定要构建的目标 # 确定要构建的目标
if [[ $USE_CROSS == true ]] && is_linux; then case "$OS" in
# 只在 Linux 上使用 cross 进行多架构构建 "Darwin")
TARGETS=($(get_targets "linux")) TARGETS=($(get_targets "macos"))
else ;;
# 其他系统或不使用 cross 时只构建当前系统的所有架构 "Linux")
case "$OS" in TARGETS=($(get_targets "linux"))
"Darwin") TARGETS=($(get_targets "macos")) ;; ;;
"Linux") TARGETS=("$CURRENT_TARGET") ;; "FreeBSD")
"MINGW"*|"MSYS"*|"CYGWIN"*|"Windows_NT") TARGETS=($(get_targets "windows")) ;; TARGETS=($(get_targets "freebsd"))
"FreeBSD") TARGETS=("$CURRENT_TARGET") ;; ;;
*) error "不支持的系统: $OS" ;; "MINGW"*|"MSYS"*|"CYGWIN"*|"Windows_NT")
esac TARGETS=($(get_targets "windows"))
fi ;;
*) error "不支持的系统: $OS" ;;
esac
# 创建 release 目录 # 创建 release 目录
mkdir -p release mkdir -p release

View File

@@ -140,16 +140,14 @@ pub async fn encode_chat_message(
Ok(hex::decode(len_prefix + &content)?) Ok(hex::decode(len_prefix + &content)?)
} }
pub async fn decode_response(data: &[u8]) -> String { pub async fn decode_response(data: &[u8]) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
if let Ok(decoded) = decode_proto_messages(data) { match decode_proto_messages(data) {
if !decoded.is_empty() { Ok(decoded) if !decoded.is_empty() => Ok(decoded),
return decoded; _ => decompress_response(data).await
}
} }
decompress_response(data).await
} }
fn decode_proto_messages(data: &[u8]) -> Result<String, Box<dyn std::error::Error>> { fn decode_proto_messages(data: &[u8]) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let hex_str = hex::encode(data); let hex_str = hex::encode(data);
let mut pos = 0; let mut pos = 0;
let mut messages = Vec::new(); let mut messages = Vec::new();
@@ -173,9 +171,9 @@ fn decode_proto_messages(data: &[u8]) -> Result<String, Box<dyn std::error::Erro
Ok(messages.join("")) Ok(messages.join(""))
} }
async fn decompress_response(data: &[u8]) -> String { async fn decompress_response(data: &[u8]) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
if data.len() <= 5 { if data.len() <= 5 {
return String::new(); return Ok(String::new());
} }
let mut decoder = GzDecoder::new(&data[5..]); let mut decoder = GzDecoder::new(&data[5..]);
@@ -184,12 +182,12 @@ async fn decompress_response(data: &[u8]) -> String {
match decoder.read_to_string(&mut text) { match decoder.read_to_string(&mut text) {
Ok(_) => { Ok(_) => {
if !text.contains("<|BEGIN_SYSTEM|>") { if !text.contains("<|BEGIN_SYSTEM|>") {
text Ok(text)
} else { } else {
String::new() Ok(String::new())
} }
} },
Err(_) => String::new(), Err(e) => Err(Box::new(e))
} }
} }

View File

@@ -11,20 +11,51 @@ use chrono::{DateTime, Local, Utc};
use futures::StreamExt; use futures::StreamExt;
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{
atomic::{AtomicUsize, Ordering},
LazyLock,
};
use std::{convert::Infallible, sync::Arc}; use std::{convert::Infallible, sync::Arc};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tower_http::cors::CorsLayer; use tower_http::cors::CorsLayer;
use uuid::Uuid; use uuid::Uuid;
// 应用状态 struct AppConfig {
struct AppState { auth_token: String,
start_time: DateTime<Local>, token_file: String,
token_list_file: String,
route_prefix: String,
version: String, version: String,
start_time: DateTime<Local>,
}
static APP_CONFIG: LazyLock<AppConfig> = LazyLock::new(|| {
// 加载环境变量
if let Err(e) = dotenvy::dotenv() {
eprintln!("警告: 无法加载 .env 文件: {}", e);
}
let auth_token = std::env::var("AUTH_TOKEN").unwrap_or_else(|_| "".to_string());
if auth_token.is_empty() {
eprintln!("错误: AUTH_TOKEN 未设置");
std::process::exit(1);
}
AppConfig {
auth_token,
token_file: std::env::var("TOKEN_FILE").unwrap_or_else(|_| ".token".to_string()),
token_list_file: std::env::var("TOKEN_LIST_FILE")
.unwrap_or_else(|_| ".token-list".to_string()),
route_prefix: std::env::var("ROUTE_PREFIX").unwrap_or_default(),
version: env!("CARGO_PKG_VERSION").to_string(),
start_time: Local::now(),
}
});
struct AppState {
total_requests: u64, total_requests: u64,
active_requests: u64, active_requests: u64,
request_logs: Vec<RequestLog>, request_logs: Vec<RequestLog>,
route_prefix: String,
token_infos: Vec<TokenInfo>, token_infos: Vec<TokenInfo>,
} }
@@ -45,6 +76,8 @@ struct RequestLog {
checksum: String, checksum: String,
auth_token: String, auth_token: String,
stream: bool, stream: bool,
status: String,
error: Option<String>,
} }
// 聊天请求 // 聊天请求
@@ -68,7 +101,6 @@ mod models;
use models::AVAILABLE_MODELS; use models::AVAILABLE_MODELS;
// 用于存储 token 信息 // 用于存储 token 信息
#[derive(Debug)]
struct TokenInfo { struct TokenInfo {
token: String, token: String,
checksum: String, checksum: String,
@@ -83,7 +115,6 @@ struct TokenUpdateRequest {
} }
// 自定义错误类型 // 自定义错误类型
#[derive(Debug)]
enum ChatError { enum ChatError {
ModelNotSupported(String), ModelNotSupported(String),
EmptyMessages, EmptyMessages,
@@ -124,26 +155,14 @@ impl ChatError {
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
// 加载环境变量
dotenvy::dotenv().ok();
// 处理 token 文件路径
let token_file = std::env::var("TOKEN_FILE").unwrap_or_else(|_| ".token".to_string());
// 加载 tokens // 加载 tokens
let token_infos = load_tokens(&token_file); let token_infos = load_tokens();
// 获取路由前缀配置 // 初始化需要互斥访问的状态
let route_prefix = std::env::var("ROUTE_PREFIX").unwrap_or_default();
// 初始化应用状态
let state = Arc::new(Mutex::new(AppState { let state = Arc::new(Mutex::new(AppState {
start_time: Local::now(),
version: env!("CARGO_PKG_VERSION").to_string(),
total_requests: 0, total_requests: 0,
active_requests: 0, active_requests: 0,
request_logs: Vec::new(), request_logs: Vec::new(),
route_prefix: route_prefix.clone(),
token_infos, token_infos,
})); }));
@@ -151,13 +170,16 @@ async fn main() {
let app = Router::new() let app = Router::new()
.route("/", get(handle_root)) .route("/", get(handle_root))
.route("/tokeninfo", get(handle_tokeninfo_page)) .route("/tokeninfo", get(handle_tokeninfo_page))
.route(&format!("{}/v1/models", route_prefix), get(handle_models)) .route(
&format!("{}/v1/models", APP_CONFIG.route_prefix),
get(handle_models),
)
.route("/checksum", get(handle_checksum)) .route("/checksum", get(handle_checksum))
.route("/update-tokeninfo", get(handle_update_tokeninfo)) .route("/update-tokeninfo", get(handle_update_tokeninfo))
.route("/get-tokeninfo", post(handle_get_tokeninfo)) .route("/get-tokeninfo", post(handle_get_tokeninfo))
.route("/update-tokeninfo", post(handle_update_tokeninfo_post)) .route("/update-tokeninfo", post(handle_update_tokeninfo_post))
.route( .route(
&format!("{}/v1/chat/completions", route_prefix), &format!("{}/v1/chat/completions", APP_CONFIG.route_prefix),
post(handle_chat), post(handle_chat),
) )
.route("/logs", get(handle_logs)) .route("/logs", get(handle_logs))
@@ -174,54 +196,69 @@ async fn main() {
} }
// Token 加载函数 // Token 加载函数
fn load_tokens(token_file: &str) -> Vec<TokenInfo> { fn load_tokens() -> Vec<TokenInfo> {
let token_list_file = // 读取 .token 文件并解析
std::env::var("TOKEN_LIST_FILE").unwrap_or_else(|_| ".token-list".to_string()); let tokens = match std::fs::read_to_string(&APP_CONFIG.token_file) {
Ok(content) => {
// 读取并规范化 .token 文件 let normalized = content.replace("\r\n", "\n");
let tokens = if let Ok(content) = std::fs::read_to_string(token_file) { // 如果内容被规范化,则更新文件
let normalized = content.replace("\r\n", "\n"); if normalized != content {
if normalized != content { if let Err(e) = std::fs::write(&APP_CONFIG.token_file, &normalized) {
std::fs::write(token_file, &normalized).unwrap(); eprintln!("警告: 无法更新规范化的token文件: {}", e);
}
normalized
.lines()
.enumerate()
.filter_map(|(idx, line)| {
let parts: Vec<&str> = line.split("::").collect();
match parts.len() {
1 => Some(line.to_string()),
2 => Some(parts[1].to_string()),
_ => {
println!("警告: 第{}行包含多个'::'分隔符,已忽略此行", idx + 1);
None
}
} }
}) }
.filter(|s| !s.is_empty())
.collect::<Vec<_>>() normalized
} else { .lines()
eprintln!("警告: 无法读取token文件 '{}'", token_file); .filter_map(|line| {
Vec::new() let line = line.trim();
if line.is_empty() {
return None;
}
// 处理 alias::token 格式
match line.split("::").collect::<Vec<_>>() {
parts if parts.len() == 1 => Some(line.to_string()),
parts if parts.len() == 2 => Some(parts[1].to_string()),
_ => {
eprintln!("警告: 忽略无效的token行: {}", line);
None
}
}
})
.collect::<Vec<_>>()
}
Err(e) => {
eprintln!("警告: 无法读取token文件 '{}': {}", APP_CONFIG.token_file, e);
Vec::new()
}
}; };
// 读取现有的 token-list // 读取现有的 token-list
let mut token_map: std::collections::HashMap<String, String> = let mut token_map: std::collections::HashMap<String, String> =
if let Ok(content) = std::fs::read_to_string(&token_list_file) { match std::fs::read_to_string(&APP_CONFIG.token_list_file) {
content Ok(content) => content
.split('\n') .lines()
.filter(|s| !s.is_empty())
.filter_map(|line| { .filter_map(|line| {
let line = line.trim();
if line.is_empty() {
return None;
}
let parts: Vec<&str> = line.split(',').collect(); let parts: Vec<&str> = line.split(',').collect();
if parts.len() == 2 { match parts[..] {
Some((parts[0].to_string(), parts[1].to_string())) [token, checksum] => Some((token.to_string(), checksum.to_string())),
} else { _ => {
None eprintln!("警告: 忽略无效的token-list行: {}", line);
None
}
} }
}) })
.collect() .collect(),
} else { Err(e) => {
std::collections::HashMap::new() eprintln!("警告: 无法读取token-list文件: {}", e);
std::collections::HashMap::new()
}
}; };
// 为新 token 生成 checksum // 为新 token 生成 checksum
@@ -241,7 +278,10 @@ fn load_tokens(token_file: &str) -> Vec<TokenInfo> {
.map(|(token, checksum)| format!("{},{}", token, checksum)) .map(|(token, checksum)| format!("{},{}", token, checksum))
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join("\n"); .join("\n");
std::fs::write(token_list_file, token_list_content).unwrap();
if let Err(e) = std::fs::write(&APP_CONFIG.token_list_file, token_list_content) {
eprintln!("警告: 无法更新token-list文件: {}", e);
}
// 转换为 TokenInfo vector // 转换为 TokenInfo vector
token_map token_map
@@ -253,14 +293,14 @@ fn load_tokens(token_file: &str) -> Vec<TokenInfo> {
// 根路由处理 // 根路由处理
async fn handle_root(State(state): State<Arc<Mutex<AppState>>>) -> Json<serde_json::Value> { async fn handle_root(State(state): State<Arc<Mutex<AppState>>>) -> Json<serde_json::Value> {
let state = state.lock().await; let state = state.lock().await;
let uptime = (Local::now() - state.start_time).num_seconds(); let uptime = (Local::now() - APP_CONFIG.start_time).num_seconds();
Json(serde_json::json!({ Json(serde_json::json!({
"status": "healthy", "status": "healthy",
"version": state.version, "version": APP_CONFIG.version,
"uptime": uptime, "uptime": uptime,
"stats": { "stats": {
"started": state.start_time, "started": APP_CONFIG.start_time,
"totalRequests": state.total_requests, "totalRequests": state.total_requests,
"activeRequests": state.active_requests, "activeRequests": state.active_requests,
"memory": { "memory": {
@@ -271,8 +311,8 @@ async fn handle_root(State(state): State<Arc<Mutex<AppState>>>) -> Json<serde_js
}, },
"models": AVAILABLE_MODELS.iter().map(|m| &m.id).collect::<Vec<_>>(), "models": AVAILABLE_MODELS.iter().map(|m| &m.id).collect::<Vec<_>>(),
"endpoints": [ "endpoints": [
&format!("{}/v1/chat/completions", state.route_prefix), &format!("{}/v1/chat/completions", APP_CONFIG.route_prefix),
&format!("{}/v1/models", state.route_prefix), &format!("{}/v1/models", APP_CONFIG.route_prefix),
"/checksum", "/checksum",
"/tokeninfo", "/tokeninfo",
"/update-tokeninfo", "/update-tokeninfo",
@@ -311,11 +351,8 @@ async fn handle_checksum() -> Json<serde_json::Value> {
async fn handle_update_tokeninfo( async fn handle_update_tokeninfo(
State(state): State<Arc<Mutex<AppState>>>, State(state): State<Arc<Mutex<AppState>>>,
) -> Json<serde_json::Value> { ) -> Json<serde_json::Value> {
// 获取当前的 token 文件路径
let token_file = std::env::var("TOKEN_FILE").unwrap_or_else(|_| ".token".to_string());
// 重新加载 tokens // 重新加载 tokens
let token_infos = load_tokens(&token_file); let token_infos = load_tokens();
// 更新应用状态 // 更新应用状态
{ {
@@ -341,25 +378,19 @@ async fn handle_get_tokeninfo(
.and_then(|h| h.strip_prefix("Bearer ")) .and_then(|h| h.strip_prefix("Bearer "))
.ok_or(StatusCode::UNAUTHORIZED)?; .ok_or(StatusCode::UNAUTHORIZED)?;
let env_token = std::env::var("AUTH_TOKEN").map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; if auth_header != APP_CONFIG.auth_token {
if auth_header != env_token {
return Err(StatusCode::UNAUTHORIZED); return Err(StatusCode::UNAUTHORIZED);
} }
// 获取文件路径
let token_file = std::env::var("TOKEN_FILE").unwrap_or_else(|_| ".token".to_string());
let token_list_file =
std::env::var("TOKEN_LIST_FILE").unwrap_or_else(|_| ".token-list".to_string());
// 读取文件内容 // 读取文件内容
let tokens = std::fs::read_to_string(&token_file).unwrap_or_else(|_| String::new()); let tokens = std::fs::read_to_string(&APP_CONFIG.token_file).unwrap_or_else(|_| String::new());
let token_list = std::fs::read_to_string(&token_list_file).unwrap_or_else(|_| String::new()); let token_list =
std::fs::read_to_string(&APP_CONFIG.token_list_file).unwrap_or_else(|_| String::new());
Ok(Json(serde_json::json!({ Ok(Json(serde_json::json!({
"status": "success", "status": "success",
"token_file": token_file, "token_file": APP_CONFIG.token_file,
"token_list_file": token_list_file, "token_list_file": APP_CONFIG.token_list_file,
"tokens": tokens, "tokens": tokens,
"token_list": token_list "token_list": token_list
}))) })))
@@ -377,28 +408,22 @@ async fn handle_update_tokeninfo_post(
.and_then(|h| h.strip_prefix("Bearer ")) .and_then(|h| h.strip_prefix("Bearer "))
.ok_or(StatusCode::UNAUTHORIZED)?; .ok_or(StatusCode::UNAUTHORIZED)?;
let env_token = std::env::var("AUTH_TOKEN").map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; if auth_header != APP_CONFIG.auth_token {
if auth_header != env_token {
return Err(StatusCode::UNAUTHORIZED); return Err(StatusCode::UNAUTHORIZED);
} }
// 获取文件路径
let token_file = std::env::var("TOKEN_FILE").unwrap_or_else(|_| ".token".to_string());
let token_list_file =
std::env::var("TOKEN_LIST_FILE").unwrap_or_else(|_| ".token-list".to_string());
// 写入 .token 文件 // 写入 .token 文件
std::fs::write(&token_file, &request.tokens).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; std::fs::write(&APP_CONFIG.token_file, &request.tokens)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
// 如果提供了 token_list则写入 // 如果提供了 token_list则写入
if let Some(token_list) = request.token_list { if let Some(token_list) = request.token_list {
std::fs::write(&token_list_file, token_list) std::fs::write(&APP_CONFIG.token_list_file, token_list)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
} }
// 重新加载 tokens // 重新加载 tokens
let token_infos = load_tokens(&token_file); let token_infos = load_tokens();
let token_infos_len = token_infos.len(); let token_infos_len = token_infos.len();
// 更新应用状态 // 更新应用状态
@@ -410,8 +435,8 @@ async fn handle_update_tokeninfo_post(
Ok(Json(serde_json::json!({ Ok(Json(serde_json::json!({
"status": "success", "status": "success",
"message": "Token files have been updated and reloaded", "message": "Token files have been updated and reloaded",
"token_file": token_file, "token_file": APP_CONFIG.token_file,
"token_list_file": token_list_file, "token_list_file": APP_CONFIG.token_list_file,
"token_count": token_infos_len "token_count": token_infos_len
}))) })))
} }
@@ -469,14 +494,12 @@ async fn handle_chat(
Json(ChatError::Unauthorized.to_json()), Json(ChatError::Unauthorized.to_json()),
))?; ))?;
// 验证环境变量中的 AUTH_TOKEN // 验证 AuthToken
if let Ok(env_token) = std::env::var("AUTH_TOKEN") { if auth_token != APP_CONFIG.auth_token {
if auth_token != env_token { return Err((
return Err(( StatusCode::UNAUTHORIZED,
StatusCode::UNAUTHORIZED, Json(ChatError::Unauthorized.to_json()),
Json(ChatError::Unauthorized.to_json()), ));
));
}
} }
// 完整的令牌处理逻辑和对应的 checksum // 完整的令牌处理逻辑和对应的 checksum
@@ -508,6 +531,8 @@ async fn handle_chat(
checksum: checksum.clone(), checksum: checksum.clone(),
auth_token: auth_token.clone(), auth_token: auth_token.clone(),
stream: request.stream, stream: request.stream,
status: "pending".to_string(),
error: None,
}); });
if state.request_logs.len() > 100 { if state.request_logs.len() > 100 {
@@ -556,13 +581,33 @@ async fn handle_chat(
.header("Host", "api2.cursor.sh") .header("Host", "api2.cursor.sh")
.body(hex_data) .body(hex_data)
.send() .send()
.await .await;
.map_err(|e| {
( // 处理请求结果
let response = match response {
Ok(resp) => {
// 更新请求日志为成功
{
let mut state = state.lock().await;
state.request_logs.last_mut().unwrap().status = "success".to_string();
}
resp
}
Err(e) => {
// 更新请求日志为失败
{
let mut state = state.lock().await;
if let Some(last_log) = state.request_logs.last_mut() {
last_log.status = "failed".to_string();
last_log.error = Some(e.to_string());
}
}
return Err((
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(ChatError::RequestFailed(format!("Request failed: {}", e)).to_json()), Json(ChatError::RequestFailed(format!("Request failed: {}", e)).to_json()),
) ));
})?; }
};
// 释放活动请求计数 // 释放活动请求计数
{ {
@@ -579,11 +624,11 @@ async fn handle_chat(
async move { async move {
let chunk = chunk.unwrap_or_default(); let chunk = chunk.unwrap_or_default();
let text = cursor_api::decode_response(&chunk).await; let text = match cursor_api::decode_response(&chunk).await {
Ok(text) if text.is_empty() => return Ok(Bytes::from("data: [DONE]\n\n")),
if text.is_empty() { Ok(text) => text,
return Ok::<_, Infallible>(Bytes::from("[DONE]")); Err(_) => return Ok(Bytes::new()),
} };
let data = serde_json::json!({ let data = serde_json::json!({
"id": &response_id, "id": &response_id,
@@ -623,7 +668,11 @@ async fn handle_chat(
), ),
) )
})?; })?;
full_text.push_str(&cursor_api::decode_response(&chunk).await); full_text.push_str(
&cursor_api::decode_response(&chunk)
.await
.unwrap_or_default(),
);
} }
// 处理文本 // 处理文本