diff --git a/.gitignore b/.gitignore index 83c1229..4114801 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,4 @@ node_modules /cursor-api.exe /release -/static/readme.html -/*.py -/src/decoder.rs \ No newline at end of file +/*.py \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 69d9ea2..22a2711 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -292,7 +292,7 @@ dependencies = [ [[package]] name = "cursor-api" -version = "0.1.3-rc.1" +version = "0.1.3-rc.2" dependencies = [ "axum", "base64", @@ -305,6 +305,7 @@ dependencies = [ "hex", "image", "lazy_static", + "paste", "prost", "prost-build", "rand", @@ -1108,6 +1109,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "percent-encoding" version = "2.3.1" diff --git a/Cargo.toml b/Cargo.toml index 561a7c2..c8ab72d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "cursor-api" -version = "0.1.3-rc.1" +version = "0.1.3-rc.2" edition = "2021" authors = ["wisdgod "] # license = "MIT" @@ -26,6 +26,7 @@ gif = { version = "0.13.1", default-features = false, features = ["std"] } hex = { version = "0.4.3", default-features = false, features = ["std"] } image = { version = "0.25.5", default-features = false, features = ["jpeg", "png", "gif", "webp"] } lazy_static = "1.5.0" +paste = "1.0.15" prost = "0.13.4" rand = { version = "0.8.5", default-features = false, features = ["std", "std_rng"] } regex = { version = "1.11.1", default-features = false, features = ["std", "perf"] } diff --git a/README.md b/README.md index 0faa06b..fb6088d 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,9 @@ ## 获取key -1. 访问 [www.cursor.com](https://www.cursor.com) 并完成注册登录(赠送 250 次快速响应,可通过删除账号再注册重置) +1. 访问 [www.cursor.com](https://www.cursor.com) 并完成注册登录 2. 在浏览器中打开开发者工具(F12) -3. 找到 Application-Cookies 中名为 `WorkosCursorSessionToken` 的值并复制其第3个字段,%3A%3A是::的编码,cookie用:分隔值 +3. 
在 Application-Cookies 中查找名为 `WorkosCursorSessionToken` 的条目,并复制其第三个字段。请注意,%3A%3A 是 :: 的 URL 编码形式,cookie 的值使用冒号 (:) 进行分隔。 ## 接口说明 @@ -107,7 +107,6 @@ gpt-4o claude-3-opus cursor-fast cursor-small -gpt-3.5 gpt-3.5-turbo gpt-4-turbo-2024-04-09 gpt-4o-128k diff --git a/build.rs b/build.rs index f9d3469..7477417 100644 --- a/build.rs +++ b/build.rs @@ -139,7 +139,7 @@ fn minify_assets() -> Result<()> { fn main() -> Result<()> { // Proto 文件处理 - println!("cargo:rerun-if-changed=src/aiserver/v1/aiserver.proto"); + println!("cargo:rerun-if-changed=src/chat/aiserver/v1/aiserver.proto"); let mut config = prost_build::Config::new(); // config.type_attribute(".", "#[derive(serde::Serialize, serde::Deserialize)]"); // config.type_attribute( @@ -147,7 +147,7 @@ fn main() -> Result<()> { // "#[derive(serde::Serialize, serde::Deserialize)]" // ); config - .compile_protos(&["src/aiserver/v1/aiserver.proto"], &["src/aiserver/v1/"]) + .compile_protos(&["src/chat/aiserver/v1/aiserver.proto"], &["src/chat/aiserver/v1/"]) .unwrap(); // 静态资源文件处理 diff --git a/src/app.rs b/src/app.rs index f83e629..c9d83c6 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,5 +1,7 @@ -pub mod models; +pub mod client; +pub mod config; pub mod constant; +pub mod models; +pub mod statics; pub mod token; pub mod utils; -pub mod client; diff --git a/src/app/config.rs b/src/app/config.rs new file mode 100644 index 0000000..c68b319 --- /dev/null +++ b/src/app/config.rs @@ -0,0 +1,301 @@ +use super::{ + constant::*, + models::{AppConfig, AppState}, + statics::*, +}; +use crate::common::models::{ + config::{ConfigData, ConfigUpdateRequest}, + ApiStatus, ErrorResponse, NormalResponse, +}; +use axum::{ + extract::State, + http::{HeaderMap, StatusCode}, + Json, +}; +use std::sync::Arc; +use tokio::sync::Mutex; + +pub async fn handle_config_update( + State(_state): State>>, + headers: HeaderMap, + Json(request): Json, +) -> Result>, (StatusCode, Json)> { + let auth_header = headers + .get(HEADER_NAME_AUTHORIZATION) + .and_then(|h| h.to_str().ok()) + .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) + .ok_or(( + StatusCode::UNAUTHORIZED, + Json(ErrorResponse { + status: ApiStatus::Failed, + code: Some(401), + error: Some("未提供认证令牌".to_string()), + message: None, + }), + ))?; + + if auth_header != get_auth_token() { + return Err(( + StatusCode::UNAUTHORIZED, + Json(ErrorResponse { + status: ApiStatus::Failed, + code: Some(401), + error: Some("无效的认证令牌".to_string()), + message: None, + }), + )); + } + + match request.action.as_str() { + "get" => Ok(Json(NormalResponse { + status: ApiStatus::Success, + data: Some(ConfigData { + page_content: AppConfig::get_page_content(&request.path), + enable_stream_check: AppConfig::get_stream_check(), + include_stop_stream: AppConfig::get_stop_stream(), + vision_ability: AppConfig::get_vision_ability(), + enable_slow_pool: AppConfig::get_slow_pool(), + enable_all_claude: AppConfig::get_allow_claude(), + check_usage_models: AppConfig::get_usage_check(), + }), + message: None, + })), + + "update" => { + // 处理页面内容更新 + if !request.path.is_empty() && request.content.is_some() { + let content = request.content.unwrap(); + + if let Err(e) = AppConfig::update_page_content(&request.path, content) { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + status: ApiStatus::Failed, + code: Some(500), + error: Some(format!("更新页面内容失败: {}", e)), + message: None, + }), + )); + } + } + + // 处理 enable_stream_check 更新 + if let Some(enable_stream_check) = request.enable_stream_check { + if let Err(e) = 
AppConfig::update_stream_check(enable_stream_check) { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + status: ApiStatus::Failed, + code: Some(500), + error: Some(format!("更新 enable_stream_check 失败: {}", e)), + message: None, + }), + )); + } + } + + // 处理 include_stop_stream 更新 + if let Some(include_stop_stream) = request.include_stop_stream { + if let Err(e) = AppConfig::update_stop_stream(include_stop_stream) { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + status: ApiStatus::Failed, + code: Some(500), + error: Some(format!("更新 include_stop_stream 失败: {}", e)), + message: None, + }), + )); + } + } + + // 处理 vision_ability 更新 + if let Some(vision_ability) = request.vision_ability { + if let Err(e) = AppConfig::update_vision_ability(vision_ability) { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + status: ApiStatus::Failed, + code: Some(500), + error: Some(format!("更新 vision_ability 失败: {}", e)), + message: None, + }), + )); + } + } + + // 处理 enable_slow_pool 更新 + if let Some(enable_slow_pool) = request.enable_slow_pool { + if let Err(e) = AppConfig::update_slow_pool(enable_slow_pool) { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + status: ApiStatus::Failed, + code: Some(500), + error: Some(format!("更新 enable_slow_pool 失败: {}", e)), + message: None, + }), + )); + } + } + + // 处理 enable_all_claude 更新 + if let Some(enable_all_claude) = request.enable_all_claude { + if let Err(e) = AppConfig::update_allow_claude(enable_all_claude) { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + status: ApiStatus::Failed, + code: Some(500), + error: Some(format!("更新 enable_all_claude 失败: {}", e)), + message: None, + }), + )); + } + } + + // 处理 check_usage_models 更新 + if let Some(check_usage_models) = request.check_usage_models { + if let Err(e) = AppConfig::update_usage_check(check_usage_models) { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + status: ApiStatus::Failed, + code: Some(500), + error: Some(format!("更新 check_usage_models 失败: {}", e)), + message: None, + }), + )); + } + } + + Ok(Json(NormalResponse { + status: ApiStatus::Success, + data: None, + message: Some("配置已更新".to_string()), + })) + } + + "reset" => { + // 重置页面内容 + if !request.path.is_empty() { + if let Err(e) = AppConfig::reset_page_content(&request.path) { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + status: ApiStatus::Failed, + code: Some(500), + error: Some(format!("重置页面内容失败: {}", e)), + message: None, + }), + )); + } + } + + // 重置 enable_stream_check + if request.enable_stream_check.is_some() { + if let Err(e) = AppConfig::reset_stream_check() { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + status: ApiStatus::Failed, + code: Some(500), + error: Some(format!("重置 enable_stream_check 失败: {}", e)), + message: None, + }), + )); + } + } + + // 重置 include_stop_stream + if request.include_stop_stream.is_some() { + if let Err(e) = AppConfig::reset_stop_stream() { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + status: ApiStatus::Failed, + code: Some(500), + error: Some(format!("重置 include_stop_stream 失败: {}", e)), + message: None, + }), + )); + } + } + + // 重置 vision_ability + if request.vision_ability.is_some() { + if let Err(e) = AppConfig::reset_vision_ability() { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + status: ApiStatus::Failed, + code: 
Some(500),
+                            error: Some(format!("重置 vision_ability 失败: {}", e)),
+                            message: None,
+                        }),
+                    ));
+                }
+            }
+
+            // 重置 enable_slow_pool
+            if request.enable_slow_pool.is_some() {
+                if let Err(e) = AppConfig::reset_slow_pool() {
+                    return Err((
+                        StatusCode::INTERNAL_SERVER_ERROR,
+                        Json(ErrorResponse {
+                            status: ApiStatus::Failed,
+                            code: Some(500),
+                            error: Some(format!("重置 enable_slow_pool 失败: {}", e)),
+                            message: None,
+                        }),
+                    ));
+                }
+            }
+
+            // 重置 enable_all_claude
+            if request.enable_all_claude.is_some() {
+                if let Err(e) = AppConfig::reset_allow_claude() {
+                    return Err((
+                        StatusCode::INTERNAL_SERVER_ERROR,
+                        Json(ErrorResponse {
+                            status: ApiStatus::Failed,
+                            code: Some(500),
+                            error: Some(format!("重置 enable_all_claude 失败: {}", e)),
+                            message: None,
+                        }),
+                    ));
+                }
+            }
+
+            // 重置 check_usage_models
+            if request.check_usage_models.is_some() {
+                if let Err(e) = AppConfig::reset_usage_check() {
+                    return Err((
+                        StatusCode::INTERNAL_SERVER_ERROR,
+                        Json(ErrorResponse {
+                            status: ApiStatus::Failed,
+                            code: Some(500),
+                            error: Some(format!("重置 check_usage_models 失败: {}", e)),
+                            message: None,
+                        }),
+                    ));
+                }
+            }
+            Ok(Json(NormalResponse {
+                status: ApiStatus::Success,
+                data: None,
+                message: Some("配置已重置".to_string()),
+            }))
+        }
+
+        _ => Err((
+            StatusCode::BAD_REQUEST,
+            Json(ErrorResponse {
+                status: ApiStatus::Failed,
+                code: Some(400),
+                error: Some("无效的操作类型".to_string()),
+                message: None,
+            }),
+        )),
+    }
+}
diff --git a/src/app/constant.rs b/src/app/constant.rs
index 5447ad6..15cff8f 100644
--- a/src/app/constant.rs
+++ b/src/app/constant.rs
@@ -1,64 +1,71 @@
-pub const PKG_VERSION: &str = env!("CARGO_PKG_VERSION");
-pub const PKG_NAME: &str = env!("CARGO_PKG_NAME");
-pub const PKG_DESCRIPTION: &str = env!("CARGO_PKG_DESCRIPTION");
-pub const PKG_AUTHORS: &str = env!("CARGO_PKG_AUTHORS");
-pub const PKG_REPOSITORY: &str = env!("CARGO_PKG_REPOSITORY");
+macro_rules!
def_pub_const { + ($name:ident, $value:expr) => { + pub const $name: &'static str = $value; + }; +} -pub const ROUTER_ROOT_PATH: &str = "/"; -pub const ROUTER_HEALTH_PATH: &str = "/health"; -pub const ROUTER_GET_CHECKSUM: &str = "/get-checksum"; -pub const ROUTER_GET_USER_INFO_PATH: &str = "/get-user-info"; -pub const ROUTER_LOGS_PATH: &str = "/logs"; -pub const ROUTER_CONFIG_PATH: &str = "/config"; -pub const ROUTER_TOKENINFO_PATH: &str = "/tokeninfo"; -pub const ROUTER_GET_TOKENINFO_PATH: &str = "/get-tokeninfo"; -pub const ROUTER_UPDATE_TOKENINFO_PATH: &str = "/update-tokeninfo"; -pub const ROUTER_ENV_EXAMPLE_PATH: &str = "/env-example"; -pub const ROUTER_SHARED_STYLES_PATH: &str = "/static/shared-styles.css"; -pub const ROUTER_SHARED_JS_PATH: &str = "/static/shared.js"; +def_pub_const!(PKG_VERSION, env!("CARGO_PKG_VERSION")); +def_pub_const!(PKG_NAME, env!("CARGO_PKG_NAME")); +def_pub_const!(PKG_DESCRIPTION, env!("CARGO_PKG_DESCRIPTION")); +def_pub_const!(PKG_AUTHORS, env!("CARGO_PKG_AUTHORS")); +def_pub_const!(PKG_REPOSITORY, env!("CARGO_PKG_REPOSITORY")); -pub const STATUS: &str = "status"; -pub const MESSAGE: &str = "message"; -pub const ERROR: &str = "error"; +def_pub_const!(EMPTY_STRING, ""); -pub const TOKEN_FILE: &str = "token_file"; -pub const TOKEN_LIST_FILE: &str = "token_list_file"; -pub const TOKENS: &str = "tokens"; -pub const TOKEN_LIST: &str = "token_list"; +def_pub_const!(ROUTE_ROOT_PATH, "/"); +def_pub_const!(ROUTE_HEALTH_PATH, "/health"); +def_pub_const!(ROUTE_GET_CHECKSUM, "/get-checksum"); +def_pub_const!(ROUTE_GET_USER_INFO_PATH, "/get-user-info"); +def_pub_const!(ROUTE_LOGS_PATH, "/logs"); +def_pub_const!(ROUTE_CONFIG_PATH, "/config"); +def_pub_const!(ROUTE_TOKENINFO_PATH, "/tokeninfo"); +def_pub_const!(ROUTE_GET_TOKENINFO_PATH, "/get-tokeninfo"); +def_pub_const!(ROUTE_UPDATE_TOKENINFO_PATH, "/update-tokeninfo"); +def_pub_const!(ROUTE_ENV_EXAMPLE_PATH, "/env-example"); +def_pub_const!(ROUTE_STATIC_PATH, "/static/:path"); +def_pub_const!(ROUTE_SHARED_STYLES_PATH, "/static/shared-styles.css"); +def_pub_const!(ROUTE_SHARED_JS_PATH, "/static/shared.js"); +def_pub_const!(ROUTE_ABOUT_PATH, "/about"); +def_pub_const!(ROUTE_README_PATH, "/readme"); -pub const STATUS_SUCCESS: &str = "success"; -pub const STATUS_FAILED: &str = "failed"; +def_pub_const!(STATUS, "status"); +def_pub_const!(MESSAGE, "message"); +def_pub_const!(ERROR, "error"); -pub const HEADER_NAME_CONTENT_TYPE: &str = "content-type"; -pub const HEADER_NAME_AUTHORIZATION: &str = "Authorization"; +def_pub_const!(TOKEN_FILE, "token_file"); +def_pub_const!(DEFAULT_TOKEN_FILE_NAME, ".token"); +def_pub_const!(TOKEN_LIST_FILE, "token_list_file"); +def_pub_const!(DEFAULT_TOKEN_LIST_FILE_NAME, ".token-list"); +def_pub_const!(TOKENS, "tokens"); +def_pub_const!(TOKEN_LIST, "token_list"); -pub const CONTENT_TYPE_PROTO: &str = "application/proto"; -pub const CONTENT_TYPE_CONNECT_PROTO: &str = "application/connect+proto"; -pub const CONTENT_TYPE_TEXT_HTML_WITH_UTF8: &str = "text/html;charset=utf-8"; -pub const CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8: &str = "text/plain;charset=utf-8"; +def_pub_const!(STATUS_SUCCESS, "success"); +def_pub_const!(STATUS_FAILED, "failed"); -pub const AUTHORIZATION_BEARER_PREFIX: &str = "Bearer "; +def_pub_const!(HEADER_NAME_CONTENT_TYPE, "content-type"); +def_pub_const!(HEADER_NAME_AUTHORIZATION, "authorization"); +def_pub_const!(HEADER_NAME_LOCATION, "Location"); -pub const OBJECT_CHAT_COMPLETION: &str = "chat.completion"; -pub const OBJECT_CHAT_COMPLETION_CHUNK: &str = "chat.completion.chunk"; 
+def_pub_const!(CONTENT_TYPE_PROTO, "application/proto"); +def_pub_const!(CONTENT_TYPE_CONNECT_PROTO, "application/connect+proto"); +def_pub_const!(CONTENT_TYPE_TEXT_HTML_WITH_UTF8, "text/html;charset=utf-8"); +def_pub_const!(CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, "text/plain;charset=utf-8"); +def_pub_const!(CONTENT_TYPE_TEXT_CSS_WITH_UTF8, "text/css;charset=utf-8"); +def_pub_const!(CONTENT_TYPE_TEXT_JS_WITH_UTF8, "text/javascript;charset=utf-8"); -pub const CURSOR_API2_HOST: &str = "api2.cursor.sh"; -pub const CURSOR_API2_BASE_URL: &str = "https://api2.cursor.sh/aiserver.v1.AiService/"; +def_pub_const!(AUTHORIZATION_BEARER_PREFIX, "Bearer "); -pub const CURSOR_API2_STREAM_CHAT: &str = "StreamChat"; -pub const CURSOR_API2_GET_USER_INFO: &str = "GetUserInfo"; +def_pub_const!(OBJECT_CHAT_COMPLETION, "chat.completion"); +def_pub_const!(OBJECT_CHAT_COMPLETION_CHUNK, "chat.completion.chunk"); -pub const FINISH_REASON_STOP: &str = "stop"; +def_pub_const!(CURSOR_API2_HOST, "api2.cursor.sh"); +def_pub_const!(CURSOR_API2_BASE_URL, "https://api2.cursor.sh/aiserver.v1.AiService/"); -pub const LONG_CONTEXT_MODELS: [&str; 4] = [ - "gpt-4o-128k", - "gemini-1.5-flash-500k", - "claude-3-haiku-200k", - "claude-3-5-sonnet-200k", -]; +def_pub_const!(CURSOR_API2_STREAM_CHAT, "StreamChat"); +def_pub_const!(CURSOR_API2_GET_USER_INFO, "GetUserInfo"); -pub const MODEL_OBJECT: &str = "model"; -pub const ANTHROPIC: &str = "anthropic"; -pub const CURSOR: &str = "cursor"; -pub const GOOGLE: &str = "google"; -pub const OPENAI: &str = "openai"; +def_pub_const!(FINISH_REASON_STOP, "stop"); + +def_pub_const!(ERR_UPDATE_CONFIG, "无法更新配置"); +def_pub_const!(ERR_RESET_CONFIG, "无法重置配置"); +def_pub_const!(ERR_INVALID_PATH, "无效的路径"); diff --git a/src/app/models.rs b/src/app/models.rs index 6c55d58..d71a1ef 100644 --- a/src/app/models.rs +++ b/src/app/models.rs @@ -1,6 +1,5 @@ use super::{constant::*, token::UserUsageInfo}; -use crate::message::Message; -use chrono::{DateTime, Local}; +use crate::chat::models::Message; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; use std::sync::RwLock; @@ -23,20 +22,19 @@ impl Default for PageContent { } } +mod usage_check; +pub use usage_check::UsageCheck; + // 静态配置 #[derive(Clone)] pub struct AppConfig { - enable_stream_check: bool, - include_stop_stream: bool, + stream_check: bool, + stop_stream: bool, vision_ability: VisionAbility, - enable_slow_pool: bool, + slow_pool: bool, allow_claude: bool, - auth_token: String, - token_file: String, - token_list_file: String, - route_prefix: String, - pub start_time: chrono::DateTime, pages: Pages, + usage_check: UsageCheck, } #[derive(Serialize, Deserialize, Clone)] @@ -50,12 +48,12 @@ pub enum VisionAbility { } impl VisionAbility { - pub fn from_str(s: &str) -> Result { + pub fn from_str(s: &str) -> Self { match s.to_lowercase().as_str() { - "none" | "disabled" => Ok(Self::None), - "base64" | "base64-only" => Ok(Self::Base64), - "all" | "base64-http" => Ok(Self::All), - _ => Err("Invalid VisionAbility value"), + "none" | "disabled" => Self::None, + "base64" | "base64-only" => Self::Base64, + "all" | "base64-http" => Self::All, + _ => Self::default(), } } } @@ -66,7 +64,7 @@ impl Default for VisionAbility { } } -#[derive(Clone)] +#[derive(Clone, Default)] pub struct Pages { pub root_content: PageContent, pub logs_content: PageContent, @@ -74,19 +72,8 @@ pub struct Pages { pub tokeninfo_content: PageContent, pub shared_styles_content: PageContent, pub shared_js_content: PageContent, -} - -impl Default for Pages { - fn default() -> Self 
{
-        Self {
-            root_content: PageContent::Default,
-            logs_content: PageContent::Default,
-            config_content: PageContent::Default,
-            tokeninfo_content: PageContent::Default,
-            shared_styles_content: PageContent::Default,
-            shared_js_content: PageContent::Default,
-        }
-    }
+    pub about_content: PageContent,
+    pub readme_content: PageContent,
 }
 
 // 运行时状态
@@ -105,58 +92,72 @@ lazy_static! {
 impl Default for AppConfig {
     fn default() -> Self {
         Self {
-            enable_stream_check: true,
-            include_stop_stream: true,
+            stream_check: true,
+            stop_stream: true,
             vision_ability: VisionAbility::Base64,
-            enable_slow_pool: false,
+            slow_pool: false,
             allow_claude: false,
-            auth_token: String::new(),
-            token_file: ".token".to_string(),
-            token_list_file: ".token-list".to_string(),
-            route_prefix: String::new(),
-            start_time: chrono::Local::now(),
             pages: Pages::default(),
+            usage_check: UsageCheck::default(),
         }
     }
 }
 
+macro_rules! config_methods {
+    ($($field:ident: $type:ty, $default:expr;)*) => {
+        $(
+            paste::paste! {
+                pub fn [<get_ $field>]() -> $type {
+                    APP_CONFIG
+                        .read()
+                        .map(|config| config.$field.clone())
+                        .unwrap_or($default)
+                }
+
+                pub fn [<update_ $field>](value: $type) -> Result<(), &'static str> {
+                    if let Ok(mut config) = APP_CONFIG.write() {
+                        config.$field = value;
+                        Ok(())
+                    } else {
+                        Err(ERR_UPDATE_CONFIG)
+                    }
+                }
+
+                pub fn [<reset_ $field>]() -> Result<(), &'static str> {
+                    if let Ok(mut config) = APP_CONFIG.write() {
+                        config.$field = $default;
+                        Ok(())
+                    } else {
+                        Err(ERR_RESET_CONFIG)
+                    }
+                }
+            }
+        )*
+    };
+}
+
 impl AppConfig {
     pub fn init(
-        enable_stream_check: bool,
-        include_stop_stream: bool,
+        stream_check: bool,
+        stop_stream: bool,
         vision_ability: VisionAbility,
-        enable_slow_pool: bool,
+        slow_pool: bool,
         allow_claude: bool,
-        auth_token: String,
-        token_file: String,
-        token_list_file: String,
-        route_prefix: String,
     ) {
         if let Ok(mut config) = APP_CONFIG.write() {
-            config.enable_stream_check = enable_stream_check;
-            config.include_stop_stream = include_stop_stream;
+            config.stream_check = stream_check;
+            config.stop_stream = stop_stream;
             config.vision_ability = vision_ability;
-            config.enable_slow_pool = enable_slow_pool;
+            config.slow_pool = slow_pool;
             config.allow_claude = allow_claude;
-            config.auth_token = auth_token;
-            config.token_file = token_file;
-            config.token_list_file = token_list_file;
-            config.route_prefix = route_prefix;
         }
     }
 
-    pub fn get_stream_check() -> bool {
-        APP_CONFIG
-            .read()
-            .map(|config| config.enable_stream_check)
-            .unwrap_or(true)
-    }
-
-    pub fn get_stop_stream() -> bool {
-        APP_CONFIG
-            .read()
-            .map(|config| config.include_stop_stream)
-            .unwrap_or(true)
+    config_methods!
{ + stream_check: bool, true; + stop_stream: bool, true; + slow_pool: bool, false; + allow_claude: bool, false; } pub fn get_vision_ability() -> VisionAbility { @@ -166,137 +167,62 @@ impl AppConfig { .unwrap_or_default() } - pub fn get_slow_pool() -> bool { - APP_CONFIG - .read() - .map(|config| config.enable_slow_pool) - .unwrap_or(false) - } - - pub fn get_allow_claude() -> bool { - APP_CONFIG - .read() - .map(|config| config.allow_claude) - .unwrap_or(false) - } - - pub fn get_auth_token() -> String { - APP_CONFIG - .read() - .map(|config| config.auth_token.clone()) - .unwrap_or_default() - } - - pub fn get_token_file() -> String { - APP_CONFIG - .read() - .map(|config| config.token_file.clone()) - .unwrap_or_default() - } - - pub fn get_token_list_file() -> String { - APP_CONFIG - .read() - .map(|config| config.token_list_file.clone()) - .unwrap_or_default() - } - - pub fn get_route_prefix() -> String { - APP_CONFIG - .read() - .map(|config| config.route_prefix.clone()) - .unwrap_or_default() - } - pub fn get_page_content(path: &str) -> Option { APP_CONFIG.read().ok().map(|config| match path { - ROUTER_ROOT_PATH => config.pages.root_content.clone(), - ROUTER_LOGS_PATH => config.pages.logs_content.clone(), - ROUTER_CONFIG_PATH => config.pages.config_content.clone(), - ROUTER_TOKENINFO_PATH => config.pages.tokeninfo_content.clone(), - ROUTER_SHARED_STYLES_PATH => config.pages.shared_styles_content.clone(), - ROUTER_SHARED_JS_PATH => config.pages.shared_js_content.clone(), - _ => PageContent::Default, + ROUTE_ROOT_PATH => config.pages.root_content.clone(), + ROUTE_LOGS_PATH => config.pages.logs_content.clone(), + ROUTE_CONFIG_PATH => config.pages.config_content.clone(), + ROUTE_TOKENINFO_PATH => config.pages.tokeninfo_content.clone(), + ROUTE_SHARED_STYLES_PATH => config.pages.shared_styles_content.clone(), + ROUTE_SHARED_JS_PATH => config.pages.shared_js_content.clone(), + ROUTE_ABOUT_PATH => config.pages.about_content.clone(), + ROUTE_README_PATH => config.pages.readme_content.clone(), + _ => PageContent::default(), }) } - pub fn update_stream_check(enable: bool) -> Result<(), &'static str> { - if let Ok(mut config) = APP_CONFIG.write() { - config.enable_stream_check = enable; - Ok(()) - } else { - Err("无法更新配置") - } + pub fn get_usage_check() -> UsageCheck { + APP_CONFIG + .read() + .map(|config| config.usage_check.clone()) + .unwrap_or_default() } - pub fn update_stop_stream(enable: bool) -> Result<(), &'static str> { - if let Ok(mut config) = APP_CONFIG.write() { - config.include_stop_stream = enable; - Ok(()) - } else { - Err("无法更新配置") - } - } pub fn update_vision_ability(new_ability: VisionAbility) -> Result<(), &'static str> { if let Ok(mut config) = APP_CONFIG.write() { config.vision_ability = new_ability; Ok(()) } else { - Err("无法更新配置") - } - } - - pub fn update_slow_pool(enable: bool) -> Result<(), &'static str> { - if let Ok(mut config) = APP_CONFIG.write() { - config.enable_slow_pool = enable; - Ok(()) - } else { - Err("无法更新配置") - } - } - - pub fn update_allow_claude(enable: bool) -> Result<(), &'static str> { - if let Ok(mut config) = APP_CONFIG.write() { - config.allow_claude = enable; - Ok(()) - } else { - Err("无法更新配置") + Err(ERR_UPDATE_CONFIG) } } pub fn update_page_content(path: &str, content: PageContent) -> Result<(), &'static str> { if let Ok(mut config) = APP_CONFIG.write() { match path { - ROUTER_ROOT_PATH => config.pages.root_content = content, - ROUTER_LOGS_PATH => config.pages.logs_content = content, - ROUTER_CONFIG_PATH => config.pages.config_content = content, 
- ROUTER_TOKENINFO_PATH => config.pages.tokeninfo_content = content, - ROUTER_SHARED_STYLES_PATH => config.pages.shared_styles_content = content, - ROUTER_SHARED_JS_PATH => config.pages.shared_js_content = content, - _ => return Err("无效的路径"), + ROUTE_ROOT_PATH => config.pages.root_content = content, + ROUTE_LOGS_PATH => config.pages.logs_content = content, + ROUTE_CONFIG_PATH => config.pages.config_content = content, + ROUTE_TOKENINFO_PATH => config.pages.tokeninfo_content = content, + ROUTE_SHARED_STYLES_PATH => config.pages.shared_styles_content = content, + ROUTE_SHARED_JS_PATH => config.pages.shared_js_content = content, + ROUTE_ABOUT_PATH => config.pages.about_content = content, + ROUTE_README_PATH => config.pages.readme_content = content, + _ => return Err(ERR_INVALID_PATH), } Ok(()) } else { - Err("无法更新配置") + Err(ERR_UPDATE_CONFIG) } } - pub fn reset_stream_check() -> Result<(), &'static str> { + pub fn update_usage_check(rule: UsageCheck) -> Result<(), &'static str> { if let Ok(mut config) = APP_CONFIG.write() { - config.enable_stream_check = true; + config.usage_check = rule; Ok(()) } else { - Err("无法重置配置") - } - } - - pub fn reset_stop_stream() -> Result<(), &'static str> { - if let Ok(mut config) = APP_CONFIG.write() { - config.include_stop_stream = true; - Ok(()) - } else { - Err("无法重置配置") + Err(ERR_UPDATE_CONFIG) } } @@ -305,44 +231,37 @@ impl AppConfig { config.vision_ability = VisionAbility::Base64; Ok(()) } else { - Err("无法重置配置") - } - } - - pub fn reset_slow_pool() -> Result<(), &'static str> { - if let Ok(mut config) = APP_CONFIG.write() { - config.enable_slow_pool = false; - Ok(()) - } else { - Err("无法重置配置") - } - } - - pub fn reset_allow_claude() -> Result<(), &'static str> { - if let Ok(mut config) = APP_CONFIG.write() { - config.allow_claude = false; - Ok(()) - } else { - Err("无法重置配置") + Err(ERR_RESET_CONFIG) } } pub fn reset_page_content(path: &str) -> Result<(), &'static str> { if let Ok(mut config) = APP_CONFIG.write() { match path { - ROUTER_ROOT_PATH => config.pages.root_content = PageContent::Default, - ROUTER_LOGS_PATH => config.pages.logs_content = PageContent::Default, - ROUTER_CONFIG_PATH => config.pages.config_content = PageContent::Default, - ROUTER_TOKENINFO_PATH => config.pages.tokeninfo_content = PageContent::Default, - ROUTER_SHARED_STYLES_PATH => { - config.pages.shared_styles_content = PageContent::Default + ROUTE_ROOT_PATH => config.pages.root_content = PageContent::default(), + ROUTE_LOGS_PATH => config.pages.logs_content = PageContent::default(), + ROUTE_CONFIG_PATH => config.pages.config_content = PageContent::default(), + ROUTE_TOKENINFO_PATH => config.pages.tokeninfo_content = PageContent::default(), + ROUTE_SHARED_STYLES_PATH => { + config.pages.shared_styles_content = PageContent::default() } - ROUTER_SHARED_JS_PATH => config.pages.shared_js_content = PageContent::Default, - _ => return Err("无效的路径"), + ROUTE_SHARED_JS_PATH => config.pages.shared_js_content = PageContent::default(), + ROUTE_ABOUT_PATH => config.pages.about_content = PageContent::default(), + ROUTE_README_PATH => config.pages.readme_content = PageContent::default(), + _ => return Err(ERR_INVALID_PATH), } Ok(()) } else { - Err("无法重置配置") + Err(ERR_RESET_CONFIG) + } + } + + pub fn reset_usage_check() -> Result<(), &'static str> { + if let Ok(mut config) = APP_CONFIG.write() { + config.usage_check = UsageCheck::default(); + Ok(()) + } else { + Err(ERR_RESET_CONFIG) } } } @@ -362,31 +281,16 @@ impl AppState { } } -// 模型定义 -#[derive(Serialize, Clone)] -pub struct Model { - pub id: 
String, - pub created: i64, - pub object: String, - pub owned_by: String, -} - -// impl Model { -// pub fn is_pesticide(&self) -> bool { -// !(self.owned_by.as_str() == CURSOR || self.id.as_str() == "gpt-4o-mini") -// } -// } - // 请求日志 #[derive(Serialize, Clone)] pub struct RequestLog { - pub timestamp: DateTime, + pub timestamp: chrono::DateTime, pub model: String, pub token_info: TokenInfo, #[serde(skip_serializing_if = "Option::is_none")] pub prompt: Option, pub stream: bool, - pub status: String, + pub status: &'static str, #[serde(skip_serializing_if = "Option::is_none")] pub error: Option, } @@ -424,24 +328,3 @@ pub struct TokenUpdateRequest { #[serde(default)] pub token_list: Option, } - -// 添加用于接收更新请求的结构体 -#[derive(Deserialize)] -pub struct ConfigUpdateRequest { - #[serde(default)] - pub action: String, // "get", "update", "reset" - #[serde(default)] - pub path: String, - #[serde(default)] - pub content: Option, // "default", "text", "html" - #[serde(default)] - pub enable_stream_check: Option, - #[serde(default)] - pub include_stop_stream: Option, - #[serde(default)] - pub vision_ability: Option, - #[serde(default)] - pub enable_slow_pool: Option, - #[serde(default)] - pub enable_all_claude: Option, -} diff --git a/src/app/models/usage_check.rs b/src/app/models/usage_check.rs new file mode 100644 index 0000000..b7c4848 --- /dev/null +++ b/src/app/models/usage_check.rs @@ -0,0 +1,91 @@ +use crate::chat::constant::AVAILABLE_MODELS; +use serde::{Deserialize, Serialize}; + +#[derive(Clone)] +pub enum UsageCheck { + None, + Default, + All, + Custom(Vec<&'static str>), +} + +impl Default for UsageCheck { + fn default() -> Self { + Self::Default + } +} + +impl Serialize for UsageCheck { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut state = serializer.serialize_struct("UsageCheck", 1)?; + match self { + UsageCheck::None => { + state.serialize_field("type", "none")?; + } + UsageCheck::Default => { + state.serialize_field("type", "default")?; + } + UsageCheck::All => { + state.serialize_field("type", "all")?; + } + UsageCheck::Custom(models) => { + state.serialize_field("type", "list")?; + state.serialize_field("content", &models.join(","))?; + } + } + state.end() + } +} + +impl<'de> Deserialize<'de> for UsageCheck { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + #[serde(tag = "type", content = "content")] + enum UsageCheckHelper { + #[serde(rename = "none")] + None, + #[serde(rename = "default")] + Default, + #[serde(rename = "all")] + All, + #[serde(rename = "list")] + Custom(String), + } + + let helper = UsageCheckHelper::deserialize(deserializer)?; + Ok(match helper { + UsageCheckHelper::None => UsageCheck::None, + UsageCheckHelper::Default => UsageCheck::Default, + UsageCheckHelper::All => UsageCheck::All, + UsageCheckHelper::Custom(list) => { + if list.is_empty() { + return Ok(UsageCheck::None); + } + + let models: Vec<&'static str> = list + .split(',') + .filter_map(|model| { + let model = model.trim(); + AVAILABLE_MODELS + .iter() + .find(|m| m.id == model) + .map(|m| m.id) + }) + .collect(); + + if models.is_empty() { + UsageCheck::None + } else { + UsageCheck::Custom(models) + } + } + }) + } +} diff --git a/src/app/statics.rs b/src/app/statics.rs new file mode 100644 index 0000000..43b559f --- /dev/null +++ b/src/app/statics.rs @@ -0,0 +1,52 @@ +use super::{ + constant::{DEFAULT_TOKEN_FILE_NAME, DEFAULT_TOKEN_LIST_FILE_NAME, 
EMPTY_STRING}, + utils::parse_string_from_env, +}; +use std::sync::LazyLock; + +macro_rules! def_pub_static { + // 基础版本:直接存储 String + ($name:ident, $value:expr) => { + pub static $name: LazyLock = LazyLock::new(|| $value); + + def_pub_static_getter!($name); + }; + + // 环境变量版本 + ($name:ident, env: $env_key:expr, default: $default:expr) => { + pub static $name: LazyLock = + LazyLock::new(|| parse_string_from_env($env_key, $default).trim().to_string()); + + def_pub_static_getter!($name); + }; +} + +macro_rules! def_pub_static_getter { + ($name:ident) => { + paste::paste! { + pub fn []() -> String { + (*$name).clone() + } + } + }; +} + +def_pub_static!(ROUTE_PREFIX, env: "ROUTE_PREFIX", default: EMPTY_STRING); +def_pub_static!(AUTH_TOKEN, env: "AUTH_TOKEN", default: EMPTY_STRING); +def_pub_static!(TOKEN_FILE, env: "TOKEN_FILE", default: DEFAULT_TOKEN_FILE_NAME); +def_pub_static!(TOKEN_LIST_FILE, env: "TOKEN_LIST_FILE", default: DEFAULT_TOKEN_LIST_FILE_NAME); +def_pub_static!( + ROUTE_MODELS_PATH, + format!("{}/v1/models", ROUTE_PREFIX.as_str()) +); +def_pub_static!( + ROUTE_CHAT_PATH, + format!("{}/v1/chat/completions", ROUTE_PREFIX.as_str()) +); + +pub static START_TIME: LazyLock> = + LazyLock::new(chrono::Local::now); + +pub fn get_start_time() -> chrono::DateTime { + *START_TIME +} diff --git a/src/app/token.rs b/src/app/token.rs index a146fc2..edd6459 100644 --- a/src/app/token.rs +++ b/src/app/token.rs @@ -1,9 +1,10 @@ use super::{ constant::*, - models::{AppConfig, AppState, TokenInfo, TokenUpdateRequest}, - utils::i32_to_u32, + models::{AppState, TokenInfo, TokenUpdateRequest}, + statics::*, + utils::{generate_checksum, generate_hash, i32_to_u32}, }; -use crate::aiserver::v1::GetUserInfoResponse; +use crate::{chat::aiserver::v1::GetUserInfoResponse, common::models::{ApiStatus, NormalResponseNoData}}; use axum::http::HeaderMap; use axum::{ extract::{Query, State}, @@ -41,13 +42,13 @@ fn parse_token_alias(token_part: &str, line: &str) -> Option<(String, Option Vec { - let token_file = AppConfig::get_token_file(); - let token_list_file = AppConfig::get_token_list_file(); + let token_file = get_token_file(); + let token_list_file = get_token_list_file(); // 确保文件存在 for file in [&token_file, &token_list_file] { if !std::path::Path::new(file).exists() { - if let Err(e) = std::fs::write(file, "") { + if let Err(e) = std::fs::write(file, EMPTY_STRING) { eprintln!("警告: 无法创建文件 '{}': {}", file, e); } } @@ -118,8 +119,7 @@ pub fn load_tokens() -> Vec { } } else { // 为新token生成checksum - let checksum = - crate::generate_checksum(&crate::generate_hash(), Some(&crate::generate_hash())); + let checksum = generate_checksum(&generate_hash(), Some(&generate_hash())); token_map.insert(token, (checksum, alias)); } } @@ -153,10 +153,20 @@ pub fn load_tokens() -> Vec { .collect() } +#[derive(Serialize)] +pub struct ChecksumResponse { + pub checksum: String, +} + +pub async fn handle_get_checksum() -> Json { + let checksum = generate_checksum(&generate_hash(), Some(&generate_hash())); + Json(ChecksumResponse { checksum }) +} + // 更新 TokenInfo 处理 pub async fn handle_update_tokeninfo( State(state): State>>, -) -> Json { +) -> Json { // 重新加载 tokens let token_infos = load_tokens(); @@ -166,20 +176,20 @@ pub async fn handle_update_tokeninfo( state.token_infos = token_infos; } - Json(serde_json::json!({ - STATUS: STATUS_SUCCESS, - MESSAGE: "Token list has been reloaded" - })) + Json(NormalResponseNoData { + status: ApiStatus::Success, + message: Some("Token list has been reloaded".to_string()), + }) } // 获取 TokenInfo 
处理 pub async fn handle_get_tokeninfo( State(_state): State>>, headers: HeaderMap, -) -> Result, StatusCode> { - let auth_token = AppConfig::get_auth_token(); - let token_file = AppConfig::get_token_file(); - let token_list_file = AppConfig::get_token_list_file(); +) -> Result, StatusCode> { + let auth_token = get_auth_token(); + let token_file = get_token_file(); + let token_list_file = get_token_list_file(); // 验证 AUTH_TOKEN let auth_header = headers @@ -196,23 +206,40 @@ pub async fn handle_get_tokeninfo( let tokens = std::fs::read_to_string(&token_file).unwrap_or_else(|_| String::new()); let token_list = std::fs::read_to_string(&token_list_file).unwrap_or_else(|_| String::new()); - Ok(Json(serde_json::json!({ - STATUS: STATUS_SUCCESS, - "token_file": token_file, - "token_list_file": token_list_file, - "tokens": tokens, - "token_list": token_list - }))) + Ok(Json(TokenInfoResponse { + status: ApiStatus::Success, + token_file: token_file.clone(), + token_list_file: token_list_file.clone(), + tokens: Some(tokens.clone()), + tokens_count: Some(tokens.len()), + token_list: Some(token_list), + message: None, + })) +} + +#[derive(Serialize)] +pub struct TokenInfoResponse { + pub status: ApiStatus, + pub token_file: String, + pub token_list_file: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tokens_count: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub token_list: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, } pub async fn handle_update_tokeninfo_post( State(state): State>>, headers: HeaderMap, Json(request): Json, -) -> Result, StatusCode> { - let auth_token = AppConfig::get_auth_token(); - let token_file = AppConfig::get_token_file(); - let token_list_file = AppConfig::get_token_list_file(); +) -> Result, StatusCode> { + let auth_token = get_auth_token(); + let token_file = get_token_file(); + let token_list_file = get_token_list_file(); // 验证 AUTH_TOKEN let auth_header = headers @@ -244,13 +271,15 @@ pub async fn handle_update_tokeninfo_post( state.token_infos = token_infos; } - Ok(Json(serde_json::json!({ - STATUS: STATUS_SUCCESS, - MESSAGE: "Token files have been updated and reloaded", - "token_file": token_file, - "token_list_file": token_list_file, - "token_count": token_infos_len - }))) + Ok(Json(TokenInfoResponse { + status: ApiStatus::Success, + token_file: token_file.clone(), + token_list_file: token_list_file.clone(), + tokens: None, + tokens_count: Some(token_infos_len), + token_list: None, + message: Some("Token files have been updated and reloaded".to_string()), + })) } #[derive(Deserialize)] @@ -262,14 +291,13 @@ pub async fn get_user_info( State(state): State>>, Query(query): Query, ) -> Json { - let (auth_token, checksum) = match { - let app_token_infos = &state.lock().await.token_infos; - app_token_infos - .iter() - .find(|token_info| token_info.alias == Some(query.alias.clone())) - .map(|token_info| (token_info.token.clone(), token_info.checksum.clone())) - } { - Some(token) => token, + let token_infos = &state.lock().await.token_infos; + let token_info = token_infos + .iter() + .find(|token_info| token_info.alias == Some(query.alias.clone())); + + let (auth_token, checksum) = match token_info { + Some(token_info) => (token_info.token.clone(), token_info.checksum.clone()), None => return Json(GetUserInfo::Error("No data".to_string())), }; @@ -300,9 +328,9 @@ pub async fn get_user_usage(auth_token: &str, checksum: &str) -> 
Option bool { std::env::var(key) .ok() diff --git a/src/app/utils/checksum.rs b/src/app/utils/checksum.rs new file mode 100644 index 0000000..75f6326 --- /dev/null +++ b/src/app/utils/checksum.rs @@ -0,0 +1,44 @@ +use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; +use rand::Rng; +use sha2::{Digest, Sha256}; + +pub fn generate_hash() -> String { + let random_bytes = rand::thread_rng().gen::<[u8; 32]>(); + let mut hasher = Sha256::new(); + hasher.update(random_bytes); + hex::encode(hasher.finalize()) +} + +fn obfuscate_bytes(bytes: &mut [u8]) { + let mut prev: u8 = 165; + for (idx, byte) in bytes.iter_mut().enumerate() { + let old_value = *byte; + *byte = (old_value ^ prev).wrapping_add((idx % 256) as u8); + prev = *byte; + } +} + +pub fn generate_checksum(device_id: &str, mac_addr: Option<&str>) -> String { + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() + / 1_000_000; + + let mut timestamp_bytes = vec![ + ((timestamp >> 40) & 255) as u8, + ((timestamp >> 32) & 255) as u8, + ((timestamp >> 24) & 255) as u8, + ((timestamp >> 16) & 255) as u8, + ((timestamp >> 8) & 255) as u8, + (255 & timestamp) as u8, + ]; + + obfuscate_bytes(&mut timestamp_bytes); + let encoded = BASE64.encode(×tamp_bytes); + + match mac_addr { + Some(mac) => format!("{}{}/{}", encoded, device_id, mac), + None => format!("{}{}", encoded, device_id), + } +} diff --git a/src/chat.rs b/src/chat.rs index 2c797f3..a2a4299 100644 --- a/src/chat.rs +++ b/src/chat.rs @@ -1,2 +1,6 @@ -pub mod stream; +pub mod aiserver; +pub mod constant; pub mod error; +pub mod models; +pub mod service; +pub mod stream; diff --git a/src/aiserver.rs b/src/chat/aiserver.rs similarity index 100% rename from src/aiserver.rs rename to src/chat/aiserver.rs diff --git a/src/aiserver/v1.rs b/src/chat/aiserver/v1.rs similarity index 100% rename from src/aiserver/v1.rs rename to src/chat/aiserver/v1.rs diff --git a/src/aiserver/v1/aiserver.proto b/src/chat/aiserver/v1/aiserver.proto similarity index 100% rename from src/aiserver/v1/aiserver.proto rename to src/chat/aiserver/v1/aiserver.proto diff --git a/src/chat/constant.rs b/src/chat/constant.rs new file mode 100644 index 0000000..37e6f85 --- /dev/null +++ b/src/chat/constant.rs @@ -0,0 +1,192 @@ +use super::models::Model; + +macro_rules! 
def_pub_const { + ($name:ident, $value:expr) => { + pub const $name: &'static str = $value; + }; +} +def_pub_const!(ERR_UNSUPPORTED_GIF, "不支持动态 GIF"); +def_pub_const!(ERR_UNSUPPORTED_IMAGE_FORMAT, "不支持的图片格式,仅支持 PNG、JPEG、WEBP 和非动态 GIF"); + +const MODEL_OBJECT: &str = "model"; +const CREATED: i64 = 1706659200; + +def_pub_const!(ANTHROPIC, "anthropic"); +def_pub_const!(CURSOR, "cursor"); +def_pub_const!(GOOGLE, "google"); +def_pub_const!(OPENAI, "openai"); + +def_pub_const!(CLAUDE_3_5_SONNET, "claude-3.5-sonnet"); +def_pub_const!(GPT_4, "gpt-4"); +def_pub_const!(GPT_4O, "gpt-4o"); +def_pub_const!(CLAUDE_3_OPUS, "claude-3-opus"); +def_pub_const!(CURSOR_FAST, "cursor-fast"); +def_pub_const!(CURSOR_SMALL, "cursor-small"); +def_pub_const!(GPT_3_5_TURBO, "gpt-3.5-turbo"); +def_pub_const!(GPT_4_TURBO_2024_04_09, "gpt-4-turbo-2024-04-09"); +def_pub_const!(GPT_4O_128K, "gpt-4o-128k"); +def_pub_const!(GEMINI_1_5_FLASH_500K, "gemini-1.5-flash-500k"); +def_pub_const!(CLAUDE_3_HAIKU_200K, "claude-3-haiku-200k"); +def_pub_const!(CLAUDE_3_5_SONNET_200K, "claude-3-5-sonnet-200k"); +def_pub_const!(CLAUDE_3_5_SONNET_20241022, "claude-3-5-sonnet-20241022"); +def_pub_const!(GPT_4O_MINI, "gpt-4o-mini"); +def_pub_const!(O1_MINI, "o1-mini"); +def_pub_const!(O1_PREVIEW, "o1-preview"); +def_pub_const!(O1, "o1"); +def_pub_const!(CLAUDE_3_5_HAIKU, "claude-3.5-haiku"); +def_pub_const!(GEMINI_EXP_1206, "gemini-exp-1206"); +def_pub_const!( + GEMINI_2_0_FLASH_THINKING_EXP, + "gemini-2.0-flash-thinking-exp" +); +def_pub_const!(GEMINI_2_0_FLASH_EXP, "gemini-2.0-flash-exp"); + +pub const AVAILABLE_MODELS: &[Model] = &[ + Model { + id: CLAUDE_3_5_SONNET, + created: CREATED, + object: MODEL_OBJECT, + owned_by: ANTHROPIC, + }, + Model { + id: GPT_4, + created: CREATED, + object: MODEL_OBJECT, + owned_by: OPENAI, + }, + Model { + id: GPT_4O, + created: CREATED, + object: MODEL_OBJECT, + owned_by: OPENAI, + }, + Model { + id: CLAUDE_3_OPUS, + created: CREATED, + object: MODEL_OBJECT, + owned_by: ANTHROPIC, + }, + Model { + id: CURSOR_FAST, + created: CREATED, + object: MODEL_OBJECT, + owned_by: CURSOR, + }, + Model { + id: CURSOR_SMALL, + created: CREATED, + object: MODEL_OBJECT, + owned_by: CURSOR, + }, + Model { + id: GPT_3_5_TURBO, + created: CREATED, + object: MODEL_OBJECT, + owned_by: OPENAI, + }, + Model { + id: GPT_4_TURBO_2024_04_09, + created: CREATED, + object: MODEL_OBJECT, + owned_by: OPENAI, + }, + Model { + id: GPT_4O_128K, + created: CREATED, + object: MODEL_OBJECT, + owned_by: OPENAI, + }, + Model { + id: GEMINI_1_5_FLASH_500K, + created: CREATED, + object: MODEL_OBJECT, + owned_by: GOOGLE, + }, + Model { + id: CLAUDE_3_HAIKU_200K, + created: CREATED, + object: MODEL_OBJECT, + owned_by: ANTHROPIC, + }, + Model { + id: CLAUDE_3_5_SONNET_200K, + created: CREATED, + object: MODEL_OBJECT, + owned_by: ANTHROPIC, + }, + Model { + id: CLAUDE_3_5_SONNET_20241022, + created: CREATED, + object: MODEL_OBJECT, + owned_by: ANTHROPIC, + }, + Model { + id: GPT_4O_MINI, + created: CREATED, + object: MODEL_OBJECT, + owned_by: OPENAI, + }, + Model { + id: O1_MINI, + created: CREATED, + object: MODEL_OBJECT, + owned_by: OPENAI, + }, + Model { + id: O1_PREVIEW, + created: CREATED, + object: MODEL_OBJECT, + owned_by: OPENAI, + }, + Model { + id: O1, + created: CREATED, + object: MODEL_OBJECT, + owned_by: OPENAI, + }, + Model { + id: CLAUDE_3_5_HAIKU, + created: CREATED, + object: MODEL_OBJECT, + owned_by: ANTHROPIC, + }, + Model { + id: GEMINI_EXP_1206, + created: CREATED, + object: MODEL_OBJECT, + owned_by: GOOGLE, + }, + Model { + 
id: GEMINI_2_0_FLASH_THINKING_EXP, + created: CREATED, + object: MODEL_OBJECT, + owned_by: GOOGLE, + }, + Model { + id: GEMINI_2_0_FLASH_EXP, + created: CREATED, + object: MODEL_OBJECT, + owned_by: GOOGLE, + }, +]; + +pub const USAGE_CHECK_MODELS: [&str; 11] = [ + CLAUDE_3_5_SONNET_20241022, + CLAUDE_3_5_SONNET, + GEMINI_EXP_1206, + GPT_4, + GPT_4_TURBO_2024_04_09, + GPT_4O, + CLAUDE_3_5_HAIKU, + GPT_4O_128K, + GEMINI_1_5_FLASH_500K, + CLAUDE_3_HAIKU_200K, + CLAUDE_3_5_SONNET_200K, +]; + +pub const LONG_CONTEXT_MODELS: [&str; 4] = [ + GPT_4O_128K, + GEMINI_1_5_FLASH_500K, + CLAUDE_3_HAIKU_200K, + CLAUDE_3_5_SONNET_200K, +]; diff --git a/src/chat/error.rs b/src/chat/error.rs index 0db0cd4..9f49ee3 100644 --- a/src/chat/error.rs +++ b/src/chat/error.rs @@ -1,4 +1,4 @@ -use crate::aiserver::v1::throw_error_check_request::Error as ErrorType; +use super::aiserver::v1::throw_error_check_request::Error as ErrorType; use reqwest::StatusCode; use serde::{Deserialize, Serialize}; @@ -31,9 +31,9 @@ pub struct ErrorDebug { } impl ErrorDebug { - pub fn is_valid(&self) -> bool { - ErrorType::from_str_name(&self.error).is_some() - } + // pub fn is_valid(&self) -> bool { + // ErrorType::from_str_name(&self.error).is_some() + // } pub fn status_code(&self) -> u16 { match ErrorType::from_str_name(&self.error) { @@ -83,6 +83,8 @@ pub struct ErrorDetails { pub is_retryable: bool, } +use crate::common::models::{ApiStatus, ErrorResponse as CommonErrorResponse}; + impl ChatError { pub fn to_json(&self) -> serde_json::Value { serde_json::to_value(self).unwrap() @@ -135,6 +137,15 @@ impl ErrorResponse { pub fn native_code(&self) -> String { self.code.replace("_", " ").to_lowercase() } + + pub fn to_common(self) -> CommonErrorResponse { + CommonErrorResponse { + status: ApiStatus::Error, + code: Some(self.status), + error: self.error.as_ref().map(|error| error.message.clone()).or(Some(self.code.clone())), + message: self.error.as_ref().map(|error| error.details.clone()), + } + } } pub enum StreamError { diff --git a/src/message.rs b/src/chat/models.rs similarity index 73% rename from src/message.rs rename to src/chat/models.rs index 2c299e6..dd064ee 100644 --- a/src/message.rs +++ b/src/chat/models.rs @@ -76,3 +76,32 @@ pub struct Usage { pub completion_tokens: i32, pub total_tokens: i32, } + +// 模型定义 +#[derive(Serialize, Clone)] +pub struct Model { + pub id: &'static str, + pub created: i64, + pub object: &'static str, + pub owned_by: &'static str, +} + +use crate::{AppConfig, UsageCheck}; +use super::constant::USAGE_CHECK_MODELS; + +impl Model { + pub fn is_usage_check(&self) -> bool { + match AppConfig::get_usage_check() { + UsageCheck::None => false, + UsageCheck::Default => USAGE_CHECK_MODELS.contains(&self.id), + UsageCheck::All => true, + UsageCheck::Custom(models) => models.contains(&self.id), + } + } +} + +#[derive(Serialize)] +pub struct ModelsResponse { + pub object: &'static str, + pub data: &'static [Model], +} diff --git a/src/chat/service.rs b/src/chat/service.rs new file mode 100644 index 0000000..3da3335 --- /dev/null +++ b/src/chat/service.rs @@ -0,0 +1,506 @@ +use axum::{ + body::Body, + extract::State, + http::{HeaderMap, StatusCode}, + response::Response, + Json, +}; +use bytes::Bytes; +use crate::{ + app::{ + client::build_client, + constant::*, + models::*, + statics::*, + token::get_user_usage, + }, chat::{ + error::StreamError, + models::*, + stream::{parse_stream_data, StreamMessage}, + }, common::models::{error::ChatError, ErrorResponse} +}; +use super::constant::AVAILABLE_MODELS; +use 
futures::{Stream, StreamExt}; +use std::{ + convert::Infallible, sync::{atomic::AtomicBool, Arc} +}; +use std::{ + pin::Pin, + sync::atomic::{AtomicUsize, Ordering}, +}; +use tokio::sync::Mutex; +use uuid::Uuid; + +// 模型列表处理 +pub async fn handle_models() -> Json { + Json(ModelsResponse { + object: "list", + data: AVAILABLE_MODELS, + }) +} + +// 聊天处理函数的签名 +pub async fn handle_chat( + State(state): State>>, + headers: HeaderMap, + Json(request): Json, +) -> Result, (StatusCode, Json)> { + let allow_claude = AppConfig::get_allow_claude(); + // 验证模型是否支持并获取模型信息 + let model = AVAILABLE_MODELS.iter().find(|m| m.id == request.model); + let model_supported = model.is_some(); + + if !(model_supported || allow_claude && request.model.starts_with("claude")) { + return Err(( + StatusCode::BAD_REQUEST, + Json(ChatError::ModelNotSupported(request.model).to_json()), + )); + } + + let request_time = chrono::Local::now(); + + // 验证请求 + if request.messages.is_empty() { + return Err(( + StatusCode::BAD_REQUEST, + Json(ChatError::EmptyMessages.to_json()), + )); + } + + // 获取并处理认证令牌 + let auth_token = headers + .get(axum::http::header::AUTHORIZATION) + .and_then(|h| h.to_str().ok()) + .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) + .ok_or(( + StatusCode::UNAUTHORIZED, + Json(ChatError::Unauthorized.to_json()), + ))?; + + // 验证 AuthToken + if auth_token != get_auth_token() { + return Err(( + StatusCode::UNAUTHORIZED, + Json(ChatError::Unauthorized.to_json()), + )); + } + + // 完整的令牌处理逻辑和对应的 checksum + let (auth_token, checksum, alias) = { + static CURRENT_KEY_INDEX: AtomicUsize = AtomicUsize::new(0); + let state_guard = state.lock().await; + let token_infos = &state_guard.token_infos; + + if token_infos.is_empty() { + return Err(( + StatusCode::SERVICE_UNAVAILABLE, + Json(ChatError::NoTokens.to_json()), + )); + } + + let index = CURRENT_KEY_INDEX.fetch_add(1, Ordering::SeqCst) % token_infos.len(); + let token_info = &token_infos[index]; + ( + token_info.token.clone(), + token_info.checksum.clone(), + token_info.alias.clone(), + ) + }; + + // 更新请求日志 + { + let state_clone = state.clone(); + let mut state = state.lock().await; + state.total_requests += 1; + state.active_requests += 1; + + // 如果有model且需要获取使用情况,创建后台任务获取 + if let Some(model) = model { + if model.is_usage_check() { + let auth_token_clone = auth_token.clone(); + let checksum_clone = checksum.clone(); + let state_clone = state_clone.clone(); + + tokio::spawn(async move { + let usage = get_user_usage(&auth_token_clone, &checksum_clone).await; + let mut state = state_clone.lock().await; + // 根据时间戳找到对应的日志 + if let Some(log) = state + .request_logs + .iter_mut() + .find(|log| log.timestamp == request_time) + { + log.token_info.usage = usage; + } + }); + } + } + + state.request_logs.push(RequestLog { + timestamp: request_time, + model: request.model.clone(), + token_info: TokenInfo { + token: auth_token.clone(), + checksum: checksum.clone(), + alias: alias.clone(), + usage: None, + }, + prompt: None, + stream: request.stream, + status: "pending", + error: None, + }); + + if state.request_logs.len() > 100 { + state.request_logs.remove(0); + } + } + + // 将消息转换为hex格式 + let hex_data = crate::encode_chat_message(request.messages, &request.model) + .await + .map_err(|_| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json( + ChatError::RequestFailed("Failed to encode chat message".to_string()).to_json(), + ), + ) + })?; + + // 构建请求客户端 + let client = build_client(&auth_token, &checksum, CURSOR_API2_STREAM_CHAT); + let response = 
client.body(hex_data).send().await; + + // 处理请求结果 + let response = match response { + Ok(resp) => { + // 更新请求日志为成功 + { + let mut state = state.lock().await; + state.request_logs.last_mut().unwrap().status = STATUS_SUCCESS; + } + resp + } + Err(e) => { + // 更新请求日志为失败 + { + let mut state = state.lock().await; + if let Some(last_log) = state.request_logs.last_mut() { + last_log.status = STATUS_FAILED; + last_log.error = Some(e.to_string()); + } + } + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ChatError::RequestFailed(e.to_string()).to_json()), + )); + } + }; + + // 释放活动请求计数 + { + let mut state = state.lock().await; + state.active_requests -= 1; + } + + if request.stream { + let response_id = format!("chatcmpl-{}", Uuid::new_v4().simple()); + let full_text = Arc::new(Mutex::new(String::with_capacity(1024))); + let is_start = Arc::new(AtomicBool::new(true)); + + let stream = { + // 创建新的 stream + let mut stream = response.bytes_stream(); + + let enable_stream_check = AppConfig::get_stream_check(); + + if enable_stream_check { + // 检查第一个 chunk + match stream.next().await { + Some(first_chunk) => { + let chunk = first_chunk.map_err(|e| { + let error_message = format!("Failed to read response chunk: {}", e); + // 理论上,若程序正常,必定成功,因为前面判断过了 + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ChatError::RequestFailed(error_message).to_json()), + ) + })?; + + match parse_stream_data(&chunk) { + Err(StreamError::ChatError(error)) => { + let error_respone = error.to_error_response(); + // 更新请求日志为失败 + { + let mut state = state.lock().await; + if let Some(last_log) = state.request_logs.last_mut() { + last_log.status = STATUS_FAILED; + last_log.error = Some(error_respone.native_code()); + } + } + return Err(( + error_respone.status_code(), + Json(error_respone.to_common()), + )); + } + Ok(_) | Err(_) => { + // 创建一个包含第一个 chunk 的 stream + Box::pin( + futures::stream::once(async move { Ok(chunk) }).chain(stream), + ) + as Pin< + Box< + dyn Stream> + Send, + >, + > + } + } + } + None => { + // Box::pin(stream) + // as Pin> + Send>> + // 更新请求日志为失败 + { + let mut state = state.lock().await; + if let Some(last_log) = state.request_logs.last_mut() { + last_log.status = STATUS_FAILED; + last_log.error = Some("Empty stream response".to_string()); + } + } + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json( + ChatError::RequestFailed("Empty stream response".to_string()) + .to_json(), + ), + )); + } + } + } else { + Box::pin(stream) + as Pin> + Send>> + } + } + .then(move |chunk| { + let response_id = response_id.clone(); + let model = request.model.clone(); + let is_start = is_start.clone(); + let full_text = full_text.clone(); + let state = state.clone(); + + async move { + let chunk = chunk.unwrap_or_default(); + match parse_stream_data(&chunk) { + Ok(StreamMessage::Content(texts)) => { + let mut response_data = String::new(); + + for text in texts { + let mut text_guard = full_text.lock().await; + text_guard.push_str(&text); + let is_first = is_start.load(Ordering::SeqCst); + + let response = ChatResponse { + id: response_id.clone(), + object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(), + created: chrono::Utc::now().timestamp(), + model: if is_first { Some(model.clone()) } else { None }, + choices: vec![Choice { + index: 0, + message: None, + delta: Some(Delta { + role: if is_first { + is_start.store(false, Ordering::SeqCst); + Some(Role::Assistant) + } else { + None + }, + content: Some(text), + }), + finish_reason: None, + }], + usage: None, + }; + + response_data.push_str(&format!( + "data: {}\n\n", 
+ serde_json::to_string(&response).unwrap() + )); + } + + Ok::<_, Infallible>(Bytes::from(response_data)) + } + Ok(StreamMessage::StreamStart) => { + // 发送初始响应,包含模型信息 + let response = ChatResponse { + id: response_id.clone(), + object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(), + created: chrono::Utc::now().timestamp(), + model: { + is_start.store(true, Ordering::SeqCst); + Some(model.clone()) + }, + choices: vec![Choice { + index: 0, + message: None, + delta: Some(Delta { + role: Some(Role::Assistant), + content: Some(String::new()), + }), + finish_reason: None, + }], + usage: None, + }; + + Ok(Bytes::from(format!( + "data: {}\n\n", + serde_json::to_string(&response).unwrap() + ))) + } + Ok(StreamMessage::StreamEnd) => { + // 根据配置决定是否发送最后的 finish_reason + let include_finish_reason = AppConfig::get_stop_stream(); + + if include_finish_reason { + let response = ChatResponse { + id: response_id.clone(), + object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(), + created: chrono::Utc::now().timestamp(), + model: None, + choices: vec![Choice { + index: 0, + message: None, + delta: Some(Delta { + role: None, + content: None, + }), + finish_reason: Some(FINISH_REASON_STOP.to_string()), + }], + usage: None, + }; + Ok(Bytes::from(format!( + "data: {}\n\ndata: [DONE]\n\n", + serde_json::to_string(&response).unwrap() + ))) + } else { + Ok(Bytes::from("data: [DONE]\n\n")) + } + } + Ok(StreamMessage::Debug(debug_prompt)) => { + if let Ok(mut state) = state.try_lock() { + if let Some(last_log) = state.request_logs.last_mut() { + last_log.prompt = Some(debug_prompt.clone()); + } + } + Ok(Bytes::new()) + } + Err(StreamError::ChatError(error)) => { + eprintln!("Stream error occurred: {}", error.to_json()); + Ok(Bytes::new()) + } + Err(e) => { + eprintln!("[警告] Stream error: {}", e); + Ok(Bytes::new()) + } + } + } + }); + + Ok(Response::builder() + .header("Cache-Control", "no-cache") + .header("Connection", "keep-alive") + .header(HEADER_NAME_CONTENT_TYPE, "text/event-stream") + .body(Body::from_stream(stream)) + .unwrap()) + } else { + // 非流式响应 + let mut full_text = String::with_capacity(1024); // 预分配合适的容量 + let mut stream = response.bytes_stream(); + let mut prompt = None; + + while let Some(chunk) = stream.next().await { + let chunk = chunk.map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json( + ChatError::RequestFailed(format!("Failed to read response chunk: {}", e)) + .to_json(), + ), + ) + })?; + + match parse_stream_data(&chunk) { + Ok(StreamMessage::Content(texts)) => { + for text in texts { + full_text.push_str(&text); + } + } + Ok(StreamMessage::Debug(debug_prompt)) => { + prompt = Some(debug_prompt); + } + Ok(StreamMessage::StreamStart) | Ok(StreamMessage::StreamEnd) => {} + Err(StreamError::ChatError(error)) => { + return Err(( + StatusCode::from_u16(error.error.details[0].debug.status_code()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), + Json(error.to_error_response().to_common()), + )); + } + Err(_) => continue, + } + } + + // 检查响应是否为空 + if full_text.is_empty() { + // 更新请求日志为失败 + { + let mut state = state.lock().await; + if let Some(last_log) = state.request_logs.last_mut() { + last_log.status = STATUS_FAILED; + last_log.error = Some("Empty response received".to_string()); + if let Some(p) = prompt { + last_log.prompt = Some(p); + } + } + } + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ChatError::RequestFailed("Empty response received".to_string()).to_json()), + )); + } + + // 更新请求日志提示词 + { + let mut state = state.lock().await; + if let Some(last_log) = 
state.request_logs.last_mut() { + last_log.prompt = prompt; + } + } + + let response_data = ChatResponse { + id: format!("chatcmpl-{}", Uuid::new_v4().simple()), + object: OBJECT_CHAT_COMPLETION.to_string(), + created: chrono::Utc::now().timestamp(), + model: Some(request.model), + choices: vec![Choice { + index: 0, + message: Some(Message { + role: Role::Assistant, + content: MessageContent::Text(full_text), + }), + delta: None, + finish_reason: Some(FINISH_REASON_STOP.to_string()), + }], + usage: Some(Usage { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + }), + }; + + Ok(Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_string(&response_data).unwrap())) + .unwrap()) + } +} diff --git a/src/chat/stream.rs b/src/chat/stream.rs index 218a7f2..f91d474 100644 --- a/src/chat/stream.rs +++ b/src/chat/stream.rs @@ -1,4 +1,4 @@ -use crate::aiserver::v1::StreamChatResponse; +use super::aiserver::v1::StreamChatResponse; use flate2::read::GzDecoder; use prost::Message; use std::io::Read; diff --git a/src/common.rs b/src/common.rs new file mode 100644 index 0000000..ff92946 --- /dev/null +++ b/src/common.rs @@ -0,0 +1 @@ +pub mod models; \ No newline at end of file diff --git a/src/common/models.rs b/src/common/models.rs new file mode 100644 index 0000000..76d5ae2 --- /dev/null +++ b/src/common/models.rs @@ -0,0 +1,70 @@ +pub mod error; +pub mod health; +pub mod config; + +use config::ConfigData; + +use serde::Serialize; + +#[derive(Serialize)] +pub enum ApiStatus { + #[serde(rename = "healthy")] + Healthy, + #[serde(rename = "success")] + Success, + #[serde(rename = "error")] + Error, + #[serde(rename = "failed")] + Failed, +} + +// #[derive(Serialize)] +// #[serde(untagged)] +// pub enum ApiResponse { +// HealthCheck(HealthCheckResponse), +// ConfigData(NormalResponse), +// Error(ErrorResponse), +// } + +// impl ApiResponse { +// pub fn to_string(&self) -> String { +// serde_json::to_string(self).unwrap() +// } +// } + +#[derive(Serialize)] +pub struct NormalResponse { + pub status: ApiStatus, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, +} + +impl std::fmt::Display for NormalResponse { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", serde_json::to_string(self).unwrap()) + } +} + +#[derive(Serialize)] +pub struct NormalResponseNoData { + pub status: ApiStatus, + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, +} + +#[derive(Serialize)] +pub struct ErrorResponse { + // status -> 成功 / 失败 + pub status: ApiStatus, + // HTTP 请求的状态码 + #[serde(skip_serializing_if = "Option::is_none")] + pub code: Option, + // HTTP 请求的错误码 + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + // 错误详情 + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, +} diff --git a/src/common/models/config.rs b/src/common/models/config.rs new file mode 100644 index 0000000..9fdbdc6 --- /dev/null +++ b/src/common/models/config.rs @@ -0,0 +1,36 @@ +use serde::{Deserialize, Serialize}; + +use crate::{PageContent, UsageCheck, VisionAbility}; + +#[derive(Serialize)] +pub struct ConfigData { + pub page_content: Option, + pub enable_stream_check: bool, + pub include_stop_stream: bool, + pub vision_ability: VisionAbility, + pub enable_slow_pool: bool, + pub enable_all_claude: bool, + pub check_usage_models: UsageCheck, +} + +#[derive(Deserialize)] +pub struct 
ConfigUpdateRequest { + #[serde(default)] + pub action: String, // "get", "update", "reset" + #[serde(default)] + pub path: String, + #[serde(default)] + pub content: Option, // "default", "text", "html" + #[serde(default)] + pub enable_stream_check: Option, + #[serde(default)] + pub include_stop_stream: Option, + #[serde(default)] + pub vision_ability: Option, + #[serde(default)] + pub enable_slow_pool: Option, + #[serde(default)] + pub enable_all_claude: Option, + #[serde(default)] + pub check_usage_models: Option, +} diff --git a/src/common/models/error.rs b/src/common/models/error.rs new file mode 100644 index 0000000..79430db --- /dev/null +++ b/src/common/models/error.rs @@ -0,0 +1,34 @@ +use super::ErrorResponse; + +pub enum ChatError { + ModelNotSupported(String), + EmptyMessages, + NoTokens, + RequestFailed(String), + Unauthorized, +} + +impl ChatError { + pub fn to_json(&self) -> ErrorResponse { + let (error, message) = match self { + ChatError::ModelNotSupported(model) => ( + "model_not_supported", + format!("Model '{}' is not supported", model), + ), + ChatError::EmptyMessages => ( + "empty_messages", + "Message array cannot be empty".to_string(), + ), + ChatError::NoTokens => ("no_tokens", "No available tokens".to_string()), + ChatError::RequestFailed(err) => ("request_failed", format!("Request failed: {}", err)), + ChatError::Unauthorized => ("unauthorized", "Invalid authorization token".to_string()), + }; + + ErrorResponse { + status: super::ApiStatus::Error, + code: None, + error: Some(error.to_string()), + message: Some(message), + } + } +} diff --git a/src/common/models/health.rs b/src/common/models/health.rs new file mode 100644 index 0000000..43d2241 --- /dev/null +++ b/src/common/models/health.rs @@ -0,0 +1,37 @@ +use serde::Serialize; + +use super::ApiStatus; + +#[derive(Serialize)] +pub struct HealthCheckResponse { + pub status: ApiStatus, + pub version: &'static str, + pub uptime: i64, + pub stats: SystemStats, + pub models: Vec<&'static str>, + pub endpoints: Vec<&'static str>, +} + +#[derive(Serialize)] +pub struct SystemStats { + pub started: String, + pub total_requests: u64, + pub active_requests: u64, + pub system: SystemInfo, +} + +#[derive(Serialize)] +pub struct SystemInfo { + pub memory: MemoryInfo, + pub cpu: CpuInfo, +} + +#[derive(Serialize)] +pub struct MemoryInfo { + pub rss: u64, // 物理内存使用量(字节) +} + +#[derive(Serialize)] +pub struct CpuInfo { + pub usage: f32, // CPU 使用率(百分比) +} diff --git a/src/lib.rs b/src/lib.rs index df1ec42..be27b01 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,23 +1,25 @@ use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; use image::guess_format; use prost::Message as _; -use rand::Rng; -use sha2::{Digest, Sha256}; use uuid::Uuid; -mod aiserver; -use aiserver::v1::*; - -pub mod message; -use message::*; +pub mod common; pub mod app; -use app::{models::*,constant::*}; +use app::{constant::EMPTY_STRING, models::*}; pub mod chat; +use chat::{ + aiserver::v1::{ + conversation_message, image_proto, ConversationMessage, ExplicitContext, GetChatRequest, + ImageProto, ModelDetails, + }, + constant::{LONG_CONTEXT_MODELS, ERR_UNSUPPORTED_GIF, ERR_UNSUPPORTED_IMAGE_FORMAT}, + models::{Message, MessageContent, Role}, +}; async fn process_chat_inputs(inputs: Vec) -> (String, Vec) { - // 收集 system 和 developer 指令 + // 收集 system 指令 let instructions = inputs .iter() .filter(|input| input.role == Role::System) @@ -56,7 +58,7 @@ async fn process_chat_inputs(inputs: Vec) -> (String, Vec) -> (String, Vec) -> (String, Vec) -> 
(String, Vec 1 { - return Err("不支持动态 GIF".into()); + return Err(ERR_UNSUPPORTED_GIF.into()); } } } @@ -304,11 +306,11 @@ async fn process_http_image( gif::DecodeOptions::new().read_info(std::io::Cursor::new(&image_data)) { if frames.into_iter().count() > 1 { - return Err("不支持动态 GIF".into()); + return Err(ERR_UNSUPPORTED_GIF.into()); } } } - _ => return Err("不支持的图片格式,仅支持 PNG、JPEG、WEBP 和非动态 GIF".into()), + _ => return Err(ERR_UNSUPPORTED_IMAGE_FORMAT.into()), } // 获取图片尺寸 @@ -399,44 +401,3 @@ pub async fn encode_chat_message( Ok(hex::decode(len_prefix + &content)?) } - -pub fn generate_hash() -> String { - let random_bytes = rand::thread_rng().gen::<[u8; 32]>(); - let mut hasher = Sha256::new(); - hasher.update(random_bytes); - hex::encode(hasher.finalize()) -} - -fn obfuscate_bytes(bytes: &mut [u8]) { - let mut prev: u8 = 165; - for (idx, byte) in bytes.iter_mut().enumerate() { - let old_value = *byte; - *byte = (old_value ^ prev).wrapping_add((idx % 256) as u8); - prev = *byte; - } -} - -pub fn generate_checksum(device_id: &str, mac_addr: Option<&str>) -> String { - let timestamp = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_millis() - / 1_000_000; - - let mut timestamp_bytes = vec![ - ((timestamp >> 40) & 255) as u8, - ((timestamp >> 32) & 255) as u8, - ((timestamp >> 24) & 255) as u8, - ((timestamp >> 16) & 255) as u8, - ((timestamp >> 8) & 255) as u8, - (255 & timestamp) as u8, - ]; - - obfuscate_bytes(&mut timestamp_bytes); - let encoded = BASE64.encode(×tamp_bytes); - - match mac_addr { - Some(mac) => format!("{}{}/{}", encoded, device_id, mac), - None => format!("{}{}", encoded, device_id), - } -} diff --git a/src/main.rs b/src/main.rs index 6544e9a..822c419 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,73 +6,32 @@ use axum::{ routing::{get, post}, Json, Router, }; -use bytes::Bytes; use chrono::Local; use cursor_api::{ app::{ - client::build_client, + config::handle_config_update, constant::*, models::*, + statics::*, token::{ - get_user_info, get_user_usage, handle_get_tokeninfo, handle_update_tokeninfo, + get_user_info, handle_get_checksum, handle_get_tokeninfo, handle_update_tokeninfo, handle_update_tokeninfo_post, load_tokens, }, utils::{parse_bool_from_env, parse_string_from_env}, }, - chat::{error::StreamError, stream::parse_stream_data}, -}; -use cursor_api::{chat::stream::StreamMessage, message::*}; -use futures::{Stream, StreamExt}; -use std::{ - convert::Infallible, - sync::{atomic::AtomicBool, Arc}, -}; -use std::{ - pin::Pin, - sync::atomic::{AtomicUsize, Ordering}, + chat::{ + constant::AVAILABLE_MODELS, + service::{handle_chat, handle_models}, + }, + common::models::{ + health::{CpuInfo, HealthCheckResponse, MemoryInfo, SystemInfo, SystemStats}, + ApiStatus, + }, }; +use std::sync::Arc; use sysinfo::{CpuRefreshKind, MemoryRefreshKind, RefreshKind, System}; use tokio::sync::Mutex; use tower_http::cors::CorsLayer; -use uuid::Uuid; - -// 支持的模型列表 -mod models; -use models::AVAILABLE_MODELS; - -// 自定义错误类型 -enum ChatError { - ModelNotSupported(String), - EmptyMessages, - NoTokens, - RequestFailed(String), - Unauthorized, -} - -impl ChatError { - fn to_json(&self) -> serde_json::Value { - let (code, message) = match self { - ChatError::ModelNotSupported(model) => ( - "model_not_supported", - format!("Model '{}' is not supported", model), - ), - ChatError::EmptyMessages => ( - "empty_messages", - "Message array cannot be empty".to_string(), - ), - ChatError::NoTokens => ("no_tokens", "No available tokens".to_string()), - 
ChatError::RequestFailed(err) => ("request_failed", format!("Request failed: {}", err)), - ChatError::Unauthorized => ("unauthorized", "Invalid authorization token".to_string()), - }; - - serde_json::json!({ - "error": { - "code": code, - MESSAGE: message - } - }) - } -} #[tokio::main] async fn main() { @@ -89,18 +48,17 @@ async fn main() { // 加载环境变量 dotenvy::dotenv().ok(); + if get_auth_token() == EMPTY_STRING { + panic!("AUTH_TOKEN must be set") + }; + // 初始化全局配置 AppConfig::init( parse_bool_from_env("ENABLE_STREAM_CHECK", true), parse_bool_from_env("INCLUDE_STOP_REASON_STREAM", true), - VisionAbility::from_str(parse_string_from_env("VISION_ABILITY", "base64").as_str()) - .unwrap_or_default(), + VisionAbility::from_str(&parse_string_from_env("VISION_ABILITY", EMPTY_STRING)), parse_bool_from_env("ENABLE_SLOW_POOL", false), parse_bool_from_env("PASS_ANY_CLAUDE", false), - std::env::var("AUTH_TOKEN").expect("AUTH_TOKEN must be set"), - parse_string_from_env("TOKEN_FILE", ".token"), - parse_string_from_env("TOKEN_LIST_FILE", ".token-list"), - parse_string_from_env("ROUTE_PREFIX", ""), ); // 加载 tokens @@ -109,43 +67,40 @@ async fn main() { // 初始化应用状态 let state = Arc::new(Mutex::new(AppState::new(token_infos))); - let route_prefix = AppConfig::get_route_prefix(); - // 设置路由 let app = Router::new() - .route(ROUTER_ROOT_PATH, get(handle_root)) - .route(ROUTER_HEALTH_PATH, get(handle_health)) - .route(ROUTER_TOKENINFO_PATH, get(handle_tokeninfo_page)) - .route(&format!("{}/v1/models", route_prefix), get(handle_models)) - .route(ROUTER_GET_CHECKSUM, get(handle_get_checksum)) - .route(ROUTER_GET_USER_INFO_PATH, get(get_user_info)) - .route(ROUTER_UPDATE_TOKENINFO_PATH, get(handle_update_tokeninfo)) - .route(ROUTER_GET_TOKENINFO_PATH, post(handle_get_tokeninfo)) + .route(ROUTE_ROOT_PATH, get(handle_root)) + .route(ROUTE_HEALTH_PATH, get(handle_health)) + .route(ROUTE_TOKENINFO_PATH, get(handle_tokeninfo_page)) + .route(ROUTE_MODELS_PATH.as_str(), get(handle_models)) + .route(ROUTE_GET_CHECKSUM, get(handle_get_checksum)) + .route(ROUTE_GET_USER_INFO_PATH, get(get_user_info)) + .route(ROUTE_UPDATE_TOKENINFO_PATH, get(handle_update_tokeninfo)) + .route(ROUTE_GET_TOKENINFO_PATH, post(handle_get_tokeninfo)) .route( - ROUTER_UPDATE_TOKENINFO_PATH, + ROUTE_UPDATE_TOKENINFO_PATH, post(handle_update_tokeninfo_post), ) - .route( - &format!("{}/v1/chat/completions", route_prefix), - post(handle_chat), - ) - .route(ROUTER_LOGS_PATH, get(handle_logs)) - .route(ROUTER_LOGS_PATH, post(handle_logs_post)) - .route(ROUTER_ENV_EXAMPLE_PATH, get(handle_env_example)) - .route(ROUTER_CONFIG_PATH, get(handle_config_page)) - .route(ROUTER_CONFIG_PATH, post(handle_config_update)) - .route("/static/:path", get(handle_static)) + .route(ROUTE_CHAT_PATH.as_str(), post(handle_chat)) + .route(ROUTE_LOGS_PATH, get(handle_logs)) + .route(ROUTE_LOGS_PATH, post(handle_logs_post)) + .route(ROUTE_ENV_EXAMPLE_PATH, get(handle_env_example)) + .route(ROUTE_CONFIG_PATH, get(handle_config_page)) + .route(ROUTE_CONFIG_PATH, post(handle_config_update)) + .route(ROUTE_STATIC_PATH, get(handle_static)) + .route(ROUTE_ABOUT_PATH, get(handle_about)) + .route(ROUTE_README_PATH, get(handle_readme)) .layer(CorsLayer::permissive()) .with_state(state); // 启动服务器 - let port = std::env::var("PORT").unwrap_or_else(|_| "3000".to_string()); + let port = parse_string_from_env("PORT", "3000"); let addr = format!("0.0.0.0:{}", port); println!("服务器运行在端口 {}", port); println!("当前版本: v{}", PKG_VERSION); - if !std::env::args().any(|arg| arg == "--no-instruction") { - 
println!(include_str!("../start_instruction")); - } + // if !std::env::args().any(|arg| arg == "--no-instruction") { + // println!(include_str!("../start_instruction")); + // } let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); axum::serve(listener, app).await.unwrap(); @@ -153,10 +108,10 @@ async fn main() { // 根路由处理 async fn handle_root() -> impl IntoResponse { - match AppConfig::get_page_content(ROUTER_ROOT_PATH).unwrap_or_default() { + match AppConfig::get_page_content(ROUTE_ROOT_PATH).unwrap_or_default() { PageContent::Default => Response::builder() .status(StatusCode::TEMPORARY_REDIRECT) - .header("Location", ROUTER_HEALTH_PATH) + .header(HEADER_NAME_LOCATION, ROUTE_HEALTH_PATH) .body(Body::empty()) .unwrap(), PageContent::Text(content) => Response::builder() @@ -170,9 +125,8 @@ async fn handle_root() -> impl IntoResponse { } } -async fn handle_health(State(state): State>>) -> Json { - let start_time = APP_CONFIG.read().unwrap().start_time; - let route_prefix = AppConfig::get_route_prefix(); +async fn handle_health(State(state): State>>) -> Json { + let start_time = get_start_time(); // 创建系统信息实例,只监控 CPU 和内存 let mut sys = System::new_with_specifics( @@ -199,42 +153,45 @@ async fn handle_health(State(state): State>>) -> Json>(), - "endpoints": [ - &format!("{}/v1/chat/completions", route_prefix), - &format!("{}/v1/models", route_prefix), - ROUTER_GET_CHECKSUM, - ROUTER_TOKENINFO_PATH, - ROUTER_UPDATE_TOKENINFO_PATH, - ROUTER_GET_TOKENINFO_PATH, - ROUTER_LOGS_PATH, - ROUTER_GET_USER_INFO_PATH, - ROUTER_ENV_EXAMPLE_PATH, - ROUTER_CONFIG_PATH, - "/static" - ] - })) + models: AVAILABLE_MODELS.iter().map(|m| m.id).collect::>(), + endpoints: vec![ + ROUTE_CHAT_PATH.as_str(), + ROUTE_MODELS_PATH.as_str(), + ROUTE_GET_CHECKSUM, + ROUTE_TOKENINFO_PATH, + ROUTE_UPDATE_TOKENINFO_PATH, + ROUTE_GET_TOKENINFO_PATH, + ROUTE_LOGS_PATH, + ROUTE_GET_USER_INFO_PATH, + ROUTE_ENV_EXAMPLE_PATH, + ROUTE_CONFIG_PATH, + ROUTE_STATIC_PATH, + ROUTE_ABOUT_PATH, + ROUTE_README_PATH + + ], + }) } async fn handle_tokeninfo_page() -> impl IntoResponse { - match AppConfig::get_page_content(ROUTER_TOKENINFO_PATH).unwrap_or_default() { + match AppConfig::get_page_content(ROUTE_TOKENINFO_PATH).unwrap_or_default() { PageContent::Default => Response::builder() .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(include_str!("../static/tokeninfo.min.html").to_string()) @@ -250,27 +207,9 @@ async fn handle_tokeninfo_page() -> impl IntoResponse { } } -// 模型列表处理 -async fn handle_models() -> Json { - Json(serde_json::json!({ - "object": "list", - "data": AVAILABLE_MODELS.to_vec() - })) -} - -async fn handle_get_checksum() -> Json { - let checksum = cursor_api::generate_checksum( - &cursor_api::generate_hash(), - Some(&cursor_api::generate_hash()), - ); - Json(serde_json::json!({ - "checksum": checksum - })) -} - // 日志处理 async fn handle_logs() -> impl IntoResponse { - match AppConfig::get_page_content(ROUTER_LOGS_PATH).unwrap_or_default() { + match AppConfig::get_page_content(ROUTE_LOGS_PATH).unwrap_or_default() { PageContent::Default => Response::builder() .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(Body::from( @@ -291,8 +230,8 @@ async fn handle_logs() -> impl IntoResponse { async fn handle_logs_post( State(state): State>>, headers: HeaderMap, -) -> Result, StatusCode> { - let auth_token = AppConfig::get_auth_token(); +) -> Result, StatusCode> { + let auth_token = get_auth_token(); // 验证 AUTH_TOKEN let auth_header = headers @@ -306,12 +245,20 @@ async fn 
handle_logs_post( } let state = state.lock().await; - Ok(Json(serde_json::json!({ - "total": state.request_logs.len(), - "logs": state.request_logs, - "timestamp": Local::now(), - STATUS: STATUS_SUCCESS - }))) + Ok(Json(LogsResponse { + status: ApiStatus::Success, + total: state.request_logs.len(), + logs: state.request_logs.clone(), + timestamp: Local::now().to_string(), + })) +} + +#[derive(serde::Serialize)] +struct LogsResponse { + status: ApiStatus, + total: usize, + logs: Vec, + timestamp: String, } async fn handle_env_example() -> impl IntoResponse { @@ -321,471 +268,9 @@ async fn handle_env_example() -> impl IntoResponse { .unwrap() } -// 聊天处理函数的签名 -async fn handle_chat( - State(state): State>>, - headers: HeaderMap, - Json(request): Json, -) -> Result, (StatusCode, Json)> { - let allow_claude = AppConfig::get_allow_claude(); - - // 验证模型是否支持 - let model_supported = AVAILABLE_MODELS.iter().any(|m| m.id == request.model); - - if !(model_supported || allow_claude && request.model.starts_with("claude")) { - return Err(( - StatusCode::BAD_REQUEST, - Json(ChatError::ModelNotSupported(request.model).to_json()), - )); - } - - let request_time = Local::now(); - - // 验证请求 - if request.messages.is_empty() { - return Err(( - StatusCode::BAD_REQUEST, - Json(ChatError::EmptyMessages.to_json()), - )); - } - - // 获取并处理认证令牌 - let auth_token = headers - .get(axum::http::header::AUTHORIZATION) - .and_then(|h| h.to_str().ok()) - .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) - .ok_or(( - StatusCode::UNAUTHORIZED, - Json(ChatError::Unauthorized.to_json()), - ))?; - - // 验证 AuthToken - if auth_token != AppConfig::get_auth_token() { - return Err(( - StatusCode::UNAUTHORIZED, - Json(ChatError::Unauthorized.to_json()), - )); - } - - // 完整的令牌处理逻辑和对应的 checksum - let (auth_token, checksum, alias) = { - static CURRENT_KEY_INDEX: AtomicUsize = AtomicUsize::new(0); - let state_guard = state.lock().await; - let token_infos = &state_guard.token_infos; - - if token_infos.is_empty() { - return Err(( - StatusCode::SERVICE_UNAVAILABLE, - Json(ChatError::NoTokens.to_json()), - )); - } - - let index = CURRENT_KEY_INDEX.fetch_add(1, Ordering::SeqCst) % token_infos.len(); - let token_info = &token_infos[index]; - ( - token_info.token.clone(), - token_info.checksum.clone(), - token_info.alias.clone(), - ) - }; - - // 更新请求日志 - { - let state_clone = state.clone(); - let mut state = state.lock().await; - state.total_requests += 1; - state.active_requests += 1; - - // 创建一个后台任务来获取使用情况 - let auth_token_clone = auth_token.clone(); - let checksum_clone = checksum.clone(); - - tokio::spawn(async move { - let usage = get_user_usage(&auth_token_clone, &checksum_clone).await; - let mut state = state_clone.lock().await; - // 根据时间戳找到对应的日志 - if let Some(log) = state - .request_logs - .iter_mut() - .find(|log| log.timestamp == request_time) - { - log.token_info.usage = usage; - } - }); - - state.request_logs.push(RequestLog { - timestamp: request_time, - model: request.model.clone(), - token_info: TokenInfo { - token: auth_token.clone(), - checksum: checksum.clone(), - alias: alias.clone(), - usage: None, - }, - prompt: None, - stream: request.stream, - status: "pending".to_string(), - error: None, - }); - - if state.request_logs.len() > 100 { - state.request_logs.remove(0); - } - } - - // 将消息转换为hex格式 - let hex_data = cursor_api::encode_chat_message(request.messages, &request.model) - .await - .map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json( - ChatError::RequestFailed("Failed to encode chat 
message".to_string()).to_json(), - ), - ) - })?; - - // 构建请求客户端 - let client = build_client(&auth_token, &checksum, CURSOR_API2_STREAM_CHAT); - let response = client.body(hex_data).send().await; - - // 处理请求结果 - let response = match response { - Ok(resp) => { - // 更新请求日志为成功 - { - let mut state = state.lock().await; - state.request_logs.last_mut().unwrap().status = STATUS_SUCCESS.to_string(); - } - resp - } - Err(e) => { - // 更新请求日志为失败 - { - let mut state = state.lock().await; - if let Some(last_log) = state.request_logs.last_mut() { - last_log.status = STATUS_FAILED.to_string(); - last_log.error = Some(e.to_string()); - } - } - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ChatError::RequestFailed(e.to_string()).to_json()), - )); - } - }; - - // 释放活动请求计数 - { - let mut state = state.lock().await; - state.active_requests -= 1; - } - - if request.stream { - let response_id = format!("chatcmpl-{}", Uuid::new_v4().simple()); - let full_text = Arc::new(Mutex::new(String::with_capacity(1024))); - let is_start = Arc::new(AtomicBool::new(true)); - - let stream = { - // 创建新的 stream - let mut stream = response.bytes_stream(); - - let enable_stream_check = AppConfig::get_stream_check(); - - if enable_stream_check { - // 检查第一个 chunk - match stream.next().await { - Some(first_chunk) => { - let chunk = first_chunk.map_err(|e| { - let error_message = format!("Failed to read response chunk: {}", e); - // 理论上,若程序正常,必定成功,因为前面判断过了 - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ChatError::RequestFailed(error_message).to_json()), - ) - })?; - - match parse_stream_data(&chunk) { - Err(StreamError::ChatError(error)) => { - let error_respone = error.to_error_response(); - // 更新请求日志为失败 - { - let mut state = state.lock().await; - if let Some(last_log) = state.request_logs.last_mut() { - last_log.status = STATUS_FAILED.to_string(); - last_log.error = Some(error_respone.native_code()); - } - } - return Err(( - error_respone.status_code(), - Json(error_respone.to_json()), - )); - } - Ok(_) | Err(_) => { - // 创建一个包含第一个 chunk 的 stream - Box::pin( - futures::stream::once(async move { Ok(chunk) }).chain(stream), - ) - as Pin< - Box< - dyn Stream> + Send, - >, - > - } - } - } - None => { - // Box::pin(stream) - // as Pin> + Send>> - // 更新请求日志为失败 - { - let mut state = state.lock().await; - if let Some(last_log) = state.request_logs.last_mut() { - last_log.status = STATUS_FAILED.to_string(); - last_log.error = Some("Empty stream response".to_string()); - } - } - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json( - ChatError::RequestFailed("Empty stream response".to_string()) - .to_json(), - ), - )); - } - } - } else { - Box::pin(stream) - as Pin> + Send>> - } - } - .then(move |chunk| { - let response_id = response_id.clone(); - let model = request.model.clone(); - let is_start = is_start.clone(); - let full_text = full_text.clone(); - let state = state.clone(); - - async move { - let chunk = chunk.unwrap_or_default(); - match parse_stream_data(&chunk) { - Ok(StreamMessage::Content(texts)) => { - let mut response_data = String::new(); - - for text in texts { - let mut text_guard = full_text.lock().await; - text_guard.push_str(&text); - let is_first = is_start.load(Ordering::SeqCst); - - let response = ChatResponse { - id: response_id.clone(), - object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(), - created: chrono::Utc::now().timestamp(), - model: if is_first { Some(model.clone()) } else { None }, - choices: vec![Choice { - index: 0, - message: None, - delta: Some(Delta { - role: if is_first { - 
is_start.store(false, Ordering::SeqCst); - Some(Role::Assistant) - } else { - None - }, - content: Some(text), - }), - finish_reason: None, - }], - usage: None, - }; - - response_data.push_str(&format!( - "data: {}\n\n", - serde_json::to_string(&response).unwrap() - )); - } - - Ok::<_, Infallible>(Bytes::from(response_data)) - } - Ok(StreamMessage::StreamStart) => { - // 发送初始响应,包含模型信息 - let response = ChatResponse { - id: response_id.clone(), - object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(), - created: chrono::Utc::now().timestamp(), - model: { - is_start.store(true, Ordering::SeqCst); - Some(model.clone()) - }, - choices: vec![Choice { - index: 0, - message: None, - delta: Some(Delta { - role: Some(Role::Assistant), - content: Some(String::new()), - }), - finish_reason: None, - }], - usage: None, - }; - - Ok(Bytes::from(format!( - "data: {}\n\n", - serde_json::to_string(&response).unwrap() - ))) - } - Ok(StreamMessage::StreamEnd) => { - // 根据配置决定是否发送最后的 finish_reason - let include_finish_reason = - parse_bool_from_env("INCLUDE_STOP_FINISH_REASON_STREAM", true); - - if include_finish_reason { - let response = ChatResponse { - id: response_id.clone(), - object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(), - created: chrono::Utc::now().timestamp(), - model: None, - choices: vec![Choice { - index: 0, - message: None, - delta: Some(Delta { - role: None, - content: None, - }), - finish_reason: Some(FINISH_REASON_STOP.to_string()), - }], - usage: None, - }; - Ok(Bytes::from(format!( - "data: {}\n\ndata: [DONE]\n\n", - serde_json::to_string(&response).unwrap() - ))) - } else { - Ok(Bytes::from("data: [DONE]\n\n")) - } - } - Ok(StreamMessage::Debug(debug_prompt)) => { - if let Ok(mut state) = state.try_lock() { - if let Some(last_log) = state.request_logs.last_mut() { - last_log.prompt = Some(debug_prompt.clone()); - } - } - Ok(Bytes::new()) - } - Err(StreamError::ChatError(error)) => { - eprintln!("Stream error occurred: {}", error.to_json()); - Ok(Bytes::new()) - } - Err(e) => { - eprintln!("[警告] Stream error: {}", e); - Ok(Bytes::new()) - } - } - } - }); - - Ok(Response::builder() - .header("Cache-Control", "no-cache") - .header("Connection", "keep-alive") - .header(HEADER_NAME_CONTENT_TYPE, "text/event-stream") - .body(Body::from_stream(stream)) - .unwrap()) - } else { - // 非流式响应 - let mut full_text = String::with_capacity(1024); // 预分配合适的容量 - let mut stream = response.bytes_stream(); - let mut prompt = None; - - while let Some(chunk) = stream.next().await { - let chunk = chunk.map_err(|e| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json( - ChatError::RequestFailed(format!("Failed to read response chunk: {}", e)) - .to_json(), - ), - ) - })?; - - match parse_stream_data(&chunk) { - Ok(StreamMessage::Content(texts)) => { - for text in texts { - full_text.push_str(&text); - } - } - Ok(StreamMessage::Debug(debug_prompt)) => { - prompt = Some(debug_prompt); - } - Ok(StreamMessage::StreamStart) | Ok(StreamMessage::StreamEnd) => {} - Err(StreamError::ChatError(error)) => { - return Err(( - StatusCode::from_u16(error.error.details[0].debug.status_code()) - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), - Json(error.to_json()), - )); - } - Err(_) => continue, - } - } - - // 检查响应是否为空 - if full_text.is_empty() { - // 更新请求日志为失败 - { - let mut state = state.lock().await; - if let Some(last_log) = state.request_logs.last_mut() { - last_log.status = STATUS_FAILED.to_string(); - last_log.error = Some("Empty response received".to_string()); - if let Some(p) = prompt { - last_log.prompt = Some(p); - } - } - } - 
return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ChatError::RequestFailed("Empty response received".to_string()).to_json()), - )); - } - - // 更新请求日志提示词 - { - let mut state = state.lock().await; - if let Some(last_log) = state.request_logs.last_mut() { - last_log.prompt = prompt; - } - } - - let response_data = ChatResponse { - id: format!("chatcmpl-{}", Uuid::new_v4().simple()), - object: OBJECT_CHAT_COMPLETION.to_string(), - created: chrono::Utc::now().timestamp(), - model: Some(request.model), - choices: vec![Choice { - index: 0, - message: Some(Message { - role: Role::Assistant, - content: MessageContent::Text(full_text), - }), - delta: None, - finish_reason: Some(FINISH_REASON_STOP.to_string()), - }], - usage: Some(Usage { - prompt_tokens: 0, - completion_tokens: 0, - total_tokens: 0, - }), - }; - - Ok(Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, "application/json") - .body(Body::from(serde_json::to_string(&response_data).unwrap())) - .unwrap()) - } -} - // 配置页面处理函数 async fn handle_config_page() -> impl IntoResponse { - match AppConfig::get_page_content(ROUTER_CONFIG_PATH).unwrap_or_default() { + match AppConfig::get_page_content(ROUTE_CONFIG_PATH).unwrap_or_default() { PageContent::Default => Response::builder() .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(include_str!("../static/config.min.html").to_string()) @@ -801,244 +286,70 @@ async fn handle_config_page() -> impl IntoResponse { } } -// 配置更新处理函数 -async fn handle_config_update( - State(_state): State>>, - headers: HeaderMap, - Json(request): Json, -) -> Result, (StatusCode, Json)> { - // 验证 AUTH_TOKEN - let auth_token = AppConfig::get_auth_token(); - - let auth_header = headers - .get(HEADER_NAME_AUTHORIZATION) - .and_then(|h| h.to_str().ok()) - .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) - .ok_or(( - StatusCode::UNAUTHORIZED, - Json(serde_json::json!({ - "error": "未提供认证令牌" - })), - ))?; - - if auth_header != auth_token { - return Err(( - StatusCode::UNAUTHORIZED, - Json(serde_json::json!({ - "error": "无效的认证令牌" - })), - )); - } - - match request.action.as_str() { - "get" => Ok(Json(serde_json::json!({ - STATUS: STATUS_SUCCESS, - "data": { - "page_content": AppConfig::get_page_content(&request.path), - "enable_stream_check": AppConfig::get_stream_check(), - "include_stop_stream": AppConfig::get_stop_stream(), - "vision_ability": AppConfig::get_vision_ability(), - "enable_slow_pool": AppConfig::get_slow_pool(), - "enable_all_claude": AppConfig::get_allow_claude(), - } - }))), - - "update" => { - // 处理页面内容更新 - if !request.path.is_empty() && request.content.is_some() { - let content = request.content.unwrap(); - - if let Err(e) = AppConfig::update_page_content(&request.path, content) { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": format!("更新页面内容失败: {}", e) - })), - )); - } - } - - // 处理 enable_stream_check 更新 - if let Some(enable_stream_check) = request.enable_stream_check { - if let Err(e) = AppConfig::update_stream_check(enable_stream_check) { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": format!("更新 enable_stream_check 失败: {}", e) - })), - )); - } - } - - // 处理 include_stop_stream 更新 - if let Some(include_stop_stream) = request.include_stop_stream { - if let Err(e) = AppConfig::update_stop_stream(include_stop_stream) { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": format!("更新 include_stop_stream 失败: {}", e) - })), - )); - } - } - - // 处理 
vision_ability 更新 - if let Some(vision_ability) = request.vision_ability { - if let Err(e) = AppConfig::update_vision_ability(vision_ability) { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": format!("更新 vision_ability 失败: {}", e) - })), - )); - } - } - - // 处理 enable_slow_pool 更新 - if let Some(enable_slow_pool) = request.enable_slow_pool { - if let Err(e) = AppConfig::update_slow_pool(enable_slow_pool) { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": format!("更新 enable_slow_pool 失败: {}", e) - })), - )); - } - } - - // 处理 enable_all_claude 更新 - if let Some(enable_all_claude) = request.enable_all_claude { - if let Err(e) = AppConfig::update_allow_claude(enable_all_claude) { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": format!("更新 enable_all_claude 失败: {}", e) - })), - )); - } - } - - Ok(Json(serde_json::json!({ - STATUS: STATUS_SUCCESS, - MESSAGE: "配置已更新" - }))) - } - - "reset" => { - // 重置页面内容 - if !request.path.is_empty() { - if let Err(e) = AppConfig::reset_page_content(&request.path) { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": format!("重置页面内容失败: {}", e) - })), - )); - } - } - - // 重置 enable_stream_check - if request.enable_stream_check.is_some() { - if let Err(e) = AppConfig::reset_stream_check() { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": format!("重置 enable_stream_check 失败: {}", e) - })), - )); - } - } - - // 重置 include_stop_stream - if request.include_stop_stream.is_some() { - if let Err(e) = AppConfig::reset_stop_stream() { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": format!("重置 include_stop_stream 失败: {}", e) - })), - )); - } - } - - // 重置 vision_ability - if request.vision_ability.is_some() { - if let Err(e) = AppConfig::reset_vision_ability() { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": format!("重置 vision_ability 失败: {}", e) - })), - )); - } - } - - // 重置 enable_slow_pool - if request.enable_slow_pool.is_some() { - if let Err(e) = AppConfig::reset_slow_pool() { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": format!("重置 enable_slow_pool 失败: {}", e) - })), - )); - } - } - - // 重置 enable_all_claude - if request.enable_all_claude.is_some() { - if let Err(e) = AppConfig::reset_allow_claude() { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": format!("重置 enable_slow_pool 失败: {}", e) - })), - )); - } - } - Ok(Json(serde_json::json!({ - STATUS: STATUS_SUCCESS, - MESSAGE: "配置已重置" - }))) - } - - _ => Err(( - StatusCode::BAD_REQUEST, - Json(serde_json::json!({ - "error": "无效的操作类型" - })), - )), - } -} - async fn handle_static(Path(path): Path) -> impl IntoResponse { match path.as_str() { "shared-styles.css" => { - match AppConfig::get_page_content(ROUTER_SHARED_STYLES_PATH).unwrap_or_default() { + match AppConfig::get_page_content(ROUTE_SHARED_STYLES_PATH).unwrap_or_default() { PageContent::Default => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, "text/css;charset=utf-8") + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8) .body(include_str!("../static/shared-styles.min.css").to_string()) .unwrap(), PageContent::Text(content) | PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, "text/css;charset=utf-8") + 
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8) + .body(content.clone()) + .unwrap(), + } + } + "shared.js" => { + match AppConfig::get_page_content(ROUTE_SHARED_JS_PATH).unwrap_or_default() { + PageContent::Default => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8) + .body(include_str!("../static/shared.min.js").to_string()) + .unwrap(), + PageContent::Text(content) | PageContent::Html(content) => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8) .body(content.clone()) .unwrap(), } } - "shared.js" => match AppConfig::get_page_content(ROUTER_SHARED_JS_PATH).unwrap_or_default() - { - PageContent::Default => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, "text/javascript;charset=utf-8") - .body(include_str!("../static/shared.min.js").to_string()) - .unwrap(), - PageContent::Text(content) | PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, "text/javascript;charset=utf-8") - .body(content.clone()) - .unwrap(), - }, _ => Response::builder() .status(StatusCode::NOT_FOUND) .body("Not found".to_string()) .unwrap(), } } + +async fn handle_about() -> impl IntoResponse { + match AppConfig::get_page_content(ROUTE_ABOUT_PATH).unwrap_or_default() { + PageContent::Default => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .body(include_str!("../static/readme.min.html").to_string()) + .unwrap(), + PageContent::Text(content) => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .body(content.clone()) + .unwrap(), + PageContent::Html(content) => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .body(content.clone()) + .unwrap(), + } +} + +async fn handle_readme() -> impl IntoResponse { + match AppConfig::get_page_content(ROUTE_README_PATH).unwrap_or_default() { + PageContent::Default => Response::builder() + .status(StatusCode::TEMPORARY_REDIRECT) + .header(HEADER_NAME_LOCATION, ROUTE_ABOUT_PATH) + .body(Body::empty()) + .unwrap(), + PageContent::Text(content) => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .body(Body::from(content.clone())) + .unwrap(), + PageContent::Html(content) => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .body(Body::from(content.clone())) + .unwrap(), + } +} diff --git a/src/message.proto b/src/message.proto deleted file mode 100644 index 561ee16..0000000 --- a/src/message.proto +++ /dev/null @@ -1,53 +0,0 @@ -syntax = "proto3"; - -package cursor; - -message ChatMessage { - message FileContent { - message Position { - int32 line = 1; - int32 column = 2; - } - message Range { - Position start = 1; - Position end = 2; - } - - string filename = 1; - string content = 2; - Position position = 3; - string language = 5; - Range range = 6; - int32 length = 8; - int32 type = 9; - int32 error_code = 11; - } - - message Message { - string content = 1; - int32 role = 2; - string message_id = 13; - } - - message Instructions { - string content = 1; - } - - message Model { - string name = 1; - string empty = 4; - } - - // repeated FileContent files = 1; - repeated Message messages = 2; - Instructions instructions = 4; - string projectPath = 5; - Model model = 7; - string requestId = 9; - string summary = 11; // 或许是空的,描述会话做了什么事情,但是不是标题 或许可以当作额外的设定来用 - string conversationId = 15; // 又来一个uuid -} - -message ResMessage { - string msg = 1; -} 
\ No newline at end of file diff --git a/src/models.rs b/src/models.rs deleted file mode 100644 index d8ea513..0000000 --- a/src/models.rs +++ /dev/null @@ -1,141 +0,0 @@ -use crate::Model; -use std::sync::LazyLock; - -use super::{ANTHROPIC, CURSOR, GOOGLE, MODEL_OBJECT, OPENAI}; - -pub static AVAILABLE_MODELS: LazyLock> = LazyLock::new(|| { - vec![ - Model { - id: "claude-3.5-sonnet".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: ANTHROPIC.into(), - }, - Model { - id: "gpt-3.5".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: OPENAI.into(), - }, - Model { - id: "gpt-4".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: OPENAI.into(), - }, - Model { - id: "gpt-4o".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: OPENAI.into(), - }, - Model { - id: "claude-3-opus".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: ANTHROPIC.into(), - }, - Model { - id: "cursor-fast".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: CURSOR.into(), - }, - Model { - id: "cursor-small".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: CURSOR.into(), - }, - Model { - id: "gpt-3.5-turbo".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: OPENAI.into(), - }, - Model { - id: "gpt-4-turbo-2024-04-09".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: OPENAI.into(), - }, - Model { - id: "gpt-4o-128k".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: OPENAI.into(), - }, - Model { - id: "gemini-1.5-flash-500k".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: GOOGLE.into(), - }, - Model { - id: "claude-3-haiku-200k".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: ANTHROPIC.into(), - }, - Model { - id: "claude-3-5-sonnet-200k".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: ANTHROPIC.into(), - }, - Model { - id: "claude-3-5-sonnet-20241022".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: ANTHROPIC.into(), - }, - Model { - id: "gpt-4o-mini".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: OPENAI.into(), - }, - Model { - id: "o1-mini".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: OPENAI.into(), - }, - Model { - id: "o1-preview".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: OPENAI.into(), - }, - Model { - id: "o1".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: OPENAI.into(), - }, - Model { - id: "claude-3.5-haiku".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: ANTHROPIC.into(), - }, - Model { - id: "gemini-exp-1206".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: GOOGLE.into(), - }, - Model { - id: "gemini-2.0-flash-thinking-exp".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: GOOGLE.into(), - }, - Model { - id: "gemini-2.0-flash-exp".into(), - created: 1706659200, - object: MODEL_OBJECT.into(), - owned_by: GOOGLE.into(), - }, - ] -}); diff --git a/start_instruction b/start_instruction deleted file mode 100644 index cad93bf..0000000 --- a/start_instruction +++ /dev/null @@ -1,6 +0,0 @@ -当前版本已稳定,若发现响应出现缺字漏字,与本程序无关。 -若发现首字慢,与本程序无关。 -若发现响应出现乱码,也与本程序无关。 -属于官方的问题,请不要像作者反馈。 -本程序拥有堪比客户端原本的速度,甚至可能更快。 -本程序的性能是非常厉害的。 \ No newline at end of file diff --git a/static/config.html b/static/config.html index 
529a79e..48585a6 100644 --- a/static/config.html +++ b/static/config.html @@ -23,13 +23,15 @@ + +
@@ -91,6 +93,18 @@
+
+ + + +
+
@@ -147,6 +161,8 @@ parseStringFromBoolean(data.data.enable_slow_pool, ''); document.getElementById('enable_all_claude').value = parseStringFromBoolean(data.data.enable_all_claude, ''); + document.getElementById('check_usage_models_type').value = data.data.check_usage_models?.type || ''; + document.getElementById('check_usage_models_list').value = data.data.check_usage_models?.type === 'list' ? data.data.check_usage_models?.content || '' : document.getElementById('check_usage_models_list').value; } } @@ -186,6 +202,14 @@ }), ...(document.getElementById('enable_all_claude').value && { enable_all_claude: parseBooleanFromString(document.getElementById('enable_all_claude').value) + }), + ...(document.getElementById('check_usage_models_type').value && { + check_usage_models: { + type: document.getElementById('check_usage_models_type').value, + ...(document.getElementById('check_usage_models_type').value === 'list' && { + content: document.getElementById('check_usage_models_list').value + }) + } }) }; @@ -220,6 +244,12 @@ // 初始化 token 处理 initializeTokenHandling('authToken'); + + // 添加使用量检查模型类型变更处理 + document.getElementById('check_usage_models_type').addEventListener('change', function() { + const input = document.getElementById('check_usage_models_list'); + input.style.display = this.value === 'list' ? 'inline-block' : 'none'; + }); diff --git a/static/readme.html b/static/readme.html new file mode 100644 index 0000000..2d9d327 --- /dev/null +++ b/static/readme.html @@ -0,0 +1,399 @@ +

cursor-api

+ +

说明

+ +
    +
  • 当前版本已稳定,若发现响应出现缺字漏字,与本程序无关。
  • +
  • 若发现首字慢,与本程序无关。
  • +
  • 若发现响应出现乱码,也与本程序无关。
  • +
  • 属于官方的问题,请不要向作者反馈。
  • +
  • 本程序拥有堪比客户端原本的速度,甚至可能更快。
  • +
  • 本程序的性能是非常厉害的。
  • +
+ +

获取key

+ +
    +
  1. 访问 www.cursor.com 并完成注册登录
  2. +
  3. 在浏览器中打开开发者工具(F12)
  4. +
  5. 在 Application-Cookies 中查找名为 WorkosCursorSessionToken 的条目,并复制其第三个字段。请注意,%3A%3A 是 :: 的 URL 编码形式,cookie 的值使用冒号 (:) 进行分隔。
  6. +
+ +

配置说明

+ +

环境变量

+ +
    +
  • PORT: 服务器端口号(默认:3000)
  • +
  • AUTH_TOKEN: 认证令牌(必须,用于API认证)
  • +
  • ROUTE_PREFIX: 路由前缀(可选)
  • +
  • TOKEN_FILE: token文件路径(默认:.token)
  • +
  • TOKEN_LIST_FILE: token列表文件路径(默认:.token-list)
  • +
+ +

更多请查看 /env-example
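下面给出一个最小的 .env 配置示例(字段含义见上方列表,取值仅为占位,请按实际环境填写):

AUTH_TOKEN=your-auth-token
+PORT=3000
+TOKEN_FILE=.token
+TOKEN_LIST_FILE=.token-list
+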

+ +

Token文件格式

+ +
    +
  1. +

    .token 文件:每行一个token,支持以下格式:

    + +
    # 这是注释
    +token1
    +# alias与标签的作用差不多
    +alias::token2
    +
    + +

    alias 可以是任意值,用于区分不同的 token、方便管理;WorkosCursorSessionToken 本身也是相同的格式
    +该文件将自动向.token-list文件中追加token,同时自动生成checksum

    +
  2. + +
  3. +

    .token-list 文件:每行为token和checksum的对应关系:

    + +
    # 这里的#表示这行在下次读取时将被删除
    +token1,checksum1
    +# 支持像.token一样的alias,冲突时以.token为准
    +alias::token2,checksum2
    +
    + +

    该文件可以被自动管理,但用户仅可在确认自己拥有修改能力时修改,一般仅有以下情况需要手动修改:

    + +
      +
    • 需要删除某个 token
    • +
    • 需要使用已有 checksum 来对应某一个 token
    • +
    +
  4. +
+ +

模型列表

+ +

写死了,后续也不会支持自定义模型列表

+ +
claude-3.5-sonnet
+gpt-4
+gpt-4o
+claude-3-opus
+cursor-fast
+cursor-small
+gpt-3.5-turbo
+gpt-4-turbo-2024-04-09
+gpt-4o-128k
+gemini-1.5-flash-500k
+claude-3-haiku-200k
+claude-3-5-sonnet-200k
+claude-3-5-sonnet-20241022
+gpt-4o-mini
+o1-mini
+o1-preview
+o1
+claude-3.5-haiku
+gemini-exp-1206
+gemini-2.0-flash-thinking-exp
+gemini-2.0-flash-exp
+
+ +

接口说明

+ +

基础对话

+ +
    +
  • 接口地址: /v1/chat/completions
  • +
  • 请求方法: POST
  • +
  • 认证方式: Bearer Token +
      +
    1. 使用环境变量 AUTH_TOKEN 进行认证
    2. +
    3. 使用 .token 文件中的令牌列表进行轮询认证
    4. +
  • +
+ +
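即在请求头中携带认证令牌,示例如下(令牌为占位值):

Authorization: Bearer your-auth-token
+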

请求格式

+ +
{
+  "model": "string",
+  "messages": [
+    {
+      "role": "system" | "user" | "assistant", // 也可以是 "developer" | "human" | "ai"
+      "content": "string" | [
+        {
+          "type": "text" | "image_url",
+          "text": "string",
+          "image_url": {
+            "url": "string"
+          }
+        }
+      ]
+    }
+  ],
+  "stream": boolean
+}
+
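例如,一个开启流式输出的请求体(模型与消息内容仅作演示):

{
+  "model": "claude-3.5-sonnet",
+  "messages": [
+    {
+      "role": "user",
+      "content": "你好"
+    }
+  ],
+  "stream": true
+}
+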
+ +

响应格式

+ +

如果 streamfalse:

+ +
{
+  "id": "string",
+  "object": "chat.completion",
+  "created": number,
+  "model": "string",
+  "choices": [
+    {
+      "index": number,
+      "message": {
+        "role": "assistant",
+        "content": "string"
+      },
+      "finish_reason": "stop" | "length"
+    }
+  ],
+  "usage": {
+    "prompt_tokens": number,
+    "completion_tokens": number,
+    "total_tokens": number
+  }
+}
+
+ +

如果 streamtrue:

+ +
data: {"id":"string","object":"chat.completion.chunk","created":number,"model":"string","choices":[{"index":number,"delta":{"role":"assistant","content":"string"},"finish_reason":null}]}
+
+data: {"id":"string","object":"chat.completion.chunk","created":number,"model":"string","choices":[{"index":number,"delta":{"content":"string"},"finish_reason":null}]}
+
+data: {"id":"string","object":"chat.completion.chunk","created":number,"model":"string","choices":[{"index":number,"delta":{},"finish_reason":"stop"}]}
+
+data: [DONE]
+
+ +

Token管理接口

+ +

简易Token信息管理页面

+ +
    +
  • 接口地址: /tokeninfo
  • +
  • 请求方法: GET
  • +
  • 响应格式: HTML页面
  • +
  • 功能: 获取 .token 和 .token-list 文件内容,并允许用户方便地使用 API 修改文件内容
  • +
+ +

更新Token信息 (GET)

+ +
    +
  • 接口地址: /update-tokeninfo
  • +
  • 请求方法: GET
  • +
  • 认证方式: 不需要
  • +
  • 功能: 请求内容不包括文件内容,直接修改文件,调用重载函数
  • +
+ +

更新Token信息 (POST)

+ +
    +
  • 接口地址: /update-tokeninfo
  • +
  • 请求方法: POST
  • +
  • 认证方式: Bearer Token
  • +
  • 请求格式:
  • +
+ +
{
+  "tokens": "string",
+  "token_list": "string"
+}
+
+ +
    +
  • 响应格式:
  • +
+ +
{
+  "status": "success",
+  "message": "Token files have been updated and reloaded",
+  "token_file": "string",
+  "token_list_file": "string",
+  "token_count": number
+}
+
+ +

获取Token信息

+ +
    +
  • 接口地址: /get-tokeninfo
  • +
  • 请求方法: POST
  • +
  • 认证方式: Bearer Token
  • +
  • 响应格式:
  • +
+ +
{
+  "status": "success",
+  "token_file": "string",
+  "token_list_file": "string",
+  "tokens": "string",
+  "token_list": "string"
+}
+
+ +

配置管理接口

+ +

配置页面

+ +
    +
  • 接口地址: /config
  • +
  • 请求方法: GET
  • +
  • 响应格式: HTML页面
  • +
  • 功能: 提供配置管理界面,可以修改页面内容和系统配置
  • +
+ +

更新配置

+ +
    +
  • 接口地址: /config
  • +
  • 请求方法: POST
  • +
  • 认证方式: Bearer Token
  • +
  • 请求格式:
  • +
+ +
{
+  "action": "get" | "update" | "reset",
+  "path": "string",
+  "content": "string",
+  "content_type": "default" | "text" | "html",
+  "enable_stream_check": boolean,
+  "enable_stream_check": boolean,
+  "vision_ability": "none" | "base64" | "all", // "disabled" | "base64-only" | "base64-http"
+  "enable_slow_pool": boolean,
+  "enable_slow_pool": boolean
+}
+
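例如,一个只更新部分配置项的请求(字段组合与取值仅作演示):

{
+  "action": "update",
+  "enable_slow_pool": true,
+  "vision_ability": "base64"
+}
+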
+ +
    +
  • 响应格式:
  • +
+ +
{
+  "status": "success",
+  "message": "string",
+  "data": {
+    "page_content": {
+      "type": "default" | "text" | "html",
+      "content": "string"
+    },
+    "enable_stream_check": boolean,
+    "vision_ability": "base64" | "url" | "none", 
+    "enable_slow_pool": boolean
+  }
+}
+
+ +

静态资源接口

+ +

获取共享样式

+ +
    +
  • 接口地址: /static/shared-styles.css
  • +
  • 请求方法: GET
  • +
  • 响应格式: CSS文件
  • +
  • 功能: 获取共享样式表
  • +
+ +

获取共享脚本

+ +
    +
  • 接口地址: /static/shared.js
  • +
  • 请求方法: GET
  • +
  • 响应格式: JavaScript文件
  • +
  • 功能: 获取共享JavaScript代码
  • +
+ +

环境变量示例

+ +
    +
  • 接口地址: /env-example
  • +
  • 请求方法: GET
  • +
  • 响应格式: 文本文件
  • +
  • 功能: 获取环境变量配置示例
  • +
+ +

其他接口

+ +

获取模型列表

+ +
    +
  • 接口地址: /v1/models
  • +
  • 请求方法: GET
  • +
  • 响应格式:
  • +
+ +
{
+  "object": "list",
+  "data": [
+    {
+      "id": "string",
+      "object": "model",
+      "created": number,
+      "owned_by": "string"
+    }
+  ]
+}
+
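示例响应(字段取值仅作示意,以实际返回为准):

{
+  "object": "list",
+  "data": [
+    {
+      "id": "claude-3.5-sonnet",
+      "object": "model",
+      "created": 1706659200,
+      "owned_by": "anthropic"
+    }
+  ]
+}
+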
+ +

获取随机checksum

+ +
    +
  • 接口地址: /checksum
  • +
  • 请求方法: GET
  • +
  • 响应格式:
  • +
+ +
{
+  "checksum": "string"
+}
+
+ +

健康检查接口

+ +
    +
  • 接口地址: /health(根路径 / 默认重定向至此)
  • +
  • 请求方法: GET
  • +
  • 响应格式: 根据配置返回不同的内容类型(默认、文本或HTML)
  • +
+ +
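其 JSON 响应的结构大致如下(依据本次改动中 HealthCheckResponse 的定义整理,仅供参考):

{
+  "status": "healthy",
+  "version": "string",
+  "uptime": number,
+  "stats": {
+    "started": "string",
+    "total_requests": number,
+    "active_requests": number,
+    "system": {
+      "memory": { "rss": number },
+      "cpu": { "usage": number }
+    }
+  },
+  "models": ["string"],
+  "endpoints": ["string"]
+}
+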

获取日志接口

+ +
    +
  • 接口地址: /logs
  • +
  • 请求方法: GET
  • +
  • 响应格式: 根据配置返回不同的内容类型(默认、文本或HTML)
  • +
+ +

获取日志数据

+ +
    +
  • 接口地址: /logs
  • +
  • 请求方法: POST
  • +
  • 认证方式: Bearer Token
  • +
  • 响应格式:
  • +
+ +
{
+  "total": number,
+  "logs": [
+    {
+      "timestamp": "string",
+      "model": "string",
+      "token_info": {
+        "token": "string",
+        "checksum": "string",
+        "alias": "string"
+      },
+      "prompt": "string",
+      "stream": boolean,
+      "status": "string",
+      "error": "string"
+    }
+  ],
+  "timestamp": "string",
+  "status": "success"
+}
+
diff --git a/static/shared-styles.css b/static/shared-styles.css index 8e8a885..23556b0 100644 --- a/static/shared-styles.css +++ b/static/shared-styles.css @@ -44,25 +44,38 @@ label { font-weight: 500; } -input[type="text"], +/* input[type="text"], 由于minify.js会删除input[type="text"],所以改为input */ +input, input[type="password"], select, textarea { width: 100%; - padding: 8px 12px; + padding: 10px 12px; border: 1px solid #ddd; border-radius: 4px; font-size: 14px; - transition: border-color 0.2s; + transition: border-color 0.2s, box-shadow 0.2s; + background: white; + color: #333; + appearance: none; } -input[type="text"]:focus, +/* input[type="text"]:focus, 由于minify.js会删除input[type="text"]:focus,所以改为input:focus */ +input:focus, input[type="password"]:focus, select:focus, textarea:focus { border-color: var(--primary-color); outline: none; - box-shadow: 0 0 0 2px rgba(33, 150, 243, 0.1); + box-shadow: 0 0 0 2px rgba(33, 150, 243, 0.2); +} + +select { + background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24' fill='%23757575'%3E%3Cpath d='M7 10l5 5 5-5H7z'/%3E%3C/svg%3E"); + background-repeat: no-repeat; + background-position: right 8px center; + background-size: 20px; + padding-right: 36px; } textarea {