mirror of
https://github.com/wisdgod/cursor-api.git
synced 2025-10-06 15:16:51 +08:00
修复一些bug
This commit is contained in:
281
src/main.rs
281
src/main.rs
@@ -11,20 +11,51 @@ use chrono::{DateTime, Local, Utc};
|
||||
use futures::StreamExt;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
LazyLock,
|
||||
};
|
||||
use std::{convert::Infallible, sync::Arc};
|
||||
use tokio::sync::Mutex;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use uuid::Uuid;
|
||||
|
||||
// 应用状态
|
||||
struct AppState {
|
||||
start_time: DateTime<Local>,
|
||||
struct AppConfig {
|
||||
auth_token: String,
|
||||
token_file: String,
|
||||
token_list_file: String,
|
||||
route_prefix: String,
|
||||
version: String,
|
||||
start_time: DateTime<Local>,
|
||||
}
|
||||
|
||||
static APP_CONFIG: LazyLock<AppConfig> = LazyLock::new(|| {
|
||||
// 加载环境变量
|
||||
if let Err(e) = dotenvy::dotenv() {
|
||||
eprintln!("警告: 无法加载 .env 文件: {}", e);
|
||||
}
|
||||
|
||||
let auth_token = std::env::var("AUTH_TOKEN").unwrap_or_else(|_| "".to_string());
|
||||
if auth_token.is_empty() {
|
||||
eprintln!("错误: AUTH_TOKEN 未设置");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
AppConfig {
|
||||
auth_token,
|
||||
token_file: std::env::var("TOKEN_FILE").unwrap_or_else(|_| ".token".to_string()),
|
||||
token_list_file: std::env::var("TOKEN_LIST_FILE")
|
||||
.unwrap_or_else(|_| ".token-list".to_string()),
|
||||
route_prefix: std::env::var("ROUTE_PREFIX").unwrap_or_default(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
start_time: Local::now(),
|
||||
}
|
||||
});
|
||||
|
||||
struct AppState {
|
||||
total_requests: u64,
|
||||
active_requests: u64,
|
||||
request_logs: Vec<RequestLog>,
|
||||
route_prefix: String,
|
||||
token_infos: Vec<TokenInfo>,
|
||||
}
|
||||
|
||||
@@ -45,6 +76,8 @@ struct RequestLog {
|
||||
checksum: String,
|
||||
auth_token: String,
|
||||
stream: bool,
|
||||
status: String,
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
// 聊天请求
|
||||
@@ -68,7 +101,6 @@ mod models;
|
||||
use models::AVAILABLE_MODELS;
|
||||
|
||||
// 用于存储 token 信息
|
||||
#[derive(Debug)]
|
||||
struct TokenInfo {
|
||||
token: String,
|
||||
checksum: String,
|
||||
@@ -83,7 +115,6 @@ struct TokenUpdateRequest {
|
||||
}
|
||||
|
||||
// 自定义错误类型
|
||||
#[derive(Debug)]
|
||||
enum ChatError {
|
||||
ModelNotSupported(String),
|
||||
EmptyMessages,
|
||||
@@ -124,26 +155,14 @@ impl ChatError {
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
// 加载环境变量
|
||||
dotenvy::dotenv().ok();
|
||||
|
||||
// 处理 token 文件路径
|
||||
let token_file = std::env::var("TOKEN_FILE").unwrap_or_else(|_| ".token".to_string());
|
||||
|
||||
// 加载 tokens
|
||||
let token_infos = load_tokens(&token_file);
|
||||
let token_infos = load_tokens();
|
||||
|
||||
// 获取路由前缀配置
|
||||
let route_prefix = std::env::var("ROUTE_PREFIX").unwrap_or_default();
|
||||
|
||||
// 初始化应用状态
|
||||
// 初始化需要互斥访问的状态
|
||||
let state = Arc::new(Mutex::new(AppState {
|
||||
start_time: Local::now(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
total_requests: 0,
|
||||
active_requests: 0,
|
||||
request_logs: Vec::new(),
|
||||
route_prefix: route_prefix.clone(),
|
||||
token_infos,
|
||||
}));
|
||||
|
||||
@@ -151,13 +170,16 @@ async fn main() {
|
||||
let app = Router::new()
|
||||
.route("/", get(handle_root))
|
||||
.route("/tokeninfo", get(handle_tokeninfo_page))
|
||||
.route(&format!("{}/v1/models", route_prefix), get(handle_models))
|
||||
.route(
|
||||
&format!("{}/v1/models", APP_CONFIG.route_prefix),
|
||||
get(handle_models),
|
||||
)
|
||||
.route("/checksum", get(handle_checksum))
|
||||
.route("/update-tokeninfo", get(handle_update_tokeninfo))
|
||||
.route("/get-tokeninfo", post(handle_get_tokeninfo))
|
||||
.route("/update-tokeninfo", post(handle_update_tokeninfo_post))
|
||||
.route(
|
||||
&format!("{}/v1/chat/completions", route_prefix),
|
||||
&format!("{}/v1/chat/completions", APP_CONFIG.route_prefix),
|
||||
post(handle_chat),
|
||||
)
|
||||
.route("/logs", get(handle_logs))
|
||||
@@ -174,54 +196,69 @@ async fn main() {
|
||||
}
|
||||
|
||||
// Token 加载函数
|
||||
fn load_tokens(token_file: &str) -> Vec<TokenInfo> {
|
||||
let token_list_file =
|
||||
std::env::var("TOKEN_LIST_FILE").unwrap_or_else(|_| ".token-list".to_string());
|
||||
|
||||
// 读取并规范化 .token 文件
|
||||
let tokens = if let Ok(content) = std::fs::read_to_string(token_file) {
|
||||
let normalized = content.replace("\r\n", "\n");
|
||||
if normalized != content {
|
||||
std::fs::write(token_file, &normalized).unwrap();
|
||||
}
|
||||
normalized
|
||||
.lines()
|
||||
.enumerate()
|
||||
.filter_map(|(idx, line)| {
|
||||
let parts: Vec<&str> = line.split("::").collect();
|
||||
match parts.len() {
|
||||
1 => Some(line.to_string()),
|
||||
2 => Some(parts[1].to_string()),
|
||||
_ => {
|
||||
println!("警告: 第{}行包含多个'::'分隔符,已忽略此行", idx + 1);
|
||||
None
|
||||
}
|
||||
fn load_tokens() -> Vec<TokenInfo> {
|
||||
// 读取 .token 文件并解析
|
||||
let tokens = match std::fs::read_to_string(&APP_CONFIG.token_file) {
|
||||
Ok(content) => {
|
||||
let normalized = content.replace("\r\n", "\n");
|
||||
// 如果内容被规范化,则更新文件
|
||||
if normalized != content {
|
||||
if let Err(e) = std::fs::write(&APP_CONFIG.token_file, &normalized) {
|
||||
eprintln!("警告: 无法更新规范化的token文件: {}", e);
|
||||
}
|
||||
})
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
eprintln!("警告: 无法读取token文件 '{}'", token_file);
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
normalized
|
||||
.lines()
|
||||
.filter_map(|line| {
|
||||
let line = line.trim();
|
||||
if line.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// 处理 alias::token 格式
|
||||
match line.split("::").collect::<Vec<_>>() {
|
||||
parts if parts.len() == 1 => Some(line.to_string()),
|
||||
parts if parts.len() == 2 => Some(parts[1].to_string()),
|
||||
_ => {
|
||||
eprintln!("警告: 忽略无效的token行: {}", line);
|
||||
None
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("警告: 无法读取token文件 '{}': {}", APP_CONFIG.token_file, e);
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
// 读取现有的 token-list
|
||||
let mut token_map: std::collections::HashMap<String, String> =
|
||||
if let Ok(content) = std::fs::read_to_string(&token_list_file) {
|
||||
content
|
||||
.split('\n')
|
||||
.filter(|s| !s.is_empty())
|
||||
match std::fs::read_to_string(&APP_CONFIG.token_list_file) {
|
||||
Ok(content) => content
|
||||
.lines()
|
||||
.filter_map(|line| {
|
||||
let line = line.trim();
|
||||
if line.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let parts: Vec<&str> = line.split(',').collect();
|
||||
if parts.len() == 2 {
|
||||
Some((parts[0].to_string(), parts[1].to_string()))
|
||||
} else {
|
||||
None
|
||||
match parts[..] {
|
||||
[token, checksum] => Some((token.to_string(), checksum.to_string())),
|
||||
_ => {
|
||||
eprintln!("警告: 忽略无效的token-list行: {}", line);
|
||||
None
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
std::collections::HashMap::new()
|
||||
.collect(),
|
||||
Err(e) => {
|
||||
eprintln!("警告: 无法读取token-list文件: {}", e);
|
||||
std::collections::HashMap::new()
|
||||
}
|
||||
};
|
||||
|
||||
// 为新 token 生成 checksum
|
||||
@@ -241,7 +278,10 @@ fn load_tokens(token_file: &str) -> Vec<TokenInfo> {
|
||||
.map(|(token, checksum)| format!("{},{}", token, checksum))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
std::fs::write(token_list_file, token_list_content).unwrap();
|
||||
|
||||
if let Err(e) = std::fs::write(&APP_CONFIG.token_list_file, token_list_content) {
|
||||
eprintln!("警告: 无法更新token-list文件: {}", e);
|
||||
}
|
||||
|
||||
// 转换为 TokenInfo vector
|
||||
token_map
|
||||
@@ -253,14 +293,14 @@ fn load_tokens(token_file: &str) -> Vec<TokenInfo> {
|
||||
// 根路由处理
|
||||
async fn handle_root(State(state): State<Arc<Mutex<AppState>>>) -> Json<serde_json::Value> {
|
||||
let state = state.lock().await;
|
||||
let uptime = (Local::now() - state.start_time).num_seconds();
|
||||
let uptime = (Local::now() - APP_CONFIG.start_time).num_seconds();
|
||||
|
||||
Json(serde_json::json!({
|
||||
"status": "healthy",
|
||||
"version": state.version,
|
||||
"version": APP_CONFIG.version,
|
||||
"uptime": uptime,
|
||||
"stats": {
|
||||
"started": state.start_time,
|
||||
"started": APP_CONFIG.start_time,
|
||||
"totalRequests": state.total_requests,
|
||||
"activeRequests": state.active_requests,
|
||||
"memory": {
|
||||
@@ -271,8 +311,8 @@ async fn handle_root(State(state): State<Arc<Mutex<AppState>>>) -> Json<serde_js
|
||||
},
|
||||
"models": AVAILABLE_MODELS.iter().map(|m| &m.id).collect::<Vec<_>>(),
|
||||
"endpoints": [
|
||||
&format!("{}/v1/chat/completions", state.route_prefix),
|
||||
&format!("{}/v1/models", state.route_prefix),
|
||||
&format!("{}/v1/chat/completions", APP_CONFIG.route_prefix),
|
||||
&format!("{}/v1/models", APP_CONFIG.route_prefix),
|
||||
"/checksum",
|
||||
"/tokeninfo",
|
||||
"/update-tokeninfo",
|
||||
@@ -311,11 +351,8 @@ async fn handle_checksum() -> Json<serde_json::Value> {
|
||||
async fn handle_update_tokeninfo(
|
||||
State(state): State<Arc<Mutex<AppState>>>,
|
||||
) -> Json<serde_json::Value> {
|
||||
// 获取当前的 token 文件路径
|
||||
let token_file = std::env::var("TOKEN_FILE").unwrap_or_else(|_| ".token".to_string());
|
||||
|
||||
// 重新加载 tokens
|
||||
let token_infos = load_tokens(&token_file);
|
||||
let token_infos = load_tokens();
|
||||
|
||||
// 更新应用状态
|
||||
{
|
||||
@@ -341,25 +378,19 @@ async fn handle_get_tokeninfo(
|
||||
.and_then(|h| h.strip_prefix("Bearer "))
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let env_token = std::env::var("AUTH_TOKEN").map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if auth_header != env_token {
|
||||
if auth_header != APP_CONFIG.auth_token {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
// 获取文件路径
|
||||
let token_file = std::env::var("TOKEN_FILE").unwrap_or_else(|_| ".token".to_string());
|
||||
let token_list_file =
|
||||
std::env::var("TOKEN_LIST_FILE").unwrap_or_else(|_| ".token-list".to_string());
|
||||
|
||||
// 读取文件内容
|
||||
let tokens = std::fs::read_to_string(&token_file).unwrap_or_else(|_| String::new());
|
||||
let token_list = std::fs::read_to_string(&token_list_file).unwrap_or_else(|_| String::new());
|
||||
let tokens = std::fs::read_to_string(&APP_CONFIG.token_file).unwrap_or_else(|_| String::new());
|
||||
let token_list =
|
||||
std::fs::read_to_string(&APP_CONFIG.token_list_file).unwrap_or_else(|_| String::new());
|
||||
|
||||
Ok(Json(serde_json::json!({
|
||||
"status": "success",
|
||||
"token_file": token_file,
|
||||
"token_list_file": token_list_file,
|
||||
"token_file": APP_CONFIG.token_file,
|
||||
"token_list_file": APP_CONFIG.token_list_file,
|
||||
"tokens": tokens,
|
||||
"token_list": token_list
|
||||
})))
|
||||
@@ -377,28 +408,22 @@ async fn handle_update_tokeninfo_post(
|
||||
.and_then(|h| h.strip_prefix("Bearer "))
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let env_token = std::env::var("AUTH_TOKEN").map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if auth_header != env_token {
|
||||
if auth_header != APP_CONFIG.auth_token {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
// 获取文件路径
|
||||
let token_file = std::env::var("TOKEN_FILE").unwrap_or_else(|_| ".token".to_string());
|
||||
let token_list_file =
|
||||
std::env::var("TOKEN_LIST_FILE").unwrap_or_else(|_| ".token-list".to_string());
|
||||
|
||||
// 写入 .token 文件
|
||||
std::fs::write(&token_file, &request.tokens).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
std::fs::write(&APP_CONFIG.token_file, &request.tokens)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
// 如果提供了 token_list,则写入
|
||||
if let Some(token_list) = request.token_list {
|
||||
std::fs::write(&token_list_file, token_list)
|
||||
std::fs::write(&APP_CONFIG.token_list_file, token_list)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
}
|
||||
|
||||
// 重新加载 tokens
|
||||
let token_infos = load_tokens(&token_file);
|
||||
let token_infos = load_tokens();
|
||||
let token_infos_len = token_infos.len();
|
||||
|
||||
// 更新应用状态
|
||||
@@ -410,8 +435,8 @@ async fn handle_update_tokeninfo_post(
|
||||
Ok(Json(serde_json::json!({
|
||||
"status": "success",
|
||||
"message": "Token files have been updated and reloaded",
|
||||
"token_file": token_file,
|
||||
"token_list_file": token_list_file,
|
||||
"token_file": APP_CONFIG.token_file,
|
||||
"token_list_file": APP_CONFIG.token_list_file,
|
||||
"token_count": token_infos_len
|
||||
})))
|
||||
}
|
||||
@@ -469,14 +494,12 @@ async fn handle_chat(
|
||||
Json(ChatError::Unauthorized.to_json()),
|
||||
))?;
|
||||
|
||||
// 验证环境变量中的 AUTH_TOKEN
|
||||
if let Ok(env_token) = std::env::var("AUTH_TOKEN") {
|
||||
if auth_token != env_token {
|
||||
return Err((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(ChatError::Unauthorized.to_json()),
|
||||
));
|
||||
}
|
||||
// 验证 AuthToken
|
||||
if auth_token != APP_CONFIG.auth_token {
|
||||
return Err((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(ChatError::Unauthorized.to_json()),
|
||||
));
|
||||
}
|
||||
|
||||
// 完整的令牌处理逻辑和对应的 checksum
|
||||
@@ -508,6 +531,8 @@ async fn handle_chat(
|
||||
checksum: checksum.clone(),
|
||||
auth_token: auth_token.clone(),
|
||||
stream: request.stream,
|
||||
status: "pending".to_string(),
|
||||
error: None,
|
||||
});
|
||||
|
||||
if state.request_logs.len() > 100 {
|
||||
@@ -556,13 +581,33 @@ async fn handle_chat(
|
||||
.header("Host", "api2.cursor.sh")
|
||||
.body(hex_data)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
.await;
|
||||
|
||||
// 处理请求结果
|
||||
let response = match response {
|
||||
Ok(resp) => {
|
||||
// 更新请求日志为成功
|
||||
{
|
||||
let mut state = state.lock().await;
|
||||
state.request_logs.last_mut().unwrap().status = "success".to_string();
|
||||
}
|
||||
resp
|
||||
}
|
||||
Err(e) => {
|
||||
// 更新请求日志为失败
|
||||
{
|
||||
let mut state = state.lock().await;
|
||||
if let Some(last_log) = state.request_logs.last_mut() {
|
||||
last_log.status = "failed".to_string();
|
||||
last_log.error = Some(e.to_string());
|
||||
}
|
||||
}
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ChatError::RequestFailed(format!("Request failed: {}", e)).to_json()),
|
||||
)
|
||||
})?;
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// 释放活动请求计数
|
||||
{
|
||||
@@ -579,11 +624,11 @@ async fn handle_chat(
|
||||
|
||||
async move {
|
||||
let chunk = chunk.unwrap_or_default();
|
||||
let text = cursor_api::decode_response(&chunk).await;
|
||||
|
||||
if text.is_empty() {
|
||||
return Ok::<_, Infallible>(Bytes::from("[DONE]"));
|
||||
}
|
||||
let text = match cursor_api::decode_response(&chunk).await {
|
||||
Ok(text) if text.is_empty() => return Ok(Bytes::from("data: [DONE]\n\n")),
|
||||
Ok(text) => text,
|
||||
Err(_) => return Ok(Bytes::new()),
|
||||
};
|
||||
|
||||
let data = serde_json::json!({
|
||||
"id": &response_id,
|
||||
@@ -623,7 +668,11 @@ async fn handle_chat(
|
||||
),
|
||||
)
|
||||
})?;
|
||||
full_text.push_str(&cursor_api::decode_response(&chunk).await);
|
||||
full_text.push_str(
|
||||
&cursor_api::decode_response(&chunk)
|
||||
.await
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
}
|
||||
|
||||
// 处理文本
|
||||
|
Reference in New Issue
Block a user