mirror of
https://github.com/wisdgod/cursor-api.git
synced 2025-10-05 14:46:53 +08:00
这是可回退普通版的提交
This commit is contained in:
@@ -37,3 +37,6 @@ VISION_ABILITY=base64
|
||||
|
||||
# 默认提示词
|
||||
DEFAULT_INSTRUCTIONS="Respond in Chinese by default"
|
||||
|
||||
# 反向代理服务器主机名
|
||||
CURSOR_API2_HOST=
|
||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@@ -17,3 +17,4 @@ node_modules
|
||||
/release
|
||||
|
||||
/*.py
|
||||
/logs
|
||||
|
88
Cargo.lock
generated
88
Cargo.lock
generated
@@ -17,6 +17,18 @@ version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627"
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.8.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"once_cell",
|
||||
"version_check",
|
||||
"zerocopy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "1.1.3"
|
||||
@@ -292,8 +304,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "cursor-api"
|
||||
version = "0.1.3-rc.3"
|
||||
version = "0.1.3"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"axum",
|
||||
"base64",
|
||||
"bytes",
|
||||
@@ -311,6 +324,7 @@ dependencies = [
|
||||
"rand",
|
||||
"regex",
|
||||
"reqwest",
|
||||
"rusqlite",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
@@ -318,6 +332,7 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tower-http",
|
||||
"urlencoding",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
@@ -379,6 +394,18 @@ dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fallible-iterator"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649"
|
||||
|
||||
[[package]]
|
||||
name = "fallible-streaming-iterator"
|
||||
version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
|
||||
|
||||
[[package]]
|
||||
name = "fastrand"
|
||||
version = "2.3.0"
|
||||
@@ -573,12 +600,30 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.14.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.15.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289"
|
||||
|
||||
[[package]]
|
||||
name = "hashlink"
|
||||
version = "0.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af"
|
||||
dependencies = [
|
||||
"hashbrown 0.14.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.5.0"
|
||||
@@ -906,7 +951,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f"
|
||||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown",
|
||||
"hashbrown 0.15.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -952,6 +997,17 @@ version = "0.2.169"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
|
||||
|
||||
[[package]]
|
||||
name = "libsqlite3-sys"
|
||||
version = "0.30.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"pkg-config",
|
||||
"vcpkg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.4.14"
|
||||
@@ -1318,9 +1374,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
|
||||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
version = "0.12.11"
|
||||
version = "0.12.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7fe060fe50f524be480214aba758c71f99f90ee8c83c5a36b5e9e1d568eb4eb3"
|
||||
checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da"
|
||||
dependencies = [
|
||||
"async-compression",
|
||||
"base64",
|
||||
@@ -1378,6 +1434,20 @@ dependencies = [
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rusqlite"
|
||||
version = "0.32.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e"
|
||||
dependencies = [
|
||||
"bitflags 2.6.0",
|
||||
"fallible-iterator",
|
||||
"fallible-streaming-iterator",
|
||||
"hashlink",
|
||||
"libsqlite3-sys",
|
||||
"smallvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc-demangle"
|
||||
version = "0.1.24"
|
||||
@@ -1602,9 +1672,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.92"
|
||||
version = "2.0.94"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "70ae51629bf965c5c098cc9e87908a3df5301051a9e087d6f9bef5c9771ed126"
|
||||
checksum = "987bc0be1cdea8b10216bd06e2ca407d40b9543468fafd3ddfb02f36e77f71f3"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -1856,6 +1926,12 @@ dependencies = [
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "urlencoding"
|
||||
version = "2.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
|
||||
|
||||
[[package]]
|
||||
name = "utf16_iter"
|
||||
version = "1.0.5"
|
||||
|
24
Cargo.toml
24
Cargo.toml
@@ -1,10 +1,8 @@
|
||||
[package]
|
||||
name = "cursor-api"
|
||||
version = "0.1.3-rc.3"
|
||||
version = "0.1.3"
|
||||
edition = "2021"
|
||||
authors = ["wisdgod <nav@wisdgod.com>"]
|
||||
# license = "MIT"
|
||||
# copyright = "Copyright (c) 2024 wisdgod"
|
||||
description = "OpenAI format compatibility layer for the Cursor API"
|
||||
repository = "https://github.com/wisdgod/cursor-api"
|
||||
|
||||
@@ -14,6 +12,7 @@ sha2 = { version = "0.10.8", default-features = false }
|
||||
serde_json = "1.0.134"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.95"
|
||||
axum = { version = "0.7.9", features = ["json"] }
|
||||
base64 = { version = "0.22.1", default-features = false, features = ["std"] }
|
||||
# brotli = { version = "7.0.0", default-features = false, features = ["std"] }
|
||||
@@ -30,7 +29,8 @@ paste = "1.0.15"
|
||||
prost = "0.13.4"
|
||||
rand = { version = "0.8.5", default-features = false, features = ["std", "std_rng"] }
|
||||
regex = { version = "1.11.1", default-features = false, features = ["std", "perf"] }
|
||||
reqwest = { version = "0.12.11", default-features = false, features = ["gzip", "json", "stream", "__tls", "charset", "default-tls", "h2", "http2", "macos-system-configuration"] }
|
||||
reqwest = { version = "0.12.12", default-features = false, features = ["gzip", "json", "stream", "__tls", "charset", "default-tls", "h2", "http2", "macos-system-configuration"] }
|
||||
rusqlite = { version = "0.32.1", features = ["bundled"], optional = true }
|
||||
serde = { version = "1.0.217", default-features = false, features = ["std", "derive"] }
|
||||
serde_json = "1.0.134"
|
||||
sha2 = { version = "0.10.8", default-features = false }
|
||||
@@ -38,6 +38,7 @@ sysinfo = { version = "0.33.1", default-features = false, features = ["system"]
|
||||
tokio = { version = "1.42.0", features = ["rt-multi-thread", "macros", "net", "sync", "time"] }
|
||||
tokio-stream = { version = "0.1.17", features = ["time"] }
|
||||
tower-http = { version = "0.6.2", features = ["cors"] }
|
||||
urlencoding = "2.1.3"
|
||||
uuid = { version = "1.11.0", features = ["v4"] }
|
||||
|
||||
[profile.release]
|
||||
@@ -47,15 +48,6 @@ panic = 'abort'
|
||||
strip = true
|
||||
opt-level = 3
|
||||
|
||||
# 构建脚本设置
|
||||
[package.metadata.cross.target.x86_64-unknown-linux-gnu]
|
||||
image = "ghcr.io/cross-rs/x86_64-unknown-linux-gnu:main"
|
||||
|
||||
[package.metadata.cross.target.aarch64-unknown-linux-gnu]
|
||||
image = "ghcr.io/cross-rs/aarch64-unknown-linux-gnu:main"
|
||||
|
||||
[package.metadata.cross.target.x86_64-apple-darwin]
|
||||
image = "ghcr.io/cross-rs/x86_64-apple-darwin:main"
|
||||
|
||||
[package.metadata.cross.target.aarch64-apple-darwin]
|
||||
image = "ghcr.io/cross-rs/aarch64-apple-darwin:main"
|
||||
[features]
|
||||
default = []
|
||||
sqlite = ["dep:rusqlite"]
|
||||
|
@@ -1,4 +1,6 @@
|
||||
pub mod config;
|
||||
pub mod constant;
|
||||
#[cfg(feature = "sqlite")]
|
||||
pub mod db;
|
||||
pub mod model;
|
||||
pub mod lazy;
|
||||
|
@@ -50,9 +50,6 @@ def_pub_const!(AUTHORIZATION_BEARER_PREFIX, "Bearer ");
|
||||
def_pub_const!(OBJECT_CHAT_COMPLETION, "chat.completion");
|
||||
def_pub_const!(OBJECT_CHAT_COMPLETION_CHUNK, "chat.completion.chunk");
|
||||
|
||||
def_pub_const!(CURSOR_API2_HOST, "api2.cursor.sh");
|
||||
def_pub_const!(CURSOR_API2_BASE_URL, "https://api2.cursor.sh/aiserver.v1.AiService/");
|
||||
|
||||
def_pub_const!(CURSOR_API2_STREAM_CHAT, "StreamChat");
|
||||
def_pub_const!(CURSOR_API2_GET_USER_INFO, "GetUserInfo");
|
||||
|
||||
|
262
src/app/db.rs
Normal file
262
src/app/db.rs
Normal file
@@ -0,0 +1,262 @@
|
||||
use crate::app::model::{RequestLog, TokenInfo};
|
||||
use crate::common::models::usage::UserUsageInfo;
|
||||
use chrono::{DateTime, Local};
|
||||
use lazy_static::lazy_static;
|
||||
use rusqlite::params;
|
||||
use rusqlite::{Connection, Result};
|
||||
use std::path::Path;
|
||||
use std::sync::Mutex;
|
||||
|
||||
const DB_PATH: &str = "logs/sqlite.db";
|
||||
|
||||
pub struct AppDb {
|
||||
conn: Connection,
|
||||
}
|
||||
|
||||
impl AppDb {
|
||||
pub fn new() -> Result<Self> {
|
||||
// 确保目录存在
|
||||
if let Some(parent) = Path::new(DB_PATH).parent() {
|
||||
std::fs::create_dir_all(parent).map_err(|e| {
|
||||
rusqlite::Error::SqliteFailure(
|
||||
rusqlite::ffi::Error::new(rusqlite::ffi::SQLITE_IOERR),
|
||||
Some(e.to_string()),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
let conn = Connection::open(DB_PATH)?;
|
||||
|
||||
// 启用WAL模式以提升性能
|
||||
conn.execute_batch("PRAGMA journal_mode = WAL")?;
|
||||
|
||||
// 创建token信息表
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS token_infos (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
checksum TEXT NOT NULL,
|
||||
alias TEXT,
|
||||
fast_requests INTEGER,
|
||||
max_fast_requests INTEGER
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
// 创建请求日志表
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS request_logs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
timestamp TEXT NOT NULL,
|
||||
model TEXT NOT NULL,
|
||||
token_id INTEGER NOT NULL,
|
||||
prompt TEXT,
|
||||
stream BOOLEAN NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
error TEXT,
|
||||
FOREIGN KEY(token_id) REFERENCES token_infos(id)
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
// 创建索引
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_token ON token_infos(token)",
|
||||
[],
|
||||
)?;
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_timestamp_model ON request_logs(timestamp, model)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
Ok(Self { conn })
|
||||
}
|
||||
|
||||
fn get_or_create_token_info(&self, token_info: &TokenInfo) -> Result<i64> {
|
||||
let mut stmt = self.conn.prepare_cached(
|
||||
"INSERT OR REPLACE INTO token_infos (token, checksum, alias, fast_requests, max_fast_requests)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5)
|
||||
RETURNING id"
|
||||
)?;
|
||||
|
||||
stmt.query_row(
|
||||
params![
|
||||
&token_info.token,
|
||||
&token_info.checksum,
|
||||
&token_info.alias,
|
||||
token_info.usage.as_ref().map(|u| u.fast_requests),
|
||||
token_info.usage.as_ref().map(|u| u.max_fast_requests),
|
||||
],
|
||||
|row| row.get(0),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn add_log(&self, log: &RequestLog) -> Result<()> {
|
||||
let token_id = self.get_or_create_token_info(&log.token_info)?;
|
||||
|
||||
self.conn.execute(
|
||||
"INSERT INTO request_logs (timestamp, model, token_id, prompt, stream, status, error)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
|
||||
params![
|
||||
log.timestamp.to_rfc3339(),
|
||||
&log.model,
|
||||
token_id,
|
||||
&log.prompt,
|
||||
log.stream,
|
||||
&log.status,
|
||||
&log.error,
|
||||
],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn map_row_to_log(&self, row: &rusqlite::Row) -> Result<RequestLog> {
|
||||
let token_id: i64 = row.get(3)?;
|
||||
let token_info = self.get_token_info_by_id(token_id)?;
|
||||
|
||||
Ok(RequestLog {
|
||||
id: row.get(0)?,
|
||||
timestamp: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
|
||||
.unwrap()
|
||||
.with_timezone(&Local),
|
||||
model: row.get(2)?,
|
||||
token_info,
|
||||
prompt: row.get(4)?,
|
||||
stream: row.get(5)?,
|
||||
status: row.get(6)?,
|
||||
error: row.get(7)?,
|
||||
})
|
||||
}
|
||||
|
||||
fn get_token_info_by_id(&self, id: i64) -> Result<TokenInfo> {
|
||||
let mut stmt = self.conn.prepare_cached(
|
||||
"SELECT token, checksum, alias, fast_requests, max_fast_requests
|
||||
FROM token_infos
|
||||
WHERE id = ?",
|
||||
)?;
|
||||
|
||||
stmt.query_row([id], |row| {
|
||||
Ok(TokenInfo {
|
||||
token: row.get(0)?,
|
||||
checksum: row.get(1)?,
|
||||
alias: row.get(2)?,
|
||||
usage: Some(UserUsageInfo {
|
||||
fast_requests: row.get(3)?,
|
||||
max_fast_requests: row.get(4)?,
|
||||
}),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_token_infos(&self) -> Result<Vec<TokenInfo>> {
|
||||
let mut stmt = self.conn.prepare_cached(
|
||||
"SELECT token, checksum, alias, fast_requests, max_fast_requests
|
||||
FROM token_infos",
|
||||
)?;
|
||||
|
||||
let tokens = stmt.query_map([], |row| {
|
||||
Ok(TokenInfo {
|
||||
token: row.get(0)?,
|
||||
checksum: row.get(1)?,
|
||||
alias: row.get(2)?,
|
||||
usage: Some(UserUsageInfo {
|
||||
fast_requests: row.get(3)?,
|
||||
max_fast_requests: row.get(4)?,
|
||||
}),
|
||||
})
|
||||
})?;
|
||||
tokens.collect()
|
||||
}
|
||||
|
||||
pub fn get_recent_logs(&self, limit: i64) -> Result<Vec<RequestLog>> {
|
||||
let mut stmt = self.conn.prepare_cached(
|
||||
"SELECT r.id, r.timestamp, r.model, r.token_id, r.prompt, r.stream, r.status, r.error, t.token, t.checksum, t.alias, t.fast_requests, t.max_fast_requests
|
||||
FROM request_logs r
|
||||
JOIN token_infos t ON r.token_id = t.id
|
||||
ORDER BY r.timestamp DESC
|
||||
LIMIT ?",
|
||||
)?;
|
||||
|
||||
let logs = stmt.query_map([limit], |row| {
|
||||
Ok(RequestLog {
|
||||
id: row.get(0)?,
|
||||
timestamp: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
|
||||
.unwrap()
|
||||
.with_timezone(&Local),
|
||||
model: row.get(2)?,
|
||||
token_info: TokenInfo {
|
||||
token: row.get(8)?,
|
||||
checksum: row.get(9)?,
|
||||
alias: row.get(10)?,
|
||||
usage: Some(UserUsageInfo {
|
||||
fast_requests: row.get(11)?,
|
||||
max_fast_requests: row.get(12)?,
|
||||
}),
|
||||
},
|
||||
prompt: row.get(4)?,
|
||||
stream: row.get(5)?,
|
||||
status: row.get(6)?,
|
||||
error: row.get(7)?,
|
||||
})
|
||||
})?;
|
||||
logs.collect()
|
||||
}
|
||||
|
||||
pub fn get_logs_by_timerange(
|
||||
&self,
|
||||
start: DateTime<Local>,
|
||||
end: DateTime<Local>,
|
||||
) -> Result<Vec<RequestLog>> {
|
||||
let mut stmt = self.conn.prepare_cached(
|
||||
"SELECT r.id, r.timestamp, r.model, r.token_id, r.prompt, r.stream, r.status, r.error, t.token, t.checksum, t.alias, t.fast_requests, t.max_fast_requests
|
||||
FROM request_logs r
|
||||
JOIN token_infos t ON r.token_id = t.id
|
||||
WHERE r.timestamp BETWEEN ?1 AND ?2
|
||||
ORDER BY r.timestamp DESC",
|
||||
)?;
|
||||
|
||||
let logs = stmt.query_map([start.to_rfc3339(), end.to_rfc3339()], |row| {
|
||||
Ok(RequestLog {
|
||||
id: row.get(0)?,
|
||||
timestamp: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
|
||||
.unwrap()
|
||||
.with_timezone(&Local),
|
||||
model: row.get(2)?,
|
||||
token_info: TokenInfo {
|
||||
token: row.get(8)?,
|
||||
checksum: row.get(9)?,
|
||||
alias: row.get(10)?,
|
||||
usage: Some(UserUsageInfo {
|
||||
fast_requests: row.get(11)?,
|
||||
max_fast_requests: row.get(12)?,
|
||||
}),
|
||||
},
|
||||
prompt: row.get(4)?,
|
||||
stream: row.get(5)?,
|
||||
status: row.get(6)?,
|
||||
error: row.get(7)?,
|
||||
})
|
||||
})?;
|
||||
logs.collect()
|
||||
}
|
||||
|
||||
pub fn update_token_info(&self, token_info: &TokenInfo) -> Result<()> {
|
||||
self.conn.execute(
|
||||
"INSERT OR REPLACE INTO token_infos (token, checksum, alias, fast_requests, max_fast_requests)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5)",
|
||||
params![
|
||||
&token_info.token,
|
||||
&token_info.checksum,
|
||||
&token_info.alias,
|
||||
token_info.usage.as_ref().map(|u| u.fast_requests),
|
||||
token_info.usage.as_ref().map(|u| u.max_fast_requests),
|
||||
],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
pub static ref APP_DB: Mutex<AppDb> =
|
||||
Mutex::new(AppDb::new().expect("Failed to initialize database"));
|
||||
}
|
@@ -33,11 +33,11 @@ def_pub_static!(TOKEN_FILE, env: "TOKEN_FILE", default: DEFAULT_TOKEN_FILE_NAME)
|
||||
def_pub_static!(TOKEN_LIST_FILE, env: "TOKEN_LIST_FILE", default: DEFAULT_TOKEN_LIST_FILE_NAME);
|
||||
def_pub_static!(
|
||||
ROUTE_MODELS_PATH,
|
||||
format!("{}/v1/models", ROUTE_PREFIX.as_str())
|
||||
format!("{}/v1/models", *ROUTE_PREFIX)
|
||||
);
|
||||
def_pub_static!(
|
||||
ROUTE_CHAT_PATH,
|
||||
format!("{}/v1/chat/completions", ROUTE_PREFIX.as_str())
|
||||
format!("{}/v1/chat/completions", *ROUTE_PREFIX)
|
||||
);
|
||||
|
||||
pub static START_TIME: LazyLock<chrono::DateTime<chrono::Local>> =
|
||||
@@ -49,6 +49,12 @@ pub fn get_start_time() -> chrono::DateTime<chrono::Local> {
|
||||
|
||||
def_pub_static!(DEFAULT_INSTRUCTIONS, env: "DEFAULT_INSTRUCTIONS", default: "Respond in Chinese by default");
|
||||
|
||||
def_pub_static!(CURSOR_API2_HOST, env: "REVERSE_PROXY_HOST", default: "api2.cursor.sh");
|
||||
|
||||
pub static CURSOR_API2_BASE_URL: LazyLock<String> = LazyLock::new(|| {
|
||||
format!("https://{}/aiserver.v1.AiService/", *CURSOR_API2_HOST)
|
||||
});
|
||||
|
||||
// pub static DEBUG: LazyLock<bool> = LazyLock::new(|| parse_bool_from_env("DEBUG", false));
|
||||
|
||||
// #[macro_export]
|
||||
|
@@ -87,7 +87,9 @@ pub struct Pages {
|
||||
pub struct AppState {
|
||||
pub total_requests: u64,
|
||||
pub active_requests: u64,
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
pub request_logs: Vec<RequestLog>,
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
pub token_infos: Vec<TokenInfo>,
|
||||
}
|
||||
|
||||
@@ -273,6 +275,7 @@ impl AppConfig {
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
pub fn new(token_infos: Vec<TokenInfo>) -> Self {
|
||||
Self {
|
||||
total_requests: 0,
|
||||
@@ -281,11 +284,20 @@ impl AppState {
|
||||
token_infos,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "sqlite")]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
total_requests: 0,
|
||||
active_requests: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 请求日志
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct RequestLog {
|
||||
pub id: u64,
|
||||
pub timestamp: chrono::DateTime<chrono::Local>,
|
||||
pub model: String,
|
||||
pub token_info: TokenInfo,
|
||||
|
@@ -7,6 +7,7 @@ macro_rules! def_pub_const {
|
||||
}
|
||||
def_pub_const!(ERR_UNSUPPORTED_GIF, "不支持动态 GIF");
|
||||
def_pub_const!(ERR_UNSUPPORTED_IMAGE_FORMAT, "不支持的图片格式,仅支持 PNG、JPEG、WEBP 和非动态 GIF");
|
||||
def_pub_const!(ERR_NODATA, "No data");
|
||||
|
||||
const MODEL_OBJECT: &str = "model";
|
||||
const CREATED: &i64 = &1706659200;
|
||||
|
@@ -2,41 +2,66 @@ use super::aiserver::v1::throw_error_check_request::Error as ErrorType;
|
||||
use reqwest::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(Deserialize)]
|
||||
pub struct ChatError {
|
||||
pub error: ErrorBody,
|
||||
error: ErrorBody,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(Deserialize)]
|
||||
pub struct ErrorBody {
|
||||
pub code: String,
|
||||
pub message: String,
|
||||
pub details: Vec<ErrorDetail>,
|
||||
code: String,
|
||||
// message: String, always: Error
|
||||
details: Vec<ErrorDetail>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(Deserialize)]
|
||||
pub struct ErrorDetail {
|
||||
#[serde(rename = "type")]
|
||||
pub error_type: String,
|
||||
pub debug: ErrorDebug,
|
||||
pub value: String,
|
||||
// #[serde(rename = "type")]
|
||||
// error_type: String, always: aiserver.v1.ErrorDetails
|
||||
debug: ErrorDebug,
|
||||
value: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(Deserialize)]
|
||||
pub struct ErrorDebug {
|
||||
pub error: String,
|
||||
pub details: ErrorDetails,
|
||||
#[serde(rename = "isExpected")]
|
||||
pub is_expected: bool,
|
||||
error: String,
|
||||
details: ErrorDetails,
|
||||
// #[serde(rename = "isExpected")]
|
||||
// is_expected: Option<bool>,
|
||||
}
|
||||
|
||||
impl ErrorDebug {
|
||||
// pub fn is_valid(&self) -> bool {
|
||||
// ErrorType::from_str_name(&self.error).is_some()
|
||||
// }
|
||||
#[derive(Deserialize)]
|
||||
pub struct ErrorDetails {
|
||||
title: String,
|
||||
detail: String,
|
||||
// #[serde(rename = "isRetryable")]
|
||||
// is_retryable: Option<bool>,
|
||||
}
|
||||
|
||||
use crate::common::models::{ApiStatus, ErrorResponse as CommonErrorResponse};
|
||||
|
||||
impl ChatError {
|
||||
pub fn to_error_response(&self) -> ErrorResponse {
|
||||
if self.error.details.is_empty() {
|
||||
return ErrorResponse {
|
||||
status: 500,
|
||||
code: "unknown".to_string(),
|
||||
error: None,
|
||||
};
|
||||
}
|
||||
ErrorResponse {
|
||||
status: self.status_code(),
|
||||
code: self.error.code.clone(),
|
||||
error: Some(Error {
|
||||
message: self.error.details[0].debug.details.title.clone(),
|
||||
details: self.error.details[0].debug.details.detail.clone(),
|
||||
value: self.error.details[0].value.clone(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn status_code(&self) -> u16 {
|
||||
match ErrorType::from_str_name(&self.error) {
|
||||
match ErrorType::from_str_name(&self.error.details[0].debug.error) {
|
||||
Some(error) => match error {
|
||||
ErrorType::Unspecified => 500,
|
||||
ErrorType::BadApiKey
|
||||
@@ -68,46 +93,26 @@ impl ErrorDebug {
|
||||
| ErrorType::SlashEditFileTooLong
|
||||
| ErrorType::FileUnsupported
|
||||
| ErrorType::ClaudeImageTooLarge => 400,
|
||||
_ => 500,
|
||||
ErrorType::Deprecated
|
||||
| ErrorType::FreeUserUsageLimit
|
||||
| ErrorType::ProUserUsageLimit
|
||||
| ErrorType::ResourceExhausted
|
||||
| ErrorType::Openai
|
||||
| ErrorType::MaxTokens
|
||||
| ErrorType::ApiKeyNotSupported
|
||||
| ErrorType::UserAbortedRequest
|
||||
| ErrorType::CustomMessage
|
||||
| ErrorType::OutdatedClient
|
||||
| ErrorType::Debounced
|
||||
| ErrorType::RepositoryServiceRepositoryIsNotInitialized => 500,
|
||||
},
|
||||
None => 500,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ErrorDetails {
|
||||
pub title: String,
|
||||
pub detail: String,
|
||||
#[serde(rename = "isRetryable")]
|
||||
pub is_retryable: bool,
|
||||
}
|
||||
|
||||
use crate::common::models::{ApiStatus, ErrorResponse as CommonErrorResponse};
|
||||
|
||||
impl ChatError {
|
||||
pub fn to_json(&self) -> serde_json::Value {
|
||||
serde_json::to_value(self).unwrap()
|
||||
}
|
||||
|
||||
pub fn to_error_response(&self) -> ErrorResponse {
|
||||
if self.error.details.is_empty() {
|
||||
return ErrorResponse {
|
||||
status: 500,
|
||||
code: "ERROR_UNKNOWN".to_string(),
|
||||
error: None,
|
||||
};
|
||||
}
|
||||
ErrorResponse {
|
||||
status: self.error.details[0].debug.status_code(),
|
||||
code: self.error.details[0].debug.error.clone(),
|
||||
error: Some(Error {
|
||||
message: self.error.details[0].debug.details.title.clone(),
|
||||
details: self.error.details[0].debug.details.detail.clone(),
|
||||
value: self.error.details[0].value.clone(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
// pub fn is_expected(&self) -> bool {
|
||||
// self.error.details[0].debug.is_expected.unwrap_or_default()
|
||||
// }
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
@@ -135,7 +140,7 @@ impl ErrorResponse {
|
||||
}
|
||||
|
||||
pub fn native_code(&self) -> String {
|
||||
self.code.replace("_", " ").to_lowercase()
|
||||
self.code.replace("_", " ")
|
||||
}
|
||||
|
||||
pub fn to_common(self) -> CommonErrorResponse {
|
||||
@@ -157,7 +162,7 @@ pub enum StreamError {
|
||||
impl std::fmt::Display for StreamError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
StreamError::ChatError(error) => write!(f, "{}", serde_json::to_string(error).unwrap()),
|
||||
StreamError::ChatError(error) => write!(f, "{}", error.error.details[0].debug.details.title),
|
||||
StreamError::DataLengthLessThan5 => write!(f, "data length less than 5"),
|
||||
StreamError::EmptyMessage => write!(f, "empty message"),
|
||||
}
|
||||
|
@@ -5,7 +5,7 @@ use crate::{
|
||||
CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, HEADER_NAME_AUTHORIZATION, HEADER_NAME_CONTENT_TYPE,
|
||||
ROUTE_TOKENINFO_PATH,
|
||||
},
|
||||
model::{AppConfig, AppState, PageContent, TokenUpdateRequest},
|
||||
model::{AppConfig, PageContent, TokenUpdateRequest},
|
||||
lazy::{AUTH_TOKEN, TOKEN_FILE, TOKEN_LIST_FILE},
|
||||
},
|
||||
common::{
|
||||
@@ -13,15 +13,22 @@ use crate::{
|
||||
utils::{generate_checksum, generate_hash, tokens::load_tokens},
|
||||
},
|
||||
};
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
use crate::app::model::AppState;
|
||||
#[cfg(feature = "sqlite")]
|
||||
use crate::app::db::APP_DB;
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::HeaderMap,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
use axum::extract::State;
|
||||
use reqwest::StatusCode;
|
||||
use serde::Serialize;
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
use std::sync::Arc;
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[derive(Serialize)]
|
||||
@@ -36,17 +43,28 @@ pub async fn handle_get_checksum() -> Json<ChecksumResponse> {
|
||||
|
||||
// 更新 TokenInfo 处理
|
||||
pub async fn handle_update_tokeninfo(
|
||||
State(state): State<Arc<Mutex<AppState>>>,
|
||||
#[cfg(not(feature = "sqlite"))] State(state): State<Arc<Mutex<AppState>>>,
|
||||
) -> Json<NormalResponseNoData> {
|
||||
// 重新加载 tokens
|
||||
let token_infos = load_tokens();
|
||||
|
||||
// 更新应用状态
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
{
|
||||
let mut state = state.lock().await;
|
||||
state.token_infos = token_infos;
|
||||
}
|
||||
|
||||
#[cfg(feature = "sqlite")]
|
||||
{
|
||||
// 使用 APP_DB 更新 token_infos
|
||||
if let Ok(db) = APP_DB.lock() {
|
||||
for token_info in token_infos {
|
||||
let _ = db.update_token_info(&token_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Json(NormalResponseNoData {
|
||||
status: ApiStatus::Success,
|
||||
message: Some("Token list has been reloaded".to_string()),
|
||||
@@ -55,13 +73,8 @@ pub async fn handle_update_tokeninfo(
|
||||
|
||||
// 获取 TokenInfo 处理
|
||||
pub async fn handle_get_tokeninfo(
|
||||
State(_state): State<Arc<Mutex<AppState>>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<TokenInfoResponse>, StatusCode> {
|
||||
let auth_token = AUTH_TOKEN.as_str();
|
||||
let token_file = TOKEN_FILE.as_str();
|
||||
let token_list_file = TOKEN_LIST_FILE.as_str();
|
||||
|
||||
// 验证 AUTH_TOKEN
|
||||
let auth_header = headers
|
||||
.get(HEADER_NAME_AUTHORIZATION)
|
||||
@@ -69,20 +82,37 @@ pub async fn handle_get_tokeninfo(
|
||||
.and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX))
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
if auth_header != auth_token {
|
||||
if auth_header != AUTH_TOKEN.as_str() {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
let token_file = TOKEN_FILE.as_str();
|
||||
let token_list_file = TOKEN_LIST_FILE.as_str();
|
||||
|
||||
// 读取文件内容
|
||||
let tokens = std::fs::read_to_string(&token_file).unwrap_or_else(|_| String::new());
|
||||
let token_list = std::fs::read_to_string(&token_list_file).unwrap_or_else(|_| String::new());
|
||||
|
||||
// 获取 tokens_count
|
||||
let tokens_count = {
|
||||
#[cfg(feature = "sqlite")]
|
||||
{
|
||||
APP_DB.lock()
|
||||
.map(|db| db.get_token_infos().map(|v| v.len()).unwrap_or(0))
|
||||
.unwrap_or(0)
|
||||
}
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
{
|
||||
tokens.len()
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(TokenInfoResponse {
|
||||
status: ApiStatus::Success,
|
||||
token_file: token_file.to_string(),
|
||||
token_list_file: token_list_file.to_string(),
|
||||
tokens: Some(tokens.clone()),
|
||||
tokens_count: Some(tokens.len()),
|
||||
tokens: Some(tokens),
|
||||
tokens_count: Some(tokens_count),
|
||||
token_list: Some(token_list),
|
||||
message: None,
|
||||
}))
|
||||
@@ -104,14 +134,10 @@ pub struct TokenInfoResponse {
|
||||
}
|
||||
|
||||
pub async fn handle_update_tokeninfo_post(
|
||||
State(state): State<Arc<Mutex<AppState>>>,
|
||||
#[cfg(not(feature = "sqlite"))] State(state): State<Arc<Mutex<AppState>>>,
|
||||
headers: HeaderMap,
|
||||
Json(request): Json<TokenUpdateRequest>,
|
||||
) -> Result<Json<TokenInfoResponse>, StatusCode> {
|
||||
let auth_token = AUTH_TOKEN.as_str();
|
||||
let token_file = TOKEN_FILE.as_str();
|
||||
let token_list_file = TOKEN_LIST_FILE.as_str();
|
||||
|
||||
// 验证 AUTH_TOKEN
|
||||
let auth_header = headers
|
||||
.get(HEADER_NAME_AUTHORIZATION)
|
||||
@@ -119,15 +145,18 @@ pub async fn handle_update_tokeninfo_post(
|
||||
.and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX))
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
if auth_header != auth_token {
|
||||
if auth_header != AUTH_TOKEN.as_str() {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
// 写入 .token 文件
|
||||
std::fs::write(&token_file, &request.tokens).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let token_file = TOKEN_FILE.as_str();
|
||||
let token_list_file = TOKEN_LIST_FILE.as_str();
|
||||
|
||||
// 如果提供了 token_list,则写入
|
||||
if let Some(token_list) = request.token_list {
|
||||
// 写入文件
|
||||
std::fs::write(&token_file, &request.tokens)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if let Some(token_list) = &request.token_list {
|
||||
std::fs::write(&token_list_file, token_list)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
}
|
||||
@@ -137,11 +166,21 @@ pub async fn handle_update_tokeninfo_post(
|
||||
let token_infos_len = token_infos.len();
|
||||
|
||||
// 更新应用状态
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
{
|
||||
let mut state = state.lock().await;
|
||||
state.token_infos = token_infos;
|
||||
}
|
||||
|
||||
#[cfg(feature = "sqlite")]
|
||||
{
|
||||
if let Ok(db) = APP_DB.lock() {
|
||||
for token_info in token_infos {
|
||||
let _ = db.update_token_info(&token_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(TokenInfoResponse {
|
||||
status: ApiStatus::Success,
|
||||
token_file: token_file.to_string(),
|
||||
|
@@ -1,5 +1,6 @@
|
||||
use crate::{
|
||||
app::model::AppState,
|
||||
chat::constant::ERR_NODATA,
|
||||
common::{models::usage::GetUserInfo, utils::get_user_usage},
|
||||
};
|
||||
use axum::{
|
||||
@@ -26,11 +27,11 @@ pub async fn get_user_info(
|
||||
|
||||
let (auth_token, checksum) = match token_info {
|
||||
Some(token_info) => (token_info.token.clone(), token_info.checksum.clone()),
|
||||
None => return Json(GetUserInfo::Error("No data".to_string())),
|
||||
None => return Json(GetUserInfo::Error(ERR_NODATA.to_string())),
|
||||
};
|
||||
|
||||
match get_user_usage(&auth_token, &checksum).await {
|
||||
Some(usage) => Json(GetUserInfo::Usage(usage)),
|
||||
None => Json(GetUserInfo::Error("No data".to_string())),
|
||||
None => Json(GetUserInfo::Error(ERR_NODATA.to_string())),
|
||||
}
|
||||
}
|
||||
|
@@ -147,7 +147,9 @@ pub async fn handle_chat(
|
||||
}
|
||||
}
|
||||
|
||||
let next_id = state.request_logs.last().map_or(1, |log| log.id + 1);
|
||||
state.request_logs.push(RequestLog {
|
||||
id: next_id,
|
||||
timestamp: request_time,
|
||||
model: request.model.clone(),
|
||||
token_info: TokenInfo {
|
||||
@@ -420,11 +422,6 @@ pub async fn handle_chat(
|
||||
}
|
||||
Ok(Bytes::new())
|
||||
}
|
||||
Err(StreamError::ChatError(error)) => {
|
||||
buffer_guard.clear();
|
||||
eprintln!("Stream error occurred: {}", error.to_json());
|
||||
Ok(Bytes::new())
|
||||
}
|
||||
Err(e) => {
|
||||
buffer_guard.clear();
|
||||
eprintln!("[警告] Stream error: {}", e);
|
||||
@@ -480,7 +477,7 @@ pub async fn handle_chat(
|
||||
}
|
||||
Err(StreamError::ChatError(error)) => {
|
||||
return Err((
|
||||
StatusCode::from_u16(error.error.details[0].debug.status_code())
|
||||
StatusCode::from_u16(error.status_code())
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
Json(error.to_error_response().to_common()),
|
||||
));
|
||||
|
@@ -1,7 +1,10 @@
|
||||
use crate::app::constant::{
|
||||
use crate::app::{
|
||||
constant::{
|
||||
AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_CONNECT_PROTO, CONTENT_TYPE_PROTO,
|
||||
CURSOR_API2_BASE_URL, CURSOR_API2_HOST, CURSOR_API2_STREAM_CHAT, HEADER_NAME_AUTHORIZATION,
|
||||
CURSOR_API2_STREAM_CHAT, HEADER_NAME_AUTHORIZATION,
|
||||
HEADER_NAME_CONTENT_TYPE,
|
||||
},
|
||||
lazy::{CURSOR_API2_HOST, CURSOR_API2_BASE_URL},
|
||||
};
|
||||
use reqwest::Client;
|
||||
use uuid::Uuid;
|
||||
@@ -17,7 +20,7 @@ pub fn build_client(auth_token: &str, checksum: &str, endpoint: &str) -> reqwest
|
||||
};
|
||||
|
||||
client
|
||||
.post(format!("{}{}", CURSOR_API2_BASE_URL, endpoint))
|
||||
.post(format!("{}{}", *CURSOR_API2_BASE_URL, endpoint))
|
||||
.header(HEADER_NAME_CONTENT_TYPE, content_type)
|
||||
.header(
|
||||
HEADER_NAME_AUTHORIZATION,
|
||||
@@ -32,5 +35,5 @@ pub fn build_client(auth_token: &str, checksum: &str, endpoint: &str) -> reqwest
|
||||
.header("x-cursor-timezone", "Asia/Shanghai")
|
||||
.header("x-ghost-mode", "false")
|
||||
.header("x-request-id", trace_id)
|
||||
.header("Host", CURSOR_API2_HOST)
|
||||
.header("Host", CURSOR_API2_HOST.clone())
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
mod checksum;
|
||||
pub use checksum::*;
|
||||
pub mod tokens;
|
||||
pub mod oauth;
|
||||
use prost::Message as _;
|
||||
|
||||
use crate::{app::constant::CURSOR_API2_GET_USER_INFO, chat::aiserver::v1::GetUserInfoResponse};
|
||||
|
@@ -42,3 +42,54 @@ pub fn generate_checksum(device_id: &str, mac_addr: Option<&str>) -> String {
|
||||
None => format!("{}{}", encoded, device_id),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn validate_checksum(checksum: &str) -> bool {
|
||||
// 首先检查是否包含基本的 base64 编码部分和 hash 格式的 device_id
|
||||
let parts: Vec<&str> = checksum.split('/').collect();
|
||||
|
||||
match parts.len() {
|
||||
// 没有 MAC 地址的情况
|
||||
1 => {
|
||||
// 检查是否包含 BASE64 编码的 timestamp (8字符) + 64字符的hash
|
||||
if checksum.len() != 72 {
|
||||
// 8 + 64 = 72
|
||||
return false;
|
||||
}
|
||||
|
||||
// 验证 BASE64 部分
|
||||
let base64_len = 8;
|
||||
let encoded_part = &checksum[..base64_len];
|
||||
if !BASE64.decode(encoded_part).is_ok() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 验证 device_id hash 部分
|
||||
let device_hash = &checksum[base64_len..];
|
||||
is_valid_hash(device_hash)
|
||||
}
|
||||
// 包含 MAC hash 的情况
|
||||
2 => {
|
||||
let first_part = parts[0];
|
||||
let mac_hash = parts[1];
|
||||
|
||||
// MAC hash 必须是64字符的十六进制
|
||||
if !is_valid_hash(mac_hash) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 递归验证第一部分
|
||||
validate_checksum(first_part)
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_valid_hash(hash: &str) -> bool {
|
||||
// 检查长度是否为64
|
||||
if hash.len() != 64 {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 检查是否都是有效的十六进制字符
|
||||
hash.chars().all(|c| c.is_ascii_hexdigit())
|
||||
}
|
||||
|
80
src/common/utils/oauth.rs
Normal file
80
src/common/utils/oauth.rs
Normal file
@@ -0,0 +1,80 @@
|
||||
use anyhow::Result;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
const OAUTH_AUTHORIZE_URL: &str = "https://connect.linux.do/oauth2/authorize";
|
||||
const OAUTH_TOKEN_URL: &str = "https://connect.linux.do/oauth2/token";
|
||||
const OAUTH_USER_INFO_URL: &str = "https://connect.linux.do/api/user";
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ForumUser {
|
||||
pub id: i64,
|
||||
pub username: String,
|
||||
pub name: String,
|
||||
pub active: bool,
|
||||
pub trust_level: i32,
|
||||
pub silenced: bool,
|
||||
}
|
||||
|
||||
pub struct ForumOAuth {
|
||||
client_id: String,
|
||||
client_secret: String,
|
||||
redirect_uri: String,
|
||||
http_client: Client,
|
||||
}
|
||||
|
||||
impl ForumOAuth {
|
||||
pub fn new(client_id: String, client_secret: String, redirect_uri: String) -> Self {
|
||||
Self {
|
||||
client_id,
|
||||
client_secret,
|
||||
redirect_uri,
|
||||
http_client: Client::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_authorize_url(&self, state: &str) -> String {
|
||||
format!(
|
||||
"{}?response_type=code&client_id={}&redirect_uri={}&state={}",
|
||||
OAUTH_AUTHORIZE_URL,
|
||||
self.client_id,
|
||||
urlencoding::encode(&self.redirect_uri),
|
||||
state
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn exchange_code_for_token(&self, code: &str) -> Result<String> {
|
||||
let response = self
|
||||
.http_client
|
||||
.post(OAUTH_TOKEN_URL)
|
||||
.form(&[
|
||||
("grant_type", "authorization_code"),
|
||||
("code", code),
|
||||
("client_id", &self.client_id),
|
||||
("client_secret", &self.client_secret),
|
||||
("redirect_uri", &self.redirect_uri),
|
||||
])
|
||||
.send()
|
||||
.await?
|
||||
.json::<serde_json::Value>()
|
||||
.await?;
|
||||
|
||||
Ok(response["access_token"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("No access token found"))?
|
||||
.to_string())
|
||||
}
|
||||
|
||||
pub async fn get_user_info(&self, access_token: &str) -> Result<ForumUser> {
|
||||
let user = self
|
||||
.http_client
|
||||
.get(OAUTH_USER_INFO_URL)
|
||||
.bearer_auth(access_token)
|
||||
.send()
|
||||
.await?
|
||||
.json::<ForumUser>()
|
||||
.await?;
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
}
|
@@ -63,6 +63,9 @@ async fn main() {
|
||||
let token_infos = load_tokens();
|
||||
|
||||
// 初始化应用状态
|
||||
#[cfg(feature = "sqlite")]
|
||||
let state = Arc::new(Mutex::new(AppState::new()));
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
let state = Arc::new(Mutex::new(AppState::new(token_infos)));
|
||||
|
||||
// 设置路由
|
||||
|
@@ -200,6 +200,7 @@
|
||||
<table id="logsTable">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>id</th>
|
||||
<th>时间</th>
|
||||
<th>模型</th>
|
||||
<th>Token信息</th>
|
||||
@@ -304,6 +305,7 @@
|
||||
|
||||
tbody.innerHTML = data.logs.map(log => `
|
||||
<tr>
|
||||
<td>${log.id}</td>
|
||||
<td>${new Date(log.timestamp).toLocaleString()}</td>
|
||||
<td>${log.model}</td>
|
||||
<td>
|
||||
|
Reference in New Issue
Block a user