这是可回退普通版的提交

This commit is contained in:
wisdgod
2025-01-04 02:08:16 +08:00
parent c709f9bfc7
commit ea19cbc70a
20 changed files with 655 additions and 121 deletions

View File

@@ -37,3 +37,6 @@ VISION_ABILITY=base64
# 默认提示词
DEFAULT_INSTRUCTIONS="Respond in Chinese by default"
# 反向代理服务器主机名
CURSOR_API2_HOST=

1
.gitignore vendored
View File

@@ -17,3 +17,4 @@ node_modules
/release
/*.py
/logs

88
Cargo.lock generated
View File

@@ -17,6 +17,18 @@ version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627"
[[package]]
name = "ahash"
version = "0.8.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011"
dependencies = [
"cfg-if",
"once_cell",
"version_check",
"zerocopy",
]
[[package]]
name = "aho-corasick"
version = "1.1.3"
@@ -292,8 +304,9 @@ dependencies = [
[[package]]
name = "cursor-api"
version = "0.1.3-rc.3"
version = "0.1.3"
dependencies = [
"anyhow",
"axum",
"base64",
"bytes",
@@ -311,6 +324,7 @@ dependencies = [
"rand",
"regex",
"reqwest",
"rusqlite",
"serde",
"serde_json",
"sha2",
@@ -318,6 +332,7 @@ dependencies = [
"tokio",
"tokio-stream",
"tower-http",
"urlencoding",
"uuid",
]
@@ -379,6 +394,18 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "fallible-iterator"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649"
[[package]]
name = "fallible-streaming-iterator"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
[[package]]
name = "fastrand"
version = "2.3.0"
@@ -573,12 +600,30 @@ dependencies = [
"tracing",
]
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
dependencies = [
"ahash",
]
[[package]]
name = "hashbrown"
version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289"
[[package]]
name = "hashlink"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af"
dependencies = [
"hashbrown 0.14.5",
]
[[package]]
name = "heck"
version = "0.5.0"
@@ -906,7 +951,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f"
dependencies = [
"equivalent",
"hashbrown",
"hashbrown 0.15.2",
]
[[package]]
@@ -952,6 +997,17 @@ version = "0.2.169"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
[[package]]
name = "libsqlite3-sys"
version = "0.30.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149"
dependencies = [
"cc",
"pkg-config",
"vcpkg",
]
[[package]]
name = "linux-raw-sys"
version = "0.4.14"
@@ -1318,9 +1374,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "reqwest"
version = "0.12.11"
version = "0.12.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fe060fe50f524be480214aba758c71f99f90ee8c83c5a36b5e9e1d568eb4eb3"
checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da"
dependencies = [
"async-compression",
"base64",
@@ -1378,6 +1434,20 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "rusqlite"
version = "0.32.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e"
dependencies = [
"bitflags 2.6.0",
"fallible-iterator",
"fallible-streaming-iterator",
"hashlink",
"libsqlite3-sys",
"smallvec",
]
[[package]]
name = "rustc-demangle"
version = "0.1.24"
@@ -1602,9 +1672,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
[[package]]
name = "syn"
version = "2.0.92"
version = "2.0.94"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70ae51629bf965c5c098cc9e87908a3df5301051a9e087d6f9bef5c9771ed126"
checksum = "987bc0be1cdea8b10216bd06e2ca407d40b9543468fafd3ddfb02f36e77f71f3"
dependencies = [
"proc-macro2",
"quote",
@@ -1856,6 +1926,12 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "urlencoding"
version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
[[package]]
name = "utf16_iter"
version = "1.0.5"

View File

@@ -1,10 +1,8 @@
[package]
name = "cursor-api"
version = "0.1.3-rc.3"
version = "0.1.3"
edition = "2021"
authors = ["wisdgod <nav@wisdgod.com>"]
# license = "MIT"
# copyright = "Copyright (c) 2024 wisdgod"
description = "OpenAI format compatibility layer for the Cursor API"
repository = "https://github.com/wisdgod/cursor-api"
@@ -14,6 +12,7 @@ sha2 = { version = "0.10.8", default-features = false }
serde_json = "1.0.134"
[dependencies]
anyhow = "1.0.95"
axum = { version = "0.7.9", features = ["json"] }
base64 = { version = "0.22.1", default-features = false, features = ["std"] }
# brotli = { version = "7.0.0", default-features = false, features = ["std"] }
@@ -30,7 +29,8 @@ paste = "1.0.15"
prost = "0.13.4"
rand = { version = "0.8.5", default-features = false, features = ["std", "std_rng"] }
regex = { version = "1.11.1", default-features = false, features = ["std", "perf"] }
reqwest = { version = "0.12.11", default-features = false, features = ["gzip", "json", "stream", "__tls", "charset", "default-tls", "h2", "http2", "macos-system-configuration"] }
reqwest = { version = "0.12.12", default-features = false, features = ["gzip", "json", "stream", "__tls", "charset", "default-tls", "h2", "http2", "macos-system-configuration"] }
rusqlite = { version = "0.32.1", features = ["bundled"], optional = true }
serde = { version = "1.0.217", default-features = false, features = ["std", "derive"] }
serde_json = "1.0.134"
sha2 = { version = "0.10.8", default-features = false }
@@ -38,6 +38,7 @@ sysinfo = { version = "0.33.1", default-features = false, features = ["system"]
tokio = { version = "1.42.0", features = ["rt-multi-thread", "macros", "net", "sync", "time"] }
tokio-stream = { version = "0.1.17", features = ["time"] }
tower-http = { version = "0.6.2", features = ["cors"] }
urlencoding = "2.1.3"
uuid = { version = "1.11.0", features = ["v4"] }
[profile.release]
@@ -47,15 +48,6 @@ panic = 'abort'
strip = true
opt-level = 3
# 构建脚本设置
[package.metadata.cross.target.x86_64-unknown-linux-gnu]
image = "ghcr.io/cross-rs/x86_64-unknown-linux-gnu:main"
[package.metadata.cross.target.aarch64-unknown-linux-gnu]
image = "ghcr.io/cross-rs/aarch64-unknown-linux-gnu:main"
[package.metadata.cross.target.x86_64-apple-darwin]
image = "ghcr.io/cross-rs/x86_64-apple-darwin:main"
[package.metadata.cross.target.aarch64-apple-darwin]
image = "ghcr.io/cross-rs/aarch64-apple-darwin:main"
[features]
default = []
sqlite = ["dep:rusqlite"]

View File

@@ -1,4 +1,6 @@
pub mod config;
pub mod constant;
#[cfg(feature = "sqlite")]
pub mod db;
pub mod model;
pub mod lazy;

View File

@@ -50,9 +50,6 @@ def_pub_const!(AUTHORIZATION_BEARER_PREFIX, "Bearer ");
def_pub_const!(OBJECT_CHAT_COMPLETION, "chat.completion");
def_pub_const!(OBJECT_CHAT_COMPLETION_CHUNK, "chat.completion.chunk");
def_pub_const!(CURSOR_API2_HOST, "api2.cursor.sh");
def_pub_const!(CURSOR_API2_BASE_URL, "https://api2.cursor.sh/aiserver.v1.AiService/");
def_pub_const!(CURSOR_API2_STREAM_CHAT, "StreamChat");
def_pub_const!(CURSOR_API2_GET_USER_INFO, "GetUserInfo");

262
src/app/db.rs Normal file
View File

@@ -0,0 +1,262 @@
use crate::app::model::{RequestLog, TokenInfo};
use crate::common::models::usage::UserUsageInfo;
use chrono::{DateTime, Local};
use lazy_static::lazy_static;
use rusqlite::params;
use rusqlite::{Connection, Result};
use std::path::Path;
use std::sync::Mutex;
const DB_PATH: &str = "logs/sqlite.db";
pub struct AppDb {
conn: Connection,
}
impl AppDb {
pub fn new() -> Result<Self> {
// 确保目录存在
if let Some(parent) = Path::new(DB_PATH).parent() {
std::fs::create_dir_all(parent).map_err(|e| {
rusqlite::Error::SqliteFailure(
rusqlite::ffi::Error::new(rusqlite::ffi::SQLITE_IOERR),
Some(e.to_string()),
)
})?;
}
let conn = Connection::open(DB_PATH)?;
// 启用WAL模式以提升性能
conn.execute_batch("PRAGMA journal_mode = WAL")?;
// 创建token信息表
conn.execute(
"CREATE TABLE IF NOT EXISTS token_infos (
id INTEGER PRIMARY KEY AUTOINCREMENT,
token TEXT NOT NULL UNIQUE,
checksum TEXT NOT NULL,
alias TEXT,
fast_requests INTEGER,
max_fast_requests INTEGER
)",
[],
)?;
// 创建请求日志表
conn.execute(
"CREATE TABLE IF NOT EXISTS request_logs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp TEXT NOT NULL,
model TEXT NOT NULL,
token_id INTEGER NOT NULL,
prompt TEXT,
stream BOOLEAN NOT NULL,
status TEXT NOT NULL,
error TEXT,
FOREIGN KEY(token_id) REFERENCES token_infos(id)
)",
[],
)?;
// 创建索引
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_token ON token_infos(token)",
[],
)?;
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_timestamp_model ON request_logs(timestamp, model)",
[],
)?;
Ok(Self { conn })
}
fn get_or_create_token_info(&self, token_info: &TokenInfo) -> Result<i64> {
let mut stmt = self.conn.prepare_cached(
"INSERT OR REPLACE INTO token_infos (token, checksum, alias, fast_requests, max_fast_requests)
VALUES (?1, ?2, ?3, ?4, ?5)
RETURNING id"
)?;
stmt.query_row(
params![
&token_info.token,
&token_info.checksum,
&token_info.alias,
token_info.usage.as_ref().map(|u| u.fast_requests),
token_info.usage.as_ref().map(|u| u.max_fast_requests),
],
|row| row.get(0),
)
}
pub fn add_log(&self, log: &RequestLog) -> Result<()> {
let token_id = self.get_or_create_token_info(&log.token_info)?;
self.conn.execute(
"INSERT INTO request_logs (timestamp, model, token_id, prompt, stream, status, error)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
params![
log.timestamp.to_rfc3339(),
&log.model,
token_id,
&log.prompt,
log.stream,
&log.status,
&log.error,
],
)?;
Ok(())
}
fn map_row_to_log(&self, row: &rusqlite::Row) -> Result<RequestLog> {
let token_id: i64 = row.get(3)?;
let token_info = self.get_token_info_by_id(token_id)?;
Ok(RequestLog {
id: row.get(0)?,
timestamp: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
.unwrap()
.with_timezone(&Local),
model: row.get(2)?,
token_info,
prompt: row.get(4)?,
stream: row.get(5)?,
status: row.get(6)?,
error: row.get(7)?,
})
}
fn get_token_info_by_id(&self, id: i64) -> Result<TokenInfo> {
let mut stmt = self.conn.prepare_cached(
"SELECT token, checksum, alias, fast_requests, max_fast_requests
FROM token_infos
WHERE id = ?",
)?;
stmt.query_row([id], |row| {
Ok(TokenInfo {
token: row.get(0)?,
checksum: row.get(1)?,
alias: row.get(2)?,
usage: Some(UserUsageInfo {
fast_requests: row.get(3)?,
max_fast_requests: row.get(4)?,
}),
})
})
}
pub fn get_token_infos(&self) -> Result<Vec<TokenInfo>> {
let mut stmt = self.conn.prepare_cached(
"SELECT token, checksum, alias, fast_requests, max_fast_requests
FROM token_infos",
)?;
let tokens = stmt.query_map([], |row| {
Ok(TokenInfo {
token: row.get(0)?,
checksum: row.get(1)?,
alias: row.get(2)?,
usage: Some(UserUsageInfo {
fast_requests: row.get(3)?,
max_fast_requests: row.get(4)?,
}),
})
})?;
tokens.collect()
}
pub fn get_recent_logs(&self, limit: i64) -> Result<Vec<RequestLog>> {
let mut stmt = self.conn.prepare_cached(
"SELECT r.id, r.timestamp, r.model, r.token_id, r.prompt, r.stream, r.status, r.error, t.token, t.checksum, t.alias, t.fast_requests, t.max_fast_requests
FROM request_logs r
JOIN token_infos t ON r.token_id = t.id
ORDER BY r.timestamp DESC
LIMIT ?",
)?;
let logs = stmt.query_map([limit], |row| {
Ok(RequestLog {
id: row.get(0)?,
timestamp: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
.unwrap()
.with_timezone(&Local),
model: row.get(2)?,
token_info: TokenInfo {
token: row.get(8)?,
checksum: row.get(9)?,
alias: row.get(10)?,
usage: Some(UserUsageInfo {
fast_requests: row.get(11)?,
max_fast_requests: row.get(12)?,
}),
},
prompt: row.get(4)?,
stream: row.get(5)?,
status: row.get(6)?,
error: row.get(7)?,
})
})?;
logs.collect()
}
pub fn get_logs_by_timerange(
&self,
start: DateTime<Local>,
end: DateTime<Local>,
) -> Result<Vec<RequestLog>> {
let mut stmt = self.conn.prepare_cached(
"SELECT r.id, r.timestamp, r.model, r.token_id, r.prompt, r.stream, r.status, r.error, t.token, t.checksum, t.alias, t.fast_requests, t.max_fast_requests
FROM request_logs r
JOIN token_infos t ON r.token_id = t.id
WHERE r.timestamp BETWEEN ?1 AND ?2
ORDER BY r.timestamp DESC",
)?;
let logs = stmt.query_map([start.to_rfc3339(), end.to_rfc3339()], |row| {
Ok(RequestLog {
id: row.get(0)?,
timestamp: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
.unwrap()
.with_timezone(&Local),
model: row.get(2)?,
token_info: TokenInfo {
token: row.get(8)?,
checksum: row.get(9)?,
alias: row.get(10)?,
usage: Some(UserUsageInfo {
fast_requests: row.get(11)?,
max_fast_requests: row.get(12)?,
}),
},
prompt: row.get(4)?,
stream: row.get(5)?,
status: row.get(6)?,
error: row.get(7)?,
})
})?;
logs.collect()
}
pub fn update_token_info(&self, token_info: &TokenInfo) -> Result<()> {
self.conn.execute(
"INSERT OR REPLACE INTO token_infos (token, checksum, alias, fast_requests, max_fast_requests)
VALUES (?1, ?2, ?3, ?4, ?5)",
params![
&token_info.token,
&token_info.checksum,
&token_info.alias,
token_info.usage.as_ref().map(|u| u.fast_requests),
token_info.usage.as_ref().map(|u| u.max_fast_requests),
],
)?;
Ok(())
}
}
lazy_static! {
pub static ref APP_DB: Mutex<AppDb> =
Mutex::new(AppDb::new().expect("Failed to initialize database"));
}

View File

@@ -33,11 +33,11 @@ def_pub_static!(TOKEN_FILE, env: "TOKEN_FILE", default: DEFAULT_TOKEN_FILE_NAME)
def_pub_static!(TOKEN_LIST_FILE, env: "TOKEN_LIST_FILE", default: DEFAULT_TOKEN_LIST_FILE_NAME);
def_pub_static!(
ROUTE_MODELS_PATH,
format!("{}/v1/models", ROUTE_PREFIX.as_str())
format!("{}/v1/models", *ROUTE_PREFIX)
);
def_pub_static!(
ROUTE_CHAT_PATH,
format!("{}/v1/chat/completions", ROUTE_PREFIX.as_str())
format!("{}/v1/chat/completions", *ROUTE_PREFIX)
);
pub static START_TIME: LazyLock<chrono::DateTime<chrono::Local>> =
@@ -49,6 +49,12 @@ pub fn get_start_time() -> chrono::DateTime<chrono::Local> {
def_pub_static!(DEFAULT_INSTRUCTIONS, env: "DEFAULT_INSTRUCTIONS", default: "Respond in Chinese by default");
def_pub_static!(CURSOR_API2_HOST, env: "REVERSE_PROXY_HOST", default: "api2.cursor.sh");
pub static CURSOR_API2_BASE_URL: LazyLock<String> = LazyLock::new(|| {
format!("https://{}/aiserver.v1.AiService/", *CURSOR_API2_HOST)
});
// pub static DEBUG: LazyLock<bool> = LazyLock::new(|| parse_bool_from_env("DEBUG", false));
// #[macro_export]

View File

@@ -87,7 +87,9 @@ pub struct Pages {
pub struct AppState {
pub total_requests: u64,
pub active_requests: u64,
#[cfg(not(feature = "sqlite"))]
pub request_logs: Vec<RequestLog>,
#[cfg(not(feature = "sqlite"))]
pub token_infos: Vec<TokenInfo>,
}
@@ -273,6 +275,7 @@ impl AppConfig {
}
impl AppState {
#[cfg(not(feature = "sqlite"))]
pub fn new(token_infos: Vec<TokenInfo>) -> Self {
Self {
total_requests: 0,
@@ -281,11 +284,20 @@ impl AppState {
token_infos,
}
}
#[cfg(feature = "sqlite")]
pub fn new() -> Self {
Self {
total_requests: 0,
active_requests: 0,
}
}
}
// 请求日志
#[derive(Serialize, Clone)]
pub struct RequestLog {
pub id: u64,
pub timestamp: chrono::DateTime<chrono::Local>,
pub model: String,
pub token_info: TokenInfo,

View File

@@ -7,6 +7,7 @@ macro_rules! def_pub_const {
}
def_pub_const!(ERR_UNSUPPORTED_GIF, "不支持动态 GIF");
def_pub_const!(ERR_UNSUPPORTED_IMAGE_FORMAT, "不支持的图片格式,仅支持 PNG、JPEG、WEBP 和非动态 GIF");
def_pub_const!(ERR_NODATA, "No data");
const MODEL_OBJECT: &str = "model";
const CREATED: &i64 = &1706659200;

View File

@@ -2,41 +2,66 @@ use super::aiserver::v1::throw_error_check_request::Error as ErrorType;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
#[derive(Deserialize)]
pub struct ChatError {
pub error: ErrorBody,
error: ErrorBody,
}
#[derive(Serialize, Deserialize)]
#[derive(Deserialize)]
pub struct ErrorBody {
pub code: String,
pub message: String,
pub details: Vec<ErrorDetail>,
code: String,
// message: String, always: Error
details: Vec<ErrorDetail>,
}
#[derive(Serialize, Deserialize)]
#[derive(Deserialize)]
pub struct ErrorDetail {
#[serde(rename = "type")]
pub error_type: String,
pub debug: ErrorDebug,
pub value: String,
// #[serde(rename = "type")]
// error_type: String, always: aiserver.v1.ErrorDetails
debug: ErrorDebug,
value: String,
}
#[derive(Serialize, Deserialize)]
#[derive(Deserialize)]
pub struct ErrorDebug {
pub error: String,
pub details: ErrorDetails,
#[serde(rename = "isExpected")]
pub is_expected: bool,
error: String,
details: ErrorDetails,
// #[serde(rename = "isExpected")]
// is_expected: Option<bool>,
}
impl ErrorDebug {
// pub fn is_valid(&self) -> bool {
// ErrorType::from_str_name(&self.error).is_some()
// }
#[derive(Deserialize)]
pub struct ErrorDetails {
title: String,
detail: String,
// #[serde(rename = "isRetryable")]
// is_retryable: Option<bool>,
}
use crate::common::models::{ApiStatus, ErrorResponse as CommonErrorResponse};
impl ChatError {
pub fn to_error_response(&self) -> ErrorResponse {
if self.error.details.is_empty() {
return ErrorResponse {
status: 500,
code: "unknown".to_string(),
error: None,
};
}
ErrorResponse {
status: self.status_code(),
code: self.error.code.clone(),
error: Some(Error {
message: self.error.details[0].debug.details.title.clone(),
details: self.error.details[0].debug.details.detail.clone(),
value: self.error.details[0].value.clone(),
}),
}
}
pub fn status_code(&self) -> u16 {
match ErrorType::from_str_name(&self.error) {
match ErrorType::from_str_name(&self.error.details[0].debug.error) {
Some(error) => match error {
ErrorType::Unspecified => 500,
ErrorType::BadApiKey
@@ -68,46 +93,26 @@ impl ErrorDebug {
| ErrorType::SlashEditFileTooLong
| ErrorType::FileUnsupported
| ErrorType::ClaudeImageTooLarge => 400,
_ => 500,
ErrorType::Deprecated
| ErrorType::FreeUserUsageLimit
| ErrorType::ProUserUsageLimit
| ErrorType::ResourceExhausted
| ErrorType::Openai
| ErrorType::MaxTokens
| ErrorType::ApiKeyNotSupported
| ErrorType::UserAbortedRequest
| ErrorType::CustomMessage
| ErrorType::OutdatedClient
| ErrorType::Debounced
| ErrorType::RepositoryServiceRepositoryIsNotInitialized => 500,
},
None => 500,
}
}
}
#[derive(Serialize, Deserialize)]
pub struct ErrorDetails {
pub title: String,
pub detail: String,
#[serde(rename = "isRetryable")]
pub is_retryable: bool,
}
use crate::common::models::{ApiStatus, ErrorResponse as CommonErrorResponse};
impl ChatError {
pub fn to_json(&self) -> serde_json::Value {
serde_json::to_value(self).unwrap()
}
pub fn to_error_response(&self) -> ErrorResponse {
if self.error.details.is_empty() {
return ErrorResponse {
status: 500,
code: "ERROR_UNKNOWN".to_string(),
error: None,
};
}
ErrorResponse {
status: self.error.details[0].debug.status_code(),
code: self.error.details[0].debug.error.clone(),
error: Some(Error {
message: self.error.details[0].debug.details.title.clone(),
details: self.error.details[0].debug.details.detail.clone(),
value: self.error.details[0].value.clone(),
}),
}
}
// pub fn is_expected(&self) -> bool {
// self.error.details[0].debug.is_expected.unwrap_or_default()
// }
}
#[derive(Serialize)]
@@ -135,7 +140,7 @@ impl ErrorResponse {
}
pub fn native_code(&self) -> String {
self.code.replace("_", " ").to_lowercase()
self.code.replace("_", " ")
}
pub fn to_common(self) -> CommonErrorResponse {
@@ -157,7 +162,7 @@ pub enum StreamError {
impl std::fmt::Display for StreamError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
StreamError::ChatError(error) => write!(f, "{}", serde_json::to_string(error).unwrap()),
StreamError::ChatError(error) => write!(f, "{}", error.error.details[0].debug.details.title),
StreamError::DataLengthLessThan5 => write!(f, "data length less than 5"),
StreamError::EmptyMessage => write!(f, "empty message"),
}

View File

@@ -5,7 +5,7 @@ use crate::{
CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, HEADER_NAME_AUTHORIZATION, HEADER_NAME_CONTENT_TYPE,
ROUTE_TOKENINFO_PATH,
},
model::{AppConfig, AppState, PageContent, TokenUpdateRequest},
model::{AppConfig, PageContent, TokenUpdateRequest},
lazy::{AUTH_TOKEN, TOKEN_FILE, TOKEN_LIST_FILE},
},
common::{
@@ -13,15 +13,22 @@ use crate::{
utils::{generate_checksum, generate_hash, tokens::load_tokens},
},
};
#[cfg(not(feature = "sqlite"))]
use crate::app::model::AppState;
#[cfg(feature = "sqlite")]
use crate::app::db::APP_DB;
use axum::{
extract::State,
http::HeaderMap,
response::{IntoResponse, Response},
Json,
};
#[cfg(not(feature = "sqlite"))]
use axum::extract::State;
use reqwest::StatusCode;
use serde::Serialize;
#[cfg(not(feature = "sqlite"))]
use std::sync::Arc;
#[cfg(not(feature = "sqlite"))]
use tokio::sync::Mutex;
#[derive(Serialize)]
@@ -36,17 +43,28 @@ pub async fn handle_get_checksum() -> Json<ChecksumResponse> {
// 更新 TokenInfo 处理
pub async fn handle_update_tokeninfo(
State(state): State<Arc<Mutex<AppState>>>,
#[cfg(not(feature = "sqlite"))] State(state): State<Arc<Mutex<AppState>>>,
) -> Json<NormalResponseNoData> {
// 重新加载 tokens
let token_infos = load_tokens();
// 更新应用状态
#[cfg(not(feature = "sqlite"))]
{
let mut state = state.lock().await;
state.token_infos = token_infos;
}
#[cfg(feature = "sqlite")]
{
// 使用 APP_DB 更新 token_infos
if let Ok(db) = APP_DB.lock() {
for token_info in token_infos {
let _ = db.update_token_info(&token_info);
}
}
}
Json(NormalResponseNoData {
status: ApiStatus::Success,
message: Some("Token list has been reloaded".to_string()),
@@ -55,13 +73,8 @@ pub async fn handle_update_tokeninfo(
// 获取 TokenInfo 处理
pub async fn handle_get_tokeninfo(
State(_state): State<Arc<Mutex<AppState>>>,
headers: HeaderMap,
) -> Result<Json<TokenInfoResponse>, StatusCode> {
let auth_token = AUTH_TOKEN.as_str();
let token_file = TOKEN_FILE.as_str();
let token_list_file = TOKEN_LIST_FILE.as_str();
// 验证 AUTH_TOKEN
let auth_header = headers
.get(HEADER_NAME_AUTHORIZATION)
@@ -69,20 +82,37 @@ pub async fn handle_get_tokeninfo(
.and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX))
.ok_or(StatusCode::UNAUTHORIZED)?;
if auth_header != auth_token {
if auth_header != AUTH_TOKEN.as_str() {
return Err(StatusCode::UNAUTHORIZED);
}
let token_file = TOKEN_FILE.as_str();
let token_list_file = TOKEN_LIST_FILE.as_str();
// 读取文件内容
let tokens = std::fs::read_to_string(&token_file).unwrap_or_else(|_| String::new());
let token_list = std::fs::read_to_string(&token_list_file).unwrap_or_else(|_| String::new());
// 获取 tokens_count
let tokens_count = {
#[cfg(feature = "sqlite")]
{
APP_DB.lock()
.map(|db| db.get_token_infos().map(|v| v.len()).unwrap_or(0))
.unwrap_or(0)
}
#[cfg(not(feature = "sqlite"))]
{
tokens.len()
}
};
Ok(Json(TokenInfoResponse {
status: ApiStatus::Success,
token_file: token_file.to_string(),
token_list_file: token_list_file.to_string(),
tokens: Some(tokens.clone()),
tokens_count: Some(tokens.len()),
tokens: Some(tokens),
tokens_count: Some(tokens_count),
token_list: Some(token_list),
message: None,
}))
@@ -104,14 +134,10 @@ pub struct TokenInfoResponse {
}
pub async fn handle_update_tokeninfo_post(
State(state): State<Arc<Mutex<AppState>>>,
#[cfg(not(feature = "sqlite"))] State(state): State<Arc<Mutex<AppState>>>,
headers: HeaderMap,
Json(request): Json<TokenUpdateRequest>,
) -> Result<Json<TokenInfoResponse>, StatusCode> {
let auth_token = AUTH_TOKEN.as_str();
let token_file = TOKEN_FILE.as_str();
let token_list_file = TOKEN_LIST_FILE.as_str();
// 验证 AUTH_TOKEN
let auth_header = headers
.get(HEADER_NAME_AUTHORIZATION)
@@ -119,15 +145,18 @@ pub async fn handle_update_tokeninfo_post(
.and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX))
.ok_or(StatusCode::UNAUTHORIZED)?;
if auth_header != auth_token {
if auth_header != AUTH_TOKEN.as_str() {
return Err(StatusCode::UNAUTHORIZED);
}
// 写入 .token 文件
std::fs::write(&token_file, &request.tokens).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let token_file = TOKEN_FILE.as_str();
let token_list_file = TOKEN_LIST_FILE.as_str();
// 如果提供了 token_list写入
if let Some(token_list) = request.token_list {
// 写入文件
std::fs::write(&token_file, &request.tokens)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if let Some(token_list) = &request.token_list {
std::fs::write(&token_list_file, token_list)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
}
@@ -137,11 +166,21 @@ pub async fn handle_update_tokeninfo_post(
let token_infos_len = token_infos.len();
// 更新应用状态
#[cfg(not(feature = "sqlite"))]
{
let mut state = state.lock().await;
state.token_infos = token_infos;
}
#[cfg(feature = "sqlite")]
{
if let Ok(db) = APP_DB.lock() {
for token_info in token_infos {
let _ = db.update_token_info(&token_info);
}
}
}
Ok(Json(TokenInfoResponse {
status: ApiStatus::Success,
token_file: token_file.to_string(),

View File

@@ -1,5 +1,6 @@
use crate::{
app::model::AppState,
chat::constant::ERR_NODATA,
common::{models::usage::GetUserInfo, utils::get_user_usage},
};
use axum::{
@@ -26,11 +27,11 @@ pub async fn get_user_info(
let (auth_token, checksum) = match token_info {
Some(token_info) => (token_info.token.clone(), token_info.checksum.clone()),
None => return Json(GetUserInfo::Error("No data".to_string())),
None => return Json(GetUserInfo::Error(ERR_NODATA.to_string())),
};
match get_user_usage(&auth_token, &checksum).await {
Some(usage) => Json(GetUserInfo::Usage(usage)),
None => Json(GetUserInfo::Error("No data".to_string())),
None => Json(GetUserInfo::Error(ERR_NODATA.to_string())),
}
}

View File

@@ -147,7 +147,9 @@ pub async fn handle_chat(
}
}
let next_id = state.request_logs.last().map_or(1, |log| log.id + 1);
state.request_logs.push(RequestLog {
id: next_id,
timestamp: request_time,
model: request.model.clone(),
token_info: TokenInfo {
@@ -420,11 +422,6 @@ pub async fn handle_chat(
}
Ok(Bytes::new())
}
Err(StreamError::ChatError(error)) => {
buffer_guard.clear();
eprintln!("Stream error occurred: {}", error.to_json());
Ok(Bytes::new())
}
Err(e) => {
buffer_guard.clear();
eprintln!("[警告] Stream error: {}", e);
@@ -480,7 +477,7 @@ pub async fn handle_chat(
}
Err(StreamError::ChatError(error)) => {
return Err((
StatusCode::from_u16(error.error.details[0].debug.status_code())
StatusCode::from_u16(error.status_code())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
Json(error.to_error_response().to_common()),
));

View File

@@ -1,7 +1,10 @@
use crate::app::constant::{
use crate::app::{
constant::{
AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_CONNECT_PROTO, CONTENT_TYPE_PROTO,
CURSOR_API2_BASE_URL, CURSOR_API2_HOST, CURSOR_API2_STREAM_CHAT, HEADER_NAME_AUTHORIZATION,
CURSOR_API2_STREAM_CHAT, HEADER_NAME_AUTHORIZATION,
HEADER_NAME_CONTENT_TYPE,
},
lazy::{CURSOR_API2_HOST, CURSOR_API2_BASE_URL},
};
use reqwest::Client;
use uuid::Uuid;
@@ -17,7 +20,7 @@ pub fn build_client(auth_token: &str, checksum: &str, endpoint: &str) -> reqwest
};
client
.post(format!("{}{}", CURSOR_API2_BASE_URL, endpoint))
.post(format!("{}{}", *CURSOR_API2_BASE_URL, endpoint))
.header(HEADER_NAME_CONTENT_TYPE, content_type)
.header(
HEADER_NAME_AUTHORIZATION,
@@ -32,5 +35,5 @@ pub fn build_client(auth_token: &str, checksum: &str, endpoint: &str) -> reqwest
.header("x-cursor-timezone", "Asia/Shanghai")
.header("x-ghost-mode", "false")
.header("x-request-id", trace_id)
.header("Host", CURSOR_API2_HOST)
.header("Host", CURSOR_API2_HOST.clone())
}

View File

@@ -1,6 +1,7 @@
mod checksum;
pub use checksum::*;
pub mod tokens;
pub mod oauth;
use prost::Message as _;
use crate::{app::constant::CURSOR_API2_GET_USER_INFO, chat::aiserver::v1::GetUserInfoResponse};

View File

@@ -42,3 +42,54 @@ pub fn generate_checksum(device_id: &str, mac_addr: Option<&str>) -> String {
None => format!("{}{}", encoded, device_id),
}
}
pub fn validate_checksum(checksum: &str) -> bool {
// 首先检查是否包含基本的 base64 编码部分和 hash 格式的 device_id
let parts: Vec<&str> = checksum.split('/').collect();
match parts.len() {
// 没有 MAC 地址的情况
1 => {
// 检查是否包含 BASE64 编码的 timestamp (8字符) + 64字符的hash
if checksum.len() != 72 {
// 8 + 64 = 72
return false;
}
// 验证 BASE64 部分
let base64_len = 8;
let encoded_part = &checksum[..base64_len];
if !BASE64.decode(encoded_part).is_ok() {
return false;
}
// 验证 device_id hash 部分
let device_hash = &checksum[base64_len..];
is_valid_hash(device_hash)
}
// 包含 MAC hash 的情况
2 => {
let first_part = parts[0];
let mac_hash = parts[1];
// MAC hash 必须是64字符的十六进制
if !is_valid_hash(mac_hash) {
return false;
}
// 递归验证第一部分
validate_checksum(first_part)
}
_ => false,
}
}
fn is_valid_hash(hash: &str) -> bool {
// 检查长度是否为64
if hash.len() != 64 {
return false;
}
// 检查是否都是有效的十六进制字符
hash.chars().all(|c| c.is_ascii_hexdigit())
}

80
src/common/utils/oauth.rs Normal file
View File

@@ -0,0 +1,80 @@
use anyhow::Result;
use reqwest::Client;
use serde::{Deserialize, Serialize};
const OAUTH_AUTHORIZE_URL: &str = "https://connect.linux.do/oauth2/authorize";
const OAUTH_TOKEN_URL: &str = "https://connect.linux.do/oauth2/token";
const OAUTH_USER_INFO_URL: &str = "https://connect.linux.do/api/user";
#[derive(Debug, Serialize, Deserialize)]
pub struct ForumUser {
pub id: i64,
pub username: String,
pub name: String,
pub active: bool,
pub trust_level: i32,
pub silenced: bool,
}
pub struct ForumOAuth {
client_id: String,
client_secret: String,
redirect_uri: String,
http_client: Client,
}
impl ForumOAuth {
pub fn new(client_id: String, client_secret: String, redirect_uri: String) -> Self {
Self {
client_id,
client_secret,
redirect_uri,
http_client: Client::new(),
}
}
pub fn get_authorize_url(&self, state: &str) -> String {
format!(
"{}?response_type=code&client_id={}&redirect_uri={}&state={}",
OAUTH_AUTHORIZE_URL,
self.client_id,
urlencoding::encode(&self.redirect_uri),
state
)
}
pub async fn exchange_code_for_token(&self, code: &str) -> Result<String> {
let response = self
.http_client
.post(OAUTH_TOKEN_URL)
.form(&[
("grant_type", "authorization_code"),
("code", code),
("client_id", &self.client_id),
("client_secret", &self.client_secret),
("redirect_uri", &self.redirect_uri),
])
.send()
.await?
.json::<serde_json::Value>()
.await?;
Ok(response["access_token"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("No access token found"))?
.to_string())
}
pub async fn get_user_info(&self, access_token: &str) -> Result<ForumUser> {
let user = self
.http_client
.get(OAUTH_USER_INFO_URL)
.bearer_auth(access_token)
.send()
.await?
.json::<ForumUser>()
.await?;
Ok(user)
}
}

View File

@@ -63,6 +63,9 @@ async fn main() {
let token_infos = load_tokens();
// 初始化应用状态
#[cfg(feature = "sqlite")]
let state = Arc::new(Mutex::new(AppState::new()));
#[cfg(not(feature = "sqlite"))]
let state = Arc::new(Mutex::new(AppState::new(token_infos)));
// 设置路由

View File

@@ -200,6 +200,7 @@
<table id="logsTable">
<thead>
<tr>
<th>id</th>
<th>时间</th>
<th>模型</th>
<th>Token信息</th>
@@ -304,6 +305,7 @@
tbody.innerHTML = data.logs.map(log => `
<tr>
<td>${log.id}</td>
<td>${new Date(log.timestamp).toLocaleString()}</td>
<td>${log.model}</td>
<td>