From ea19cbc70a16b8d77a7bec1979af2f31cd805555 Mon Sep 17 00:00:00 2001 From: wisdgod Date: Sat, 4 Jan 2025 02:08:16 +0800 Subject: [PATCH] =?UTF-8?q?=E8=BF=99=E6=98=AF=E5=8F=AF=E5=9B=9E=E9=80=80?= =?UTF-8?q?=E6=99=AE=E9=80=9A=E7=89=88=E7=9A=84=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.example | 3 + .gitignore | 3 +- Cargo.lock | 88 +++++++++++- Cargo.toml | 24 ++-- src/app.rs | 2 + src/app/constant.rs | 3 - src/app/db.rs | 262 +++++++++++++++++++++++++++++++++++ src/app/lazy.rs | 10 +- src/app/model.rs | 12 ++ src/chat/constant.rs | 1 + src/chat/error.rs | 121 ++++++++-------- src/chat/route/token.rs | 81 ++++++++--- src/chat/route/usage.rs | 5 +- src/chat/service.rs | 9 +- src/common/client.rs | 15 +- src/common/utils.rs | 1 + src/common/utils/checksum.rs | 51 +++++++ src/common/utils/oauth.rs | 80 +++++++++++ src/main.rs | 3 + static/logs.html | 2 + 20 files changed, 655 insertions(+), 121 deletions(-) create mode 100644 src/app/db.rs create mode 100644 src/common/utils/oauth.rs diff --git a/.env.example b/.env.example index 5c606d4..472c252 100644 --- a/.env.example +++ b/.env.example @@ -37,3 +37,6 @@ VISION_ABILITY=base64 # 默认提示词 DEFAULT_INSTRUCTIONS="Respond in Chinese by default" + +# 反向代理服务器主机名 +CURSOR_API2_HOST= diff --git a/.gitignore b/.gitignore index 4114801..6be81ea 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,5 @@ node_modules /cursor-api.exe /release -/*.py \ No newline at end of file +/*.py +/logs diff --git a/Cargo.lock b/Cargo.lock index 2b01dfc..1334525 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,18 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -292,8 +304,9 @@ dependencies = [ [[package]] name = "cursor-api" -version = "0.1.3-rc.3" +version = "0.1.3" dependencies = [ + "anyhow", "axum", "base64", "bytes", @@ -311,6 +324,7 @@ dependencies = [ "rand", "regex", "reqwest", + "rusqlite", "serde", "serde_json", "sha2", @@ -318,6 +332,7 @@ dependencies = [ "tokio", "tokio-stream", "tower-http", + "urlencoding", "uuid", ] @@ -379,6 +394,18 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fastrand" version = "2.3.0" @@ -573,12 +600,30 @@ dependencies = [ "tracing", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", +] + [[package]] name = "hashbrown" version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +[[package]] +name = "hashlink" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +dependencies = [ + "hashbrown 0.14.5", +] + [[package]] name = "heck" version = "0.5.0" @@ -906,7 +951,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.15.2", ] [[package]] @@ -952,6 +997,17 @@ version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +[[package]] +name = "libsqlite3-sys" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.4.14" @@ -1318,9 +1374,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "reqwest" -version = "0.12.11" +version = "0.12.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fe060fe50f524be480214aba758c71f99f90ee8c83c5a36b5e9e1d568eb4eb3" +checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da" dependencies = [ "async-compression", "base64", @@ -1378,6 +1434,20 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rusqlite" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e" +dependencies = [ + "bitflags 2.6.0", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", +] + [[package]] name = "rustc-demangle" version = "0.1.24" @@ -1602,9 +1672,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.92" +version = "2.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70ae51629bf965c5c098cc9e87908a3df5301051a9e087d6f9bef5c9771ed126" +checksum = "987bc0be1cdea8b10216bd06e2ca407d40b9543468fafd3ddfb02f36e77f71f3" dependencies = [ "proc-macro2", "quote", @@ -1856,6 +1926,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf16_iter" version = "1.0.5" diff --git a/Cargo.toml b/Cargo.toml index 714c3d8..2bebb0a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,8 @@ [package] name = "cursor-api" -version = "0.1.3-rc.3" +version = "0.1.3" edition = "2021" authors = ["wisdgod "] -# license = "MIT" -# copyright = "Copyright (c) 2024 wisdgod" description = "OpenAI format compatibility layer for the Cursor API" repository = "https://github.com/wisdgod/cursor-api" @@ -14,6 +12,7 @@ sha2 = { version = "0.10.8", default-features = false } serde_json = "1.0.134" [dependencies] +anyhow = "1.0.95" axum = { version = "0.7.9", features = ["json"] } base64 = { version = "0.22.1", default-features = false, features = ["std"] } # brotli = { version = "7.0.0", default-features = false, features = ["std"] } @@ -30,7 +29,8 @@ paste = "1.0.15" prost = "0.13.4" rand = { version = "0.8.5", default-features = false, features = ["std", "std_rng"] } regex = { version = "1.11.1", default-features = false, features = ["std", "perf"] } -reqwest = { version = "0.12.11", default-features = false, features = ["gzip", "json", "stream", "__tls", "charset", "default-tls", "h2", "http2", "macos-system-configuration"] } +reqwest = { version = "0.12.12", default-features = false, features = ["gzip", "json", "stream", "__tls", "charset", "default-tls", "h2", "http2", "macos-system-configuration"] } +rusqlite = { version = "0.32.1", features = ["bundled"], optional = true } serde = { version = "1.0.217", default-features = false, features = ["std", "derive"] } serde_json = "1.0.134" sha2 = { version = "0.10.8", default-features = false } @@ -38,6 +38,7 @@ sysinfo = { version = "0.33.1", default-features = false, features = ["system"] tokio = { version = "1.42.0", features = ["rt-multi-thread", "macros", "net", "sync", "time"] } tokio-stream = { version = "0.1.17", features = ["time"] } tower-http = { version = "0.6.2", features = ["cors"] } +urlencoding = "2.1.3" uuid = { version = "1.11.0", features = ["v4"] } [profile.release] @@ -47,15 +48,6 @@ panic = 'abort' strip = true opt-level = 3 -# 构建脚本设置 -[package.metadata.cross.target.x86_64-unknown-linux-gnu] -image = "ghcr.io/cross-rs/x86_64-unknown-linux-gnu:main" - -[package.metadata.cross.target.aarch64-unknown-linux-gnu] -image = "ghcr.io/cross-rs/aarch64-unknown-linux-gnu:main" - -[package.metadata.cross.target.x86_64-apple-darwin] -image = "ghcr.io/cross-rs/x86_64-apple-darwin:main" - -[package.metadata.cross.target.aarch64-apple-darwin] -image = "ghcr.io/cross-rs/aarch64-apple-darwin:main" +[features] +default = [] +sqlite = ["dep:rusqlite"] diff --git a/src/app.rs b/src/app.rs index a08e33d..9b33bc5 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,4 +1,6 @@ pub mod config; pub mod constant; +#[cfg(feature = "sqlite")] +pub mod db; pub mod model; pub mod lazy; diff --git a/src/app/constant.rs b/src/app/constant.rs index 0658add..a23503f 100644 --- a/src/app/constant.rs +++ b/src/app/constant.rs @@ -50,9 +50,6 @@ def_pub_const!(AUTHORIZATION_BEARER_PREFIX, "Bearer "); def_pub_const!(OBJECT_CHAT_COMPLETION, "chat.completion"); def_pub_const!(OBJECT_CHAT_COMPLETION_CHUNK, "chat.completion.chunk"); -def_pub_const!(CURSOR_API2_HOST, "api2.cursor.sh"); -def_pub_const!(CURSOR_API2_BASE_URL, "https://api2.cursor.sh/aiserver.v1.AiService/"); - def_pub_const!(CURSOR_API2_STREAM_CHAT, "StreamChat"); def_pub_const!(CURSOR_API2_GET_USER_INFO, "GetUserInfo"); diff --git a/src/app/db.rs b/src/app/db.rs new file mode 100644 index 0000000..b29b33b --- /dev/null +++ b/src/app/db.rs @@ -0,0 +1,262 @@ +use crate::app::model::{RequestLog, TokenInfo}; +use crate::common::models::usage::UserUsageInfo; +use chrono::{DateTime, Local}; +use lazy_static::lazy_static; +use rusqlite::params; +use rusqlite::{Connection, Result}; +use std::path::Path; +use std::sync::Mutex; + +const DB_PATH: &str = "logs/sqlite.db"; + +pub struct AppDb { + conn: Connection, +} + +impl AppDb { + pub fn new() -> Result { + // 确保目录存在 + if let Some(parent) = Path::new(DB_PATH).parent() { + std::fs::create_dir_all(parent).map_err(|e| { + rusqlite::Error::SqliteFailure( + rusqlite::ffi::Error::new(rusqlite::ffi::SQLITE_IOERR), + Some(e.to_string()), + ) + })?; + } + + let conn = Connection::open(DB_PATH)?; + + // 启用WAL模式以提升性能 + conn.execute_batch("PRAGMA journal_mode = WAL")?; + + // 创建token信息表 + conn.execute( + "CREATE TABLE IF NOT EXISTS token_infos ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + token TEXT NOT NULL UNIQUE, + checksum TEXT NOT NULL, + alias TEXT, + fast_requests INTEGER, + max_fast_requests INTEGER + )", + [], + )?; + + // 创建请求日志表 + conn.execute( + "CREATE TABLE IF NOT EXISTS request_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL, + model TEXT NOT NULL, + token_id INTEGER NOT NULL, + prompt TEXT, + stream BOOLEAN NOT NULL, + status TEXT NOT NULL, + error TEXT, + FOREIGN KEY(token_id) REFERENCES token_infos(id) + )", + [], + )?; + + // 创建索引 + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_token ON token_infos(token)", + [], + )?; + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_timestamp_model ON request_logs(timestamp, model)", + [], + )?; + + Ok(Self { conn }) + } + + fn get_or_create_token_info(&self, token_info: &TokenInfo) -> Result { + let mut stmt = self.conn.prepare_cached( + "INSERT OR REPLACE INTO token_infos (token, checksum, alias, fast_requests, max_fast_requests) + VALUES (?1, ?2, ?3, ?4, ?5) + RETURNING id" + )?; + + stmt.query_row( + params![ + &token_info.token, + &token_info.checksum, + &token_info.alias, + token_info.usage.as_ref().map(|u| u.fast_requests), + token_info.usage.as_ref().map(|u| u.max_fast_requests), + ], + |row| row.get(0), + ) + } + + pub fn add_log(&self, log: &RequestLog) -> Result<()> { + let token_id = self.get_or_create_token_info(&log.token_info)?; + + self.conn.execute( + "INSERT INTO request_logs (timestamp, model, token_id, prompt, stream, status, error) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)", + params![ + log.timestamp.to_rfc3339(), + &log.model, + token_id, + &log.prompt, + log.stream, + &log.status, + &log.error, + ], + )?; + Ok(()) + } + + fn map_row_to_log(&self, row: &rusqlite::Row) -> Result { + let token_id: i64 = row.get(3)?; + let token_info = self.get_token_info_by_id(token_id)?; + + Ok(RequestLog { + id: row.get(0)?, + timestamp: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?) + .unwrap() + .with_timezone(&Local), + model: row.get(2)?, + token_info, + prompt: row.get(4)?, + stream: row.get(5)?, + status: row.get(6)?, + error: row.get(7)?, + }) + } + + fn get_token_info_by_id(&self, id: i64) -> Result { + let mut stmt = self.conn.prepare_cached( + "SELECT token, checksum, alias, fast_requests, max_fast_requests + FROM token_infos + WHERE id = ?", + )?; + + stmt.query_row([id], |row| { + Ok(TokenInfo { + token: row.get(0)?, + checksum: row.get(1)?, + alias: row.get(2)?, + usage: Some(UserUsageInfo { + fast_requests: row.get(3)?, + max_fast_requests: row.get(4)?, + }), + }) + }) + } + + pub fn get_token_infos(&self) -> Result> { + let mut stmt = self.conn.prepare_cached( + "SELECT token, checksum, alias, fast_requests, max_fast_requests + FROM token_infos", + )?; + + let tokens = stmt.query_map([], |row| { + Ok(TokenInfo { + token: row.get(0)?, + checksum: row.get(1)?, + alias: row.get(2)?, + usage: Some(UserUsageInfo { + fast_requests: row.get(3)?, + max_fast_requests: row.get(4)?, + }), + }) + })?; + tokens.collect() + } + + pub fn get_recent_logs(&self, limit: i64) -> Result> { + let mut stmt = self.conn.prepare_cached( + "SELECT r.id, r.timestamp, r.model, r.token_id, r.prompt, r.stream, r.status, r.error, t.token, t.checksum, t.alias, t.fast_requests, t.max_fast_requests + FROM request_logs r + JOIN token_infos t ON r.token_id = t.id + ORDER BY r.timestamp DESC + LIMIT ?", + )?; + + let logs = stmt.query_map([limit], |row| { + Ok(RequestLog { + id: row.get(0)?, + timestamp: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?) + .unwrap() + .with_timezone(&Local), + model: row.get(2)?, + token_info: TokenInfo { + token: row.get(8)?, + checksum: row.get(9)?, + alias: row.get(10)?, + usage: Some(UserUsageInfo { + fast_requests: row.get(11)?, + max_fast_requests: row.get(12)?, + }), + }, + prompt: row.get(4)?, + stream: row.get(5)?, + status: row.get(6)?, + error: row.get(7)?, + }) + })?; + logs.collect() + } + + pub fn get_logs_by_timerange( + &self, + start: DateTime, + end: DateTime, + ) -> Result> { + let mut stmt = self.conn.prepare_cached( + "SELECT r.id, r.timestamp, r.model, r.token_id, r.prompt, r.stream, r.status, r.error, t.token, t.checksum, t.alias, t.fast_requests, t.max_fast_requests + FROM request_logs r + JOIN token_infos t ON r.token_id = t.id + WHERE r.timestamp BETWEEN ?1 AND ?2 + ORDER BY r.timestamp DESC", + )?; + + let logs = stmt.query_map([start.to_rfc3339(), end.to_rfc3339()], |row| { + Ok(RequestLog { + id: row.get(0)?, + timestamp: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?) + .unwrap() + .with_timezone(&Local), + model: row.get(2)?, + token_info: TokenInfo { + token: row.get(8)?, + checksum: row.get(9)?, + alias: row.get(10)?, + usage: Some(UserUsageInfo { + fast_requests: row.get(11)?, + max_fast_requests: row.get(12)?, + }), + }, + prompt: row.get(4)?, + stream: row.get(5)?, + status: row.get(6)?, + error: row.get(7)?, + }) + })?; + logs.collect() + } + + pub fn update_token_info(&self, token_info: &TokenInfo) -> Result<()> { + self.conn.execute( + "INSERT OR REPLACE INTO token_infos (token, checksum, alias, fast_requests, max_fast_requests) + VALUES (?1, ?2, ?3, ?4, ?5)", + params![ + &token_info.token, + &token_info.checksum, + &token_info.alias, + token_info.usage.as_ref().map(|u| u.fast_requests), + token_info.usage.as_ref().map(|u| u.max_fast_requests), + ], + )?; + Ok(()) + } +} + +lazy_static! { + pub static ref APP_DB: Mutex = + Mutex::new(AppDb::new().expect("Failed to initialize database")); +} diff --git a/src/app/lazy.rs b/src/app/lazy.rs index 553a87a..74ab67e 100644 --- a/src/app/lazy.rs +++ b/src/app/lazy.rs @@ -33,11 +33,11 @@ def_pub_static!(TOKEN_FILE, env: "TOKEN_FILE", default: DEFAULT_TOKEN_FILE_NAME) def_pub_static!(TOKEN_LIST_FILE, env: "TOKEN_LIST_FILE", default: DEFAULT_TOKEN_LIST_FILE_NAME); def_pub_static!( ROUTE_MODELS_PATH, - format!("{}/v1/models", ROUTE_PREFIX.as_str()) + format!("{}/v1/models", *ROUTE_PREFIX) ); def_pub_static!( ROUTE_CHAT_PATH, - format!("{}/v1/chat/completions", ROUTE_PREFIX.as_str()) + format!("{}/v1/chat/completions", *ROUTE_PREFIX) ); pub static START_TIME: LazyLock> = @@ -49,6 +49,12 @@ pub fn get_start_time() -> chrono::DateTime { def_pub_static!(DEFAULT_INSTRUCTIONS, env: "DEFAULT_INSTRUCTIONS", default: "Respond in Chinese by default"); +def_pub_static!(CURSOR_API2_HOST, env: "REVERSE_PROXY_HOST", default: "api2.cursor.sh"); + +pub static CURSOR_API2_BASE_URL: LazyLock = LazyLock::new(|| { + format!("https://{}/aiserver.v1.AiService/", *CURSOR_API2_HOST) +}); + // pub static DEBUG: LazyLock = LazyLock::new(|| parse_bool_from_env("DEBUG", false)); // #[macro_export] diff --git a/src/app/model.rs b/src/app/model.rs index 0476065..bba146f 100644 --- a/src/app/model.rs +++ b/src/app/model.rs @@ -87,7 +87,9 @@ pub struct Pages { pub struct AppState { pub total_requests: u64, pub active_requests: u64, + #[cfg(not(feature = "sqlite"))] pub request_logs: Vec, + #[cfg(not(feature = "sqlite"))] pub token_infos: Vec, } @@ -273,6 +275,7 @@ impl AppConfig { } impl AppState { + #[cfg(not(feature = "sqlite"))] pub fn new(token_infos: Vec) -> Self { Self { total_requests: 0, @@ -281,11 +284,20 @@ impl AppState { token_infos, } } + + #[cfg(feature = "sqlite")] + pub fn new() -> Self { + Self { + total_requests: 0, + active_requests: 0, + } + } } // 请求日志 #[derive(Serialize, Clone)] pub struct RequestLog { + pub id: u64, pub timestamp: chrono::DateTime, pub model: String, pub token_info: TokenInfo, diff --git a/src/chat/constant.rs b/src/chat/constant.rs index 73eba5f..84a8d6a 100644 --- a/src/chat/constant.rs +++ b/src/chat/constant.rs @@ -7,6 +7,7 @@ macro_rules! def_pub_const { } def_pub_const!(ERR_UNSUPPORTED_GIF, "不支持动态 GIF"); def_pub_const!(ERR_UNSUPPORTED_IMAGE_FORMAT, "不支持的图片格式,仅支持 PNG、JPEG、WEBP 和非动态 GIF"); +def_pub_const!(ERR_NODATA, "No data"); const MODEL_OBJECT: &str = "model"; const CREATED: &i64 = &1706659200; diff --git a/src/chat/error.rs b/src/chat/error.rs index 290c981..75e530d 100644 --- a/src/chat/error.rs +++ b/src/chat/error.rs @@ -2,41 +2,66 @@ use super::aiserver::v1::throw_error_check_request::Error as ErrorType; use reqwest::StatusCode; use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize)] +#[derive(Deserialize)] pub struct ChatError { - pub error: ErrorBody, + error: ErrorBody, } -#[derive(Serialize, Deserialize)] +#[derive(Deserialize)] pub struct ErrorBody { - pub code: String, - pub message: String, - pub details: Vec, + code: String, + // message: String, always: Error + details: Vec, } -#[derive(Serialize, Deserialize)] +#[derive(Deserialize)] pub struct ErrorDetail { - #[serde(rename = "type")] - pub error_type: String, - pub debug: ErrorDebug, - pub value: String, + // #[serde(rename = "type")] + // error_type: String, always: aiserver.v1.ErrorDetails + debug: ErrorDebug, + value: String, } -#[derive(Serialize, Deserialize)] +#[derive(Deserialize)] pub struct ErrorDebug { - pub error: String, - pub details: ErrorDetails, - #[serde(rename = "isExpected")] - pub is_expected: bool, + error: String, + details: ErrorDetails, + // #[serde(rename = "isExpected")] + // is_expected: Option, } -impl ErrorDebug { - // pub fn is_valid(&self) -> bool { - // ErrorType::from_str_name(&self.error).is_some() - // } +#[derive(Deserialize)] +pub struct ErrorDetails { + title: String, + detail: String, + // #[serde(rename = "isRetryable")] + // is_retryable: Option, +} + +use crate::common::models::{ApiStatus, ErrorResponse as CommonErrorResponse}; + +impl ChatError { + pub fn to_error_response(&self) -> ErrorResponse { + if self.error.details.is_empty() { + return ErrorResponse { + status: 500, + code: "unknown".to_string(), + error: None, + }; + } + ErrorResponse { + status: self.status_code(), + code: self.error.code.clone(), + error: Some(Error { + message: self.error.details[0].debug.details.title.clone(), + details: self.error.details[0].debug.details.detail.clone(), + value: self.error.details[0].value.clone(), + }), + } + } pub fn status_code(&self) -> u16 { - match ErrorType::from_str_name(&self.error) { + match ErrorType::from_str_name(&self.error.details[0].debug.error) { Some(error) => match error { ErrorType::Unspecified => 500, ErrorType::BadApiKey @@ -68,46 +93,26 @@ impl ErrorDebug { | ErrorType::SlashEditFileTooLong | ErrorType::FileUnsupported | ErrorType::ClaudeImageTooLarge => 400, - _ => 500, + ErrorType::Deprecated + | ErrorType::FreeUserUsageLimit + | ErrorType::ProUserUsageLimit + | ErrorType::ResourceExhausted + | ErrorType::Openai + | ErrorType::MaxTokens + | ErrorType::ApiKeyNotSupported + | ErrorType::UserAbortedRequest + | ErrorType::CustomMessage + | ErrorType::OutdatedClient + | ErrorType::Debounced + | ErrorType::RepositoryServiceRepositoryIsNotInitialized => 500, }, None => 500, } } -} -#[derive(Serialize, Deserialize)] -pub struct ErrorDetails { - pub title: String, - pub detail: String, - #[serde(rename = "isRetryable")] - pub is_retryable: bool, -} - -use crate::common::models::{ApiStatus, ErrorResponse as CommonErrorResponse}; - -impl ChatError { - pub fn to_json(&self) -> serde_json::Value { - serde_json::to_value(self).unwrap() - } - - pub fn to_error_response(&self) -> ErrorResponse { - if self.error.details.is_empty() { - return ErrorResponse { - status: 500, - code: "ERROR_UNKNOWN".to_string(), - error: None, - }; - } - ErrorResponse { - status: self.error.details[0].debug.status_code(), - code: self.error.details[0].debug.error.clone(), - error: Some(Error { - message: self.error.details[0].debug.details.title.clone(), - details: self.error.details[0].debug.details.detail.clone(), - value: self.error.details[0].value.clone(), - }), - } - } + // pub fn is_expected(&self) -> bool { + // self.error.details[0].debug.is_expected.unwrap_or_default() + // } } #[derive(Serialize)] @@ -135,7 +140,7 @@ impl ErrorResponse { } pub fn native_code(&self) -> String { - self.code.replace("_", " ").to_lowercase() + self.code.replace("_", " ") } pub fn to_common(self) -> CommonErrorResponse { @@ -157,7 +162,7 @@ pub enum StreamError { impl std::fmt::Display for StreamError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - StreamError::ChatError(error) => write!(f, "{}", serde_json::to_string(error).unwrap()), + StreamError::ChatError(error) => write!(f, "{}", error.error.details[0].debug.details.title), StreamError::DataLengthLessThan5 => write!(f, "data length less than 5"), StreamError::EmptyMessage => write!(f, "empty message"), } diff --git a/src/chat/route/token.rs b/src/chat/route/token.rs index f637c13..95a2687 100644 --- a/src/chat/route/token.rs +++ b/src/chat/route/token.rs @@ -5,7 +5,7 @@ use crate::{ CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, HEADER_NAME_AUTHORIZATION, HEADER_NAME_CONTENT_TYPE, ROUTE_TOKENINFO_PATH, }, - model::{AppConfig, AppState, PageContent, TokenUpdateRequest}, + model::{AppConfig, PageContent, TokenUpdateRequest}, lazy::{AUTH_TOKEN, TOKEN_FILE, TOKEN_LIST_FILE}, }, common::{ @@ -13,15 +13,22 @@ use crate::{ utils::{generate_checksum, generate_hash, tokens::load_tokens}, }, }; +#[cfg(not(feature = "sqlite"))] +use crate::app::model::AppState; +#[cfg(feature = "sqlite")] +use crate::app::db::APP_DB; use axum::{ - extract::State, http::HeaderMap, response::{IntoResponse, Response}, Json, }; +#[cfg(not(feature = "sqlite"))] +use axum::extract::State; use reqwest::StatusCode; use serde::Serialize; +#[cfg(not(feature = "sqlite"))] use std::sync::Arc; +#[cfg(not(feature = "sqlite"))] use tokio::sync::Mutex; #[derive(Serialize)] @@ -36,17 +43,28 @@ pub async fn handle_get_checksum() -> Json { // 更新 TokenInfo 处理 pub async fn handle_update_tokeninfo( - State(state): State>>, + #[cfg(not(feature = "sqlite"))] State(state): State>>, ) -> Json { // 重新加载 tokens let token_infos = load_tokens(); // 更新应用状态 + #[cfg(not(feature = "sqlite"))] { let mut state = state.lock().await; state.token_infos = token_infos; } + #[cfg(feature = "sqlite")] + { + // 使用 APP_DB 更新 token_infos + if let Ok(db) = APP_DB.lock() { + for token_info in token_infos { + let _ = db.update_token_info(&token_info); + } + } + } + Json(NormalResponseNoData { status: ApiStatus::Success, message: Some("Token list has been reloaded".to_string()), @@ -55,13 +73,8 @@ pub async fn handle_update_tokeninfo( // 获取 TokenInfo 处理 pub async fn handle_get_tokeninfo( - State(_state): State>>, headers: HeaderMap, ) -> Result, StatusCode> { - let auth_token = AUTH_TOKEN.as_str(); - let token_file = TOKEN_FILE.as_str(); - let token_list_file = TOKEN_LIST_FILE.as_str(); - // 验证 AUTH_TOKEN let auth_header = headers .get(HEADER_NAME_AUTHORIZATION) @@ -69,20 +82,37 @@ pub async fn handle_get_tokeninfo( .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) .ok_or(StatusCode::UNAUTHORIZED)?; - if auth_header != auth_token { + if auth_header != AUTH_TOKEN.as_str() { return Err(StatusCode::UNAUTHORIZED); } + let token_file = TOKEN_FILE.as_str(); + let token_list_file = TOKEN_LIST_FILE.as_str(); + // 读取文件内容 let tokens = std::fs::read_to_string(&token_file).unwrap_or_else(|_| String::new()); let token_list = std::fs::read_to_string(&token_list_file).unwrap_or_else(|_| String::new()); + // 获取 tokens_count + let tokens_count = { + #[cfg(feature = "sqlite")] + { + APP_DB.lock() + .map(|db| db.get_token_infos().map(|v| v.len()).unwrap_or(0)) + .unwrap_or(0) + } + #[cfg(not(feature = "sqlite"))] + { + tokens.len() + } + }; + Ok(Json(TokenInfoResponse { status: ApiStatus::Success, token_file: token_file.to_string(), token_list_file: token_list_file.to_string(), - tokens: Some(tokens.clone()), - tokens_count: Some(tokens.len()), + tokens: Some(tokens), + tokens_count: Some(tokens_count), token_list: Some(token_list), message: None, })) @@ -104,14 +134,10 @@ pub struct TokenInfoResponse { } pub async fn handle_update_tokeninfo_post( - State(state): State>>, + #[cfg(not(feature = "sqlite"))] State(state): State>>, headers: HeaderMap, Json(request): Json, ) -> Result, StatusCode> { - let auth_token = AUTH_TOKEN.as_str(); - let token_file = TOKEN_FILE.as_str(); - let token_list_file = TOKEN_LIST_FILE.as_str(); - // 验证 AUTH_TOKEN let auth_header = headers .get(HEADER_NAME_AUTHORIZATION) @@ -119,15 +145,18 @@ pub async fn handle_update_tokeninfo_post( .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) .ok_or(StatusCode::UNAUTHORIZED)?; - if auth_header != auth_token { + if auth_header != AUTH_TOKEN.as_str() { return Err(StatusCode::UNAUTHORIZED); } - // 写入 .token 文件 - std::fs::write(&token_file, &request.tokens).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let token_file = TOKEN_FILE.as_str(); + let token_list_file = TOKEN_LIST_FILE.as_str(); - // 如果提供了 token_list,则写入 - if let Some(token_list) = request.token_list { + // 写入文件 + std::fs::write(&token_file, &request.tokens) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + if let Some(token_list) = &request.token_list { std::fs::write(&token_list_file, token_list) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; } @@ -137,11 +166,21 @@ pub async fn handle_update_tokeninfo_post( let token_infos_len = token_infos.len(); // 更新应用状态 + #[cfg(not(feature = "sqlite"))] { let mut state = state.lock().await; state.token_infos = token_infos; } + #[cfg(feature = "sqlite")] + { + if let Ok(db) = APP_DB.lock() { + for token_info in token_infos { + let _ = db.update_token_info(&token_info); + } + } + } + Ok(Json(TokenInfoResponse { status: ApiStatus::Success, token_file: token_file.to_string(), diff --git a/src/chat/route/usage.rs b/src/chat/route/usage.rs index ae16f50..ebd9466 100644 --- a/src/chat/route/usage.rs +++ b/src/chat/route/usage.rs @@ -1,5 +1,6 @@ use crate::{ app::model::AppState, + chat::constant::ERR_NODATA, common::{models::usage::GetUserInfo, utils::get_user_usage}, }; use axum::{ @@ -26,11 +27,11 @@ pub async fn get_user_info( let (auth_token, checksum) = match token_info { Some(token_info) => (token_info.token.clone(), token_info.checksum.clone()), - None => return Json(GetUserInfo::Error("No data".to_string())), + None => return Json(GetUserInfo::Error(ERR_NODATA.to_string())), }; match get_user_usage(&auth_token, &checksum).await { Some(usage) => Json(GetUserInfo::Usage(usage)), - None => Json(GetUserInfo::Error("No data".to_string())), + None => Json(GetUserInfo::Error(ERR_NODATA.to_string())), } } diff --git a/src/chat/service.rs b/src/chat/service.rs index 24a7758..fea3787 100644 --- a/src/chat/service.rs +++ b/src/chat/service.rs @@ -147,7 +147,9 @@ pub async fn handle_chat( } } + let next_id = state.request_logs.last().map_or(1, |log| log.id + 1); state.request_logs.push(RequestLog { + id: next_id, timestamp: request_time, model: request.model.clone(), token_info: TokenInfo { @@ -420,11 +422,6 @@ pub async fn handle_chat( } Ok(Bytes::new()) } - Err(StreamError::ChatError(error)) => { - buffer_guard.clear(); - eprintln!("Stream error occurred: {}", error.to_json()); - Ok(Bytes::new()) - } Err(e) => { buffer_guard.clear(); eprintln!("[警告] Stream error: {}", e); @@ -480,7 +477,7 @@ pub async fn handle_chat( } Err(StreamError::ChatError(error)) => { return Err(( - StatusCode::from_u16(error.error.details[0].debug.status_code()) + StatusCode::from_u16(error.status_code()) .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), Json(error.to_error_response().to_common()), )); diff --git a/src/common/client.rs b/src/common/client.rs index 70f888e..b5e69e1 100644 --- a/src/common/client.rs +++ b/src/common/client.rs @@ -1,7 +1,10 @@ -use crate::app::constant::{ - AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_CONNECT_PROTO, CONTENT_TYPE_PROTO, - CURSOR_API2_BASE_URL, CURSOR_API2_HOST, CURSOR_API2_STREAM_CHAT, HEADER_NAME_AUTHORIZATION, - HEADER_NAME_CONTENT_TYPE, +use crate::app::{ + constant::{ + AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_CONNECT_PROTO, CONTENT_TYPE_PROTO, + CURSOR_API2_STREAM_CHAT, HEADER_NAME_AUTHORIZATION, + HEADER_NAME_CONTENT_TYPE, + }, + lazy::{CURSOR_API2_HOST, CURSOR_API2_BASE_URL}, }; use reqwest::Client; use uuid::Uuid; @@ -17,7 +20,7 @@ pub fn build_client(auth_token: &str, checksum: &str, endpoint: &str) -> reqwest }; client - .post(format!("{}{}", CURSOR_API2_BASE_URL, endpoint)) + .post(format!("{}{}", *CURSOR_API2_BASE_URL, endpoint)) .header(HEADER_NAME_CONTENT_TYPE, content_type) .header( HEADER_NAME_AUTHORIZATION, @@ -32,5 +35,5 @@ pub fn build_client(auth_token: &str, checksum: &str, endpoint: &str) -> reqwest .header("x-cursor-timezone", "Asia/Shanghai") .header("x-ghost-mode", "false") .header("x-request-id", trace_id) - .header("Host", CURSOR_API2_HOST) + .header("Host", CURSOR_API2_HOST.clone()) } diff --git a/src/common/utils.rs b/src/common/utils.rs index 9e1a402..33dfa7b 100644 --- a/src/common/utils.rs +++ b/src/common/utils.rs @@ -1,6 +1,7 @@ mod checksum; pub use checksum::*; pub mod tokens; +pub mod oauth; use prost::Message as _; use crate::{app::constant::CURSOR_API2_GET_USER_INFO, chat::aiserver::v1::GetUserInfoResponse}; diff --git a/src/common/utils/checksum.rs b/src/common/utils/checksum.rs index 75f6326..1015c2a 100644 --- a/src/common/utils/checksum.rs +++ b/src/common/utils/checksum.rs @@ -42,3 +42,54 @@ pub fn generate_checksum(device_id: &str, mac_addr: Option<&str>) -> String { None => format!("{}{}", encoded, device_id), } } + +pub fn validate_checksum(checksum: &str) -> bool { + // 首先检查是否包含基本的 base64 编码部分和 hash 格式的 device_id + let parts: Vec<&str> = checksum.split('/').collect(); + + match parts.len() { + // 没有 MAC 地址的情况 + 1 => { + // 检查是否包含 BASE64 编码的 timestamp (8字符) + 64字符的hash + if checksum.len() != 72 { + // 8 + 64 = 72 + return false; + } + + // 验证 BASE64 部分 + let base64_len = 8; + let encoded_part = &checksum[..base64_len]; + if !BASE64.decode(encoded_part).is_ok() { + return false; + } + + // 验证 device_id hash 部分 + let device_hash = &checksum[base64_len..]; + is_valid_hash(device_hash) + } + // 包含 MAC hash 的情况 + 2 => { + let first_part = parts[0]; + let mac_hash = parts[1]; + + // MAC hash 必须是64字符的十六进制 + if !is_valid_hash(mac_hash) { + return false; + } + + // 递归验证第一部分 + validate_checksum(first_part) + } + _ => false, + } +} + +fn is_valid_hash(hash: &str) -> bool { + // 检查长度是否为64 + if hash.len() != 64 { + return false; + } + + // 检查是否都是有效的十六进制字符 + hash.chars().all(|c| c.is_ascii_hexdigit()) +} diff --git a/src/common/utils/oauth.rs b/src/common/utils/oauth.rs new file mode 100644 index 0000000..4732f3f --- /dev/null +++ b/src/common/utils/oauth.rs @@ -0,0 +1,80 @@ +use anyhow::Result; +use reqwest::Client; +use serde::{Deserialize, Serialize}; + +const OAUTH_AUTHORIZE_URL: &str = "https://connect.linux.do/oauth2/authorize"; +const OAUTH_TOKEN_URL: &str = "https://connect.linux.do/oauth2/token"; +const OAUTH_USER_INFO_URL: &str = "https://connect.linux.do/api/user"; + +#[derive(Debug, Serialize, Deserialize)] +pub struct ForumUser { + pub id: i64, + pub username: String, + pub name: String, + pub active: bool, + pub trust_level: i32, + pub silenced: bool, +} + +pub struct ForumOAuth { + client_id: String, + client_secret: String, + redirect_uri: String, + http_client: Client, +} + +impl ForumOAuth { + pub fn new(client_id: String, client_secret: String, redirect_uri: String) -> Self { + Self { + client_id, + client_secret, + redirect_uri, + http_client: Client::new(), + } + } + + pub fn get_authorize_url(&self, state: &str) -> String { + format!( + "{}?response_type=code&client_id={}&redirect_uri={}&state={}", + OAUTH_AUTHORIZE_URL, + self.client_id, + urlencoding::encode(&self.redirect_uri), + state + ) + } + + pub async fn exchange_code_for_token(&self, code: &str) -> Result { + let response = self + .http_client + .post(OAUTH_TOKEN_URL) + .form(&[ + ("grant_type", "authorization_code"), + ("code", code), + ("client_id", &self.client_id), + ("client_secret", &self.client_secret), + ("redirect_uri", &self.redirect_uri), + ]) + .send() + .await? + .json::() + .await?; + + Ok(response["access_token"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("No access token found"))? + .to_string()) + } + + pub async fn get_user_info(&self, access_token: &str) -> Result { + let user = self + .http_client + .get(OAUTH_USER_INFO_URL) + .bearer_auth(access_token) + .send() + .await? + .json::() + .await?; + + Ok(user) + } +} diff --git a/src/main.rs b/src/main.rs index b7a4389..e01c3b4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -63,6 +63,9 @@ async fn main() { let token_infos = load_tokens(); // 初始化应用状态 + #[cfg(feature = "sqlite")] + let state = Arc::new(Mutex::new(AppState::new())); + #[cfg(not(feature = "sqlite"))] let state = Arc::new(Mutex::new(AppState::new(token_infos))); // 设置路由 diff --git a/static/logs.html b/static/logs.html index 8401fe5..c408ce1 100644 --- a/static/logs.html +++ b/static/logs.html @@ -200,6 +200,7 @@ + @@ -304,6 +305,7 @@ tbody.innerHTML = data.logs.map(log => ` +
id 时间 模型 Token信息
${log.id} ${new Date(log.timestamp).toLocaleString()} ${log.model}