diff --git a/.env.example b/.env.example index 5c606d4..472c252 100644 --- a/.env.example +++ b/.env.example @@ -37,3 +37,6 @@ VISION_ABILITY=base64 # 默认提示词 DEFAULT_INSTRUCTIONS="Respond in Chinese by default" + +# 反向代理服务器主机名 +CURSOR_API2_HOST= diff --git a/.github/workflows/build-linux.yml b/.github/workflows/build-linux.yml new file mode 100644 index 0000000..a4c2b57 --- /dev/null +++ b/.github/workflows/build-linux.yml @@ -0,0 +1,34 @@ +name: Build Linux Binaries + +on: + workflow_dispatch: + +jobs: + build: + name: Build ${{ matrix.target }} + runs-on: ubuntu-latest + strategy: + matrix: + target: [x86_64-unknown-linux-gnu] + + steps: + - uses: actions/checkout@v4.2.2 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + with: + targets: ${{ matrix.target }} + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y protobuf-compiler pkg-config libssl-dev nodejs npm + + - name: Build binary + run: cargo build --release --target ${{ matrix.target }} + + - name: Upload artifact + uses: actions/upload-artifact@v4.5.0 + with: + name: cursor-api-${{ matrix.target }} + path: target/${{ matrix.target }}/release/cursor-api diff --git a/.gitignore b/.gitignore index 4114801..6be81ea 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,5 @@ node_modules /cursor-api.exe /release -/*.py \ No newline at end of file +/*.py +/logs diff --git a/Cargo.lock b/Cargo.lock index 2b01dfc..6eaceb0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,18 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -62,9 +74,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.83" +version = "0.1.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" +checksum = "1b1244b10dcd56c92219da4e14caa97e312079e185f04ba3eea25061561dc0a0" dependencies = [ "proc-macro2", "quote", @@ -212,9 +224,9 @@ checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" [[package]] name = "cc" -version = "1.2.6" +version = "1.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d6dbb628b8f8555f86d0323c2eb39e3ec81901f4b83e091db8a6a76d316a333" +checksum = "a012a0df96dd6d06ba9a1b29d6402d1a5d77c6befd2566afdc26e10603dc93d7" dependencies = [ "shlex", ] @@ -292,8 +304,9 @@ dependencies = [ [[package]] name = "cursor-api" -version = "0.1.3-rc.3" +version = "0.1.3" dependencies = [ + "anyhow", "axum", "base64", "bytes", @@ -311,6 +324,7 @@ dependencies = [ "rand", "regex", "reqwest", + "rusqlite", "serde", "serde_json", "sha2", @@ -318,6 +332,7 @@ dependencies = [ "tokio", "tokio-stream", "tower-http", + "urlencoding", "uuid", ] @@ -379,6 +394,18 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fastrand" version = "2.3.0" @@ -573,12 +600,30 @@ dependencies = [ "tracing", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", +] + [[package]] name = "hashbrown" version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +[[package]] +name = "hashlink" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +dependencies = [ + "hashbrown 0.14.5", +] + [[package]] name = "heck" version = "0.5.0" @@ -906,7 +951,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.15.2", ] [[package]] @@ -952,6 +997,17 @@ version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +[[package]] +name = "libsqlite3-sys" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.4.14" @@ -1318,9 +1374,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "reqwest" -version = "0.12.11" +version = "0.12.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fe060fe50f524be480214aba758c71f99f90ee8c83c5a36b5e9e1d568eb4eb3" +checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da" dependencies = [ "async-compression", "base64", @@ -1378,6 +1434,20 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rusqlite" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e" +dependencies = [ + "bitflags 2.6.0", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", +] + [[package]] name = "rustc-demangle" version = "0.1.24" @@ -1602,9 +1672,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.92" +version = "2.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70ae51629bf965c5c098cc9e87908a3df5301051a9e087d6f9bef5c9771ed126" +checksum = "46f71c0377baf4ef1cc3e3402ded576dccc315800fbc62dfc7fe04b009773b4a" dependencies = [ "proc-macro2", "quote", @@ -1667,12 +1737,13 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.14.0" +version = "3.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" +checksum = "9a8a559c81686f576e8cd0290cd2a24a2a9ad80c98b3478856500fcbd7acd704" dependencies = [ "cfg-if", "fastrand", + "getrandom", "once_cell", "rustix", "windows-sys 0.59.0", @@ -1856,6 +1927,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf16_iter" version = "1.0.5" diff --git a/Cargo.toml b/Cargo.toml index 714c3d8..9736b19 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,8 @@ [package] name = "cursor-api" -version = "0.1.3-rc.3" +version = "0.1.3" edition = "2021" authors = ["wisdgod "] -# license = "MIT" -# copyright = "Copyright (c) 2024 wisdgod" description = "OpenAI format compatibility layer for the Cursor API" repository = "https://github.com/wisdgod/cursor-api" @@ -14,6 +12,7 @@ sha2 = { version = "0.10.8", default-features = false } serde_json = "1.0.134" [dependencies] +anyhow = "1.0.95" axum = { version = "0.7.9", features = ["json"] } base64 = { version = "0.22.1", default-features = false, features = ["std"] } # brotli = { version = "7.0.0", default-features = false, features = ["std"] } @@ -30,7 +29,8 @@ paste = "1.0.15" prost = "0.13.4" rand = { version = "0.8.5", default-features = false, features = ["std", "std_rng"] } regex = { version = "1.11.1", default-features = false, features = ["std", "perf"] } -reqwest = { version = "0.12.11", default-features = false, features = ["gzip", "json", "stream", "__tls", "charset", "default-tls", "h2", "http2", "macos-system-configuration"] } +reqwest = { version = "0.12.12", default-features = false, features = ["gzip", "json", "stream", "__tls", "charset", "default-tls", "h2", "http2", "macos-system-configuration"] } +rusqlite = { version = "0.32.1", features = ["bundled"], optional = true } serde = { version = "1.0.217", default-features = false, features = ["std", "derive"] } serde_json = "1.0.134" sha2 = { version = "0.10.8", default-features = false } @@ -38,6 +38,7 @@ sysinfo = { version = "0.33.1", default-features = false, features = ["system"] tokio = { version = "1.42.0", features = ["rt-multi-thread", "macros", "net", "sync", "time"] } tokio-stream = { version = "0.1.17", features = ["time"] } tower-http = { version = "0.6.2", features = ["cors"] } +urlencoding = "2.1.3" uuid = { version = "1.11.0", features = ["v4"] } [profile.release] @@ -46,16 +47,3 @@ codegen-units = 1 panic = 'abort' strip = true opt-level = 3 - -# 构建脚本设置 -[package.metadata.cross.target.x86_64-unknown-linux-gnu] -image = "ghcr.io/cross-rs/x86_64-unknown-linux-gnu:main" - -[package.metadata.cross.target.aarch64-unknown-linux-gnu] -image = "ghcr.io/cross-rs/aarch64-unknown-linux-gnu:main" - -[package.metadata.cross.target.x86_64-apple-darwin] -image = "ghcr.io/cross-rs/x86_64-apple-darwin:main" - -[package.metadata.cross.target.aarch64-apple-darwin] -image = "ghcr.io/cross-rs/aarch64-apple-darwin:main" diff --git a/src/app/config.rs b/src/app/config.rs index d7c7ab7..6b20554 100644 --- a/src/app/config.rs +++ b/src/app/config.rs @@ -1,7 +1,7 @@ use super::{ - constant::{HEADER_NAME_AUTHORIZATION, AUTHORIZATION_BEARER_PREFIX}, - model::{AppConfig, AppState}, + constant::AUTHORIZATION_BEARER_PREFIX, lazy::AUTH_TOKEN, + model::{AppConfig, AppState}, }; use crate::common::models::{ config::{ConfigData, ConfigUpdateRequest}, @@ -9,7 +9,7 @@ use crate::common::models::{ }; use axum::{ extract::State, - http::{HeaderMap, StatusCode}, + http::{header::AUTHORIZATION, HeaderMap, StatusCode}, Json, }; use std::sync::Arc; @@ -59,7 +59,7 @@ pub async fn handle_config_update( Json(request): Json, ) -> Result>, (StatusCode, Json)> { let auth_header = headers - .get(HEADER_NAME_AUTHORIZATION) + .get(AUTHORIZATION) .and_then(|h| h.to_str().ok()) .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) .ok_or(( @@ -116,12 +116,42 @@ pub async fn handle_config_update( } } - handle_update!(request, enable_stream_check, AppConfig::update_stream_check, "enable_stream_check"); - handle_update!(request, include_stop_stream, AppConfig::update_stop_stream, "include_stop_stream"); - handle_update!(request, vision_ability, AppConfig::update_vision_ability, "vision_ability"); - handle_update!(request, enable_slow_pool, AppConfig::update_slow_pool, "enable_slow_pool"); - handle_update!(request, enable_all_claude, AppConfig::update_allow_claude, "enable_all_claude"); - handle_update!(request, check_usage_models, AppConfig::update_usage_check, "check_usage_models"); + handle_update!( + request, + enable_stream_check, + AppConfig::update_stream_check, + "enable_stream_check" + ); + handle_update!( + request, + include_stop_stream, + AppConfig::update_stop_stream, + "include_stop_stream" + ); + handle_update!( + request, + vision_ability, + AppConfig::update_vision_ability, + "vision_ability" + ); + handle_update!( + request, + enable_slow_pool, + AppConfig::update_slow_pool, + "enable_slow_pool" + ); + handle_update!( + request, + enable_all_claude, + AppConfig::update_allow_claude, + "enable_all_claude" + ); + handle_update!( + request, + check_usage_models, + AppConfig::update_usage_check, + "check_usage_models" + ); Ok(Json(NormalResponse { status: ApiStatus::Success, @@ -146,12 +176,42 @@ pub async fn handle_config_update( } } - handle_reset!(request, enable_stream_check, AppConfig::reset_stream_check, "enable_stream_check"); - handle_reset!(request, include_stop_stream, AppConfig::reset_stop_stream, "include_stop_stream"); - handle_reset!(request, vision_ability, AppConfig::reset_vision_ability, "vision_ability"); - handle_reset!(request, enable_slow_pool, AppConfig::reset_slow_pool, "enable_slow_pool"); - handle_reset!(request, enable_all_claude, AppConfig::reset_allow_claude, "enable_all_claude"); - handle_reset!(request, check_usage_models, AppConfig::reset_usage_check, "check_usage_models"); + handle_reset!( + request, + enable_stream_check, + AppConfig::reset_stream_check, + "enable_stream_check" + ); + handle_reset!( + request, + include_stop_stream, + AppConfig::reset_stop_stream, + "include_stop_stream" + ); + handle_reset!( + request, + vision_ability, + AppConfig::reset_vision_ability, + "vision_ability" + ); + handle_reset!( + request, + enable_slow_pool, + AppConfig::reset_slow_pool, + "enable_slow_pool" + ); + handle_reset!( + request, + enable_all_claude, + AppConfig::reset_allow_claude, + "enable_all_claude" + ); + handle_reset!( + request, + check_usage_models, + AppConfig::reset_usage_check, + "check_usage_models" + ); Ok(Json(NormalResponse { status: ApiStatus::Success, diff --git a/src/app/constant.rs b/src/app/constant.rs index 0658add..4b291b9 100644 --- a/src/app/constant.rs +++ b/src/app/constant.rs @@ -27,6 +27,7 @@ def_pub_const!(ROUTE_SHARED_STYLES_PATH, "/static/shared-styles.css"); def_pub_const!(ROUTE_SHARED_JS_PATH, "/static/shared.js"); def_pub_const!(ROUTE_ABOUT_PATH, "/about"); def_pub_const!(ROUTE_README_PATH, "/readme"); +def_pub_const!(ROUTE_BASIC_CALIBRATION_PATH, "/basic-calibration"); def_pub_const!(DEFAULT_TOKEN_FILE_NAME, ".token"); def_pub_const!(DEFAULT_TOKEN_LIST_FILE_NAME, ".token-list"); @@ -34,9 +35,10 @@ def_pub_const!(DEFAULT_TOKEN_LIST_FILE_NAME, ".token-list"); def_pub_const!(STATUS_SUCCESS, "success"); def_pub_const!(STATUS_FAILED, "failed"); -def_pub_const!(HEADER_NAME_CONTENT_TYPE, "content-type"); -def_pub_const!(HEADER_NAME_AUTHORIZATION, "authorization"); -def_pub_const!(HEADER_NAME_LOCATION, "Location"); +def_pub_const!(HEADER_NAME_GHOST_MODE, "x-ghost-mode"); + +def_pub_const!(TRUE, "true"); +def_pub_const!(FALSE, "false"); def_pub_const!(CONTENT_TYPE_PROTO, "application/proto"); def_pub_const!(CONTENT_TYPE_CONNECT_PROTO, "application/connect+proto"); @@ -50,9 +52,6 @@ def_pub_const!(AUTHORIZATION_BEARER_PREFIX, "Bearer "); def_pub_const!(OBJECT_CHAT_COMPLETION, "chat.completion"); def_pub_const!(OBJECT_CHAT_COMPLETION_CHUNK, "chat.completion.chunk"); -def_pub_const!(CURSOR_API2_HOST, "api2.cursor.sh"); -def_pub_const!(CURSOR_API2_BASE_URL, "https://api2.cursor.sh/aiserver.v1.AiService/"); - def_pub_const!(CURSOR_API2_STREAM_CHAT, "StreamChat"); def_pub_const!(CURSOR_API2_GET_USER_INFO, "GetUserInfo"); diff --git a/src/app/lazy.rs b/src/app/lazy.rs index 553a87a..74ab67e 100644 --- a/src/app/lazy.rs +++ b/src/app/lazy.rs @@ -33,11 +33,11 @@ def_pub_static!(TOKEN_FILE, env: "TOKEN_FILE", default: DEFAULT_TOKEN_FILE_NAME) def_pub_static!(TOKEN_LIST_FILE, env: "TOKEN_LIST_FILE", default: DEFAULT_TOKEN_LIST_FILE_NAME); def_pub_static!( ROUTE_MODELS_PATH, - format!("{}/v1/models", ROUTE_PREFIX.as_str()) + format!("{}/v1/models", *ROUTE_PREFIX) ); def_pub_static!( ROUTE_CHAT_PATH, - format!("{}/v1/chat/completions", ROUTE_PREFIX.as_str()) + format!("{}/v1/chat/completions", *ROUTE_PREFIX) ); pub static START_TIME: LazyLock> = @@ -49,6 +49,12 @@ pub fn get_start_time() -> chrono::DateTime { def_pub_static!(DEFAULT_INSTRUCTIONS, env: "DEFAULT_INSTRUCTIONS", default: "Respond in Chinese by default"); +def_pub_static!(CURSOR_API2_HOST, env: "REVERSE_PROXY_HOST", default: "api2.cursor.sh"); + +pub static CURSOR_API2_BASE_URL: LazyLock = LazyLock::new(|| { + format!("https://{}/aiserver.v1.AiService/", *CURSOR_API2_HOST) +}); + // pub static DEBUG: LazyLock = LazyLock::new(|| parse_bool_from_env("DEBUG", false)); // #[macro_export] diff --git a/src/app/model.rs b/src/app/model.rs index 0476065..9f2a3c9 100644 --- a/src/app/model.rs +++ b/src/app/model.rs @@ -286,6 +286,7 @@ impl AppState { // 请求日志 #[derive(Serialize, Clone)] pub struct RequestLog { + pub id: u64, pub timestamp: chrono::DateTime, pub model: String, pub token_info: TokenInfo, diff --git a/src/chat/constant.rs b/src/chat/constant.rs index 73eba5f..84a8d6a 100644 --- a/src/chat/constant.rs +++ b/src/chat/constant.rs @@ -7,6 +7,7 @@ macro_rules! def_pub_const { } def_pub_const!(ERR_UNSUPPORTED_GIF, "不支持动态 GIF"); def_pub_const!(ERR_UNSUPPORTED_IMAGE_FORMAT, "不支持的图片格式,仅支持 PNG、JPEG、WEBP 和非动态 GIF"); +def_pub_const!(ERR_NODATA, "No data"); const MODEL_OBJECT: &str = "model"; const CREATED: &i64 = &1706659200; diff --git a/src/chat/error.rs b/src/chat/error.rs index 290c981..75e530d 100644 --- a/src/chat/error.rs +++ b/src/chat/error.rs @@ -2,41 +2,66 @@ use super::aiserver::v1::throw_error_check_request::Error as ErrorType; use reqwest::StatusCode; use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize)] +#[derive(Deserialize)] pub struct ChatError { - pub error: ErrorBody, + error: ErrorBody, } -#[derive(Serialize, Deserialize)] +#[derive(Deserialize)] pub struct ErrorBody { - pub code: String, - pub message: String, - pub details: Vec, + code: String, + // message: String, always: Error + details: Vec, } -#[derive(Serialize, Deserialize)] +#[derive(Deserialize)] pub struct ErrorDetail { - #[serde(rename = "type")] - pub error_type: String, - pub debug: ErrorDebug, - pub value: String, + // #[serde(rename = "type")] + // error_type: String, always: aiserver.v1.ErrorDetails + debug: ErrorDebug, + value: String, } -#[derive(Serialize, Deserialize)] +#[derive(Deserialize)] pub struct ErrorDebug { - pub error: String, - pub details: ErrorDetails, - #[serde(rename = "isExpected")] - pub is_expected: bool, + error: String, + details: ErrorDetails, + // #[serde(rename = "isExpected")] + // is_expected: Option, } -impl ErrorDebug { - // pub fn is_valid(&self) -> bool { - // ErrorType::from_str_name(&self.error).is_some() - // } +#[derive(Deserialize)] +pub struct ErrorDetails { + title: String, + detail: String, + // #[serde(rename = "isRetryable")] + // is_retryable: Option, +} + +use crate::common::models::{ApiStatus, ErrorResponse as CommonErrorResponse}; + +impl ChatError { + pub fn to_error_response(&self) -> ErrorResponse { + if self.error.details.is_empty() { + return ErrorResponse { + status: 500, + code: "unknown".to_string(), + error: None, + }; + } + ErrorResponse { + status: self.status_code(), + code: self.error.code.clone(), + error: Some(Error { + message: self.error.details[0].debug.details.title.clone(), + details: self.error.details[0].debug.details.detail.clone(), + value: self.error.details[0].value.clone(), + }), + } + } pub fn status_code(&self) -> u16 { - match ErrorType::from_str_name(&self.error) { + match ErrorType::from_str_name(&self.error.details[0].debug.error) { Some(error) => match error { ErrorType::Unspecified => 500, ErrorType::BadApiKey @@ -68,46 +93,26 @@ impl ErrorDebug { | ErrorType::SlashEditFileTooLong | ErrorType::FileUnsupported | ErrorType::ClaudeImageTooLarge => 400, - _ => 500, + ErrorType::Deprecated + | ErrorType::FreeUserUsageLimit + | ErrorType::ProUserUsageLimit + | ErrorType::ResourceExhausted + | ErrorType::Openai + | ErrorType::MaxTokens + | ErrorType::ApiKeyNotSupported + | ErrorType::UserAbortedRequest + | ErrorType::CustomMessage + | ErrorType::OutdatedClient + | ErrorType::Debounced + | ErrorType::RepositoryServiceRepositoryIsNotInitialized => 500, }, None => 500, } } -} -#[derive(Serialize, Deserialize)] -pub struct ErrorDetails { - pub title: String, - pub detail: String, - #[serde(rename = "isRetryable")] - pub is_retryable: bool, -} - -use crate::common::models::{ApiStatus, ErrorResponse as CommonErrorResponse}; - -impl ChatError { - pub fn to_json(&self) -> serde_json::Value { - serde_json::to_value(self).unwrap() - } - - pub fn to_error_response(&self) -> ErrorResponse { - if self.error.details.is_empty() { - return ErrorResponse { - status: 500, - code: "ERROR_UNKNOWN".to_string(), - error: None, - }; - } - ErrorResponse { - status: self.error.details[0].debug.status_code(), - code: self.error.details[0].debug.error.clone(), - error: Some(Error { - message: self.error.details[0].debug.details.title.clone(), - details: self.error.details[0].debug.details.detail.clone(), - value: self.error.details[0].value.clone(), - }), - } - } + // pub fn is_expected(&self) -> bool { + // self.error.details[0].debug.is_expected.unwrap_or_default() + // } } #[derive(Serialize)] @@ -135,7 +140,7 @@ impl ErrorResponse { } pub fn native_code(&self) -> String { - self.code.replace("_", " ").to_lowercase() + self.code.replace("_", " ") } pub fn to_common(self) -> CommonErrorResponse { @@ -157,7 +162,7 @@ pub enum StreamError { impl std::fmt::Display for StreamError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - StreamError::ChatError(error) => write!(f, "{}", serde_json::to_string(error).unwrap()), + StreamError::ChatError(error) => write!(f, "{}", error.error.details[0].debug.details.title), StreamError::DataLengthLessThan5 => write!(f, "data length less than 5"), StreamError::EmptyMessage => write!(f, "empty message"), } diff --git a/src/chat/route.rs b/src/chat/route.rs index e345a32..956b11f 100644 --- a/src/chat/route.rs +++ b/src/chat/route.rs @@ -3,7 +3,7 @@ pub use logs::{handle_logs, handle_logs_post}; mod health; pub use health::{handle_root, handle_health}; mod token; -pub use token::{handle_get_checksum, handle_update_tokeninfo, handle_get_tokeninfo, handle_update_tokeninfo_post, handle_tokeninfo_page}; +pub use token::{handle_get_checksum, handle_update_tokeninfo, handle_get_tokeninfo, handle_update_tokeninfo_post, handle_tokeninfo_page, handle_basic_calibration}; mod usage; pub use usage::get_user_info; mod config; diff --git a/src/chat/route/config.rs b/src/chat/route/config.rs index 328d541..06ae99e 100644 --- a/src/chat/route/config.rs +++ b/src/chat/route/config.rs @@ -1,22 +1,24 @@ use crate::app::{ constant::{ CONTENT_TYPE_TEXT_CSS_WITH_UTF8, CONTENT_TYPE_TEXT_HTML_WITH_UTF8, - CONTENT_TYPE_TEXT_JS_WITH_UTF8, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, - HEADER_NAME_CONTENT_TYPE, HEADER_NAME_LOCATION, ROUTE_ABOUT_PATH, ROUTE_CONFIG_PATH, - ROUTE_README_PATH, ROUTE_SHARED_JS_PATH, ROUTE_SHARED_STYLES_PATH, + CONTENT_TYPE_TEXT_JS_WITH_UTF8, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, ROUTE_ABOUT_PATH, + ROUTE_CONFIG_PATH, ROUTE_README_PATH, ROUTE_SHARED_JS_PATH, ROUTE_SHARED_STYLES_PATH, }, model::{AppConfig, PageContent}, }; use axum::{ body::Body, extract::Path, - http::StatusCode, + http::{ + header::{CONTENT_TYPE, LOCATION}, + StatusCode, + }, response::{IntoResponse, Response}, }; pub async fn handle_env_example() -> impl IntoResponse { Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) .body(include_str!("../../../.env.example").to_string()) .unwrap() } @@ -25,15 +27,15 @@ pub async fn handle_env_example() -> impl IntoResponse { pub async fn handle_config_page() -> impl IntoResponse { match AppConfig::get_page_content(ROUTE_CONFIG_PATH).unwrap_or_default() { PageContent::Default => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(include_str!("../../../static/config.min.html").to_string()) .unwrap(), PageContent::Text(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) .body(content.clone()) .unwrap(), PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(content.clone()) .unwrap(), } @@ -44,11 +46,11 @@ pub async fn handle_static(Path(path): Path) -> impl IntoResponse { "shared-styles.css" => { match AppConfig::get_page_content(ROUTE_SHARED_STYLES_PATH).unwrap_or_default() { PageContent::Default => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8) .body(include_str!("../../../static/shared-styles.min.css").to_string()) .unwrap(), PageContent::Text(content) | PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8) .body(content.clone()) .unwrap(), } @@ -56,11 +58,11 @@ pub async fn handle_static(Path(path): Path) -> impl IntoResponse { "shared.js" => { match AppConfig::get_page_content(ROUTE_SHARED_JS_PATH).unwrap_or_default() { PageContent::Default => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8) .body(include_str!("../../../static/shared.min.js").to_string()) .unwrap(), PageContent::Text(content) | PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8) .body(content.clone()) .unwrap(), } @@ -75,15 +77,15 @@ pub async fn handle_static(Path(path): Path) -> impl IntoResponse { pub async fn handle_about() -> impl IntoResponse { match AppConfig::get_page_content(ROUTE_ABOUT_PATH).unwrap_or_default() { PageContent::Default => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(include_str!("../../../static/readme.min.html").to_string()) .unwrap(), PageContent::Text(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) .body(content.clone()) .unwrap(), PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(content.clone()) .unwrap(), } @@ -93,15 +95,15 @@ pub async fn handle_readme() -> impl IntoResponse { match AppConfig::get_page_content(ROUTE_README_PATH).unwrap_or_default() { PageContent::Default => Response::builder() .status(StatusCode::TEMPORARY_REDIRECT) - .header(HEADER_NAME_LOCATION, ROUTE_ABOUT_PATH) + .header(LOCATION, ROUTE_ABOUT_PATH) .body(Body::empty()) .unwrap(), PageContent::Text(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) .body(Body::from(content.clone())) .unwrap(), PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(Body::from(content.clone())) .unwrap(), } diff --git a/src/chat/route/health.rs b/src/chat/route/health.rs index e5acdb1..f27c255 100644 --- a/src/chat/route/health.rs +++ b/src/chat/route/health.rs @@ -1,15 +1,15 @@ use crate::{ app::{ constant::{ - CONTENT_TYPE_TEXT_HTML_WITH_UTF8, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, - HEADER_NAME_CONTENT_TYPE, HEADER_NAME_LOCATION, PKG_VERSION, ROUTE_ABOUT_PATH, - ROUTE_CONFIG_PATH, ROUTE_ENV_EXAMPLE_PATH, ROUTE_GET_CHECKSUM, - ROUTE_GET_TOKENINFO_PATH, ROUTE_GET_USER_INFO_PATH, ROUTE_HEALTH_PATH, ROUTE_LOGS_PATH, - ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_STATIC_PATH, ROUTE_TOKENINFO_PATH, - ROUTE_UPDATE_TOKENINFO_PATH, + AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_TEXT_HTML_WITH_UTF8, + CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, PKG_VERSION, ROUTE_ABOUT_PATH, + ROUTE_BASIC_CALIBRATION_PATH, ROUTE_CONFIG_PATH, ROUTE_ENV_EXAMPLE_PATH, + ROUTE_GET_CHECKSUM, ROUTE_GET_TOKENINFO_PATH, ROUTE_GET_USER_INFO_PATH, + ROUTE_HEALTH_PATH, ROUTE_LOGS_PATH, ROUTE_README_PATH, ROUTE_ROOT_PATH, + ROUTE_STATIC_PATH, ROUTE_TOKENINFO_PATH, ROUTE_UPDATE_TOKENINFO_PATH, }, + lazy::{get_start_time, AUTH_TOKEN, ROUTE_CHAT_PATH, ROUTE_MODELS_PATH}, model::{AppConfig, AppState, PageContent}, - lazy::{get_start_time, ROUTE_CHAT_PATH, ROUTE_MODELS_PATH}, }, chat::constant::AVAILABLE_MODELS, common::models::{ @@ -20,11 +20,15 @@ use crate::{ use axum::{ body::Body, extract::State, - http::StatusCode, + http::{ + header::{CONTENT_TYPE, LOCATION}, + HeaderMap, StatusCode, + }, response::{IntoResponse, Response}, Json, }; use chrono::Local; +use reqwest::header::AUTHORIZATION; use std::sync::Arc; use sysinfo::{CpuRefreshKind, MemoryRefreshKind, RefreshKind, System}; use tokio::sync::Mutex; @@ -33,53 +37,59 @@ pub async fn handle_root() -> impl IntoResponse { match AppConfig::get_page_content(ROUTE_ROOT_PATH).unwrap_or_default() { PageContent::Default => Response::builder() .status(StatusCode::TEMPORARY_REDIRECT) - .header(HEADER_NAME_LOCATION, ROUTE_HEALTH_PATH) + .header(LOCATION, ROUTE_HEALTH_PATH) .body(Body::empty()) .unwrap(), PageContent::Text(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) .body(Body::from(content.clone())) .unwrap(), PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(Body::from(content.clone())) .unwrap(), } } -pub async fn handle_health(State(state): State>>) -> Json { +pub async fn handle_health( + State(state): State>>, + headers: HeaderMap, +) -> Json { let start_time = get_start_time(); - - // 创建系统信息实例,只监控 CPU 和内存 - let mut sys = System::new_with_specifics( - RefreshKind::nothing() - .with_memory(MemoryRefreshKind::everything()) - .with_cpu(CpuRefreshKind::everything()), - ); - - std::thread::sleep(sysinfo::MINIMUM_CPU_UPDATE_INTERVAL); - - // 刷新 CPU 和内存信息 - sys.refresh_memory(); - sys.refresh_cpu_usage(); - - let pid = std::process::id() as usize; - let process = sys.process(pid.into()); - - // 获取内存信息 - let memory = process.map(|p| p.memory()).unwrap_or(0); - - // 获取 CPU 使用率 - let cpu_usage = sys.global_cpu_usage(); - - let state = state.lock().await; let uptime = (Local::now() - start_time).num_seconds(); - Json(HealthCheckResponse { - status: ApiStatus::Healthy, - version: PKG_VERSION, - uptime, - stats: SystemStats { + // 先检查 headers 是否包含有效的认证信息 + let stats = if headers + .get(AUTHORIZATION) + .and_then(|h| h.to_str().ok()) + .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) + .map_or(false, |token| token == AUTH_TOKEN.as_str()) + { + // 只有在需要系统信息时才创建实例 + let mut sys = System::new_with_specifics( + RefreshKind::nothing() + .with_memory(MemoryRefreshKind::everything()) + .with_cpu(CpuRefreshKind::everything()), + ); + + std::thread::sleep(sysinfo::MINIMUM_CPU_UPDATE_INTERVAL); + + // 刷新 CPU 和内存信息 + sys.refresh_memory(); + sys.refresh_cpu_usage(); + + let pid = std::process::id() as usize; + let process = sys.process(pid.into()); + + // 获取内存信息 + let memory = process.map(|p| p.memory()).unwrap_or(0); + + // 获取 CPU 使用率 + let cpu_usage = sys.global_cpu_usage(); + + let state = state.lock().await; + + Some(SystemStats { started: start_time.to_string(), total_requests: state.total_requests, active_requests: state.active_requests, @@ -91,7 +101,16 @@ pub async fn handle_health(State(state): State>>) -> Json>(), endpoints: vec![ ROUTE_CHAT_PATH.as_str(), @@ -101,12 +120,13 @@ pub async fn handle_health(State(state): State>>) -> Json impl IntoResponse { match AppConfig::get_page_content(ROUTE_LOGS_PATH).unwrap_or_default() { PageContent::Default => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(Body::from( include_str!("../../../static/logs.min.html").to_string(), )) .unwrap(), PageContent::Text(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) .body(Body::from(content.clone())) .unwrap(), PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(Body::from(content.clone())) .unwrap(), } @@ -47,22 +49,68 @@ pub async fn handle_logs_post( ) -> Result, StatusCode> { let auth_token = AUTH_TOKEN.as_str(); - // 验证 AUTH_TOKEN + // 获取认证头 let auth_header = headers - .get(HEADER_NAME_AUTHORIZATION) + .get(AUTHORIZATION) .and_then(|h| h.to_str().ok()) .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) .ok_or(StatusCode::UNAUTHORIZED)?; - if auth_header != auth_token { + let state = state.lock().await; + + // 如果是管理员token,返回所有日志 + if auth_header == auth_token { + return Ok(Json(LogsResponse { + status: ApiStatus::Success, + total: state.request_logs.len(), + logs: state.request_logs.clone(), + timestamp: Local::now().to_string(), + })); + } + + // 解析 token 和 checksum + let token_part = if let Some(pos) = auth_header.find("::") { + let (_, rest) = auth_header.split_at(pos + 2); + if let Some(comma_pos) = rest.find(',') { + let (token, _) = rest.split_at(comma_pos); + token + } else { + rest + } + } else if let Some(pos) = auth_header.find("%3A%3A") { + let (_, rest) = auth_header.split_at(pos + 6); + if let Some(comma_pos) = rest.find(',') { + let (token, _) = rest.split_at(comma_pos); + token + } else { + rest + } + } else { + if let Some(comma_pos) = auth_header.find(',') { + let (token, _) = auth_header.split_at(comma_pos); + token + } else { + auth_header + } + }; + + // 否则筛选出token匹配的日志 + let filtered_logs: Vec = state + .request_logs + .iter() + .filter(|log| log.token_info.token == token_part) + .cloned() + .collect(); + + // 如果没有匹配的日志,返回未授权错误 + if filtered_logs.is_empty() { return Err(StatusCode::UNAUTHORIZED); } - let state = state.lock().await; Ok(Json(LogsResponse { status: ApiStatus::Success, - total: state.request_logs.len(), - logs: state.request_logs.clone(), + total: filtered_logs.len(), + logs: filtered_logs, timestamp: Local::now().to_string(), })) } diff --git a/src/chat/route/token.rs b/src/chat/route/token.rs index f637c13..20b3506 100644 --- a/src/chat/route/token.rs +++ b/src/chat/route/token.rs @@ -2,25 +2,30 @@ use crate::{ app::{ constant::{ AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_TEXT_HTML_WITH_UTF8, - CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, HEADER_NAME_AUTHORIZATION, HEADER_NAME_CONTENT_TYPE, - ROUTE_TOKENINFO_PATH, + CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, ROUTE_TOKENINFO_PATH, }, - model::{AppConfig, AppState, PageContent, TokenUpdateRequest}, lazy::{AUTH_TOKEN, TOKEN_FILE, TOKEN_LIST_FILE}, + model::{AppConfig, AppState, PageContent, TokenUpdateRequest}, }, common::{ models::{ApiStatus, NormalResponseNoData}, - utils::{generate_checksum, generate_hash, tokens::load_tokens}, + utils::{ + extract_time, extract_user_id, generate_checksum_with_default, load_tokens, + validate_checksum, validate_token, + }, }, }; use axum::{ extract::State, - http::HeaderMap, + http::{ + header::{AUTHORIZATION, CONTENT_TYPE}, + HeaderMap, + }, response::{IntoResponse, Response}, Json, }; use reqwest::StatusCode; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use std::sync::Arc; use tokio::sync::Mutex; @@ -30,7 +35,7 @@ pub struct ChecksumResponse { } pub async fn handle_get_checksum() -> Json { - let checksum = generate_checksum(&generate_hash(), Some(&generate_hash())); + let checksum = generate_checksum_with_default(); Json(ChecksumResponse { checksum }) } @@ -55,34 +60,39 @@ pub async fn handle_update_tokeninfo( // 获取 TokenInfo 处理 pub async fn handle_get_tokeninfo( - State(_state): State>>, headers: HeaderMap, ) -> Result, StatusCode> { - let auth_token = AUTH_TOKEN.as_str(); - let token_file = TOKEN_FILE.as_str(); - let token_list_file = TOKEN_LIST_FILE.as_str(); - // 验证 AUTH_TOKEN let auth_header = headers - .get(HEADER_NAME_AUTHORIZATION) + .get(AUTHORIZATION) .and_then(|h| h.to_str().ok()) .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) .ok_or(StatusCode::UNAUTHORIZED)?; - if auth_header != auth_token { + if auth_header != AUTH_TOKEN.as_str() { return Err(StatusCode::UNAUTHORIZED); } + let token_file = TOKEN_FILE.as_str(); + let token_list_file = TOKEN_LIST_FILE.as_str(); + // 读取文件内容 let tokens = std::fs::read_to_string(&token_file).unwrap_or_else(|_| String::new()); let token_list = std::fs::read_to_string(&token_list_file).unwrap_or_else(|_| String::new()); + // 获取 tokens_count + let tokens_count = { + { + tokens.len() + } + }; + Ok(Json(TokenInfoResponse { status: ApiStatus::Success, token_file: token_file.to_string(), token_list_file: token_list_file.to_string(), - tokens: Some(tokens.clone()), - tokens_count: Some(tokens.len()), + tokens: Some(tokens), + tokens_count: Some(tokens_count), token_list: Some(token_list), message: None, })) @@ -108,26 +118,24 @@ pub async fn handle_update_tokeninfo_post( headers: HeaderMap, Json(request): Json, ) -> Result, StatusCode> { - let auth_token = AUTH_TOKEN.as_str(); - let token_file = TOKEN_FILE.as_str(); - let token_list_file = TOKEN_LIST_FILE.as_str(); - // 验证 AUTH_TOKEN let auth_header = headers - .get(HEADER_NAME_AUTHORIZATION) + .get(AUTHORIZATION) .and_then(|h| h.to_str().ok()) .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) .ok_or(StatusCode::UNAUTHORIZED)?; - if auth_header != auth_token { + if auth_header != AUTH_TOKEN.as_str() { return Err(StatusCode::UNAUTHORIZED); } - // 写入 .token 文件 + let token_file = TOKEN_FILE.as_str(); + let token_list_file = TOKEN_LIST_FILE.as_str(); + + // 写入文件 std::fs::write(&token_file, &request.tokens).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - // 如果提供了 token_list,则写入 - if let Some(token_list) = request.token_list { + if let Some(token_list) = &request.token_list { std::fs::write(&token_list_file, token_list) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; } @@ -156,16 +164,106 @@ pub async fn handle_update_tokeninfo_post( pub async fn handle_tokeninfo_page() -> impl IntoResponse { match AppConfig::get_page_content(ROUTE_TOKENINFO_PATH).unwrap_or_default() { PageContent::Default => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(include_str!("../../../static/tokeninfo.min.html").to_string()) .unwrap(), PageContent::Text(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) .body(content.clone()) .unwrap(), PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(content.clone()) .unwrap(), } } + +#[derive(Deserialize)] +pub struct TokenRequest { + pub token: Option, +} + +#[derive(Serialize)] +pub struct BasicCalibrationResponse { + pub status: ApiStatus, + pub message: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub user_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub create_at: Option, +} + +pub async fn handle_basic_calibration( + Json(request): Json, +) -> Json { + // 从请求头中获取并验证 auth token + let auth_token = match request.token { + Some(token) => token, + None => { + return Json(BasicCalibrationResponse { + status: ApiStatus::Error, + message: Some("未提供授权令牌".to_string()), + user_id: None, + create_at: None, + }) + } + }; + + // 解析 token 和 checksum + let (token_part, checksum) = if let Some(pos) = auth_token.find("::") { + let (_, rest) = auth_token.split_at(pos + 2); + if let Some(comma_pos) = rest.find(',') { + let (token, checksum) = rest.split_at(comma_pos); + (token, &checksum[1..]) + } else { + (rest, "") + } + } else if let Some(pos) = auth_token.find("%3A%3A") { + let (_, rest) = auth_token.split_at(pos + 6); + if let Some(comma_pos) = rest.find(',') { + let (token, checksum) = rest.split_at(comma_pos); + (token, &checksum[1..]) + } else { + (rest, "") + } + } else { + if let Some(comma_pos) = auth_token.find(',') { + let (token, checksum) = auth_token.split_at(comma_pos); + (token, &checksum[1..]) + } else { + (&auth_token[..], "") + } + }; + + // 验证 token 有效性 + if !validate_token(token_part) { + return Json(BasicCalibrationResponse { + status: ApiStatus::Error, + message: Some("无效的授权令牌".to_string()), + user_id: None, + create_at: None, + }); + } + + // 验证 checksum + if !validate_checksum(checksum) { + return Json(BasicCalibrationResponse { + status: ApiStatus::Error, + message: Some("无效的校验和".to_string()), + user_id: None, + create_at: None, + }); + } + + // 提取用户ID和创建时间 + let user_id = extract_user_id(token_part); + let create_at = extract_time(token_part).map(|dt| dt.to_string()); + + // 返回校准结果 + Json(BasicCalibrationResponse { + status: ApiStatus::Success, + message: Some("校准成功".to_string()), + user_id, + create_at, + }) +} diff --git a/src/chat/route/usage.rs b/src/chat/route/usage.rs index ae16f50..ffd95e6 100644 --- a/src/chat/route/usage.rs +++ b/src/chat/route/usage.rs @@ -1,36 +1,48 @@ use crate::{ - app::model::AppState, - common::{models::usage::GetUserInfo, utils::get_user_usage}, + chat::constant::ERR_NODATA, + common::{ + models::usage::GetUserInfo, + utils::{generate_checksum_with_default, get_user_usage}, + }, }; -use axum::{ - extract::{Query, State}, - Json, -}; -use serde::Deserialize; -use std::sync::Arc; -use tokio::sync::Mutex; +use axum::Json; -#[derive(Deserialize)] -pub struct GetUserInfoQuery { - alias: String, -} +use super::token::TokenRequest; -pub async fn get_user_info( - State(state): State>>, - Query(query): Query, -) -> Json { - let token_infos = &state.lock().await.token_infos; - let token_info = token_infos - .iter() - .find(|token_info| token_info.alias == Some(query.alias.clone())); - - let (auth_token, checksum) = match token_info { - Some(token_info) => (token_info.token.clone(), token_info.checksum.clone()), - None => return Json(GetUserInfo::Error("No data".to_string())), +pub async fn get_user_info(Json(request): Json) -> Json { + let auth_token = match request.token { + Some(token) => token, + None => return Json(GetUserInfo::Error(ERR_NODATA.to_string())), }; - match get_user_usage(&auth_token, &checksum).await { + // 解析 token 和 checksum + let (token_part, checksum) = if let Some(pos) = auth_token.find("::") { + let (_, rest) = auth_token.split_at(pos + 2); + if let Some(comma_pos) = rest.find(',') { + let (token, checksum) = rest.split_at(comma_pos); + (token, checksum[1..].to_string()) + } else { + (rest, generate_checksum_with_default()) + } + } else if let Some(pos) = auth_token.find("%3A%3A") { + let (_, rest) = auth_token.split_at(pos + 6); + if let Some(comma_pos) = rest.find(',') { + let (token, checksum) = rest.split_at(comma_pos); + (token, checksum[1..].to_string()) + } else { + (rest, generate_checksum_with_default()) + } + } else { + if let Some(comma_pos) = auth_token.find(',') { + let (token, checksum) = auth_token.split_at(comma_pos); + (token, checksum[1..].to_string()) + } else { + (&auth_token[..], generate_checksum_with_default()) + } + }; + + match get_user_usage(&token_part, &checksum).await { Some(usage) => Json(GetUserInfo::Usage(usage)), - None => Json(GetUserInfo::Error("No data".to_string())), + None => Json(GetUserInfo::Error(ERR_NODATA.to_string())), } } diff --git a/src/chat/service.rs b/src/chat/service.rs index 24a7758..28b1fa3 100644 --- a/src/chat/service.rs +++ b/src/chat/service.rs @@ -3,11 +3,10 @@ use crate::{ app::{ constant::{ AUTHORIZATION_BEARER_PREFIX, CURSOR_API2_STREAM_CHAT, FINISH_REASON_STOP, - HEADER_NAME_CONTENT_TYPE, OBJECT_CHAT_COMPLETION, OBJECT_CHAT_COMPLETION_CHUNK, - STATUS_FAILED, STATUS_SUCCESS, + OBJECT_CHAT_COMPLETION, OBJECT_CHAT_COMPLETION_CHUNK, STATUS_FAILED, STATUS_SUCCESS, }, - model::{AppConfig, AppState, ChatRequest, RequestLog, TokenInfo}, lazy::AUTH_TOKEN, + model::{AppConfig, AppState, ChatRequest, RequestLog, TokenInfo}, }, chat::{ error::StreamError, @@ -19,13 +18,16 @@ use crate::{ common::{ client::build_client, models::{error::ChatError, ErrorResponse}, - utils::get_user_usage, + utils::{get_user_usage, validate_token_and_checksum}, }, }; use axum::{ body::Body, extract::State, - http::{HeaderMap, StatusCode}, + http::{ + header::{AUTHORIZATION, CONTENT_TYPE}, + HeaderMap, StatusCode, + }, response::Response, Json, }; @@ -42,6 +44,8 @@ use std::{ use tokio::sync::Mutex; use uuid::Uuid; +const REQUEST_LOGS_LIMIT: usize = 1000; + // 模型列表处理 pub async fn handle_models() -> Json { Json(ModelsResponse { @@ -79,8 +83,8 @@ pub async fn handle_chat( } // 获取并处理认证令牌 - let auth_token = headers - .get(axum::http::header::AUTHORIZATION) + let auth_header = headers + .get(AUTHORIZATION) .and_then(|h| h.to_str().ok()) .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) .ok_or(( @@ -88,16 +92,9 @@ pub async fn handle_chat( Json(ChatError::Unauthorized.to_json()), ))?; - // 验证 AuthToken - if auth_token != AUTH_TOKEN.as_str() { - return Err(( - StatusCode::UNAUTHORIZED, - Json(ChatError::Unauthorized.to_json()), - )); - } - - // 完整的令牌处理逻辑和对应的 checksum - let (auth_token, checksum, alias) = { + // 验证 AuthToken 和 获取 token 信息 + let (auth_token, checksum, alias) = if auth_header == AUTH_TOKEN.as_str() { + // 如果是管理员Token,使用原有逻辑 static CURRENT_KEY_INDEX: AtomicUsize = AtomicUsize::new(0); let state_guard = state.lock().await; let token_infos = &state_guard.token_infos; @@ -116,6 +113,12 @@ pub async fn handle_chat( token_info.checksum.clone(), token_info.alias.clone(), ) + } else { + // 否则尝试解析token + validate_token_and_checksum(auth_header).ok_or(( + StatusCode::UNAUTHORIZED, + Json(ChatError::Unauthorized.to_json()), + ))? }; // 更新请求日志 @@ -147,7 +150,9 @@ pub async fn handle_chat( } } + let next_id = state.request_logs.last().map_or(1, |log| log.id + 1); state.request_logs.push(RequestLog { + id: next_id, timestamp: request_time, model: request.model.clone(), token_info: TokenInfo { @@ -162,7 +167,7 @@ pub async fn handle_chat( error: None, }); - if state.request_logs.len() > 100 { + if state.request_logs.len() > REQUEST_LOGS_LIMIT { state.request_logs.remove(0); } } @@ -420,11 +425,6 @@ pub async fn handle_chat( } Ok(Bytes::new()) } - Err(StreamError::ChatError(error)) => { - buffer_guard.clear(); - eprintln!("Stream error occurred: {}", error.to_json()); - Ok(Bytes::new()) - } Err(e) => { buffer_guard.clear(); eprintln!("[警告] Stream error: {}", e); @@ -438,7 +438,7 @@ pub async fn handle_chat( Ok(Response::builder() .header("Cache-Control", "no-cache") .header("Connection", "keep-alive") - .header(HEADER_NAME_CONTENT_TYPE, "text/event-stream") + .header(CONTENT_TYPE, "text/event-stream") .body(Body::from_stream(stream)) .unwrap()) } else { @@ -480,7 +480,7 @@ pub async fn handle_chat( } Err(StreamError::ChatError(error)) => { return Err(( - StatusCode::from_u16(error.error.details[0].debug.status_code()) + StatusCode::from_u16(error.status_code()) .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), Json(error.to_error_response().to_common()), )); @@ -545,7 +545,7 @@ pub async fn handle_chat( }; Ok(Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, "application/json") + .header(CONTENT_TYPE, "application/json") .body(Body::from(serde_json::to_string(&response_data).unwrap())) .unwrap()) } diff --git a/src/common/client.rs b/src/common/client.rs index 70f888e..25732c3 100644 --- a/src/common/client.rs +++ b/src/common/client.rs @@ -1,9 +1,12 @@ -use crate::app::constant::{ - AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_CONNECT_PROTO, CONTENT_TYPE_PROTO, - CURSOR_API2_BASE_URL, CURSOR_API2_HOST, CURSOR_API2_STREAM_CHAT, HEADER_NAME_AUTHORIZATION, - HEADER_NAME_CONTENT_TYPE, +use crate::app::{ + constant::{ + AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_CONNECT_PROTO, CONTENT_TYPE_PROTO, + CURSOR_API2_STREAM_CHAT, HEADER_NAME_GHOST_MODE, + TRUE, FALSE + }, + lazy::{CURSOR_API2_BASE_URL, CURSOR_API2_HOST}, }; -use reqwest::Client; +use reqwest::{header::{CONTENT_TYPE,AUTHORIZATION,USER_AGENT,HOST}, Client}; use uuid::Uuid; /// 返回预构建的 Cursor API 客户端 @@ -17,20 +20,46 @@ pub fn build_client(auth_token: &str, checksum: &str, endpoint: &str) -> reqwest }; client - .post(format!("{}{}", CURSOR_API2_BASE_URL, endpoint)) - .header(HEADER_NAME_CONTENT_TYPE, content_type) + .post(format!("{}{}", *CURSOR_API2_BASE_URL, endpoint)) + .header(CONTENT_TYPE, content_type) .header( - HEADER_NAME_AUTHORIZATION, + AUTHORIZATION, format!("{}{}", AUTHORIZATION_BEARER_PREFIX, auth_token), ) .header("connect-accept-encoding", "gzip,br") .header("connect-protocol-version", "1") - .header("user-agent", "connect-es/1.6.1") + .header(USER_AGENT, "connect-es/1.6.1") .header("x-amzn-trace-id", format!("Root={}", trace_id)) .header("x-cursor-checksum", checksum) .header("x-cursor-client-version", "0.42.5") .header("x-cursor-timezone", "Asia/Shanghai") - .header("x-ghost-mode", "false") + .header(HEADER_NAME_GHOST_MODE, FALSE) .header("x-request-id", trace_id) - .header("Host", CURSOR_API2_HOST) + .header(HOST, CURSOR_API2_HOST.clone()) +} + +/// 返回预构建的获取 Stripe 账户信息的 Cursor API 客户端 +pub fn build_profile_client(auth_token: &str) -> reqwest::RequestBuilder { + let client = Client::new(); + + client + .get(format!("https://{}/auth/full_stripe_profile", *CURSOR_API2_HOST)) + .header(HOST, CURSOR_API2_HOST.clone()) + .header("sec-ch-ua", "\"Not-A.Brand\";v=\"99\", \"Chromium\";v=\"124\"") + .header(HEADER_NAME_GHOST_MODE, TRUE) + .header("sec-ch-ua-mobile", "?0") + .header( + AUTHORIZATION, + format!("{}{}", AUTHORIZATION_BEARER_PREFIX, auth_token), + ) + .header(USER_AGENT, "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Cursor/0.42.5 Chrome/124.0.6367.243 Electron/30.4.0 Safari/537.36") + .header("sec-ch-ua-platform", "\"Windows\"") + .header("accept", "*/*") + .header("origin", "vscode-file://vscode-app") + .header("sec-fetch-site", "cross-site") + .header("sec-fetch-mode", "cors") + .header("sec-fetch-dest", "empty") + .header("accept-encoding", "gzip, deflate, br") + .header("accept-language", "zh-CN") + .header("priority", "u=1, i") } diff --git a/src/common/models/health.rs b/src/common/models/health.rs index 43d2241..f74fe09 100644 --- a/src/common/models/health.rs +++ b/src/common/models/health.rs @@ -7,7 +7,8 @@ pub struct HealthCheckResponse { pub status: ApiStatus, pub version: &'static str, pub uptime: i64, - pub stats: SystemStats, + #[serde(skip_serializing_if = "Option::is_none")] + pub stats: Option, pub models: Vec<&'static str>, pub endpoints: Vec<&'static str>, } diff --git a/src/common/models/usage.rs b/src/common/models/usage.rs index 9a61e05..4aedcf9 100644 --- a/src/common/models/usage.rs +++ b/src/common/models/usage.rs @@ -1,4 +1,4 @@ -use serde::Serialize; +use serde::{Deserialize, Serialize}; #[derive(Serialize)] pub enum GetUserInfo { @@ -12,4 +12,14 @@ pub enum GetUserInfo { pub struct UserUsageInfo { pub fast_requests: u32, pub max_fast_requests: u32, + pub mtype: String, + pub trial_days: u32, +} + +#[derive(Deserialize)] +pub struct StripeProfile { + #[serde(rename = "membershipType")] + pub membership_type: String, + #[serde(rename = "daysRemainingOnTrial")] + pub days_remaining_on_trial: i32, } diff --git a/src/common/utils.rs b/src/common/utils.rs index 9e1a402..bbf390a 100644 --- a/src/common/utils.rs +++ b/src/common/utils.rs @@ -1,11 +1,12 @@ mod checksum; pub use checksum::*; -pub mod tokens; +mod tokens; +pub use tokens::*; use prost::Message as _; use crate::{app::constant::CURSOR_API2_GET_USER_INFO, chat::aiserver::v1::GetUserInfoResponse}; -use super::models::usage::UserUsageInfo; +use super::models::usage::{StripeProfile, UserUsageInfo}; pub fn parse_bool_from_env(key: &str, default: bool) -> bool { std::env::var(key) @@ -43,8 +44,49 @@ pub async fn get_user_usage(auth_token: &str, checksum: &str) -> Option Option<(String, u32)> { + let client = super::client::build_profile_client(auth_token); + let response = client.send().await.ok()?.json::().await.ok()?; + Some((response.membership_type, i32_to_u32(response.days_remaining_on_trial))) +} + +pub fn validate_token_and_checksum(auth_token: &str) -> Option<(String, String, Option)> { + // 提取 token、checksum 和可能的 alias + let (token, checksum, alias) = { + // 先尝试提取 alias + let (token_part, alias) = if let Some(pos) = auth_token.find("::") { + let (alias, rest) = auth_token.split_at(pos); + (&rest[2..], Some(alias)) + } else if let Some(pos) = auth_token.find("%3A%3A") { + let (alias, rest) = auth_token.split_at(pos); + (&rest[6..], Some(alias)) + } else { + (auth_token, None) + }; + + // 提取 token 和 checksum + if let Some(comma_pos) = token_part.find(',') { + let (token, checksum) = token_part.split_at(comma_pos); + (token, &checksum[1..], alias) + } else { + return None; // 缺少必要的 checksum + } + }; + + // 验证 token 和 checksum 有效性 + if validate_token(token) && validate_checksum(checksum) { + Some((token.to_string(), checksum.to_string(), alias.map(String::from))) + } else { + None + } +} diff --git a/src/common/utils/checksum.rs b/src/common/utils/checksum.rs index 75f6326..01d857b 100644 --- a/src/common/utils/checksum.rs +++ b/src/common/utils/checksum.rs @@ -2,7 +2,7 @@ use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; use rand::Rng; use sha2::{Digest, Sha256}; -pub fn generate_hash() -> String { +fn generate_hash() -> String { let random_bytes = rand::thread_rng().gen::<[u8; 32]>(); let mut hasher = Sha256::new(); hasher.update(random_bytes); @@ -18,7 +18,7 @@ fn obfuscate_bytes(bytes: &mut [u8]) { } } -pub fn generate_checksum(device_id: &str, mac_addr: Option<&str>) -> String { +fn generate_checksum(device_id: &str, mac_addr: Option<&str>) -> String { let timestamp = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() @@ -42,3 +42,51 @@ pub fn generate_checksum(device_id: &str, mac_addr: Option<&str>) -> String { None => format!("{}{}", encoded, device_id), } } + +pub fn generate_checksum_with_default() -> String { + generate_checksum(&generate_hash(), Some(&generate_hash())) +} + +pub fn validate_checksum(checksum: &str) -> bool { + // 首先检查是否包含基本的 base64 编码部分和 hash 格式的 device_id + let parts: Vec<&str> = checksum.split('/').collect(); + + match parts.len() { + // 没有 MAC 地址的情况 + 1 => { + // 检查是否包含 BASE64 编码的 timestamp (8字符) + 64字符的hash + if checksum.len() != 72 { + // 8 + 64 = 72 + return false; + } + + // 验证 device_id hash 部分 + let device_hash = &checksum[8..]; + is_valid_hash(device_hash) + } + // 包含 MAC hash 的情况 + 2 => { + let first_part = parts[0]; + let mac_hash = parts[1]; + + // MAC hash 必须是64字符的十六进制 + if !is_valid_hash(mac_hash) { + return false; + } + + // 递归验证第一部分 + validate_checksum(first_part) + } + _ => false, + } +} + +fn is_valid_hash(hash: &str) -> bool { + // 检查长度是否为64 + if hash.len() != 64 { + return false; + } + + // 检查是否都是有效的十六进制字符 + hash.chars().all(|c| c.is_ascii_hexdigit()) +} diff --git a/src/common/utils/tokens.rs b/src/common/utils/tokens.rs index acd1d7e..107fdf7 100644 --- a/src/common/utils/tokens.rs +++ b/src/common/utils/tokens.rs @@ -4,7 +4,7 @@ use crate::{ model::TokenInfo, lazy::{TOKEN_FILE, TOKEN_LIST_FILE}, }, - common::utils::{generate_checksum, generate_hash}, + common::utils::generate_checksum_with_default, }; // 规范化文件内容并写入 @@ -109,7 +109,7 @@ pub fn load_tokens() -> Vec { } } else { // 为新token生成checksum - let checksum = generate_checksum(&generate_hash(), Some(&generate_hash())); + let checksum = generate_checksum_with_default(); token_map.insert(token, (checksum, alias)); } } @@ -142,3 +142,152 @@ pub fn load_tokens() -> Vec { }) .collect() } + +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; +use chrono::{DateTime, Local, TimeZone}; + +// 验证jwt token是否有效 +pub fn validate_token(token: &str) -> bool { + // 检查 token 格式 + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() != 3 { + return false; + } + + // 解码 payload + let payload = match URL_SAFE_NO_PAD.decode(parts[1]) { + Ok(decoded) => decoded, + Err(_) => return false, + }; + + // 转换为字符串 + let payload_str = match String::from_utf8(payload) { + Ok(s) => s, + Err(_) => return false, + }; + + // 解析 JSON + let payload_json: serde_json::Value = match serde_json::from_str(&payload_str) { + Ok(v) => v, + Err(_) => return false, + }; + + // 验证必要字段是否存在且有效 + let required_fields = ["sub", "exp", "iss", "aud", "randomness", "time"]; + for field in required_fields { + if !payload_json.get(field).is_some() { + return false; + } + } + + // 验证 randomness 长度 + if let Some(randomness) = payload_json["randomness"].as_str() { + if randomness.len() != 18 { + return false; + } + } else { + return false; + } + + // 验证 time 字段 + if let Some(time) = payload_json["time"].as_str() { + // 验证 time 是否为有效的数字字符串 + if let Ok(time_value) = time.parse::() { + let current_time = chrono::Utc::now().timestamp(); + if time_value > current_time { + return false; + } + } else { + return false; + } + } else { + return false; + } + + // 验证过期时间 + if let Some(exp) = payload_json["exp"].as_i64() { + let current_time = chrono::Utc::now().timestamp(); + if current_time > exp { + return false; + } + } else { + return false; + } + + // 验证发行者 + if payload_json["iss"].as_str() != Some("https://authentication.cursor.sh") { + return false; + } + + // 验证受众 + if payload_json["aud"].as_str() != Some("https://cursor.com") { + return false; + } + + true +} + +// 从 JWT token 中提取用户 ID +pub fn extract_user_id(token: &str) -> Option { + // JWT token 由3部分组成,用 . 分隔 + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() != 3 { + return None; + } + + // 解码 payload (第二部分) + let payload = match URL_SAFE_NO_PAD.decode(parts[1]) { + Ok(decoded) => decoded, + Err(_) => return None, + }; + + // 将 payload 转换为字符串 + let payload_str = match String::from_utf8(payload) { + Ok(s) => s, + Err(_) => return None, + }; + + // 解析 JSON + let payload_json: serde_json::Value = match serde_json::from_str(&payload_str) { + Ok(v) => v, + Err(_) => return None, + }; + + // 提取 sub 字段 + payload_json["sub"] + .as_str() + .map(|s| s.split('|').nth(1).unwrap_or(s).to_string()) +} + +// 从 JWT token 中提取 time 字段 +pub fn extract_time(token: &str) -> Option> { + // JWT token 由3部分组成,用 . 分隔 + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() != 3 { + return None; + } + + // 解码 payload (第二部分) + let payload = match URL_SAFE_NO_PAD.decode(parts[1]) { + Ok(decoded) => decoded, + Err(_) => return None, + }; + + // 将 payload 转换为字符串 + let payload_str = match String::from_utf8(payload) { + Ok(s) => s, + Err(_) => return None, + }; + + // 解析 JSON + let payload_json: serde_json::Value = match serde_json::from_str(&payload_str) { + Ok(v) => v, + Err(_) => return None, + }; + + // 提取时间戳并转换为本地时间 + payload_json["time"] + .as_str() + .and_then(|t| t.parse::().ok()) + .and_then(|timestamp| Local.timestamp_opt(timestamp, 0).single()) +} diff --git a/src/main.rs b/src/main.rs index b7a4389..b400e9c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,13 +5,13 @@ mod common; use app::{ config::handle_config_update, constant::{ - EMPTY_STRING, PKG_VERSION, ROUTE_ABOUT_PATH, ROUTE_CONFIG_PATH, ROUTE_ENV_EXAMPLE_PATH, - ROUTE_GET_CHECKSUM, ROUTE_GET_TOKENINFO_PATH, ROUTE_GET_USER_INFO_PATH, ROUTE_HEALTH_PATH, - ROUTE_LOGS_PATH, ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_STATIC_PATH, - ROUTE_TOKENINFO_PATH, ROUTE_UPDATE_TOKENINFO_PATH, + EMPTY_STRING, PKG_VERSION, ROUTE_ABOUT_PATH, ROUTE_BASIC_CALIBRATION_PATH, + ROUTE_CONFIG_PATH, ROUTE_ENV_EXAMPLE_PATH, ROUTE_GET_CHECKSUM, ROUTE_GET_TOKENINFO_PATH, + ROUTE_GET_USER_INFO_PATH, ROUTE_HEALTH_PATH, ROUTE_LOGS_PATH, ROUTE_README_PATH, + ROUTE_ROOT_PATH, ROUTE_STATIC_PATH, ROUTE_TOKENINFO_PATH, ROUTE_UPDATE_TOKENINFO_PATH, }, - model::*, lazy::{AUTH_TOKEN, ROUTE_CHAT_PATH, ROUTE_MODELS_PATH}, + model::*, }; use axum::{ routing::{get, post}, @@ -19,14 +19,14 @@ use axum::{ }; use chat::{ route::{ - get_user_info, handle_about, handle_config_page, handle_env_example, handle_get_checksum, - handle_get_tokeninfo, handle_health, handle_logs, handle_logs_post, handle_readme, - handle_root, handle_static, handle_tokeninfo_page, handle_update_tokeninfo, - handle_update_tokeninfo_post, + get_user_info, handle_about, handle_basic_calibration, handle_config_page, + handle_env_example, handle_get_checksum, handle_get_tokeninfo, handle_health, handle_logs, + handle_logs_post, handle_readme, handle_root, handle_static, handle_tokeninfo_page, + handle_update_tokeninfo, handle_update_tokeninfo_post, }, service::{handle_chat, handle_models}, }; -use common::utils::{parse_bool_from_env, parse_string_from_env, tokens::load_tokens}; +use common::utils::{load_tokens, parse_bool_from_env, parse_string_from_env}; use std::sync::Arc; use tokio::sync::Mutex; use tower_http::cors::CorsLayer; @@ -72,7 +72,6 @@ async fn main() { .route(ROUTE_TOKENINFO_PATH, get(handle_tokeninfo_page)) .route(ROUTE_MODELS_PATH.as_str(), get(handle_models)) .route(ROUTE_GET_CHECKSUM, get(handle_get_checksum)) - .route(ROUTE_GET_USER_INFO_PATH, get(get_user_info)) .route(ROUTE_UPDATE_TOKENINFO_PATH, get(handle_update_tokeninfo)) .route(ROUTE_GET_TOKENINFO_PATH, post(handle_get_tokeninfo)) .route( @@ -88,6 +87,8 @@ async fn main() { .route(ROUTE_STATIC_PATH, get(handle_static)) .route(ROUTE_ABOUT_PATH, get(handle_about)) .route(ROUTE_README_PATH, get(handle_readme)) + .route(ROUTE_BASIC_CALIBRATION_PATH, get(handle_basic_calibration)) + .route(ROUTE_GET_USER_INFO_PATH, get(get_user_info)) .layer(CorsLayer::permissive()) .with_state(state); diff --git a/static/logs.html b/static/logs.html index 8401fe5..08c77d0 100644 --- a/static/logs.html +++ b/static/logs.html @@ -200,6 +200,7 @@ + @@ -234,6 +235,14 @@ + + + + + + + + @@ -272,6 +281,16 @@ document.getElementById('modalChecksum').textContent = tokenInfo.checksum || '-'; document.getElementById('modalAlias').textContent = tokenInfo.alias || '-'; + // 添加会员类型和试用天数显示 + if (tokenInfo.usage) { + document.getElementById('modalMemberType').textContent = tokenInfo.usage.mtype || '-'; + document.getElementById('modalTrialDays').textContent = + tokenInfo.usage.trial_days > 0 ? `${tokenInfo.usage.trial_days}天` : '-'; + } else { + document.getElementById('modalMemberType').textContent = '-'; + document.getElementById('modalTrialDays').textContent = '-'; + } + // 获取进度条容器 const progressContainer = document.querySelector('.usage-progress-container'); @@ -304,6 +323,7 @@ tbody.innerHTML = data.logs.map(log => ` +
id 时间 模型 Token信息别名:
会员类型:
试用剩余天数:
使用情况:
${log.id} ${new Date(log.timestamp).toLocaleString()} ${log.model}