diff --git a/.env.example b/.env.example index a781392..5c606d4 100644 --- a/.env.example +++ b/.env.example @@ -34,3 +34,6 @@ PASS_ANY_CLAUDE=false # - all 或 base64-http:支持 base64 和 HTTP 图片 # 注意:启用 HTTP 支持可能会暴露服务器 IP VISION_ABILITY=base64 + +# 默认提示词 +DEFAULT_INSTRUCTIONS="Respond in Chinese by default" diff --git a/Cargo.lock b/Cargo.lock index 22a2711..2b01dfc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -292,7 +292,7 @@ dependencies = [ [[package]] name = "cursor-api" -version = "0.1.3-rc.2" +version = "0.1.3-rc.3" dependencies = [ "axum", "base64", diff --git a/Cargo.toml b/Cargo.toml index c8ab72d..714c3d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "cursor-api" -version = "0.1.3-rc.2" +version = "0.1.3-rc.3" edition = "2021" authors = ["wisdgod "] # license = "MIT" diff --git a/src/app.rs b/src/app.rs index c9d83c6..a08e33d 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,7 +1,4 @@ -pub mod client; pub mod config; pub mod constant; -pub mod models; -pub mod statics; -pub mod token; -pub mod utils; +pub mod model; +pub mod lazy; diff --git a/src/app/config.rs b/src/app/config.rs index c68b319..d7c7ab7 100644 --- a/src/app/config.rs +++ b/src/app/config.rs @@ -1,7 +1,7 @@ use super::{ - constant::*, - models::{AppConfig, AppState}, - statics::*, + constant::{HEADER_NAME_AUTHORIZATION, AUTHORIZATION_BEARER_PREFIX}, + model::{AppConfig, AppState}, + lazy::AUTH_TOKEN, }; use crate::common::models::{ config::{ConfigData, ConfigUpdateRequest}, @@ -15,6 +15,44 @@ use axum::{ use std::sync::Arc; use tokio::sync::Mutex; +// 定义处理更新操作的宏 +macro_rules! handle_update { + ($request:expr, $field:ident, $update_fn:expr, $field_name:expr) => { + if let Some($field) = $request.$field { + if let Err(e) = $update_fn($field) { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + status: ApiStatus::Failed, + code: Some(500), + error: Some(format!("更新 {} 失败: {}", $field_name, e)), + message: None, + }), + )); + } + } + }; +} + +// 定义处理重置操作的宏 +macro_rules! handle_reset { + ($request:expr, $field:ident, $reset_fn:expr, $field_name:expr) => { + if $request.$field.is_some() { + if let Err(e) = $reset_fn() { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + status: ApiStatus::Failed, + code: Some(500), + error: Some(format!("重置 {} 失败: {}", $field_name, e)), + message: None, + }), + )); + } + } + }; +} + pub async fn handle_config_update( State(_state): State>>, headers: HeaderMap, @@ -34,7 +72,7 @@ pub async fn handle_config_update( }), ))?; - if auth_header != get_auth_token() { + if auth_header != AUTH_TOKEN.as_str() { return Err(( StatusCode::UNAUTHORIZED, Json(ErrorResponse { @@ -65,7 +103,6 @@ pub async fn handle_config_update( // 处理页面内容更新 if !request.path.is_empty() && request.content.is_some() { let content = request.content.unwrap(); - if let Err(e) = AppConfig::update_page_content(&request.path, content) { return Err(( StatusCode::INTERNAL_SERVER_ERROR, @@ -79,95 +116,12 @@ pub async fn handle_config_update( } } - // 处理 enable_stream_check 更新 - if let Some(enable_stream_check) = request.enable_stream_check { - if let Err(e) = AppConfig::update_stream_check(enable_stream_check) { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - status: ApiStatus::Failed, - code: Some(500), - error: Some(format!("更新 enable_stream_check 失败: {}", e)), - message: None, - }), - )); - } - } - - // 处理 include_stop_stream 更新 - if let Some(include_stop_stream) = request.include_stop_stream { - if let Err(e) = AppConfig::update_stop_stream(include_stop_stream) { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - status: ApiStatus::Failed, - code: Some(500), - error: Some(format!("更新 include_stop_stream 失败: {}", e)), - message: None, - }), - )); - } - } - - // 处理 vision_ability 更新 - if let Some(vision_ability) = request.vision_ability { - if let Err(e) = AppConfig::update_vision_ability(vision_ability) { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - status: ApiStatus::Failed, - code: Some(500), - error: Some(format!("更新 vision_ability 失败: {}", e)), - message: None, - }), - )); - } - } - - // 处理 enable_slow_pool 更新 - if let Some(enable_slow_pool) = request.enable_slow_pool { - if let Err(e) = AppConfig::update_slow_pool(enable_slow_pool) { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - status: ApiStatus::Failed, - code: Some(500), - error: Some(format!("更新 enable_slow_pool 失败: {}", e)), - message: None, - }), - )); - } - } - - // 处理 enable_all_claude 更新 - if let Some(enable_all_claude) = request.enable_all_claude { - if let Err(e) = AppConfig::update_allow_claude(enable_all_claude) { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - status: ApiStatus::Failed, - code: Some(500), - error: Some(format!("更新 enable_all_claude 失败: {}", e)), - message: None, - }), - )); - } - } - - // 处理 check_usage_models 更新 - if let Some(check_usage_models) = request.check_usage_models { - if let Err(e) = AppConfig::update_usage_check(check_usage_models) { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - status: ApiStatus::Failed, - code: Some(500), - error: Some(format!("更新 check_usage_models 失败: {}", e)), - message: None, - }), - )); - } - } + handle_update!(request, enable_stream_check, AppConfig::update_stream_check, "enable_stream_check"); + handle_update!(request, include_stop_stream, AppConfig::update_stop_stream, "include_stop_stream"); + handle_update!(request, vision_ability, AppConfig::update_vision_ability, "vision_ability"); + handle_update!(request, enable_slow_pool, AppConfig::update_slow_pool, "enable_slow_pool"); + handle_update!(request, enable_all_claude, AppConfig::update_allow_claude, "enable_all_claude"); + handle_update!(request, check_usage_models, AppConfig::update_usage_check, "check_usage_models"); Ok(Json(NormalResponse { status: ApiStatus::Success, @@ -192,95 +146,13 @@ pub async fn handle_config_update( } } - // 重置 enable_stream_check - if request.enable_stream_check.is_some() { - if let Err(e) = AppConfig::reset_stream_check() { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - status: ApiStatus::Failed, - code: Some(500), - error: Some(format!("重置 enable_stream_check 失败: {}", e)), - message: None, - }), - )); - } - } + handle_reset!(request, enable_stream_check, AppConfig::reset_stream_check, "enable_stream_check"); + handle_reset!(request, include_stop_stream, AppConfig::reset_stop_stream, "include_stop_stream"); + handle_reset!(request, vision_ability, AppConfig::reset_vision_ability, "vision_ability"); + handle_reset!(request, enable_slow_pool, AppConfig::reset_slow_pool, "enable_slow_pool"); + handle_reset!(request, enable_all_claude, AppConfig::reset_allow_claude, "enable_all_claude"); + handle_reset!(request, check_usage_models, AppConfig::reset_usage_check, "check_usage_models"); - // 重置 include_stop_stream - if request.include_stop_stream.is_some() { - if let Err(e) = AppConfig::reset_stop_stream() { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - status: ApiStatus::Failed, - code: Some(500), - error: Some(format!("重置 include_stop_stream 失败: {}", e)), - message: None, - }), - )); - } - } - - // 重置 vision_ability - if request.vision_ability.is_some() { - if let Err(e) = AppConfig::reset_vision_ability() { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - status: ApiStatus::Failed, - code: Some(500), - error: Some(format!("重置 vision_ability 失败: {}", e)), - message: None, - }), - )); - } - } - - // 重置 enable_slow_pool - if request.enable_slow_pool.is_some() { - if let Err(e) = AppConfig::reset_slow_pool() { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - status: ApiStatus::Failed, - code: Some(500), - error: Some(format!("重置 enable_slow_pool 失败: {}", e)), - message: None, - }), - )); - } - } - - // 重置 enable_all_claude - if request.enable_all_claude.is_some() { - if let Err(e) = AppConfig::reset_allow_claude() { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - status: ApiStatus::Failed, - code: Some(500), - error: Some(format!("重置 enable_slow_pool 失败: {}", e)), - message: None, - }), - )); - } - } - - // 重置 check_usage_models - if request.check_usage_models.is_some() { - if let Err(e) = AppConfig::reset_usage_check() { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - status: ApiStatus::Failed, - code: Some(500), - error: Some(format!("重置 check_usage_models 失败: {}", e)), - message: None, - }), - )); - } - } Ok(Json(NormalResponse { status: ApiStatus::Success, data: None, diff --git a/src/app/constant.rs b/src/app/constant.rs index 15cff8f..0658add 100644 --- a/src/app/constant.rs +++ b/src/app/constant.rs @@ -5,10 +5,10 @@ macro_rules! def_pub_const { } def_pub_const!(PKG_VERSION, env!("CARGO_PKG_VERSION")); -def_pub_const!(PKG_NAME, env!("CARGO_PKG_NAME")); -def_pub_const!(PKG_DESCRIPTION, env!("CARGO_PKG_DESCRIPTION")); -def_pub_const!(PKG_AUTHORS, env!("CARGO_PKG_AUTHORS")); -def_pub_const!(PKG_REPOSITORY, env!("CARGO_PKG_REPOSITORY")); +// def_pub_const!(PKG_NAME, env!("CARGO_PKG_NAME")); +// def_pub_const!(PKG_DESCRIPTION, env!("CARGO_PKG_DESCRIPTION")); +// def_pub_const!(PKG_AUTHORS, env!("CARGO_PKG_AUTHORS")); +// def_pub_const!(PKG_REPOSITORY, env!("CARGO_PKG_REPOSITORY")); def_pub_const!(EMPTY_STRING, ""); @@ -28,16 +28,8 @@ def_pub_const!(ROUTE_SHARED_JS_PATH, "/static/shared.js"); def_pub_const!(ROUTE_ABOUT_PATH, "/about"); def_pub_const!(ROUTE_README_PATH, "/readme"); -def_pub_const!(STATUS, "status"); -def_pub_const!(MESSAGE, "message"); -def_pub_const!(ERROR, "error"); - -def_pub_const!(TOKEN_FILE, "token_file"); def_pub_const!(DEFAULT_TOKEN_FILE_NAME, ".token"); -def_pub_const!(TOKEN_LIST_FILE, "token_list_file"); def_pub_const!(DEFAULT_TOKEN_LIST_FILE_NAME, ".token-list"); -def_pub_const!(TOKENS, "tokens"); -def_pub_const!(TOKEN_LIST, "token_list"); def_pub_const!(STATUS_SUCCESS, "success"); def_pub_const!(STATUS_FAILED, "failed"); diff --git a/src/app/statics.rs b/src/app/lazy.rs similarity index 60% rename from src/app/statics.rs rename to src/app/lazy.rs index 43b559f..553a87a 100644 --- a/src/app/statics.rs +++ b/src/app/lazy.rs @@ -1,6 +1,6 @@ -use super::{ - constant::{DEFAULT_TOKEN_FILE_NAME, DEFAULT_TOKEN_LIST_FILE_NAME, EMPTY_STRING}, - utils::parse_string_from_env, +use crate::{ + app::constant::{DEFAULT_TOKEN_FILE_NAME, DEFAULT_TOKEN_LIST_FILE_NAME, EMPTY_STRING}, + common::utils::parse_string_from_env, }; use std::sync::LazyLock; @@ -8,28 +8,24 @@ macro_rules! def_pub_static { // 基础版本:直接存储 String ($name:ident, $value:expr) => { pub static $name: LazyLock = LazyLock::new(|| $value); - - def_pub_static_getter!($name); }; // 环境变量版本 ($name:ident, env: $env_key:expr, default: $default:expr) => { pub static $name: LazyLock = LazyLock::new(|| parse_string_from_env($env_key, $default).trim().to_string()); - - def_pub_static_getter!($name); }; } -macro_rules! def_pub_static_getter { - ($name:ident) => { - paste::paste! { - pub fn []() -> String { - (*$name).clone() - } - } - }; -} +// macro_rules! def_pub_static_getter { +// ($name:ident) => { +// paste::paste! { +// pub fn []() -> String { +// (*$name).clone() +// } +// } +// }; +// } def_pub_static!(ROUTE_PREFIX, env: "ROUTE_PREFIX", default: EMPTY_STRING); def_pub_static!(AUTH_TOKEN, env: "AUTH_TOKEN", default: EMPTY_STRING); @@ -50,3 +46,16 @@ pub static START_TIME: LazyLock> = pub fn get_start_time() -> chrono::DateTime { *START_TIME } + +def_pub_static!(DEFAULT_INSTRUCTIONS, env: "DEFAULT_INSTRUCTIONS", default: "Respond in Chinese by default"); + +// pub static DEBUG: LazyLock = LazyLock::new(|| parse_bool_from_env("DEBUG", false)); + +// #[macro_export] +// macro_rules! debug_println { +// ($($arg:tt)*) => { +// if *crate::app::statics::DEBUG { +// println!($($arg)*); +// } +// }; +// } diff --git a/src/app/models.rs b/src/app/model.rs similarity index 96% rename from src/app/models.rs rename to src/app/model.rs index d71a1ef..0476065 100644 --- a/src/app/models.rs +++ b/src/app/model.rs @@ -1,5 +1,12 @@ -use super::{constant::*, token::UserUsageInfo}; -use crate::chat::models::Message; +use crate::{ + app::constant::{ + ERR_INVALID_PATH, ERR_RESET_CONFIG, ERR_UPDATE_CONFIG, ROUTE_ABOUT_PATH, ROUTE_CONFIG_PATH, + ROUTE_LOGS_PATH, ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_SHARED_JS_PATH, + ROUTE_SHARED_STYLES_PATH, ROUTE_TOKENINFO_PATH, + }, + common::models::usage::UserUsageInfo, +}; +use crate::chat::model::Message; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; use std::sync::RwLock; @@ -113,7 +120,7 @@ macro_rules! config_methods { .map(|config| config.$field.clone()) .unwrap_or($default) } - + pub fn [](value: $type) -> Result<(), &'static str> { if let Ok(mut config) = APP_CONFIG.write() { config.$field = value; @@ -122,7 +129,7 @@ macro_rules! config_methods { Err(ERR_UPDATE_CONFIG) } } - + pub fn []() -> Result<(), &'static str> { if let Ok(mut config) = APP_CONFIG.write() { config.$field = $default; @@ -188,7 +195,6 @@ impl AppConfig { .unwrap_or_default() } - pub fn update_vision_ability(new_ability: VisionAbility) -> Result<(), &'static str> { if let Ok(mut config) = APP_CONFIG.write() { config.vision_ability = new_ability; @@ -275,10 +281,6 @@ impl AppState { token_infos, } } - - pub fn update_token_infos(&mut self, token_infos: Vec) { - self.token_infos = token_infos; - } } // 请求日志 diff --git a/src/app/models/usage_check.rs b/src/app/model/usage_check.rs similarity index 100% rename from src/app/models/usage_check.rs rename to src/app/model/usage_check.rs diff --git a/src/app/token.rs b/src/app/token.rs deleted file mode 100644 index edd6459..0000000 --- a/src/app/token.rs +++ /dev/null @@ -1,341 +0,0 @@ -use super::{ - constant::*, - models::{AppState, TokenInfo, TokenUpdateRequest}, - statics::*, - utils::{generate_checksum, generate_hash, i32_to_u32}, -}; -use crate::{chat::aiserver::v1::GetUserInfoResponse, common::models::{ApiStatus, NormalResponseNoData}}; -use axum::http::HeaderMap; -use axum::{ - extract::{Query, State}, - Json, -}; -use image::EncodableLayout; -use prost::Message; -use reqwest::StatusCode; -use serde::{Deserialize, Serialize}; -use std::sync::Arc; -use tokio::sync::Mutex; - -// 规范化文件内容并写入 -fn normalize_and_write(content: &str, file_path: &str) -> String { - let normalized = content.replace("\r\n", "\n"); - if normalized != content { - if let Err(e) = std::fs::write(file_path, &normalized) { - eprintln!("警告: 无法更新规范化的文件: {}", e); - } - } - normalized -} - -// 解析token和别名 -fn parse_token_alias(token_part: &str, line: &str) -> Option<(String, Option)> { - match token_part.split("::").collect::>() { - parts if parts.len() == 1 => Some((parts[0].to_string(), None)), - parts if parts.len() == 2 => Some((parts[1].to_string(), Some(parts[0].to_string()))), - _ => { - eprintln!("警告: 忽略无效的行: {}", line); - None - } - } -} - -// Token 加载函数 -pub fn load_tokens() -> Vec { - let token_file = get_token_file(); - let token_list_file = get_token_list_file(); - - // 确保文件存在 - for file in [&token_file, &token_list_file] { - if !std::path::Path::new(file).exists() { - if let Err(e) = std::fs::write(file, EMPTY_STRING) { - eprintln!("警告: 无法创建文件 '{}': {}", file, e); - } - } - } - - // 读取和规范化 token 文件 - let token_entries = match std::fs::read_to_string(&token_file) { - Ok(content) => { - let normalized = normalize_and_write(&content, &token_file); - normalized - .lines() - .filter_map(|line| { - let line = line.trim(); - if line.is_empty() || line.starts_with('#') { - return None; - } - parse_token_alias(line, line) - }) - .collect::>() - } - Err(e) => { - eprintln!("警告: 无法读取token文件 '{}': {}", token_file, e); - Vec::new() - } - }; - - // 读取和规范化 token-list 文件 - let mut token_map: std::collections::HashMap)> = - match std::fs::read_to_string(&token_list_file) { - Ok(content) => { - let normalized = normalize_and_write(&content, &token_list_file); - normalized - .lines() - .filter_map(|line| { - let line = line.trim(); - if line.is_empty() || line.starts_with('#') { - return None; - } - - let parts: Vec<&str> = line.split(',').collect(); - match parts[..] { - [token_part, checksum] => { - let (token, alias) = parse_token_alias(token_part, line)?; - Some((token, (checksum.to_string(), alias))) - } - _ => { - eprintln!("警告: 忽略无效的token-list行: {}", line); - None - } - } - }) - .collect() - } - Err(e) => { - eprintln!("警告: 无法读取token-list文件: {}", e); - std::collections::HashMap::new() - } - }; - - // 更新或添加新token - for (token, alias) in token_entries { - if let Some((_, existing_alias)) = token_map.get(&token) { - // 只在alias不同时更新已存在的token - if alias != *existing_alias { - if let Some((checksum, _)) = token_map.get(&token) { - token_map.insert(token.clone(), (checksum.clone(), alias)); - } - } - } else { - // 为新token生成checksum - let checksum = generate_checksum(&generate_hash(), Some(&generate_hash())); - token_map.insert(token, (checksum, alias)); - } - } - - // 更新 token-list 文件 - let token_list_content = token_map - .iter() - .map(|(token, (checksum, alias))| { - if let Some(alias) = alias { - format!("{}::{},{}", alias, token, checksum) - } else { - format!("{},{}", token, checksum) - } - }) - .collect::>() - .join("\n"); - - if let Err(e) = std::fs::write(&token_list_file, token_list_content) { - eprintln!("警告: 无法更新token-list文件: {}", e); - } - - // 转换为 TokenInfo vector - token_map - .into_iter() - .map(|(token, (checksum, alias))| TokenInfo { - token, - checksum, - alias, - usage: None, - }) - .collect() -} - -#[derive(Serialize)] -pub struct ChecksumResponse { - pub checksum: String, -} - -pub async fn handle_get_checksum() -> Json { - let checksum = generate_checksum(&generate_hash(), Some(&generate_hash())); - Json(ChecksumResponse { checksum }) -} - -// 更新 TokenInfo 处理 -pub async fn handle_update_tokeninfo( - State(state): State>>, -) -> Json { - // 重新加载 tokens - let token_infos = load_tokens(); - - // 更新应用状态 - { - let mut state = state.lock().await; - state.token_infos = token_infos; - } - - Json(NormalResponseNoData { - status: ApiStatus::Success, - message: Some("Token list has been reloaded".to_string()), - }) -} - -// 获取 TokenInfo 处理 -pub async fn handle_get_tokeninfo( - State(_state): State>>, - headers: HeaderMap, -) -> Result, StatusCode> { - let auth_token = get_auth_token(); - let token_file = get_token_file(); - let token_list_file = get_token_list_file(); - - // 验证 AUTH_TOKEN - let auth_header = headers - .get(HEADER_NAME_AUTHORIZATION) - .and_then(|h| h.to_str().ok()) - .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) - .ok_or(StatusCode::UNAUTHORIZED)?; - - if auth_header != auth_token { - return Err(StatusCode::UNAUTHORIZED); - } - - // 读取文件内容 - let tokens = std::fs::read_to_string(&token_file).unwrap_or_else(|_| String::new()); - let token_list = std::fs::read_to_string(&token_list_file).unwrap_or_else(|_| String::new()); - - Ok(Json(TokenInfoResponse { - status: ApiStatus::Success, - token_file: token_file.clone(), - token_list_file: token_list_file.clone(), - tokens: Some(tokens.clone()), - tokens_count: Some(tokens.len()), - token_list: Some(token_list), - message: None, - })) -} - -#[derive(Serialize)] -pub struct TokenInfoResponse { - pub status: ApiStatus, - pub token_file: String, - pub token_list_file: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tokens_count: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub token_list: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub message: Option, -} - -pub async fn handle_update_tokeninfo_post( - State(state): State>>, - headers: HeaderMap, - Json(request): Json, -) -> Result, StatusCode> { - let auth_token = get_auth_token(); - let token_file = get_token_file(); - let token_list_file = get_token_list_file(); - - // 验证 AUTH_TOKEN - let auth_header = headers - .get(HEADER_NAME_AUTHORIZATION) - .and_then(|h| h.to_str().ok()) - .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) - .ok_or(StatusCode::UNAUTHORIZED)?; - - if auth_header != auth_token { - return Err(StatusCode::UNAUTHORIZED); - } - - // 写入 .token 文件 - std::fs::write(&token_file, &request.tokens).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - - // 如果提供了 token_list,则写入 - if let Some(token_list) = request.token_list { - std::fs::write(&token_list_file, token_list) - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - } - - // 重新加载 tokens - let token_infos = load_tokens(); - let token_infos_len = token_infos.len(); - - // 更新应用状态 - { - let mut state = state.lock().await; - state.token_infos = token_infos; - } - - Ok(Json(TokenInfoResponse { - status: ApiStatus::Success, - token_file: token_file.clone(), - token_list_file: token_list_file.clone(), - tokens: None, - tokens_count: Some(token_infos_len), - token_list: None, - message: Some("Token files have been updated and reloaded".to_string()), - })) -} - -#[derive(Deserialize)] -pub struct GetUserInfoQuery { - alias: String, -} - -pub async fn get_user_info( - State(state): State>>, - Query(query): Query, -) -> Json { - let token_infos = &state.lock().await.token_infos; - let token_info = token_infos - .iter() - .find(|token_info| token_info.alias == Some(query.alias.clone())); - - let (auth_token, checksum) = match token_info { - Some(token_info) => (token_info.token.clone(), token_info.checksum.clone()), - None => return Json(GetUserInfo::Error("No data".to_string())), - }; - - match get_user_usage(&auth_token, &checksum).await { - Some(usage) => Json(GetUserInfo::Usage(usage)), - None => Json(GetUserInfo::Error("No data".to_string())), - } -} - -pub async fn get_user_usage(auth_token: &str, checksum: &str) -> Option { - // 构建请求客户端 - let client = super::client::build_client(auth_token, checksum, CURSOR_API2_GET_USER_INFO); - let response = client - .body(Vec::new()) - .send() - .await - .ok()? - .bytes() - .await - .ok()?; - let user_info = GetUserInfoResponse::decode(response.as_bytes()).ok()?; - - user_info.usage.map(|user_usage| UserUsageInfo { - fast_requests: i32_to_u32(user_usage.gpt4_requests), - max_fast_requests: i32_to_u32(user_usage.gpt4_max_requests), - }) -} - -#[derive(Serialize)] -pub enum GetUserInfo { - #[serde(rename = "usage")] - Usage(UserUsageInfo), - #[serde(rename = "error")] - Error(String), -} - -#[derive(Serialize, Clone)] -pub struct UserUsageInfo { - pub fast_requests: u32, - pub max_fast_requests: u32, -} diff --git a/src/app/utils.rs b/src/app/utils.rs deleted file mode 100644 index d44fe3a..0000000 --- a/src/app/utils.rs +++ /dev/null @@ -1,25 +0,0 @@ -mod checksum; -pub use checksum::*; - -pub fn parse_bool_from_env(key: &str, default: bool) -> bool { - std::env::var(key) - .ok() - .map(|v| match v.to_lowercase().as_str() { - "true" | "1" => true, - "false" | "0" => false, - _ => default, - }) - .unwrap_or(default) -} - -pub fn parse_string_from_env(key: &str, default: &str) -> String { - std::env::var(key).unwrap_or_else(|_| default.to_string()) -} - -pub fn i32_to_u32(value: i32) -> u32 { - if value < 0 { - 0 - } else { - value as u32 - } -} diff --git a/src/chat.rs b/src/chat.rs index a2a4299..b557892 100644 --- a/src/chat.rs +++ b/src/chat.rs @@ -1,6 +1,8 @@ +pub mod adapter; pub mod aiserver; pub mod constant; pub mod error; -pub mod models; +pub mod model; +pub mod route; pub mod service; pub mod stream; diff --git a/src/lib.rs b/src/chat/adapter.rs similarity index 97% rename from src/lib.rs rename to src/chat/adapter.rs index be27b01..f57f0ea 100644 --- a/src/lib.rs +++ b/src/chat/adapter.rs @@ -3,19 +3,19 @@ use image::guess_format; use prost::Message as _; use uuid::Uuid; -pub mod common; +use crate::app::{ + constant::EMPTY_STRING, + model::{AppConfig, VisionAbility}, + lazy::DEFAULT_INSTRUCTIONS, +}; -pub mod app; -use app::{constant::EMPTY_STRING, models::*}; - -pub mod chat; -use chat::{ +use super::{ aiserver::v1::{ conversation_message, image_proto, ConversationMessage, ExplicitContext, GetChatRequest, ImageProto, ModelDetails, }, - constant::{LONG_CONTEXT_MODELS, ERR_UNSUPPORTED_GIF, ERR_UNSUPPORTED_IMAGE_FORMAT}, - models::{Message, MessageContent, Role}, + constant::{ERR_UNSUPPORTED_GIF, ERR_UNSUPPORTED_IMAGE_FORMAT, LONG_CONTEXT_MODELS}, + model::{Message, MessageContent, Role}, }; async fn process_chat_inputs(inputs: Vec) -> (String, Vec) { @@ -42,7 +42,7 @@ async fn process_chat_inputs(inputs: Vec) -> (String, Vec { @@ -9,7 +9,7 @@ def_pub_const!(ERR_UNSUPPORTED_GIF, "不支持动态 GIF"); def_pub_const!(ERR_UNSUPPORTED_IMAGE_FORMAT, "不支持的图片格式,仅支持 PNG、JPEG、WEBP 和非动态 GIF"); const MODEL_OBJECT: &str = "model"; -const CREATED: i64 = 1706659200; +const CREATED: &i64 = &1706659200; def_pub_const!(ANTHROPIC, "anthropic"); def_pub_const!(CURSOR, "cursor"); @@ -41,7 +41,7 @@ def_pub_const!( ); def_pub_const!(GEMINI_2_0_FLASH_EXP, "gemini-2.0-flash-exp"); -pub const AVAILABLE_MODELS: &[Model] = &[ +pub const AVAILABLE_MODELS: [Model; 21] = [ Model { id: CLAUDE_3_5_SONNET, created: CREATED, diff --git a/src/chat/error.rs b/src/chat/error.rs index 9f49ee3..290c981 100644 --- a/src/chat/error.rs +++ b/src/chat/error.rs @@ -126,9 +126,9 @@ pub struct Error { } impl ErrorResponse { - pub fn to_json(&self) -> serde_json::Value { - serde_json::to_value(self).unwrap() - } + // pub fn to_json(&self) -> serde_json::Value { + // serde_json::to_value(self).unwrap() + // } pub fn status_code(&self) -> StatusCode { StatusCode::from_u16(self.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) diff --git a/src/chat/models.rs b/src/chat/model.rs similarity index 93% rename from src/chat/models.rs rename to src/chat/model.rs index dd064ee..0091dc6 100644 --- a/src/chat/models.rs +++ b/src/chat/model.rs @@ -72,21 +72,21 @@ pub struct Delta { #[derive(Serialize)] pub struct Usage { - pub prompt_tokens: i32, - pub completion_tokens: i32, - pub total_tokens: i32, + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, } // 模型定义 #[derive(Serialize, Clone)] pub struct Model { pub id: &'static str, - pub created: i64, + pub created: &'static i64, pub object: &'static str, pub owned_by: &'static str, } -use crate::{AppConfig, UsageCheck}; +use crate::app::model::{AppConfig, UsageCheck}; use super::constant::USAGE_CHECK_MODELS; impl Model { diff --git a/src/chat/route.rs b/src/chat/route.rs new file mode 100644 index 0000000..e345a32 --- /dev/null +++ b/src/chat/route.rs @@ -0,0 +1,10 @@ +mod logs; +pub use logs::{handle_logs, handle_logs_post}; +mod health; +pub use health::{handle_root, handle_health}; +mod token; +pub use token::{handle_get_checksum, handle_update_tokeninfo, handle_get_tokeninfo, handle_update_tokeninfo_post, handle_tokeninfo_page}; +mod usage; +pub use usage::get_user_info; +mod config; +pub use config::{handle_env_example, handle_config_page, handle_static, handle_readme, handle_about}; diff --git a/src/chat/route/config.rs b/src/chat/route/config.rs new file mode 100644 index 0000000..328d541 --- /dev/null +++ b/src/chat/route/config.rs @@ -0,0 +1,108 @@ +use crate::app::{ + constant::{ + CONTENT_TYPE_TEXT_CSS_WITH_UTF8, CONTENT_TYPE_TEXT_HTML_WITH_UTF8, + CONTENT_TYPE_TEXT_JS_WITH_UTF8, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, + HEADER_NAME_CONTENT_TYPE, HEADER_NAME_LOCATION, ROUTE_ABOUT_PATH, ROUTE_CONFIG_PATH, + ROUTE_README_PATH, ROUTE_SHARED_JS_PATH, ROUTE_SHARED_STYLES_PATH, + }, + model::{AppConfig, PageContent}, +}; +use axum::{ + body::Body, + extract::Path, + http::StatusCode, + response::{IntoResponse, Response}, +}; + +pub async fn handle_env_example() -> impl IntoResponse { + Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .body(include_str!("../../../.env.example").to_string()) + .unwrap() +} + +// 配置页面处理函数 +pub async fn handle_config_page() -> impl IntoResponse { + match AppConfig::get_page_content(ROUTE_CONFIG_PATH).unwrap_or_default() { + PageContent::Default => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .body(include_str!("../../../static/config.min.html").to_string()) + .unwrap(), + PageContent::Text(content) => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .body(content.clone()) + .unwrap(), + PageContent::Html(content) => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .body(content.clone()) + .unwrap(), + } +} + +pub async fn handle_static(Path(path): Path) -> impl IntoResponse { + match path.as_str() { + "shared-styles.css" => { + match AppConfig::get_page_content(ROUTE_SHARED_STYLES_PATH).unwrap_or_default() { + PageContent::Default => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8) + .body(include_str!("../../../static/shared-styles.min.css").to_string()) + .unwrap(), + PageContent::Text(content) | PageContent::Html(content) => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8) + .body(content.clone()) + .unwrap(), + } + } + "shared.js" => { + match AppConfig::get_page_content(ROUTE_SHARED_JS_PATH).unwrap_or_default() { + PageContent::Default => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8) + .body(include_str!("../../../static/shared.min.js").to_string()) + .unwrap(), + PageContent::Text(content) | PageContent::Html(content) => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8) + .body(content.clone()) + .unwrap(), + } + } + _ => Response::builder() + .status(StatusCode::NOT_FOUND) + .body("Not found".to_string()) + .unwrap(), + } +} + +pub async fn handle_about() -> impl IntoResponse { + match AppConfig::get_page_content(ROUTE_ABOUT_PATH).unwrap_or_default() { + PageContent::Default => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .body(include_str!("../../../static/readme.min.html").to_string()) + .unwrap(), + PageContent::Text(content) => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .body(content.clone()) + .unwrap(), + PageContent::Html(content) => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .body(content.clone()) + .unwrap(), + } +} + +pub async fn handle_readme() -> impl IntoResponse { + match AppConfig::get_page_content(ROUTE_README_PATH).unwrap_or_default() { + PageContent::Default => Response::builder() + .status(StatusCode::TEMPORARY_REDIRECT) + .header(HEADER_NAME_LOCATION, ROUTE_ABOUT_PATH) + .body(Body::empty()) + .unwrap(), + PageContent::Text(content) => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .body(Body::from(content.clone())) + .unwrap(), + PageContent::Html(content) => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .body(Body::from(content.clone())) + .unwrap(), + } +} diff --git a/src/chat/route/health.rs b/src/chat/route/health.rs new file mode 100644 index 0000000..e5acdb1 --- /dev/null +++ b/src/chat/route/health.rs @@ -0,0 +1,112 @@ +use crate::{ + app::{ + constant::{ + CONTENT_TYPE_TEXT_HTML_WITH_UTF8, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, + HEADER_NAME_CONTENT_TYPE, HEADER_NAME_LOCATION, PKG_VERSION, ROUTE_ABOUT_PATH, + ROUTE_CONFIG_PATH, ROUTE_ENV_EXAMPLE_PATH, ROUTE_GET_CHECKSUM, + ROUTE_GET_TOKENINFO_PATH, ROUTE_GET_USER_INFO_PATH, ROUTE_HEALTH_PATH, ROUTE_LOGS_PATH, + ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_STATIC_PATH, ROUTE_TOKENINFO_PATH, + ROUTE_UPDATE_TOKENINFO_PATH, + }, + model::{AppConfig, AppState, PageContent}, + lazy::{get_start_time, ROUTE_CHAT_PATH, ROUTE_MODELS_PATH}, + }, + chat::constant::AVAILABLE_MODELS, + common::models::{ + health::{CpuInfo, HealthCheckResponse, MemoryInfo, SystemInfo, SystemStats}, + ApiStatus, + }, +}; +use axum::{ + body::Body, + extract::State, + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use chrono::Local; +use std::sync::Arc; +use sysinfo::{CpuRefreshKind, MemoryRefreshKind, RefreshKind, System}; +use tokio::sync::Mutex; + +pub async fn handle_root() -> impl IntoResponse { + match AppConfig::get_page_content(ROUTE_ROOT_PATH).unwrap_or_default() { + PageContent::Default => Response::builder() + .status(StatusCode::TEMPORARY_REDIRECT) + .header(HEADER_NAME_LOCATION, ROUTE_HEALTH_PATH) + .body(Body::empty()) + .unwrap(), + PageContent::Text(content) => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .body(Body::from(content.clone())) + .unwrap(), + PageContent::Html(content) => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .body(Body::from(content.clone())) + .unwrap(), + } +} + +pub async fn handle_health(State(state): State>>) -> Json { + let start_time = get_start_time(); + + // 创建系统信息实例,只监控 CPU 和内存 + let mut sys = System::new_with_specifics( + RefreshKind::nothing() + .with_memory(MemoryRefreshKind::everything()) + .with_cpu(CpuRefreshKind::everything()), + ); + + std::thread::sleep(sysinfo::MINIMUM_CPU_UPDATE_INTERVAL); + + // 刷新 CPU 和内存信息 + sys.refresh_memory(); + sys.refresh_cpu_usage(); + + let pid = std::process::id() as usize; + let process = sys.process(pid.into()); + + // 获取内存信息 + let memory = process.map(|p| p.memory()).unwrap_or(0); + + // 获取 CPU 使用率 + let cpu_usage = sys.global_cpu_usage(); + + let state = state.lock().await; + let uptime = (Local::now() - start_time).num_seconds(); + + Json(HealthCheckResponse { + status: ApiStatus::Healthy, + version: PKG_VERSION, + uptime, + stats: SystemStats { + started: start_time.to_string(), + total_requests: state.total_requests, + active_requests: state.active_requests, + system: SystemInfo { + memory: MemoryInfo { + rss: memory, // 物理内存使用量(字节) + }, + cpu: CpuInfo { + usage: cpu_usage, // CPU 使用率(百分比) + }, + }, + }, + models: AVAILABLE_MODELS.iter().map(|m| m.id).collect::>(), + endpoints: vec![ + ROUTE_CHAT_PATH.as_str(), + ROUTE_MODELS_PATH.as_str(), + ROUTE_GET_CHECKSUM, + ROUTE_TOKENINFO_PATH, + ROUTE_UPDATE_TOKENINFO_PATH, + ROUTE_GET_TOKENINFO_PATH, + ROUTE_LOGS_PATH, + ROUTE_GET_USER_INFO_PATH, + ROUTE_ENV_EXAMPLE_PATH, + ROUTE_CONFIG_PATH, + ROUTE_STATIC_PATH, + ROUTE_ABOUT_PATH, + ROUTE_README_PATH, + ], + }) +} diff --git a/src/chat/route/logs.rs b/src/chat/route/logs.rs new file mode 100644 index 0000000..a7576bb --- /dev/null +++ b/src/chat/route/logs.rs @@ -0,0 +1,76 @@ +use crate::{ + app::{ + constant::{ + AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_TEXT_HTML_WITH_UTF8, + CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, HEADER_NAME_AUTHORIZATION, HEADER_NAME_CONTENT_TYPE, + ROUTE_LOGS_PATH, + }, + model::{AppConfig, AppState, PageContent, RequestLog}, + lazy::AUTH_TOKEN, + }, + common::models::ApiStatus, +}; +use axum::{ + body::Body, + extract::State, + http::{HeaderMap, StatusCode}, + response::{IntoResponse, Response}, + Json, +}; +use chrono::Local; +use std::sync::Arc; +use tokio::sync::Mutex; + +// 日志处理 +pub async fn handle_logs() -> impl IntoResponse { + match AppConfig::get_page_content(ROUTE_LOGS_PATH).unwrap_or_default() { + PageContent::Default => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .body(Body::from( + include_str!("../../../static/logs.min.html").to_string(), + )) + .unwrap(), + PageContent::Text(content) => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .body(Body::from(content.clone())) + .unwrap(), + PageContent::Html(content) => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .body(Body::from(content.clone())) + .unwrap(), + } +} + +pub async fn handle_logs_post( + State(state): State>>, + headers: HeaderMap, +) -> Result, StatusCode> { + let auth_token = AUTH_TOKEN.as_str(); + + // 验证 AUTH_TOKEN + let auth_header = headers + .get(HEADER_NAME_AUTHORIZATION) + .and_then(|h| h.to_str().ok()) + .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) + .ok_or(StatusCode::UNAUTHORIZED)?; + + if auth_header != auth_token { + return Err(StatusCode::UNAUTHORIZED); + } + + let state = state.lock().await; + Ok(Json(LogsResponse { + status: ApiStatus::Success, + total: state.request_logs.len(), + logs: state.request_logs.clone(), + timestamp: Local::now().to_string(), + })) +} + +#[derive(serde::Serialize)] +pub struct LogsResponse { + pub status: ApiStatus, + pub total: usize, + pub logs: Vec, + pub timestamp: String, +} diff --git a/src/chat/route/token.rs b/src/chat/route/token.rs new file mode 100644 index 0000000..f637c13 --- /dev/null +++ b/src/chat/route/token.rs @@ -0,0 +1,171 @@ +use crate::{ + app::{ + constant::{ + AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_TEXT_HTML_WITH_UTF8, + CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, HEADER_NAME_AUTHORIZATION, HEADER_NAME_CONTENT_TYPE, + ROUTE_TOKENINFO_PATH, + }, + model::{AppConfig, AppState, PageContent, TokenUpdateRequest}, + lazy::{AUTH_TOKEN, TOKEN_FILE, TOKEN_LIST_FILE}, + }, + common::{ + models::{ApiStatus, NormalResponseNoData}, + utils::{generate_checksum, generate_hash, tokens::load_tokens}, + }, +}; +use axum::{ + extract::State, + http::HeaderMap, + response::{IntoResponse, Response}, + Json, +}; +use reqwest::StatusCode; +use serde::Serialize; +use std::sync::Arc; +use tokio::sync::Mutex; + +#[derive(Serialize)] +pub struct ChecksumResponse { + pub checksum: String, +} + +pub async fn handle_get_checksum() -> Json { + let checksum = generate_checksum(&generate_hash(), Some(&generate_hash())); + Json(ChecksumResponse { checksum }) +} + +// 更新 TokenInfo 处理 +pub async fn handle_update_tokeninfo( + State(state): State>>, +) -> Json { + // 重新加载 tokens + let token_infos = load_tokens(); + + // 更新应用状态 + { + let mut state = state.lock().await; + state.token_infos = token_infos; + } + + Json(NormalResponseNoData { + status: ApiStatus::Success, + message: Some("Token list has been reloaded".to_string()), + }) +} + +// 获取 TokenInfo 处理 +pub async fn handle_get_tokeninfo( + State(_state): State>>, + headers: HeaderMap, +) -> Result, StatusCode> { + let auth_token = AUTH_TOKEN.as_str(); + let token_file = TOKEN_FILE.as_str(); + let token_list_file = TOKEN_LIST_FILE.as_str(); + + // 验证 AUTH_TOKEN + let auth_header = headers + .get(HEADER_NAME_AUTHORIZATION) + .and_then(|h| h.to_str().ok()) + .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) + .ok_or(StatusCode::UNAUTHORIZED)?; + + if auth_header != auth_token { + return Err(StatusCode::UNAUTHORIZED); + } + + // 读取文件内容 + let tokens = std::fs::read_to_string(&token_file).unwrap_or_else(|_| String::new()); + let token_list = std::fs::read_to_string(&token_list_file).unwrap_or_else(|_| String::new()); + + Ok(Json(TokenInfoResponse { + status: ApiStatus::Success, + token_file: token_file.to_string(), + token_list_file: token_list_file.to_string(), + tokens: Some(tokens.clone()), + tokens_count: Some(tokens.len()), + token_list: Some(token_list), + message: None, + })) +} + +#[derive(Serialize)] +pub struct TokenInfoResponse { + pub status: ApiStatus, + pub token_file: String, + pub token_list_file: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tokens_count: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub token_list: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, +} + +pub async fn handle_update_tokeninfo_post( + State(state): State>>, + headers: HeaderMap, + Json(request): Json, +) -> Result, StatusCode> { + let auth_token = AUTH_TOKEN.as_str(); + let token_file = TOKEN_FILE.as_str(); + let token_list_file = TOKEN_LIST_FILE.as_str(); + + // 验证 AUTH_TOKEN + let auth_header = headers + .get(HEADER_NAME_AUTHORIZATION) + .and_then(|h| h.to_str().ok()) + .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) + .ok_or(StatusCode::UNAUTHORIZED)?; + + if auth_header != auth_token { + return Err(StatusCode::UNAUTHORIZED); + } + + // 写入 .token 文件 + std::fs::write(&token_file, &request.tokens).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + // 如果提供了 token_list,则写入 + if let Some(token_list) = request.token_list { + std::fs::write(&token_list_file, token_list) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + } + + // 重新加载 tokens + let token_infos = load_tokens(); + let token_infos_len = token_infos.len(); + + // 更新应用状态 + { + let mut state = state.lock().await; + state.token_infos = token_infos; + } + + Ok(Json(TokenInfoResponse { + status: ApiStatus::Success, + token_file: token_file.to_string(), + token_list_file: token_list_file.to_string(), + tokens: None, + tokens_count: Some(token_infos_len), + token_list: None, + message: Some("Token files have been updated and reloaded".to_string()), + })) +} + +pub async fn handle_tokeninfo_page() -> impl IntoResponse { + match AppConfig::get_page_content(ROUTE_TOKENINFO_PATH).unwrap_or_default() { + PageContent::Default => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .body(include_str!("../../../static/tokeninfo.min.html").to_string()) + .unwrap(), + PageContent::Text(content) => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .body(content.clone()) + .unwrap(), + PageContent::Html(content) => Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .body(content.clone()) + .unwrap(), + } +} diff --git a/src/chat/route/usage.rs b/src/chat/route/usage.rs new file mode 100644 index 0000000..ae16f50 --- /dev/null +++ b/src/chat/route/usage.rs @@ -0,0 +1,36 @@ +use crate::{ + app::model::AppState, + common::{models::usage::GetUserInfo, utils::get_user_usage}, +}; +use axum::{ + extract::{Query, State}, + Json, +}; +use serde::Deserialize; +use std::sync::Arc; +use tokio::sync::Mutex; + +#[derive(Deserialize)] +pub struct GetUserInfoQuery { + alias: String, +} + +pub async fn get_user_info( + State(state): State>>, + Query(query): Query, +) -> Json { + let token_infos = &state.lock().await.token_infos; + let token_info = token_infos + .iter() + .find(|token_info| token_info.alias == Some(query.alias.clone())); + + let (auth_token, checksum) = match token_info { + Some(token_info) => (token_info.token.clone(), token_info.checksum.clone()), + None => return Json(GetUserInfo::Error("No data".to_string())), + }; + + match get_user_usage(&auth_token, &checksum).await { + Some(usage) => Json(GetUserInfo::Usage(usage)), + None => Json(GetUserInfo::Error("No data".to_string())), + } +} diff --git a/src/chat/service.rs b/src/chat/service.rs index 3da3335..24a7758 100644 --- a/src/chat/service.rs +++ b/src/chat/service.rs @@ -1,28 +1,39 @@ +use super::constant::AVAILABLE_MODELS; +use crate::{ + app::{ + constant::{ + AUTHORIZATION_BEARER_PREFIX, CURSOR_API2_STREAM_CHAT, FINISH_REASON_STOP, + HEADER_NAME_CONTENT_TYPE, OBJECT_CHAT_COMPLETION, OBJECT_CHAT_COMPLETION_CHUNK, + STATUS_FAILED, STATUS_SUCCESS, + }, + model::{AppConfig, AppState, ChatRequest, RequestLog, TokenInfo}, + lazy::AUTH_TOKEN, + }, + chat::{ + error::StreamError, + model::{ + ChatResponse, Choice, Delta, Message, MessageContent, ModelsResponse, Role, Usage, + }, + stream::{parse_stream_data, StreamMessage}, + }, + common::{ + client::build_client, + models::{error::ChatError, ErrorResponse}, + utils::get_user_usage, + }, +}; use axum::{ - body::Body, - extract::State, - http::{HeaderMap, StatusCode}, - response::Response, - Json, + body::Body, + extract::State, + http::{HeaderMap, StatusCode}, + response::Response, + Json, }; use bytes::Bytes; -use crate::{ - app::{ - client::build_client, - constant::*, - models::*, - statics::*, - token::get_user_usage, - }, chat::{ - error::StreamError, - models::*, - stream::{parse_stream_data, StreamMessage}, - }, common::models::{error::ChatError, ErrorResponse} -}; -use super::constant::AVAILABLE_MODELS; use futures::{Stream, StreamExt}; use std::{ - convert::Infallible, sync::{atomic::AtomicBool, Arc} + convert::Infallible, + sync::{atomic::AtomicBool, Arc}, }; use std::{ pin::Pin, @@ -33,474 +44,509 @@ use uuid::Uuid; // 模型列表处理 pub async fn handle_models() -> Json { - Json(ModelsResponse { - object: "list", - data: AVAILABLE_MODELS, - }) + Json(ModelsResponse { + object: "list", + data: &AVAILABLE_MODELS, + }) } // 聊天处理函数的签名 pub async fn handle_chat( - State(state): State>>, - headers: HeaderMap, - Json(request): Json, + State(state): State>>, + headers: HeaderMap, + Json(request): Json, ) -> Result, (StatusCode, Json)> { - let allow_claude = AppConfig::get_allow_claude(); - // 验证模型是否支持并获取模型信息 - let model = AVAILABLE_MODELS.iter().find(|m| m.id == request.model); - let model_supported = model.is_some(); + let allow_claude = AppConfig::get_allow_claude(); + // 验证模型是否支持并获取模型信息 + let model = AVAILABLE_MODELS.iter().find(|m| m.id == request.model); + let model_supported = model.is_some(); - if !(model_supported || allow_claude && request.model.starts_with("claude")) { - return Err(( - StatusCode::BAD_REQUEST, - Json(ChatError::ModelNotSupported(request.model).to_json()), - )); - } + if !(model_supported || allow_claude && request.model.starts_with("claude")) { + return Err(( + StatusCode::BAD_REQUEST, + Json(ChatError::ModelNotSupported(request.model).to_json()), + )); + } - let request_time = chrono::Local::now(); + let request_time = chrono::Local::now(); - // 验证请求 - if request.messages.is_empty() { - return Err(( - StatusCode::BAD_REQUEST, - Json(ChatError::EmptyMessages.to_json()), - )); - } + // 验证请求 + if request.messages.is_empty() { + return Err(( + StatusCode::BAD_REQUEST, + Json(ChatError::EmptyMessages.to_json()), + )); + } - // 获取并处理认证令牌 - let auth_token = headers - .get(axum::http::header::AUTHORIZATION) - .and_then(|h| h.to_str().ok()) - .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) - .ok_or(( - StatusCode::UNAUTHORIZED, - Json(ChatError::Unauthorized.to_json()), - ))?; + // 获取并处理认证令牌 + let auth_token = headers + .get(axum::http::header::AUTHORIZATION) + .and_then(|h| h.to_str().ok()) + .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) + .ok_or(( + StatusCode::UNAUTHORIZED, + Json(ChatError::Unauthorized.to_json()), + ))?; - // 验证 AuthToken - if auth_token != get_auth_token() { - return Err(( - StatusCode::UNAUTHORIZED, - Json(ChatError::Unauthorized.to_json()), - )); - } + // 验证 AuthToken + if auth_token != AUTH_TOKEN.as_str() { + return Err(( + StatusCode::UNAUTHORIZED, + Json(ChatError::Unauthorized.to_json()), + )); + } - // 完整的令牌处理逻辑和对应的 checksum - let (auth_token, checksum, alias) = { - static CURRENT_KEY_INDEX: AtomicUsize = AtomicUsize::new(0); - let state_guard = state.lock().await; - let token_infos = &state_guard.token_infos; + // 完整的令牌处理逻辑和对应的 checksum + let (auth_token, checksum, alias) = { + static CURRENT_KEY_INDEX: AtomicUsize = AtomicUsize::new(0); + let state_guard = state.lock().await; + let token_infos = &state_guard.token_infos; - if token_infos.is_empty() { - return Err(( - StatusCode::SERVICE_UNAVAILABLE, - Json(ChatError::NoTokens.to_json()), - )); - } + if token_infos.is_empty() { + return Err(( + StatusCode::SERVICE_UNAVAILABLE, + Json(ChatError::NoTokens.to_json()), + )); + } - let index = CURRENT_KEY_INDEX.fetch_add(1, Ordering::SeqCst) % token_infos.len(); - let token_info = &token_infos[index]; - ( - token_info.token.clone(), - token_info.checksum.clone(), - token_info.alias.clone(), - ) - }; + let index = CURRENT_KEY_INDEX.fetch_add(1, Ordering::SeqCst) % token_infos.len(); + let token_info = &token_infos[index]; + ( + token_info.token.clone(), + token_info.checksum.clone(), + token_info.alias.clone(), + ) + }; - // 更新请求日志 - { - let state_clone = state.clone(); - let mut state = state.lock().await; - state.total_requests += 1; - state.active_requests += 1; + // 更新请求日志 + { + let state_clone = state.clone(); + let mut state = state.lock().await; + state.total_requests += 1; + state.active_requests += 1; - // 如果有model且需要获取使用情况,创建后台任务获取 - if let Some(model) = model { - if model.is_usage_check() { - let auth_token_clone = auth_token.clone(); - let checksum_clone = checksum.clone(); - let state_clone = state_clone.clone(); + // 如果有model且需要获取使用情况,创建后台任务获取 + if let Some(model) = model { + if model.is_usage_check() { + let auth_token_clone = auth_token.clone(); + let checksum_clone = checksum.clone(); + let state_clone = state_clone.clone(); - tokio::spawn(async move { - let usage = get_user_usage(&auth_token_clone, &checksum_clone).await; - let mut state = state_clone.lock().await; - // 根据时间戳找到对应的日志 - if let Some(log) = state - .request_logs - .iter_mut() - .find(|log| log.timestamp == request_time) - { - log.token_info.usage = usage; - } - }); - } - } + tokio::spawn(async move { + let usage = get_user_usage(&auth_token_clone, &checksum_clone).await; + let mut state = state_clone.lock().await; + // 根据时间戳找到对应的日志 + if let Some(log) = state + .request_logs + .iter_mut() + .find(|log| log.timestamp == request_time) + { + log.token_info.usage = usage; + } + }); + } + } - state.request_logs.push(RequestLog { - timestamp: request_time, - model: request.model.clone(), - token_info: TokenInfo { - token: auth_token.clone(), - checksum: checksum.clone(), - alias: alias.clone(), - usage: None, - }, - prompt: None, - stream: request.stream, - status: "pending", - error: None, - }); + state.request_logs.push(RequestLog { + timestamp: request_time, + model: request.model.clone(), + token_info: TokenInfo { + token: auth_token.clone(), + checksum: checksum.clone(), + alias: alias.clone(), + usage: None, + }, + prompt: None, + stream: request.stream, + status: "pending", + error: None, + }); - if state.request_logs.len() > 100 { - state.request_logs.remove(0); - } - } + if state.request_logs.len() > 100 { + state.request_logs.remove(0); + } + } - // 将消息转换为hex格式 - let hex_data = crate::encode_chat_message(request.messages, &request.model) - .await - .map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json( - ChatError::RequestFailed("Failed to encode chat message".to_string()).to_json(), - ), - ) - })?; + // 将消息转换为hex格式 + let hex_data = super::adapter::encode_chat_message(request.messages, &request.model) + .await + .map_err(|_| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json( + ChatError::RequestFailed("Failed to encode chat message".to_string()).to_json(), + ), + ) + })?; - // 构建请求客户端 - let client = build_client(&auth_token, &checksum, CURSOR_API2_STREAM_CHAT); - let response = client.body(hex_data).send().await; + // 构建请求客户端 + let client = build_client(&auth_token, &checksum, CURSOR_API2_STREAM_CHAT); + let response = client.body(hex_data).send().await; - // 处理请求结果 - let response = match response { - Ok(resp) => { - // 更新请求日志为成功 - { - let mut state = state.lock().await; - state.request_logs.last_mut().unwrap().status = STATUS_SUCCESS; - } - resp - } - Err(e) => { - // 更新请求日志为失败 - { - let mut state = state.lock().await; - if let Some(last_log) = state.request_logs.last_mut() { - last_log.status = STATUS_FAILED; - last_log.error = Some(e.to_string()); - } - } - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ChatError::RequestFailed(e.to_string()).to_json()), - )); - } - }; + // 处理请求结果 + let response = match response { + Ok(resp) => { + // 更新请求日志为成功 + { + let mut state = state.lock().await; + state.request_logs.last_mut().unwrap().status = STATUS_SUCCESS; + } + resp + } + Err(e) => { + // 更新请求日志为失败 + { + let mut state = state.lock().await; + if let Some(last_log) = state.request_logs.last_mut() { + last_log.status = STATUS_FAILED; + last_log.error = Some(e.to_string()); + } + } + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ChatError::RequestFailed(e.to_string()).to_json()), + )); + } + }; - // 释放活动请求计数 - { - let mut state = state.lock().await; - state.active_requests -= 1; - } + // 释放活动请求计数 + { + let mut state = state.lock().await; + state.active_requests -= 1; + } - if request.stream { - let response_id = format!("chatcmpl-{}", Uuid::new_v4().simple()); - let full_text = Arc::new(Mutex::new(String::with_capacity(1024))); - let is_start = Arc::new(AtomicBool::new(true)); + if request.stream { + let response_id = format!("chatcmpl-{}", Uuid::new_v4().simple()); + let full_text = Arc::new(Mutex::new(String::with_capacity(1024))); + let is_start = Arc::new(AtomicBool::new(true)); - let stream = { - // 创建新的 stream - let mut stream = response.bytes_stream(); + let stream = { + // 创建新的 stream + let mut stream = response.bytes_stream(); - let enable_stream_check = AppConfig::get_stream_check(); + let enable_stream_check = AppConfig::get_stream_check(); - if enable_stream_check { - // 检查第一个 chunk - match stream.next().await { - Some(first_chunk) => { - let chunk = first_chunk.map_err(|e| { - let error_message = format!("Failed to read response chunk: {}", e); - // 理论上,若程序正常,必定成功,因为前面判断过了 - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ChatError::RequestFailed(error_message).to_json()), - ) - })?; + if enable_stream_check { + // 检查第一个 chunk + match stream.next().await { + Some(first_chunk) => { + let chunk = first_chunk.map_err(|e| { + let error_message = format!("Failed to read response chunk: {}", e); + // 理论上,若程序正常,必定成功,因为前面判断过了 + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ChatError::RequestFailed(error_message).to_json()), + ) + })?; - match parse_stream_data(&chunk) { - Err(StreamError::ChatError(error)) => { - let error_respone = error.to_error_response(); - // 更新请求日志为失败 - { - let mut state = state.lock().await; - if let Some(last_log) = state.request_logs.last_mut() { - last_log.status = STATUS_FAILED; - last_log.error = Some(error_respone.native_code()); - } - } - return Err(( - error_respone.status_code(), - Json(error_respone.to_common()), - )); - } - Ok(_) | Err(_) => { - // 创建一个包含第一个 chunk 的 stream - Box::pin( - futures::stream::once(async move { Ok(chunk) }).chain(stream), - ) - as Pin< - Box< - dyn Stream> + Send, - >, - > - } - } - } - None => { - // Box::pin(stream) - // as Pin> + Send>> - // 更新请求日志为失败 - { - let mut state = state.lock().await; - if let Some(last_log) = state.request_logs.last_mut() { - last_log.status = STATUS_FAILED; - last_log.error = Some("Empty stream response".to_string()); - } - } - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json( - ChatError::RequestFailed("Empty stream response".to_string()) - .to_json(), - ), - )); - } - } - } else { - Box::pin(stream) - as Pin> + Send>> - } - } - .then(move |chunk| { - let response_id = response_id.clone(); - let model = request.model.clone(); - let is_start = is_start.clone(); - let full_text = full_text.clone(); - let state = state.clone(); + match parse_stream_data(&chunk) { + Err(StreamError::ChatError(error)) => { + let error_respone = error.to_error_response(); + // 更新请求日志为失败 + { + let mut state = state.lock().await; + if let Some(last_log) = state.request_logs.last_mut() { + last_log.status = STATUS_FAILED; + last_log.error = Some(error_respone.native_code()); + } + } + return Err(( + error_respone.status_code(), + Json(error_respone.to_common()), + )); + } + Ok(_) | Err(_) => { + // 创建一个包含第一个 chunk 的 stream + Box::pin( + futures::stream::once(async move { Ok(chunk) }).chain(stream), + ) + as Pin< + Box< + dyn Stream> + Send, + >, + > + } + } + } + None => { + // Box::pin(stream) + // as Pin> + Send>> + // 更新请求日志为失败 + { + let mut state = state.lock().await; + if let Some(last_log) = state.request_logs.last_mut() { + last_log.status = STATUS_FAILED; + last_log.error = Some("Empty stream response".to_string()); + } + } + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json( + ChatError::RequestFailed("Empty stream response".to_string()) + .to_json(), + ), + )); + } + } + } else { + Box::pin(stream) + as Pin> + Send>> + } + } + .then({ + let buffer = Arc::new(Mutex::new(Vec::new())); // 创建共享的buffer - async move { - let chunk = chunk.unwrap_or_default(); - match parse_stream_data(&chunk) { - Ok(StreamMessage::Content(texts)) => { - let mut response_data = String::new(); + move |chunk| { + let buffer = buffer.clone(); + let response_id = response_id.clone(); + let model = request.model.clone(); + let is_start = is_start.clone(); + let full_text = full_text.clone(); + let state = state.clone(); - for text in texts { - let mut text_guard = full_text.lock().await; - text_guard.push_str(&text); - let is_first = is_start.load(Ordering::SeqCst); + async move { + let chunk = chunk.unwrap_or_default(); + let mut buffer_guard = buffer.lock().await; + buffer_guard.extend_from_slice(&chunk); - let response = ChatResponse { - id: response_id.clone(), - object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(), - created: chrono::Utc::now().timestamp(), - model: if is_first { Some(model.clone()) } else { None }, - choices: vec![Choice { - index: 0, - message: None, - delta: Some(Delta { - role: if is_first { - is_start.store(false, Ordering::SeqCst); - Some(Role::Assistant) - } else { - None - }, - content: Some(text), - }), - finish_reason: None, - }], - usage: None, - }; + match parse_stream_data(&buffer_guard) { + Ok(StreamMessage::Content(texts)) => { + buffer_guard.clear(); + let mut response_data = String::new(); - response_data.push_str(&format!( - "data: {}\n\n", - serde_json::to_string(&response).unwrap() - )); - } + for text in texts { + let mut text_guard = full_text.lock().await; + text_guard.push_str(&text); + let is_first = is_start.load(Ordering::SeqCst); - Ok::<_, Infallible>(Bytes::from(response_data)) - } - Ok(StreamMessage::StreamStart) => { - // 发送初始响应,包含模型信息 - let response = ChatResponse { - id: response_id.clone(), - object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(), - created: chrono::Utc::now().timestamp(), - model: { - is_start.store(true, Ordering::SeqCst); - Some(model.clone()) - }, - choices: vec![Choice { - index: 0, - message: None, - delta: Some(Delta { - role: Some(Role::Assistant), - content: Some(String::new()), - }), - finish_reason: None, - }], - usage: None, - }; + let response = ChatResponse { + id: response_id.clone(), + object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(), + created: chrono::Utc::now().timestamp(), + model: if is_first { Some(model.clone()) } else { None }, + choices: vec![Choice { + index: 0, + message: None, + delta: Some(Delta { + role: if is_first { + is_start.store(false, Ordering::SeqCst); + Some(Role::Assistant) + } else { + None + }, + content: Some(text), + }), + finish_reason: None, + }], + usage: None, + }; - Ok(Bytes::from(format!( - "data: {}\n\n", - serde_json::to_string(&response).unwrap() - ))) - } - Ok(StreamMessage::StreamEnd) => { - // 根据配置决定是否发送最后的 finish_reason - let include_finish_reason = AppConfig::get_stop_stream(); + response_data.push_str(&format!( + "data: {}\n\n", + serde_json::to_string(&response).unwrap() + )); + } - if include_finish_reason { - let response = ChatResponse { - id: response_id.clone(), - object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(), - created: chrono::Utc::now().timestamp(), - model: None, - choices: vec![Choice { - index: 0, - message: None, - delta: Some(Delta { - role: None, - content: None, - }), - finish_reason: Some(FINISH_REASON_STOP.to_string()), - }], - usage: None, - }; - Ok(Bytes::from(format!( - "data: {}\n\ndata: [DONE]\n\n", - serde_json::to_string(&response).unwrap() - ))) - } else { - Ok(Bytes::from("data: [DONE]\n\n")) - } - } - Ok(StreamMessage::Debug(debug_prompt)) => { - if let Ok(mut state) = state.try_lock() { - if let Some(last_log) = state.request_logs.last_mut() { - last_log.prompt = Some(debug_prompt.clone()); - } - } - Ok(Bytes::new()) - } - Err(StreamError::ChatError(error)) => { - eprintln!("Stream error occurred: {}", error.to_json()); - Ok(Bytes::new()) - } - Err(e) => { - eprintln!("[警告] Stream error: {}", e); - Ok(Bytes::new()) - } - } - } - }); + Ok::<_, Infallible>(Bytes::from(response_data)) + } + Ok(StreamMessage::StreamStart) => { + buffer_guard.clear(); + // 发送初始响应,包含模型信息 + let response = ChatResponse { + id: response_id.clone(), + object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(), + created: chrono::Utc::now().timestamp(), + model: { + is_start.store(true, Ordering::SeqCst); + Some(model.clone()) + }, + choices: vec![Choice { + index: 0, + message: None, + delta: Some(Delta { + role: Some(Role::Assistant), + content: Some(String::new()), + }), + finish_reason: None, + }], + usage: None, + }; - Ok(Response::builder() - .header("Cache-Control", "no-cache") - .header("Connection", "keep-alive") - .header(HEADER_NAME_CONTENT_TYPE, "text/event-stream") - .body(Body::from_stream(stream)) - .unwrap()) - } else { - // 非流式响应 - let mut full_text = String::with_capacity(1024); // 预分配合适的容量 - let mut stream = response.bytes_stream(); - let mut prompt = None; + Ok(Bytes::from(format!( + "data: {}\n\n", + serde_json::to_string(&response).unwrap() + ))) + } + Ok(StreamMessage::StreamEnd) => { + buffer_guard.clear(); + // 根据配置决定是否发送最后的 finish_reason + let include_finish_reason = AppConfig::get_stop_stream(); - while let Some(chunk) = stream.next().await { - let chunk = chunk.map_err(|e| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json( - ChatError::RequestFailed(format!("Failed to read response chunk: {}", e)) - .to_json(), - ), - ) - })?; + if include_finish_reason { + let response = ChatResponse { + id: response_id.clone(), + object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(), + created: chrono::Utc::now().timestamp(), + model: None, + choices: vec![Choice { + index: 0, + message: None, + delta: Some(Delta { + role: None, + content: None, + }), + finish_reason: Some(FINISH_REASON_STOP.to_string()), + }], + usage: None, + }; + Ok(Bytes::from(format!( + "data: {}\n\ndata: [DONE]\n\n", + serde_json::to_string(&response).unwrap() + ))) + } else { + Ok(Bytes::from("data: [DONE]\n\n")) + } + } + Ok(StreamMessage::Incomplete) => { + // 保持buffer中的数据以待下一个chunk + Ok(Bytes::new()) + } + Ok(StreamMessage::Debug(debug_prompt)) => { + buffer_guard.clear(); + if let Ok(mut state) = state.try_lock() { + if let Some(last_log) = state.request_logs.last_mut() { + last_log.prompt = Some(debug_prompt.clone()); + } + } + Ok(Bytes::new()) + } + Err(StreamError::ChatError(error)) => { + buffer_guard.clear(); + eprintln!("Stream error occurred: {}", error.to_json()); + Ok(Bytes::new()) + } + Err(e) => { + buffer_guard.clear(); + eprintln!("[警告] Stream error: {}", e); + Ok(Bytes::new()) + } + } + } + } + }); - match parse_stream_data(&chunk) { - Ok(StreamMessage::Content(texts)) => { - for text in texts { - full_text.push_str(&text); - } - } - Ok(StreamMessage::Debug(debug_prompt)) => { - prompt = Some(debug_prompt); - } - Ok(StreamMessage::StreamStart) | Ok(StreamMessage::StreamEnd) => {} - Err(StreamError::ChatError(error)) => { - return Err(( - StatusCode::from_u16(error.error.details[0].debug.status_code()) - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), - Json(error.to_error_response().to_common()), - )); - } - Err(_) => continue, - } - } + Ok(Response::builder() + .header("Cache-Control", "no-cache") + .header("Connection", "keep-alive") + .header(HEADER_NAME_CONTENT_TYPE, "text/event-stream") + .body(Body::from_stream(stream)) + .unwrap()) + } else { + // 非流式响应 + let mut full_text = String::with_capacity(1024); // 预分配合适的容量 + let mut stream = response.bytes_stream(); + let mut prompt = None; - // 检查响应是否为空 - if full_text.is_empty() { - // 更新请求日志为失败 - { - let mut state = state.lock().await; - if let Some(last_log) = state.request_logs.last_mut() { - last_log.status = STATUS_FAILED; - last_log.error = Some("Empty response received".to_string()); - if let Some(p) = prompt { - last_log.prompt = Some(p); - } - } - } - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ChatError::RequestFailed("Empty response received".to_string()).to_json()), - )); - } + let mut buffer = Vec::new(); + while let Some(chunk) = stream.next().await { + let chunk = chunk.map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json( + ChatError::RequestFailed(format!("Failed to read response chunk: {}", e)) + .to_json(), + ), + ) + })?; - // 更新请求日志提示词 - { - let mut state = state.lock().await; - if let Some(last_log) = state.request_logs.last_mut() { - last_log.prompt = prompt; - } - } + buffer.extend_from_slice(&chunk); - let response_data = ChatResponse { - id: format!("chatcmpl-{}", Uuid::new_v4().simple()), - object: OBJECT_CHAT_COMPLETION.to_string(), - created: chrono::Utc::now().timestamp(), - model: Some(request.model), - choices: vec![Choice { - index: 0, - message: Some(Message { - role: Role::Assistant, - content: MessageContent::Text(full_text), - }), - delta: None, - finish_reason: Some(FINISH_REASON_STOP.to_string()), - }], - usage: Some(Usage { - prompt_tokens: 0, - completion_tokens: 0, - total_tokens: 0, - }), - }; + match parse_stream_data(&buffer) { + Ok(StreamMessage::Content(texts)) => { + for text in texts { + full_text.push_str(&text); + } + buffer.clear(); + } + Ok(StreamMessage::Incomplete) => { + continue; + } + Ok(StreamMessage::Debug(debug_prompt)) => { + prompt = Some(debug_prompt); + buffer.clear(); + } + Ok(StreamMessage::StreamStart) | Ok(StreamMessage::StreamEnd) => { + buffer.clear(); + } + Err(StreamError::ChatError(error)) => { + return Err(( + StatusCode::from_u16(error.error.details[0].debug.status_code()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), + Json(error.to_error_response().to_common()), + )); + } + Err(_) => { + buffer.clear(); + continue; + } + } + } - Ok(Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, "application/json") - .body(Body::from(serde_json::to_string(&response_data).unwrap())) - .unwrap()) - } + let prompt_tokens = prompt.as_ref().map(|p| p.len() as u32).unwrap_or(0); + let completion_tokens = full_text.len() as u32; + let total_tokens = prompt_tokens + completion_tokens; + + // 检查响应是否为空 + if full_text.is_empty() { + // 更新请求日志为失败 + { + let mut state = state.lock().await; + if let Some(last_log) = state.request_logs.last_mut() { + last_log.status = STATUS_FAILED; + last_log.error = Some("Empty response received".to_string()); + if let Some(p) = prompt { + last_log.prompt = Some(p); + } + } + } + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ChatError::RequestFailed("Empty response received".to_string()).to_json()), + )); + } + + // 更新请求日志提示词 + { + let mut state = state.lock().await; + if let Some(last_log) = state.request_logs.last_mut() { + last_log.prompt = prompt; + } + } + + let response_data = ChatResponse { + id: format!("chatcmpl-{}", Uuid::new_v4().simple()), + object: OBJECT_CHAT_COMPLETION.to_string(), + created: chrono::Utc::now().timestamp(), + model: Some(request.model), + choices: vec![Choice { + index: 0, + message: Some(Message { + role: Role::Assistant, + content: MessageContent::Text(full_text), + }), + delta: None, + finish_reason: Some(FINISH_REASON_STOP.to_string()), + }], + usage: Some(Usage { + prompt_tokens, + completion_tokens, + total_tokens, + }), + }; + + Ok(Response::builder() + .header(HEADER_NAME_CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_string(&response_data).unwrap())) + .unwrap()) + } } diff --git a/src/chat/stream.rs b/src/chat/stream.rs index f91d474..c564b29 100644 --- a/src/chat/stream.rs +++ b/src/chat/stream.rs @@ -20,6 +20,8 @@ fn decompress_gzip(data: &[u8]) -> Option> { } pub enum StreamMessage { + // 未完成 + Incomplete, // 调试 Debug(String), // 流开始标志 b"\0\0\0\0\0" @@ -65,7 +67,7 @@ pub fn parse_stream_data(data: &[u8]) -> Result { // 检查剩余数据长度是否足够 if offset + 5 + msg_len > data.len() { - break; + return Ok(StreamMessage::Incomplete); } let msg_data = &data[offset + 5..offset + 5 + msg_len]; @@ -74,12 +76,14 @@ pub fn parse_stream_data(data: &[u8]) -> Result { // 文本消息 0 => { if let Ok(response) = StreamChatResponse::decode(msg_data) { + // crate::debug_println!("[text] StreamChatResponse: {:?}", response); if !response.text.is_empty() { messages.push(response.text); } else { // println!("[text] StreamChatResponse: {:?}", response); return Ok(StreamMessage::Debug( response.filled_prompt.unwrap_or_default(), + // response.is_using_slow_request, )); } } @@ -88,12 +92,14 @@ pub fn parse_stream_data(data: &[u8]) -> Result { 1 => { if let Some(text) = decompress_gzip(msg_data) { let response = StreamChatResponse::decode(&text[..]).unwrap_or_default(); + // crate::debug_println!("[gzip] StreamChatResponse: {:?}", response); if !response.text.is_empty() { messages.push(response.text); } else { // println!("[gzip] StreamChatResponse: {:?}", response); return Ok(StreamMessage::Debug( response.filled_prompt.unwrap_or_default(), + // response.is_using_slow_request, )); } } diff --git a/src/common.rs b/src/common.rs index ff92946..757ddf3 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1 +1,3 @@ -pub mod models; \ No newline at end of file +pub mod models; +pub mod utils; +pub mod client; diff --git a/src/app/client.rs b/src/common/client.rs similarity index 69% rename from src/app/client.rs rename to src/common/client.rs index 1d4947c..70f888e 100644 --- a/src/app/client.rs +++ b/src/common/client.rs @@ -1,4 +1,8 @@ -use crate::app::constant::*; +use crate::app::constant::{ + AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_CONNECT_PROTO, CONTENT_TYPE_PROTO, + CURSOR_API2_BASE_URL, CURSOR_API2_HOST, CURSOR_API2_STREAM_CHAT, HEADER_NAME_AUTHORIZATION, + HEADER_NAME_CONTENT_TYPE, +}; use reqwest::Client; use uuid::Uuid; @@ -7,15 +11,18 @@ pub fn build_client(auth_token: &str, checksum: &str, endpoint: &str) -> reqwest let client = Client::new(); let trace_id = Uuid::new_v4().to_string(); let content_type = if endpoint == CURSOR_API2_STREAM_CHAT { - CONTENT_TYPE_CONNECT_PROTO + CONTENT_TYPE_CONNECT_PROTO } else { - CONTENT_TYPE_PROTO + CONTENT_TYPE_PROTO }; client .post(format!("{}{}", CURSOR_API2_BASE_URL, endpoint)) .header(HEADER_NAME_CONTENT_TYPE, content_type) - .header(HEADER_NAME_AUTHORIZATION, format!("{}{}", AUTHORIZATION_BEARER_PREFIX, auth_token)) + .header( + HEADER_NAME_AUTHORIZATION, + format!("{}{}", AUTHORIZATION_BEARER_PREFIX, auth_token), + ) .header("connect-accept-encoding", "gzip,br") .header("connect-protocol-version", "1") .header("user-agent", "connect-es/1.6.1") diff --git a/src/common/models.rs b/src/common/models.rs index 76d5ae2..bd00ac0 100644 --- a/src/common/models.rs +++ b/src/common/models.rs @@ -1,6 +1,7 @@ pub mod error; pub mod health; pub mod config; +pub mod usage; use config::ConfigData; diff --git a/src/common/models/config.rs b/src/common/models/config.rs index 9fdbdc6..8c14448 100644 --- a/src/common/models/config.rs +++ b/src/common/models/config.rs @@ -1,6 +1,6 @@ use serde::{Deserialize, Serialize}; -use crate::{PageContent, UsageCheck, VisionAbility}; +use crate::app::model::{PageContent, UsageCheck, VisionAbility}; #[derive(Serialize)] pub struct ConfigData { diff --git a/src/common/models/usage.rs b/src/common/models/usage.rs new file mode 100644 index 0000000..9a61e05 --- /dev/null +++ b/src/common/models/usage.rs @@ -0,0 +1,15 @@ +use serde::Serialize; + +#[derive(Serialize)] +pub enum GetUserInfo { + #[serde(rename = "usage")] + Usage(UserUsageInfo), + #[serde(rename = "error")] + Error(String), +} + +#[derive(Serialize, Clone)] +pub struct UserUsageInfo { + pub fast_requests: u32, + pub max_fast_requests: u32, +} diff --git a/src/common/utils.rs b/src/common/utils.rs new file mode 100644 index 0000000..9e1a402 --- /dev/null +++ b/src/common/utils.rs @@ -0,0 +1,50 @@ +mod checksum; +pub use checksum::*; +pub mod tokens; +use prost::Message as _; + +use crate::{app::constant::CURSOR_API2_GET_USER_INFO, chat::aiserver::v1::GetUserInfoResponse}; + +use super::models::usage::UserUsageInfo; + +pub fn parse_bool_from_env(key: &str, default: bool) -> bool { + std::env::var(key) + .ok() + .map(|v| match v.to_lowercase().as_str() { + "true" | "1" => true, + "false" | "0" => false, + _ => default, + }) + .unwrap_or(default) +} + +pub fn parse_string_from_env(key: &str, default: &str) -> String { + std::env::var(key).unwrap_or_else(|_| default.to_string()) +} + +pub fn i32_to_u32(value: i32) -> u32 { + if value < 0 { + 0 + } else { + value as u32 + } +} + +pub async fn get_user_usage(auth_token: &str, checksum: &str) -> Option { + // 构建请求客户端 + let client = super::client::build_client(auth_token, checksum, CURSOR_API2_GET_USER_INFO); + let response = client + .body(Vec::new()) + .send() + .await + .ok()? + .bytes() + .await + .ok()?; + let user_info = GetUserInfoResponse::decode(response.as_ref()).ok()?; + + user_info.usage.map(|user_usage| UserUsageInfo { + fast_requests: i32_to_u32(user_usage.gpt4_requests), + max_fast_requests: i32_to_u32(user_usage.gpt4_max_requests), + }) +} diff --git a/src/app/utils/checksum.rs b/src/common/utils/checksum.rs similarity index 100% rename from src/app/utils/checksum.rs rename to src/common/utils/checksum.rs diff --git a/src/common/utils/tokens.rs b/src/common/utils/tokens.rs new file mode 100644 index 0000000..acd1d7e --- /dev/null +++ b/src/common/utils/tokens.rs @@ -0,0 +1,144 @@ +use crate::{ + app::{ + constant::EMPTY_STRING, + model::TokenInfo, + lazy::{TOKEN_FILE, TOKEN_LIST_FILE}, + }, + common::utils::{generate_checksum, generate_hash}, +}; + +// 规范化文件内容并写入 +fn normalize_and_write(content: &str, file_path: &str) -> String { + let normalized = content.replace("\r\n", "\n"); + if normalized != content { + if let Err(e) = std::fs::write(file_path, &normalized) { + eprintln!("警告: 无法更新规范化的文件: {}", e); + } + } + normalized +} + +// 解析token和别名 +fn parse_token_alias(token_part: &str, line: &str) -> Option<(String, Option)> { + match token_part.split("::").collect::>() { + parts if parts.len() == 1 => Some((parts[0].to_string(), None)), + parts if parts.len() == 2 => Some((parts[1].to_string(), Some(parts[0].to_string()))), + _ => { + eprintln!("警告: 忽略无效的行: {}", line); + None + } + } +} + +// Token 加载函数 +pub fn load_tokens() -> Vec { + let token_file = TOKEN_FILE.as_str(); + let token_list_file = TOKEN_LIST_FILE.as_str(); + + // 确保文件存在 + for file in [&token_file, &token_list_file] { + if !std::path::Path::new(file).exists() { + if let Err(e) = std::fs::write(file, EMPTY_STRING) { + eprintln!("警告: 无法创建文件 '{}': {}", file, e); + } + } + } + + // 读取和规范化 token 文件 + let token_entries = match std::fs::read_to_string(&token_file) { + Ok(content) => { + let normalized = normalize_and_write(&content, &token_file); + normalized + .lines() + .filter_map(|line| { + let line = line.trim(); + if line.is_empty() || line.starts_with('#') { + return None; + } + parse_token_alias(line, line) + }) + .collect::>() + } + Err(e) => { + eprintln!("警告: 无法读取token文件 '{}': {}", token_file, e); + Vec::new() + } + }; + + // 读取和规范化 token-list 文件 + let mut token_map: std::collections::HashMap)> = + match std::fs::read_to_string(&token_list_file) { + Ok(content) => { + let normalized = normalize_and_write(&content, &token_list_file); + normalized + .lines() + .filter_map(|line| { + let line = line.trim(); + if line.is_empty() || line.starts_with('#') { + return None; + } + + let parts: Vec<&str> = line.split(',').collect(); + match parts[..] { + [token_part, checksum] => { + let (token, alias) = parse_token_alias(token_part, line)?; + Some((token, (checksum.to_string(), alias))) + } + _ => { + eprintln!("警告: 忽略无效的token-list行: {}", line); + None + } + } + }) + .collect() + } + Err(e) => { + eprintln!("警告: 无法读取token-list文件: {}", e); + std::collections::HashMap::new() + } + }; + + // 更新或添加新token + for (token, alias) in token_entries { + if let Some((_, existing_alias)) = token_map.get(&token) { + // 只在alias不同时更新已存在的token + if alias != *existing_alias { + if let Some((checksum, _)) = token_map.get(&token) { + token_map.insert(token.clone(), (checksum.clone(), alias)); + } + } + } else { + // 为新token生成checksum + let checksum = generate_checksum(&generate_hash(), Some(&generate_hash())); + token_map.insert(token, (checksum, alias)); + } + } + + // 更新 token-list 文件 + let token_list_content = token_map + .iter() + .map(|(token, (checksum, alias))| { + if let Some(alias) = alias { + format!("{}::{},{}", alias, token, checksum) + } else { + format!("{},{}", token, checksum) + } + }) + .collect::>() + .join("\n"); + + if let Err(e) = std::fs::write(&token_list_file, token_list_content) { + eprintln!("警告: 无法更新token-list文件: {}", e); + } + + // 转换为 TokenInfo vector + token_map + .into_iter() + .map(|(token, (checksum, alias))| TokenInfo { + token, + checksum, + alias, + usage: None, + }) + .collect() +} diff --git a/src/main.rs b/src/main.rs index 822c419..b7a4389 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,35 +1,33 @@ +mod app; +mod chat; +mod common; + +use app::{ + config::handle_config_update, + constant::{ + EMPTY_STRING, PKG_VERSION, ROUTE_ABOUT_PATH, ROUTE_CONFIG_PATH, ROUTE_ENV_EXAMPLE_PATH, + ROUTE_GET_CHECKSUM, ROUTE_GET_TOKENINFO_PATH, ROUTE_GET_USER_INFO_PATH, ROUTE_HEALTH_PATH, + ROUTE_LOGS_PATH, ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_STATIC_PATH, + ROUTE_TOKENINFO_PATH, ROUTE_UPDATE_TOKENINFO_PATH, + }, + model::*, + lazy::{AUTH_TOKEN, ROUTE_CHAT_PATH, ROUTE_MODELS_PATH}, +}; use axum::{ - body::Body, - extract::{Path, State}, - http::{HeaderMap, StatusCode}, - response::{IntoResponse, Response}, routing::{get, post}, - Json, Router, + Router, }; -use chrono::Local; -use cursor_api::{ - app::{ - config::handle_config_update, - constant::*, - models::*, - statics::*, - token::{ - get_user_info, handle_get_checksum, handle_get_tokeninfo, handle_update_tokeninfo, - handle_update_tokeninfo_post, load_tokens, - }, - utils::{parse_bool_from_env, parse_string_from_env}, - }, - chat::{ - constant::AVAILABLE_MODELS, - service::{handle_chat, handle_models}, - }, - common::models::{ - health::{CpuInfo, HealthCheckResponse, MemoryInfo, SystemInfo, SystemStats}, - ApiStatus, +use chat::{ + route::{ + get_user_info, handle_about, handle_config_page, handle_env_example, handle_get_checksum, + handle_get_tokeninfo, handle_health, handle_logs, handle_logs_post, handle_readme, + handle_root, handle_static, handle_tokeninfo_page, handle_update_tokeninfo, + handle_update_tokeninfo_post, }, + service::{handle_chat, handle_models}, }; +use common::utils::{parse_bool_from_env, parse_string_from_env, tokens::load_tokens}; use std::sync::Arc; -use sysinfo::{CpuRefreshKind, MemoryRefreshKind, RefreshKind, System}; use tokio::sync::Mutex; use tower_http::cors::CorsLayer; @@ -48,7 +46,7 @@ async fn main() { // 加载环境变量 dotenvy::dotenv().ok(); - if get_auth_token() == EMPTY_STRING { + if AUTH_TOKEN.is_empty() { panic!("AUTH_TOKEN must be set") }; @@ -98,258 +96,7 @@ async fn main() { let addr = format!("0.0.0.0:{}", port); println!("服务器运行在端口 {}", port); println!("当前版本: v{}", PKG_VERSION); - // if !std::env::args().any(|arg| arg == "--no-instruction") { - // println!(include_str!("../start_instruction")); - // } let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); axum::serve(listener, app).await.unwrap(); } - -// 根路由处理 -async fn handle_root() -> impl IntoResponse { - match AppConfig::get_page_content(ROUTE_ROOT_PATH).unwrap_or_default() { - PageContent::Default => Response::builder() - .status(StatusCode::TEMPORARY_REDIRECT) - .header(HEADER_NAME_LOCATION, ROUTE_HEALTH_PATH) - .body(Body::empty()) - .unwrap(), - PageContent::Text(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) - .body(Body::from(content.clone())) - .unwrap(), - PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) - .body(Body::from(content.clone())) - .unwrap(), - } -} - -async fn handle_health(State(state): State>>) -> Json { - let start_time = get_start_time(); - - // 创建系统信息实例,只监控 CPU 和内存 - let mut sys = System::new_with_specifics( - RefreshKind::nothing() - .with_memory(MemoryRefreshKind::everything()) - .with_cpu(CpuRefreshKind::everything()), - ); - - std::thread::sleep(sysinfo::MINIMUM_CPU_UPDATE_INTERVAL); - - // 刷新 CPU 和内存信息 - sys.refresh_memory(); - sys.refresh_cpu_usage(); - - let pid = std::process::id() as usize; - let process = sys.process(pid.into()); - - // 获取内存信息 - let memory = process.map(|p| p.memory()).unwrap_or(0); - - // 获取 CPU 使用率 - let cpu_usage = sys.global_cpu_usage(); - - let state = state.lock().await; - let uptime = (Local::now() - start_time).num_seconds(); - - Json(HealthCheckResponse { - status: ApiStatus::Healthy, - version: PKG_VERSION, - uptime, - stats: SystemStats { - started: start_time.to_string(), - total_requests: state.total_requests, - active_requests: state.active_requests, - system: SystemInfo { - memory: MemoryInfo { - rss: memory, // 物理内存使用量(字节) - }, - cpu: CpuInfo { - usage: cpu_usage, // CPU 使用率(百分比) - }, - }, - }, - models: AVAILABLE_MODELS.iter().map(|m| m.id).collect::>(), - endpoints: vec![ - ROUTE_CHAT_PATH.as_str(), - ROUTE_MODELS_PATH.as_str(), - ROUTE_GET_CHECKSUM, - ROUTE_TOKENINFO_PATH, - ROUTE_UPDATE_TOKENINFO_PATH, - ROUTE_GET_TOKENINFO_PATH, - ROUTE_LOGS_PATH, - ROUTE_GET_USER_INFO_PATH, - ROUTE_ENV_EXAMPLE_PATH, - ROUTE_CONFIG_PATH, - ROUTE_STATIC_PATH, - ROUTE_ABOUT_PATH, - ROUTE_README_PATH - - ], - }) -} - -async fn handle_tokeninfo_page() -> impl IntoResponse { - match AppConfig::get_page_content(ROUTE_TOKENINFO_PATH).unwrap_or_default() { - PageContent::Default => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) - .body(include_str!("../static/tokeninfo.min.html").to_string()) - .unwrap(), - PageContent::Text(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) - .body(content.clone()) - .unwrap(), - PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) - .body(content.clone()) - .unwrap(), - } -} - -// 日志处理 -async fn handle_logs() -> impl IntoResponse { - match AppConfig::get_page_content(ROUTE_LOGS_PATH).unwrap_or_default() { - PageContent::Default => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) - .body(Body::from( - include_str!("../static/logs.min.html").to_string(), - )) - .unwrap(), - PageContent::Text(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) - .body(Body::from(content.clone())) - .unwrap(), - PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) - .body(Body::from(content.clone())) - .unwrap(), - } -} - -async fn handle_logs_post( - State(state): State>>, - headers: HeaderMap, -) -> Result, StatusCode> { - let auth_token = get_auth_token(); - - // 验证 AUTH_TOKEN - let auth_header = headers - .get(HEADER_NAME_AUTHORIZATION) - .and_then(|h| h.to_str().ok()) - .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) - .ok_or(StatusCode::UNAUTHORIZED)?; - - if auth_header != auth_token { - return Err(StatusCode::UNAUTHORIZED); - } - - let state = state.lock().await; - Ok(Json(LogsResponse { - status: ApiStatus::Success, - total: state.request_logs.len(), - logs: state.request_logs.clone(), - timestamp: Local::now().to_string(), - })) -} - -#[derive(serde::Serialize)] -struct LogsResponse { - status: ApiStatus, - total: usize, - logs: Vec, - timestamp: String, -} - -async fn handle_env_example() -> impl IntoResponse { - Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) - .body(include_str!("../.env.example").to_string()) - .unwrap() -} - -// 配置页面处理函数 -async fn handle_config_page() -> impl IntoResponse { - match AppConfig::get_page_content(ROUTE_CONFIG_PATH).unwrap_or_default() { - PageContent::Default => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) - .body(include_str!("../static/config.min.html").to_string()) - .unwrap(), - PageContent::Text(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) - .body(content.clone()) - .unwrap(), - PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) - .body(content.clone()) - .unwrap(), - } -} - -async fn handle_static(Path(path): Path) -> impl IntoResponse { - match path.as_str() { - "shared-styles.css" => { - match AppConfig::get_page_content(ROUTE_SHARED_STYLES_PATH).unwrap_or_default() { - PageContent::Default => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8) - .body(include_str!("../static/shared-styles.min.css").to_string()) - .unwrap(), - PageContent::Text(content) | PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8) - .body(content.clone()) - .unwrap(), - } - } - "shared.js" => { - match AppConfig::get_page_content(ROUTE_SHARED_JS_PATH).unwrap_or_default() { - PageContent::Default => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8) - .body(include_str!("../static/shared.min.js").to_string()) - .unwrap(), - PageContent::Text(content) | PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8) - .body(content.clone()) - .unwrap(), - } - } - _ => Response::builder() - .status(StatusCode::NOT_FOUND) - .body("Not found".to_string()) - .unwrap(), - } -} - -async fn handle_about() -> impl IntoResponse { - match AppConfig::get_page_content(ROUTE_ABOUT_PATH).unwrap_or_default() { - PageContent::Default => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) - .body(include_str!("../static/readme.min.html").to_string()) - .unwrap(), - PageContent::Text(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) - .body(content.clone()) - .unwrap(), - PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) - .body(content.clone()) - .unwrap(), - } -} - -async fn handle_readme() -> impl IntoResponse { - match AppConfig::get_page_content(ROUTE_README_PATH).unwrap_or_default() { - PageContent::Default => Response::builder() - .status(StatusCode::TEMPORARY_REDIRECT) - .header(HEADER_NAME_LOCATION, ROUTE_ABOUT_PATH) - .body(Body::empty()) - .unwrap(), - PageContent::Text(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) - .body(Body::from(content.clone())) - .unwrap(), - PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) - .body(Body::from(content.clone())) - .unwrap(), - } -} diff --git a/static/shared-styles.css b/static/shared-styles.css index 23556b0..7bcd7ed 100644 --- a/static/shared-styles.css +++ b/static/shared-styles.css @@ -1,12 +1,51 @@ :root { + /* 基础颜色变量 */ --primary-color: #2196F3; --primary-dark: #1976D2; + --primary-color-alpha: rgba(33, 150, 243, 0.1); --success-color: #4CAF50; --error-color: #F44336; --background-color: #F5F5F5; --card-background: #FFFFFF; + --text-primary: #333333; + --text-secondary: #757575; + --border-color: #e0e0e0; + --disabled-bg: #f5f5f5; + + /* 布局变量 */ --border-radius: 8px; --spacing: 20px; + + /* 动画变量 */ + --transition-fast: 0.2s; + --transition-slow: 0.3s; +} + +/* 暗色模式 */ +@media (prefers-color-scheme: dark) { + :root { + --primary-color: #90CAF9; + --primary-dark: #64B5F6; + --background-color: #121212; + --card-background: #1e1e1e; + --text-primary: #e0e0e0; + --text-secondary: #9e9e9e; + --border-color: #404040; + --disabled-bg: #2d2d2d; + color-scheme: dark; + } +} + +/* 基础样式 */ +html { + scroll-behavior: smooth; + box-sizing: border-box; +} + +*, +*:before, +*:after { + box-sizing: inherit; } body { @@ -15,61 +54,111 @@ body { margin: 0 auto; padding: var(--spacing); background: var(--background-color); - color: #333; + color: var(--text-primary); line-height: 1.6; } +/* 容器样式 */ .container { background: var(--card-background); padding: var(--spacing); border-radius: var(--border-radius); box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); margin-bottom: var(--spacing); + transition: transform var(--transition-fast); } +.container:hover { + transform: translateY(-2px); +} + +/* 标题样式 */ h1, h2, h3 { - color: #1a1a1a; + color: var(--text-primary); margin-top: 0; + line-height: 1.2; } +/* 表单元素样式 */ .form-group { - margin-bottom: 15px; + margin-bottom: 20px; } +/* 标签样式 */ label { display: block; margin-bottom: 8px; font-weight: 500; + color: var(--text-primary); } -/* input[type="text"], 由于minify.js会删除input[type="text"],所以改为input */ input, -input[type="password"], select, -textarea { +textarea, +.form-control { width: 100%; padding: 10px 12px; - border: 1px solid #ddd; + border: 1px solid var(--border-color); border-radius: 4px; + background: var(--card-background); + color: var(--text-primary); font-size: 14px; - transition: border-color 0.2s, box-shadow 0.2s; - background: white; - color: #333; + line-height: 1.5; + transition: all var(--transition-fast); appearance: none; } -/* input[type="text"]:focus, 由于minify.js会删除input[type="text"]:focus,所以改为input:focus */ -input:focus, -input[type="password"]:focus, -select:focus, -textarea:focus { - border-color: var(--primary-color); - outline: none; - box-shadow: 0 0 0 2px rgba(33, 150, 243, 0.2); +input[type="checkbox"] { + width: auto; + margin-right: 8px; + cursor: pointer; + appearance: auto; } +input[type="checkbox"] + label { + cursor: pointer; + color: var(--text-primary); + user-select: none; +} + +input:hover, +select:hover, +textarea:hover, +.form-control:hover { + border-color: var(--primary-color); +} + +input:focus, +select:focus, +textarea:focus, +.form-control:focus { + border-color: var(--primary-color); + box-shadow: 0 0 0 2px var(--primary-color-alpha); + outline: none; +} + +/* 禁用状态 */ +input:disabled, +select:disabled, +textarea:disabled, +.form-control:disabled { + background-color: var(--disabled-bg); + border-color: var(--border-color); + cursor: not-allowed; + opacity: 0.7; +} + +/* 错误状态 */ +input.error, +select.error, +textarea.error, +.form-control.error { + border-color: var(--error-color); +} + +/* Select 特殊样式 */ select { background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24' fill='%23757575'%3E%3Cpath d='M7 10l5 5 5-5H7z'/%3E%3C/svg%3E"); background-repeat: no-repeat; @@ -78,63 +167,85 @@ select { padding-right: 36px; } +/* Textarea 特殊样式 */ textarea { min-height: 150px; - font-family: monospace; resize: vertical; + font-family: monospace; + line-height: 1.4; } -.button-group { - display: flex; - gap: 10px; - margin: var(--spacing) 0; -} - +/* 按钮基础样式 */ button { + display: inline-flex; + align-items: center; + justify-content: center; + min-height: 44px; + padding: 8px 24px; + border: none; + border-radius: var(--border-radius); background: var(--primary-color); color: white; - padding: 8px 16px; - border: none; - border-radius: 4px; - cursor: pointer; + font-size: 16px; font-weight: 500; - transition: background-color 0.2s, transform 0.1s; + text-align: center; + text-decoration: none; + cursor: pointer; + transition: all var(--transition-fast); + user-select: none; + -webkit-tap-highlight-color: transparent; } +/* 按钮状态 */ button:hover { background: var(--primary-dark); + transform: translateY(-1px); + box-shadow: 0 4px 12px var(--primary-color-alpha); } button:active { transform: translateY(1px); } +button:disabled { + background: var(--disabled-bg); + color: var(--text-secondary); + cursor: not-allowed; + transform: none; + box-shadow: none; +} + +/* 次要按钮样式 */ button.secondary { - background: #757575; + background: var(--text-secondary); } -button.secondary:hover { - background: #616161; +/* 按钮组 */ +.button-group { + display: flex; + gap: 10px; + margin: var(--spacing) 0; } +/* 消息提示 */ .message { padding: 12px; border-radius: var(--border-radius); margin: 10px 0; + border: 1px solid transparent; } .success { - background: #E8F5E9; - color: #2E7D32; - border: 1px solid #A5D6A7; + background: var(--success-color); + color: #fff; } .error { - background: #FFEBEE; - color: #C62828; - border: 1px solid #FFCDD2; + background: var(--error-color); + color: #fff; } +/* 表格样式 */ table { width: 100%; border-collapse: collapse; @@ -148,7 +259,7 @@ th, td { padding: 12px; text-align: left; - border-bottom: 1px solid #eee; + border-bottom: 1px solid var(--text-secondary); } th { @@ -158,15 +269,53 @@ th { } tr:nth-child(even) { - background: #f8f9fa; + background: rgba(0, 0, 0, 0.02); } tr:hover { - background: #f1f3f4; + background: rgba(0, 0, 0, 0.04); +} + +/* 辅助类 */ +.visually-hidden { + position: absolute; + width: 1px; + height: 1px; + padding: 0; + margin: -1px; + overflow: hidden; + clip: rect(0, 0, 0, 0); + border: 0; +} + +.text-center { + text-align: center; +} + +.help-text { + margin-top: 4px; + font-size: 14px; + color: var(--text-secondary); +} + +.error-text { + color: var(--error-color); +} + +.mt-0 { + margin-top: 0; +} + +.mb-0 { + margin-bottom: 0; } /* 响应式设计 */ @media (max-width: 768px) { + :root { + --spacing: 16px; + } + body { padding: 10px; } @@ -175,8 +324,27 @@ tr:hover { flex-direction: column; } + button { + width: 100%; + padding: 12px 20px; + } + + input, + select, + textarea, + .form-control { + font-size: 16px; + padding: 14px 16px; + } + table { display: block; overflow-x: auto; + -webkit-overflow-scrolling: touch; + } + + th, + td { + white-space: nowrap; } } \ No newline at end of file