diff --git a/Dockerfile b/Dockerfile index b397578..1d8a2c4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ WORKDIR /app RUN apt-get update && apt-get install -y --no-install-recommends build-essential protobuf-compiler nodejs npm musl-tools && rm -rf /var/lib/apt/lists/* && case "$TARGETARCH" in amd64) rustup target add x86_64-unknown-linux-musl ;; arm64) rustup target add aarch64-unknown-linux-musl ;; *) echo "Unsupported architecture for rustup: $TARGETARCH" && exit 1 ;; esac COPY . . -RUN case "$TARGETARCH" in amd64) TARGET_TRIPLE="x86_64-unknown-linux-musl"; TARGET_CPU="x86-64-v3" ;; arm64) TARGET_TRIPLE="aarch64-unknown-linux-musl"; TARGET_CPU="neoverse-n1" ;; *) echo "Unsupported architecture: $TARGETARCH" && exit 1 ;; esac && cargo build --bin cursor-api --release --target=$TARGET_TRIPLE -- -C link-arg=-s -C target-feature=+crt-static -C target-cpu=$TARGET_CPU -A unused && cp target/$TARGET_TRIPLE/release/cursor-api /app/cursor-api +RUN case "$TARGETARCH" in amd64) TARGET_TRIPLE="x86_64-unknown-linux-musl"; TARGET_CPU="x86-64-v3" ;; arm64) TARGET_TRIPLE="aarch64-unknown-linux-musl"; TARGET_CPU="neoverse-n1" ;; *) echo "Unsupported architecture: $TARGETARCH" && exit 1 ;; esac && RUSTFLAGS="-C link-arg=-s -C target-feature=+crt-static -C target-cpu=$TARGET_CPU -A unused" cargo build --bin cursor-api --release --target=$TARGET_TRIPLE && cp target/$TARGET_TRIPLE/release/cursor-api /app/cursor-api # 运行阶段 FROM scratch diff --git a/src/app/constant.rs b/src/app/constant.rs index 90fbd55..5db535e 100644 --- a/src/app/constant.rs +++ b/src/app/constant.rs @@ -229,7 +229,7 @@ pub fn init_thinking_tags() { return; } - let tag = crate::common::utils::parse_string_from_env("THINKING_TAG", DEFAULT_THINKING_TAG); + let tag = crate::common::utils::parse_from_env("THINKING_TAG", DEFAULT_THINKING_TAG); if tag == DEFAULT_THINKING_TAG { return; diff --git a/src/app/constant/header/version.rs b/src/app/constant/header/version.rs index ff44d0b..ebc42c8 100644 --- a/src/app/constant/header/version.rs +++ b/src/app/constant/header/version.rs @@ -102,7 +102,7 @@ pub fn header_value_ua_cursor_latest() -> http::header::HeaderValue { pub fn initialize_cursor_version() { use ::core::ops::Deref as _; - let version = crate::common::utils::parse_string_from_env( + let version = crate::common::utils::parse_from_env( ENV_CURSOR_CLIENT_VERSION, DEFAULT_CLIENT_VERSION, ); diff --git a/src/app/lazy.rs b/src/app/lazy.rs index e57e76e..5a9f4e0 100644 --- a/src/app/lazy.rs +++ b/src/app/lazy.rs @@ -1,5 +1,11 @@ pub mod log; +use ::std::{ + borrow::Cow, + path::PathBuf, + sync::{LazyLock, OnceLock}, +}; + use super::{ constant::{ CURSOR_API2_HOST, CURSOR_API4_HOST, CURSOR_GCPP_ASIA_HOST, CURSOR_GCPP_EU_HOST, @@ -7,12 +13,7 @@ use super::{ }, model::{DateTime, GcppHost}, }; -use crate::common::utils::{parse_bool_from_env, parse_string_from_env, parse_usize_from_env}; -use std::{ - borrow::Cow, - path::PathBuf, - sync::{LazyLock, OnceLock}, -}; +use crate::common::utils::parse_from_env; macro_rules! def_pub_static { // 基础版本:直接存储 String @@ -23,7 +24,7 @@ macro_rules! def_pub_static { // 环境变量版本 ($name:ident,env: $env_key:expr,default: $default:expr) => { pub static $name: LazyLock> = - LazyLock::new(|| parse_string_from_env($env_key, $default)); + LazyLock::new(|| parse_from_env($env_key, $default)); }; } @@ -38,7 +39,7 @@ pub fn get_start_time() -> &'static chrono::NaiveDateTime { pub static GENERAL_TIMEZONE: LazyLock = LazyLock::new(|| { use std::str::FromStr as _; - let tz = parse_string_from_env("GENERAL_TIMEZONE", EMPTY_STRING); + let tz = parse_from_env("GENERAL_TIMEZONE", EMPTY_STRING); if tz.is_empty() { __eprintln!( "未配置时区,请在环境变量GENERAL_TIMEZONE中设置,格式如'Asia/Shanghai'\n将使用默认时区: Asia/Shanghai" @@ -64,7 +65,7 @@ pub fn get_default_instructions(now_with_tz: chrono::DateTime) -> } pub static GENERAL_GCPP_HOST: LazyLock = LazyLock::new(|| { - let gcpp_host = parse_string_from_env("GENERAL_GCPP_HOST", EMPTY_STRING); + let gcpp_host = parse_from_env("GENERAL_GCPP_HOST", EMPTY_STRING); let gcpp_host = gcpp_host.trim(); if gcpp_host.is_empty() { __eprintln!( @@ -276,7 +277,7 @@ def_cursor_api_url!( ); static DATA_DIR: LazyLock = LazyLock::new(|| { - let data_dir = parse_string_from_env("DATA_DIR", "data"); + let data_dir = parse_from_env("DATA_DIR", "data"); let path = std::env::current_exe() .ok() .and_then(|exe_path| exe_path.parent().map(|p| p.to_path_buf())) @@ -306,7 +307,7 @@ const DEFAULT_TCP_KEEPALIVE: usize = 90; const MAX_TCP_KEEPALIVE: u64 = 600; pub static TCP_KEEPALIVE: LazyLock = LazyLock::new(|| { - let keepalive = parse_usize_from_env("TCP_KEEPALIVE", DEFAULT_TCP_KEEPALIVE); + let keepalive = parse_from_env("TCP_KEEPALIVE", DEFAULT_TCP_KEEPALIVE); u64::try_from(keepalive) .map(|t| t.min(MAX_TCP_KEEPALIVE)) .unwrap_or(DEFAULT_TCP_KEEPALIVE as u64) @@ -316,13 +317,13 @@ const DEFAULT_SERVICE_TIMEOUT: usize = 30; const MAX_SERVICE_TIMEOUT: u64 = 600; pub static SERVICE_TIMEOUT: LazyLock = LazyLock::new(|| { - let timeout = parse_usize_from_env("SERVICE_TIMEOUT", DEFAULT_SERVICE_TIMEOUT); + let timeout = parse_from_env("SERVICE_TIMEOUT", DEFAULT_SERVICE_TIMEOUT); u64::try_from(timeout) .map(|t| t.min(MAX_SERVICE_TIMEOUT)) .unwrap_or(DEFAULT_SERVICE_TIMEOUT as u64) }); -pub static REAL_USAGE: LazyLock = LazyLock::new(|| parse_bool_from_env("REAL_USAGE", true)); +pub static REAL_USAGE: LazyLock = LazyLock::new(|| parse_from_env("REAL_USAGE", true)); // pub static TOKEN_VALIDITY_RANGE: LazyLock = LazyLock::new(|| { // let short = if let Ok(Ok(validity)) = std::env::var("TOKEN_SHORT_VALIDITY") diff --git a/src/app/lazy/log.rs b/src/app/lazy/log.rs index f9d0307..8fb4874 100644 --- a/src/app/lazy/log.rs +++ b/src/app/lazy/log.rs @@ -1,10 +1,6 @@ -use std::{ - borrow::Cow, - sync::{Arc, atomic::AtomicU64}, - time::Duration, -}; - -use tokio::{ +use ::core::{sync::atomic::AtomicU64, time::Duration}; +use ::std::{borrow::Cow, sync::Arc}; +use ::tokio::{ fs::File, io::AsyncWriteExt as _, sync::{ @@ -15,10 +11,7 @@ use tokio::{ task::JoinHandle, }; -use crate::{ - common::utils::{parse_bool_from_env, parse_string_from_env}, - leak::manually_init::ManuallyInit, -}; +use crate::{common::utils::parse_from_env, leak::manually_init::ManuallyInit}; // --- 全局配置 --- @@ -30,8 +23,8 @@ static DEBUG_LOG_FILE: ManuallyInit> = ManuallyInit::new(); #[forbid(unused)] pub fn init() { unsafe { - DEBUG.init(parse_bool_from_env("DEBUG", true)); - DEBUG_LOG_FILE.init(parse_string_from_env("DEBUG_LOG_FILE", "debug.log")); + DEBUG.init(parse_from_env("DEBUG", true)); + DEBUG_LOG_FILE.init(parse_from_env("DEBUG_LOG_FILE", "debug.log")); } } diff --git a/src/app/model/config.rs b/src/app/model/config.rs index 2b52a2b..72c771b 100644 --- a/src/app/model/config.rs +++ b/src/app/model/config.rs @@ -13,7 +13,7 @@ use crate::{ lazy::CONFIG_FILE_PATH, model::FetchMode, }, - common::utils::{parse_bool_from_env, parse_string_from_env}, + common::utils::parse_from_env, leak::manually_init::ManuallyInit, }; @@ -123,16 +123,15 @@ impl AppConfig { let mut config = APP_CONFIG.write(); config.vision_ability = - VisionAbility::from_str(&parse_string_from_env("VISION_ABILITY", EMPTY_STRING)); - config.slow_pool = parse_bool_from_env("ENABLE_SLOW_POOL", false); - config.long_context = parse_bool_from_env("ENABLE_LONG_CONTEXT", false); - config.usage_check = - UsageCheck::from_str(&parse_string_from_env("USAGE_CHECK", EMPTY_STRING)); - config.dynamic_key = parse_bool_from_env("DYNAMIC_KEY", false); - config.share_token = parse_string_from_env("SHARED_TOKEN", EMPTY_STRING).into_owned(); - config.web_refs = parse_bool_from_env("INCLUDE_WEB_REFERENCES", false); + VisionAbility::from_str(&parse_from_env("VISION_ABILITY", EMPTY_STRING)); + config.slow_pool = parse_from_env("ENABLE_SLOW_POOL", false); + config.long_context = parse_from_env("ENABLE_LONG_CONTEXT", false); + config.usage_check = UsageCheck::from_str(&parse_from_env("USAGE_CHECK", EMPTY_STRING)); + config.dynamic_key = parse_from_env("DYNAMIC_KEY", false); + config.share_token = parse_from_env("SHARED_TOKEN", EMPTY_STRING).into_owned(); + config.web_refs = parse_from_env("INCLUDE_WEB_REFERENCES", false); config.fetch_models = - FetchMode::from_str(&parse_string_from_env("FETCH_RAW_MODELS", EMPTY_STRING)); + FetchMode::from_str(&parse_from_env("FETCH_RAW_MODELS", EMPTY_STRING)); } config_methods! { diff --git a/src/app/model/hash.rs b/src/app/model/hash.rs index 532a4f6..b3c2aa6 100644 --- a/src/app/model/hash.rs +++ b/src/app/model/hash.rs @@ -1,16 +1,16 @@ -use rand::{ +use ::core::{fmt, str::FromStr}; +use ::rand::{ RngCore as _, distr::{Distribution, StandardUniform}, }; -use sha2::Digest as _; -use std::{fmt, str::FromStr}; +use ::sha2::Digest as _; use crate::common::utils::hex::HEX_CHARS; static mut SAFE_HASH: bool = false; pub(super) fn init_hash() { - unsafe { SAFE_HASH = crate::common::utils::parse_bool_from_env("SAFE_HASH", true) } + unsafe { SAFE_HASH = crate::common::utils::parse_from_env("SAFE_HASH", true) } } #[derive(Debug)] diff --git a/src/app/model/state/log.rs b/src/app/model/state/log.rs index bbc7bb6..6217d2d 100644 --- a/src/app/model/state/log.rs +++ b/src/app/model/state/log.rs @@ -79,7 +79,7 @@ impl LogManager { /// 从存储中加载日志 #[inline(never)] pub async fn load() -> Result> { - let logs_limit = RequestLogsLimit::from_usize(crate::common::utils::parse_usize_from_env( + let logs_limit = RequestLogsLimit::from_usize(crate::common::utils::parse_from_env( "REQUEST_LOGS_LIMIT", 100, )); diff --git a/src/app/model/tz.rs b/src/app/model/tz.rs index 032abf4..066dc71 100644 --- a/src/app/model/tz.rs +++ b/src/app/model/tz.rs @@ -1,11 +1,11 @@ -use crate::{common::utils::parse_string_from_env, leak::manually_init::ManuallyInit}; +use crate::{common::utils::parse_from_env, leak::manually_init::ManuallyInit}; pub static TZ: ManuallyInit = ManuallyInit::new(); #[inline(always)] pub fn __init() { use std::str::FromStr as _; - let tz = match chrono_tz::Tz::from_str(&parse_string_from_env("TZ", super::EMPTY_STRING)) { + let tz = match chrono_tz::Tz::from_str(&parse_from_env("TZ", super::EMPTY_STRING)) { Ok(tz) => tz, Err(_e) => chrono_tz::Tz::UTC, }; diff --git a/src/common/client.rs b/src/common/client.rs index 9f59f87..15cb6bc 100644 --- a/src/common/client.rs +++ b/src/common/client.rs @@ -28,7 +28,7 @@ use reqwest::{ }, }; -trait RequestBuilderExt { +trait RequestBuilderExt: Sized { fn opt_header(self, key: K, value: Option) -> Self where http::HeaderName: TryFrom, diff --git a/src/common/impls.rs b/src/common/impls.rs new file mode 100644 index 0000000..3fc3da9 --- /dev/null +++ b/src/common/impls.rs @@ -0,0 +1,4 @@ +// pub trait ToBufStr: Copy { +// const BUF_SIZE: usize; +// fn to_str<'buf>(&self, buf: &'buf mut [u8; Self::BUF_SIZE]) -> &'buf mut str; +// } diff --git a/src/common/utils.rs b/src/common/utils.rs index b7d36aa..e90827b 100644 --- a/src/common/utils.rs +++ b/src/common/utils.rs @@ -8,364 +8,391 @@ pub mod duration_fmt; pub mod hex; pub mod string_builder; +use ::base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; +use ::core::str::FromStr as _; +use ::prost::Message as _; +use ::reqwest::Client; +use ::std::borrow::Cow; +pub use base64::{from_base64, to_base64}; pub use hex::{byte_to_hex, hex_to_byte}; pub use string_builder::StringBuilder; -use std::{borrow::Cow, str::FromStr as _}; - -use ::base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; -pub use base64::*; -use prost::Message as _; -use reqwest::Client; - use super::model::userinfo::{ - GetTeamsResponse, ListActiveSessionsResponse, Session, StripeProfile, Team, UsageProfile, - UserProfile, + GetTeamsResponse, ListActiveSessionsResponse, Session, StripeProfile, Team, UsageProfile, + UserProfile, }; use crate::{ - app::{ - lazy::{ - aggregated_usage_events_url, chat_models_url, filtered_usage_events_url, - is_on_new_pricing_url, server_config_url, teams_url, + app::{ + lazy::{ + aggregated_usage_events_url, chat_models_url, filtered_usage_events_url, + is_on_new_pricing_url, server_config_url, teams_url, + }, + model::{ChainUsage, Checksum, DateTime, ExtToken, GcppHost, Hash, RawToken, Token}, }, - model::{ChainUsage, Checksum, DateTime, ExtToken, GcppHost, Hash, RawToken, Token}, - }, - common::model::userinfo::{MembershipType, SubscriptionStatus}, - core::{ - aiserver::v1::{ - AvailableModelsRequest, AvailableModelsResponse, GetAggregatedUsageEventsRequest, - GetAggregatedUsageEventsResponse, GetFilteredUsageEventsRequest, - GetFilteredUsageEventsResponse, GetServerConfigResponse, + common::model::userinfo::{MembershipType, SubscriptionStatus}, + core::{ + aiserver::v1::{ + AvailableModelsRequest, AvailableModelsResponse, GetAggregatedUsageEventsRequest, + GetAggregatedUsageEventsResponse, GetFilteredUsageEventsRequest, + GetFilteredUsageEventsResponse, GetServerConfigResponse, + }, + config::key_config, }, - config::key_config, - }, }; -pub fn parse_bool_from_env(key: &str, default: bool) -> bool { - std::env::var(key) - .ok() - .map(|mut val| { - let res = { - val.make_ascii_lowercase(); - val.trim() - }; - match res { - "true" | "1" => true, - "false" | "0" => false, - _ => default, - } - }) - .unwrap_or(default) +mod private { + pub trait Sealed: Sized {} + + impl Sealed for bool {} + impl Sealed for &'static str {} + impl Sealed for usize {} } -pub fn parse_string_from_env(key: &str, default: &'static str) -> Cow<'static, str> { - match std::env::var(key) { - Ok(mut value) => { - let trimmed = value.trim(); +pub trait ParseFromEnv: private::Sealed { + type Result = Self; + fn parse_from_env(key: &str, default: Self) -> Self::Result; +} - if trimmed.is_empty() { - // 如果 trim 后为空,使用默认值(不分配) - Cow::Borrowed(default) - } else if trimmed.len() == value.len() { - // 不需要 trim,直接使用 - Cow::Owned(value) - } else { - // 需要 trim - 就地修改 - let trimmed_len = trimmed.len(); - let start_offset = trimmed.as_ptr() as usize - value.as_ptr() as usize; - - unsafe { - let vec = value.as_mut_vec(); - if start_offset > 0 { - vec.copy_within(start_offset..start_offset + trimmed_len, 0); - } - vec.set_len(trimmed_len); - } - - Cow::Owned(value) - } +impl ParseFromEnv for bool { + #[inline] + fn parse_from_env(key: &str, default: Self) -> Self::Result { + ::std::env::var(key) + .ok() + .map(|mut val| { + let res = { + val.make_ascii_lowercase(); + val.trim() + }; + match res { + "true" | "1" => true, + "false" | "0" => false, + _ => default, + } + }) + .unwrap_or(default) } - Err(_) => Cow::Borrowed(default), - } } -pub fn parse_usize_from_env(key: &str, default: usize) -> usize { - std::env::var(key) - .ok() - .and_then(|v| v.trim().parse().ok()) - .unwrap_or(default) +impl ParseFromEnv for &'static str { + type Result = Cow<'static, str>; + #[inline] + fn parse_from_env(key: &str, default: Self) -> Self::Result { + match ::std::env::var(key) { + Ok(mut value) => { + let trimmed = value.trim(); + + if trimmed.is_empty() { + // 如果 trim 后为空,使用默认值(不分配) + Cow::Borrowed(default) + } else if trimmed.len() == value.len() { + // 不需要 trim,直接使用 + Cow::Owned(value) + } else { + // 需要 trim - 就地修改 + let trimmed_len = trimmed.len(); + let start_offset = trimmed.as_ptr() as usize - value.as_ptr() as usize; + + unsafe { + let vec = value.as_mut_vec(); + if start_offset > 0 { + vec.copy_within(start_offset..start_offset + trimmed_len, 0); + } + vec.set_len(trimmed_len); + } + + Cow::Owned(value) + } + } + Err(_) => Cow::Borrowed(default), + } + } +} + +impl ParseFromEnv for usize { + #[inline] + fn parse_from_env(key: &str, default: Self) -> Self::Result { + ::std::env::var(key) + .ok() + .and_then(|v| v.trim().parse().ok()) + .unwrap_or(default) + } +} + +#[inline] +pub fn parse_from_env(key: &str, default: T) -> T::Result { + ParseFromEnv::parse_from_env(key, default) } pub fn now_secs() -> u64 { - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .expect("system time before Unix epoch") - .as_secs() + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("system time before Unix epoch") + .as_secs() } const LEN: usize = 2; -pub trait TrimNewlines { - fn trim_leading_newlines(self) -> Self; +pub trait TrimNewlines: Sized { + fn trim_leading_newlines(self) -> Self; } impl TrimNewlines for &str { - #[inline(always)] - fn trim_leading_newlines(self) -> Self { - let bytes = self.as_bytes(); - if bytes.len() >= LEN && bytes[0] == b'\n' && bytes[1] == b'\n' { - return unsafe { self.get_unchecked(LEN..) }; + #[inline(always)] + fn trim_leading_newlines(self) -> Self { + let bytes = self.as_bytes(); + if bytes.len() >= LEN && bytes[0] == b'\n' && bytes[1] == b'\n' { + return unsafe { self.get_unchecked(LEN..) }; + } + self } - self - } } impl TrimNewlines for String { - #[inline(always)] - fn trim_leading_newlines(mut self) -> Self { - let bytes = self.as_bytes(); - if bytes.len() >= LEN && bytes[0] == b'\n' && bytes[1] == b'\n' { - unsafe { - let vec = self.as_mut_vec(); - vec.drain(..LEN); - } + #[inline(always)] + fn trim_leading_newlines(mut self) -> Self { + let bytes = self.as_bytes(); + if bytes.len() >= LEN && bytes[0] == b'\n' && bytes[1] == b'\n' { + unsafe { + let vec = self.as_mut_vec(); + vec.drain(..LEN); + } + } + self } - self - } } /// 获取完整的token配置文件 /// 协调多个数据源,可选择性获取用户信息 #[inline(never)] pub async fn get_token_profile( - client: Client, - token: &Token, - maybe_token: Option<&Token>, - is_pri: bool, - include_user: bool, - include_sessions: bool, + client: Client, + token: &Token, + maybe_token: Option<&Token>, + is_pri: bool, + include_user: bool, + include_sessions: bool, ) -> ( - Option, - Option, - Option>, + Option, + Option, + Option>, ) { - let maybe_token = maybe_token.unwrap_or(token); + let maybe_token = maybe_token.unwrap_or(token); - let mut buf = [0; 31]; - let user_id = maybe_token.raw().subject.id.to_str(&mut buf) as &str; + let mut buf = [0; 31]; + let user_id = maybe_token.raw().subject.id.to_str(&mut buf) as &str; - if include_user { - if include_sessions { - // 并发获取所有数据,user为必需 - let (mut stripe, _, mut user, teams, is_on_new_pricing, sessions) = tokio::join!( - get_stripe_profile(&client, token.as_str(), is_pri), - get_usage_profile(&client, user_id, maybe_token.as_str(), is_pri), - get_user_profile(&client, user_id, maybe_token.as_str(), is_pri), - get_teams(&client, user_id, maybe_token.as_str(), is_pri), - get_is_on_new_pricing(&client, user_id, maybe_token.as_str(), is_pri), - get_sessions(&client, user_id, maybe_token.as_str(), is_pri) - ); + if include_user { + if include_sessions { + // 并发获取所有数据,user为必需 + let (mut stripe, _, mut user, teams, is_on_new_pricing, sessions) = tokio::join!( + get_stripe_profile(&client, token.as_str(), is_pri), + get_usage_profile(&client, user_id, maybe_token.as_str(), is_pri), + get_user_profile(&client, user_id, maybe_token.as_str(), is_pri), + get_teams(&client, user_id, maybe_token.as_str(), is_pri), + get_is_on_new_pricing(&client, user_id, maybe_token.as_str(), is_pri), + get_sessions(&client, user_id, maybe_token.as_str(), is_pri) + ); - if let Some(stripe) = stripe.as_mut() - && teams.is_some_and(|teams| { - teams.into_iter().any(|team| { - team.has_billing - && team.subscription_status.is_some_and(|subscription_status| { - matches!(subscription_status, SubscriptionStatus::Active) - }) - }) - }) - { - stripe.membership_type = MembershipType::Enterprise; - } + if let Some(stripe) = stripe.as_mut() + && teams.is_some_and(|teams| { + teams.into_iter().any(|team| { + team.has_billing + && team.subscription_status.is_some_and(|subscription_status| { + matches!(subscription_status, SubscriptionStatus::Active) + }) + }) + }) + { + stripe.membership_type = MembershipType::Enterprise; + } - if let Some(user) = user.as_mut() { - user.is_on_new_pricing = is_on_new_pricing.unwrap_or(true); - } + if let Some(user) = user.as_mut() { + user.is_on_new_pricing = is_on_new_pricing.unwrap_or(true); + } - // 所有数据都必需成功 - (user, stripe, sessions) + // 所有数据都必需成功 + (user, stripe, sessions) + } else { + // 并发获取所有数据,user为必需 + let (mut stripe, _, mut user, teams, is_on_new_pricing) = tokio::join!( + get_stripe_profile(&client, token.as_str(), is_pri), + get_usage_profile(&client, user_id, maybe_token.as_str(), is_pri), + get_user_profile(&client, user_id, maybe_token.as_str(), is_pri), + get_teams(&client, user_id, maybe_token.as_str(), is_pri), + get_is_on_new_pricing(&client, user_id, maybe_token.as_str(), is_pri) + ); + + if let Some(stripe) = stripe.as_mut() + && teams.is_some_and(|teams| { + teams.into_iter().any(|team| { + team.has_billing + && team.subscription_status.is_some_and(|subscription_status| { + matches!(subscription_status, SubscriptionStatus::Active) + }) + }) + }) + { + stripe.membership_type = MembershipType::Enterprise; + } + + if let Some(user) = user.as_mut() { + user.is_on_new_pricing = is_on_new_pricing.unwrap_or(true); + } + + // 所有数据都必需成功 + (user, stripe, None) + } } else { - // 并发获取所有数据,user为必需 - let (mut stripe, _, mut user, teams, is_on_new_pricing) = tokio::join!( - get_stripe_profile(&client, token.as_str(), is_pri), - get_usage_profile(&client, user_id, maybe_token.as_str(), is_pri), - get_user_profile(&client, user_id, maybe_token.as_str(), is_pri), - get_teams(&client, user_id, maybe_token.as_str(), is_pri), - get_is_on_new_pricing(&client, user_id, maybe_token.as_str(), is_pri) - ); + // 仅获取stripe数据 + let (mut stripe, _, teams) = tokio::join!( + get_stripe_profile(&client, token.as_str(), is_pri), + get_usage_profile(&client, user_id, maybe_token.as_str(), is_pri), + get_teams(&client, user_id, maybe_token.as_str(), is_pri), + ); - if let Some(stripe) = stripe.as_mut() - && teams.is_some_and(|teams| { - teams.into_iter().any(|team| { - team.has_billing - && team.subscription_status.is_some_and(|subscription_status| { - matches!(subscription_status, SubscriptionStatus::Active) - }) - }) - }) - { - stripe.membership_type = MembershipType::Enterprise; - } - - if let Some(user) = user.as_mut() { - user.is_on_new_pricing = is_on_new_pricing.unwrap_or(true); - } - - // 所有数据都必需成功 - (user, stripe, None) - } - } else { - // 仅获取stripe数据 - let (mut stripe, _, teams) = tokio::join!( - get_stripe_profile(&client, token.as_str(), is_pri), - get_usage_profile(&client, user_id, maybe_token.as_str(), is_pri), - get_teams(&client, user_id, maybe_token.as_str(), is_pri), - ); - - if let Some(stripe) = stripe.as_mut() - && teams.is_some_and(|teams| { - teams.into_iter().any(|team| { - team.has_billing - && team.subscription_status.is_some_and(|subscription_status| { - matches!(subscription_status, SubscriptionStatus::Active) + if let Some(stripe) = stripe.as_mut() + && teams.is_some_and(|teams| { + teams.into_iter().any(|team| { + team.has_billing + && team.subscription_status.is_some_and(|subscription_status| { + matches!(subscription_status, SubscriptionStatus::Active) + }) + }) }) - }) - }) - { - stripe.membership_type = MembershipType::Enterprise; - } + { + stripe.membership_type = MembershipType::Enterprise; + } - (None, stripe, None) - } + (None, stripe, None) + } } /// 获取用户使用情况配置文件 pub async fn get_usage_profile(client: &Client, user_id: &str, auth_token: &str, is_pri: bool) { - if !*crate::app::lazy::log::DEBUG { - return; - } - - let request = super::client::build_usage_request(client, user_id, auth_token, is_pri); - let response = match request.send().await { - Ok(r) => r, - Err(_) => { - crate::debug!(" send error"); - return; + if !*crate::app::lazy::log::DEBUG { + return; } - }; - crate::debug!(" got {}", response.status()); - let usage = response.json::().await.ok(); - crate::debug!( - " got {}", - __unwrap!(serde_json::to_string_pretty(&usage)) - ); + + let request = super::client::build_usage_request(client, user_id, auth_token, is_pri); + let response = match request.send().await { + Ok(r) => r, + Err(_) => { + crate::debug!(" send error"); + return; + } + }; + crate::debug!(" got {}", response.status()); + let usage = response.json::().await.ok(); + crate::debug!( + " got {}", + __unwrap!(serde_json::to_string_pretty(&usage)) + ); } /// 获取Stripe付费配置文件 pub async fn get_stripe_profile( - client: &Client, - auth_token: &str, - is_pri: bool, + client: &Client, + auth_token: &str, + is_pri: bool, ) -> Option { - let request = super::client::build_profile_request(client, auth_token, is_pri); + let request = super::client::build_profile_request(client, auth_token, is_pri); - let response = request.send().await.ok()?; - crate::debug!(" {}", response.status()); - response.json::().await.ok() + let response = request.send().await.ok()?; + crate::debug!(" {}", response.status()); + response.json::().await.ok() } /// 获取用户基础配置文件 pub async fn get_user_profile( - client: &Client, - user_id: &str, - auth_token: &str, - is_pri: bool, + client: &Client, + user_id: &str, + auth_token: &str, + is_pri: bool, ) -> Option { - let request = super::client::build_userinfo_request(client, user_id, auth_token, is_pri); + let request = super::client::build_userinfo_request(client, user_id, auth_token, is_pri); - // let response = request.send().await.ok()?; - // crate::debug!("get_user_profile \n{response:?}"); - // let bytes = response.bytes().await.ok()?; - // crate::debug!("bytes \n{:?}", unsafe { std::str::from_utf8_unchecked(&bytes[..]) }); - // serde_json::from_slice::(&bytes).ok() - let response = request.send().await.ok()?; - crate::debug!(" {}", response.status()); - response.json::().await.ok() + // let response = request.send().await.ok()?; + // crate::debug!("get_user_profile \n{response:?}"); + // let bytes = response.bytes().await.ok()?; + // crate::debug!("bytes \n{:?}", unsafe { std::str::from_utf8_unchecked(&bytes[..]) }); + // serde_json::from_slice::(&bytes).ok() + let response = request.send().await.ok()?; + crate::debug!(" {}", response.status()); + response.json::().await.ok() } pub async fn get_available_models( - ext_token: ExtToken, - is_pri: bool, - request: AvailableModelsRequest, + ext_token: ExtToken, + is_pri: bool, + request: AvailableModelsRequest, ) -> Option { - let response = { - let client = super::client::build_client_request(super::client::AiServiceRequest { - ext_token, - fs_client_key: None, - url: chat_models_url(is_pri), - is_stream: false, - trace_id: Some(new_uuid_v4()), - is_pri, - cookie: None, - }); - client - .body(__unwrap!(encode_message(&request, false))) - .send() - .await - .ok()? - .bytes() - .await - .ok()? - }; - let available_models = AvailableModelsResponse::decode(response.as_ref()).ok()?; - Some(available_models) + let response = { + let client = super::client::build_client_request(super::client::AiServiceRequest { + ext_token, + fs_client_key: None, + url: chat_models_url(is_pri), + is_stream: false, + trace_id: Some(new_uuid_v4()), + is_pri, + cookie: None, + }); + client + .body(__unwrap!(encode_message(&request, false))) + .send() + .await + .ok()? + .bytes() + .await + .ok()? + }; + let available_models = AvailableModelsResponse::decode(response.as_ref()).ok()?; + Some(available_models) } pub async fn get_token_usage( - ext_token: ExtToken, - is_pri: bool, - time: DateTime, - model_id: &'static str, + ext_token: ExtToken, + is_pri: bool, + time: DateTime, + model_id: &'static str, ) -> Option { - let maybe_token = ext_token - .secondary_token - .as_ref() - .unwrap_or(&ext_token.primary_token); + let maybe_token = ext_token + .secondary_token + .as_ref() + .unwrap_or(&ext_token.primary_token); - let mut buf = [0; 31]; - let user_id = maybe_token.raw().subject.id.to_str(&mut buf) as &str; - let mut token_usage = None; + let mut buf = [0; 31]; + let user_id = maybe_token.raw().subject.id.to_str(&mut buf) as &str; + let mut token_usage = None; - for _ in 0..5 { - tokio::time::sleep(::core::time::Duration::from_millis(POLL_INTERVAL_MS)).await; - let res = get_filtered_usage_events( - &ext_token.get_client(), - user_id, - maybe_token.as_str(), - is_pri, - FilteredUsageArgs { - start: Some(time), - end: None, - model_id: Some(model_id), - size: Some(10), - }, - ) - .await?; + for _ in 0..5 { + tokio::time::sleep(::core::time::Duration::from_millis(POLL_INTERVAL_MS)).await; + let res = get_filtered_usage_events( + &ext_token.get_client(), + user_id, + maybe_token.as_str(), + is_pri, + FilteredUsageArgs { + start: Some(time), + end: None, + model_id: Some(model_id), + size: Some(10), + }, + ) + .await?; - if let Some(usage) = res.usage_events_display.first()?.token_usage { - token_usage = Some(usage); - break; - }; - } + if let Some(usage) = res.usage_events_display.first()?.token_usage { + token_usage = Some(usage); + break; + }; + } - token_usage.map(|token_usage| ChainUsage { - input: token_usage.input_tokens, - output: token_usage.output_tokens, - cache_write: token_usage.cache_write_tokens, - cache_read: token_usage.cache_read_tokens, - cents: token_usage.total_cents, - }) + token_usage.map(|token_usage| ChainUsage { + input: token_usage.input_tokens, + output: token_usage.output_tokens, + cache_write: token_usage.cache_write_tokens, + cache_read: token_usage.cache_read_tokens, + cents: token_usage.total_cents, + }) } // pub fn validate_token_and_checksum(auth_token: &str) -> Option<(String, Checksum)> { @@ -455,321 +482,327 @@ pub fn format_time_ms(seconds: f64) -> f64 { (seconds * 1000.0).round() / 1000.0 /// 将 JWT token 转换为 TokenInfo #[inline] pub fn token_to_tokeninfo( - token: RawToken, - checksum: Checksum, - client_key: Hash, - config_version: Option, - session_id: uuid::Uuid, - proxy_name: Option, - timezone: Option, - gcpp_host: Option, + token: RawToken, + checksum: Checksum, + client_key: Hash, + config_version: Option, + session_id: uuid::Uuid, + proxy_name: Option, + timezone: Option, + gcpp_host: Option, ) -> key_config::TokenInfo { - key_config::TokenInfo { - token: Some(key_config::token_info::Token::from_raw(token)), - checksum: checksum.into_bytes().to_vec(), - client_key: client_key.into_bytes().to_vec(), - config_version: config_version.map(|v| v.into_bytes().to_vec()), - session_id: session_id.into_bytes().to_vec(), - proxy_name, - timezone, - gcpp_host, - } + key_config::TokenInfo { + token: Some(key_config::token_info::Token::from_raw(token)), + checksum: checksum.into_bytes().to_vec(), + client_key: client_key.into_bytes().to_vec(), + config_version: config_version.map(|v| v.into_bytes().to_vec()), + session_id: session_id.into_bytes().to_vec(), + proxy_name, + timezone, + gcpp_host, + } } /// 将 TokenInfo 转换为 JWT token #[inline] pub fn tokeninfo_to_token(info: key_config::TokenInfo) -> Option { - let checksum = Checksum::from_bytes(info.checksum.try_into().ok()?); - let client_key = Hash::from_bytes(info.client_key.try_into().ok()?); - let config_version = info - .config_version - .and_then(|v| uuid::Uuid::from_slice(&v).ok()); - let session_id = uuid::Uuid::from_slice(&info.session_id).ok()?; - let timezone = info.timezone.and_then(|s| chrono_tz::Tz::from_str(&s).ok()); - let gcpp_host = info.gcpp_host.and_then(GcppHost::from_i32); - Some(ExtToken { - primary_token: Token::new(info.token?.into_raw()?, None), - secondary_token: None, - checksum, - client_key, - config_version, - session_id, - proxy: info.proxy_name, - timezone, - gcpp_host, - user: None, - }) + let checksum = Checksum::from_bytes(info.checksum.try_into().ok()?); + let client_key = Hash::from_bytes(info.client_key.try_into().ok()?); + let config_version = info + .config_version + .and_then(|v| uuid::Uuid::from_slice(&v).ok()); + let session_id = uuid::Uuid::from_slice(&info.session_id).ok()?; + let timezone = info.timezone.and_then(|s| chrono_tz::Tz::from_str(&s).ok()); + let gcpp_host = info.gcpp_host.and_then(GcppHost::from_i32); + Some(ExtToken { + primary_token: Token::new(info.token?.into_raw()?, None), + secondary_token: None, + checksum, + client_key, + config_version, + session_id, + proxy: info.proxy_name, + timezone, + gcpp_host, + user: None, + }) } /// 压缩数据为gzip格式 #[inline] fn compress_gzip(data: &[u8]) -> Result, std::io::Error> { - use flate2::{Compression, write::GzEncoder}; - use std::io::Write as _; + use std::io::Write as _; - const LEVEL: Compression = Compression::new(6); + use flate2::{Compression, write::GzEncoder}; - let mut encoder = GzEncoder::new(Vec::new(), LEVEL); - encoder.write_all(data)?; - encoder.finish() + const LEVEL: Compression = Compression::new(6); + + let mut encoder = GzEncoder::new(Vec::new(), LEVEL); + encoder.write_all(data)?; + encoder.finish() } #[allow(clippy::uninit_vec)] #[inline(always)] pub fn encode_message( - message: &impl prost::Message, - maybe_stream: bool, + message: &impl prost::Message, + maybe_stream: bool, ) -> Result, Box> { - const COMPRESSION_THRESHOLD: usize = 1024; // 1KB - const LENGTH_OVERFLOW_MSG: &str = "Message length exceeds ~4 GiB"; + const COMPRESSION_THRESHOLD: usize = 1024; // 1KB + const LENGTH_OVERFLOW_MSG: &str = "Message length exceeds ~4 GiB"; - let estimated_size = message.encoded_len(); + let estimated_size = message.encoded_len(); - if !maybe_stream { - let mut encoded = Vec::with_capacity(estimated_size); - __unwrap!(message.encode(&mut encoded)); - return Ok(encoded); - } + if !maybe_stream { + let mut encoded = Vec::with_capacity(estimated_size); + __unwrap!(message.encode(&mut encoded)); + return Ok(encoded); + } - // 预留头部空间 - let mut buf = Vec::with_capacity(5 + estimated_size); + // 预留头部空间 + let mut buf = Vec::with_capacity(5 + estimated_size); - unsafe { - // 跳过头部5字节 - buf.set_len(5); + unsafe { + // 跳过头部5字节 + buf.set_len(5); - // 编码消息 - __unwrap!(message.encode(&mut buf)); - let message_len = buf.len() - 5; + // 编码消息 + __unwrap!(message.encode(&mut buf)); + let message_len = buf.len() - 5; - // 判断是否需要压缩 - let (compression_flag, final_len) = if message_len >= COMPRESSION_THRESHOLD { - // 需要压缩 - let compressed = compress_gzip(buf.get_unchecked(5..))?; - let compressed_len = compressed.len(); + // 判断是否需要压缩 + let (compression_flag, final_len) = if message_len >= COMPRESSION_THRESHOLD { + // 需要压缩 + let compressed = compress_gzip(buf.get_unchecked(5..))?; + let compressed_len = compressed.len(); - // 只在压缩后更小时才使用压缩版本 - if compressed_len < message_len { - // 直接覆盖原数据 - let dst = buf.as_mut_ptr().add(5); - ::core::ptr::copy_nonoverlapping(compressed.as_ptr(), dst, compressed_len); - // 截断到正确长度 - buf.set_len(5 + compressed_len); - (0x01, compressed_len) - } else { - // 压缩后反而更大,保持原样 - (0x00, message_len) - } - } else { - // 不需要压缩 - (0x00, message_len) - }; + // 只在压缩后更小时才使用压缩版本 + if compressed_len < message_len { + // 直接覆盖原数据 + let dst = buf.as_mut_ptr().add(5); + ::core::ptr::copy_nonoverlapping(compressed.as_ptr(), dst, compressed_len); + // 截断到正确长度 + buf.set_len(5 + compressed_len); + (0x01, compressed_len) + } else { + // 压缩后反而更大,保持原样 + (0x00, message_len) + } + } else { + // 不需要压缩 + (0x00, message_len) + }; - // 统一写入头部 - let len = u32::try_from(final_len).map_err(|_| LENGTH_OVERFLOW_MSG)?; - let ptr = buf.as_mut_ptr(); - *ptr = compression_flag; - *(ptr.add(1) as *mut [u8; 4]) = len.to_be_bytes(); - } + // 统一写入头部 + let len = u32::try_from(final_len).map_err(|_| LENGTH_OVERFLOW_MSG)?; + let ptr = buf.as_mut_ptr(); + *ptr = compression_flag; + *(ptr.add(1) as *mut [u8; 4]) = len.to_be_bytes(); + } - Ok(buf) + Ok(buf) } /// 生成 PKCE code_verifier 和对应的 code_challenge (S256 method). /// 返回一个包含 (verifier, challenge) 的元组。 #[inline] fn generate_pkce_pair() -> ([u8; 43], [u8; 43]) { - use rand::TryRngCore as _; - use sha2::Digest as _; + use rand::TryRngCore as _; + use sha2::Digest as _; - // 1. 生成 code_verifier 的原始随机字节 (32 bytes is recommended) - let mut verifier_bytes = [0u8; 32]; + // 1. 生成 code_verifier 的原始随机字节 (32 bytes is recommended) + let mut verifier_bytes = [0u8; 32]; - // 使用 OsRng 填充字节。如果失败(极其罕见),则直接 panic - rand::rngs::OsRng - .try_fill_bytes(&mut verifier_bytes) - .expect("获取系统安全随机数失败,这是一个严重错误!"); + // 使用 OsRng 填充字节。如果失败(极其罕见),则直接 panic + rand::rngs::OsRng + .try_fill_bytes(&mut verifier_bytes) + .expect("获取系统安全随机数失败,这是一个严重错误!"); - // 2. 将随机字节编码为 URL 安全 Base64 字符串,这就是 code_verifier - let mut code_verifier = [0; 43]; - __unwrap_panic!(URL_SAFE_NO_PAD.encode_slice(verifier_bytes, &mut code_verifier)); + // 2. 将随机字节编码为 URL 安全 Base64 字符串,这就是 code_verifier + let mut code_verifier = [0; 43]; + __unwrap_panic!(URL_SAFE_NO_PAD.encode_slice(verifier_bytes, &mut code_verifier)); - // 3. 计算 code_verifier 字符串的 SHA-256 哈希值 - let hash_result = sha2::Sha256::digest(code_verifier); + // 3. 计算 code_verifier 字符串的 SHA-256 哈希值 + let hash_result = sha2::Sha256::digest(code_verifier); - // 4. 将哈希结果编码为 URL 安全 Base64 字符串,这就是 code_challenge - let mut code_challenge = [0; 43]; - __unwrap_panic!(URL_SAFE_NO_PAD.encode_slice(hash_result, &mut code_challenge)); + // 4. 将哈希结果编码为 URL 安全 Base64 字符串,这就是 code_challenge + let mut code_challenge = [0; 43]; + __unwrap_panic!(URL_SAFE_NO_PAD.encode_slice(hash_result, &mut code_challenge)); - // 5. 同时返回 verifier 和 challenge - (code_verifier, code_challenge) + // 5. 同时返回 verifier 和 challenge + (code_verifier, code_challenge) } const POLL_MAX_ATTEMPTS: u8 = 5; const POLL_INTERVAL_MS: u64 = 1000; pub async fn get_new_token(ext_token: &mut ExtToken, is_pri: bool) -> bool { - let is_session = ext_token.primary_token.is_session(); + let is_session = ext_token.primary_token.is_session(); - match if is_session { - refresh_token(ext_token, is_pri).await - } else { - upgrade_token(ext_token, is_pri).await - } { - Some((new_token, s)) => { - let tmp = Token::new(new_token, Some(s)); - if !is_session && ext_token.secondary_token.is_none() { - let old_token = ::core::mem::replace(&mut ext_token.primary_token, tmp); - ext_token.secondary_token = Some(old_token); - } else { - ext_token.primary_token = tmp; - } - true + match if is_session { + refresh_token(ext_token, is_pri).await + } else { + upgrade_token(ext_token, is_pri).await + } { + Some((new_token, s)) => { + let tmp = Token::new(new_token, Some(s)); + if !is_session && ext_token.secondary_token.is_none() { + let old_token = ::core::mem::replace(&mut ext_token.primary_token, tmp); + ext_token.secondary_token = Some(old_token); + } else { + ext_token.primary_token = tmp; + } + true + } + None => false, } - None => false, - } } async fn upgrade_token(ext_token: &ExtToken, is_pri: bool) -> Option<(RawToken, String)> { - #[derive(::serde::Deserialize)] - #[serde(rename_all = "camelCase")] - struct PollResponse { - pub access_token: String, - // pub refresh_token: String, - // pub challenge: String, - // pub auth_id: String, - // pub uuid: String, - } - - let (verifier, challenge) = generate_pkce_pair(); - let verifier = unsafe { ::core::str::from_utf8_unchecked(&verifier) }; - let challenge = unsafe { ::core::str::from_utf8_unchecked(&challenge) }; - let mut buf = [0; 36]; - let uuid = uuid::Uuid::new_v4().hyphenated().encode_lower(&mut buf) as &str; - - let token = ext_token - .secondary_token - .as_ref() - .unwrap_or(&ext_token.primary_token); - let mut buf = [0; 31]; - let user_id = token.raw().subject.id.to_str(&mut buf) as &str; - let auth_token = token.as_str(); - - // 发起刷新请求 - let upgrade_response = super::client::build_token_upgrade_request( - &ext_token.get_client(), - uuid, - challenge, - user_id, - auth_token, - is_pri, - ) - .send() - .await - .ok()?; - - if !upgrade_response.status().is_success() { - return None; - } - - // 轮询获取token - for _ in 0..POLL_MAX_ATTEMPTS { - let poll_response = - super::client::build_token_poll_request(&ext_token.get_client(), uuid, verifier, is_pri) - .send() - .await - .ok()?; - - match poll_response.status() { - reqwest::StatusCode::OK => { - let token = poll_response - .json::() - .await - .ok()? - .access_token; - return parse_token(token); - } - reqwest::StatusCode::NOT_FOUND => { - tokio::time::sleep(::core::time::Duration::from_millis(POLL_INTERVAL_MS)).await; - } - _ => return None, + #[derive(::serde::Deserialize)] + #[serde(rename_all = "camelCase")] + struct PollResponse { + pub access_token: String, + // pub refresh_token: String, + // pub challenge: String, + // pub auth_id: String, + // pub uuid: String, } - } - None -} + let (verifier, challenge) = generate_pkce_pair(); + let verifier = unsafe { ::core::str::from_utf8_unchecked(&verifier) }; + let challenge = unsafe { ::core::str::from_utf8_unchecked(&challenge) }; + let mut buf = [0; 36]; + let uuid = uuid::Uuid::new_v4().hyphenated().encode_lower(&mut buf) as &str; -async fn refresh_token(ext_token: &ExtToken, is_pri: bool) -> Option<(RawToken, String)> { - const CLIENT_ID: &str = "KbZUR41cY7W6zRSdpSUJ7I7mLYBKOCmB"; + let token = ext_token + .secondary_token + .as_ref() + .unwrap_or(&ext_token.primary_token); + let mut buf = [0; 31]; + let user_id = token.raw().subject.id.to_str(&mut buf) as &str; + let auth_token = token.as_str(); - struct RefreshTokenRequest<'a> { - refresh_token: &'a str, - } - - impl ::serde::Serialize for RefreshTokenRequest<'_> { - fn serialize(&self, serializer: S) -> Result - where - S: ::serde::Serializer, - { - use ::serde::ser::SerializeStruct as _; - let mut state = serializer.serialize_struct("RefreshTokenRequest", 3)?; - state.serialize_field("grant_type", "refresh_token")?; - state.serialize_field("client_id", CLIENT_ID)?; - state.serialize_field("refresh_token", self.refresh_token)?; - state.end() - } - } - - #[derive(::serde::Deserialize)] - struct RefreshTokenResponse { - access_token: String, - // id_token: String, - // #[serde(rename = "shouldLogout")] - // should_logout: bool, - } - - let refresh_request = RefreshTokenRequest { - refresh_token: ext_token.primary_token.as_str(), - }; - - let body = serde_json::to_vec(&refresh_request).ok()?; - - let response = super::client::build_token_refresh_request(&ext_token.get_client(), is_pri, body) + // 发起刷新请求 + let upgrade_response = super::client::build_token_upgrade_request( + &ext_token.get_client(), + uuid, + challenge, + user_id, + auth_token, + is_pri, + ) .send() .await .ok()?; - let token = response - .json::() - .await - .ok()? - .access_token; + if !upgrade_response.status().is_success() { + return None; + } - parse_token(token) + // 轮询获取token + for _ in 0..POLL_MAX_ATTEMPTS { + let poll_response = super::client::build_token_poll_request( + &ext_token.get_client(), + uuid, + verifier, + is_pri, + ) + .send() + .await + .ok()?; + + match poll_response.status() { + reqwest::StatusCode::OK => { + let token = poll_response + .json::() + .await + .ok()? + .access_token; + return parse_token(token); + } + reqwest::StatusCode::NOT_FOUND => { + tokio::time::sleep(::core::time::Duration::from_millis(POLL_INTERVAL_MS)).await; + } + _ => return None, + } + } + + None +} + +async fn refresh_token(ext_token: &ExtToken, is_pri: bool) -> Option<(RawToken, String)> { + const CLIENT_ID: &str = "KbZUR41cY7W6zRSdpSUJ7I7mLYBKOCmB"; + + struct RefreshTokenRequest<'a> { + refresh_token: &'a str, + } + + impl ::serde::Serialize for RefreshTokenRequest<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: ::serde::Serializer, + { + use ::serde::ser::SerializeStruct as _; + let mut state = serializer.serialize_struct("RefreshTokenRequest", 3)?; + state.serialize_field("grant_type", "refresh_token")?; + state.serialize_field("client_id", CLIENT_ID)?; + state.serialize_field("refresh_token", self.refresh_token)?; + state.end() + } + } + + #[derive(::serde::Deserialize)] + struct RefreshTokenResponse { + access_token: String, + // id_token: String, + // #[serde(rename = "shouldLogout")] + // should_logout: bool, + } + + let refresh_request = RefreshTokenRequest { + refresh_token: ext_token.primary_token.as_str(), + }; + + let body = serde_json::to_vec(&refresh_request).ok()?; + + let response = + super::client::build_token_refresh_request(&ext_token.get_client(), is_pri, body) + .send() + .await + .ok()?; + + let token = response + .json::() + .await + .ok()? + .access_token; + + parse_token(token) } // 提取token解析逻辑 #[inline] fn parse_token(token_string: String) -> Option<(RawToken, String)> { - let raw_token = token_string.parse().ok()?; - Some((raw_token, token_string)) + let raw_token = token_string.parse().ok()?; + Some((raw_token, token_string)) } pub async fn get_server_config(ext_token: ExtToken, is_pri: bool) -> Option { - let response = { - let client = super::client::build_client_request(super::client::AiServiceRequest { - ext_token, - fs_client_key: None, - url: server_config_url(is_pri), - is_stream: false, - trace_id: Some(new_uuid_v4()), - is_pri, - cookie: None, - }); - client.send().await.ok()?.bytes().await.ok()? - }; - let server_config = GetServerConfigResponse::decode(response.as_ref()).ok()?; - uuid::Uuid::try_parse(&server_config.config_version).ok() + let response = { + let client = super::client::build_client_request(super::client::AiServiceRequest { + ext_token, + fs_client_key: None, + url: server_config_url(is_pri), + is_stream: false, + trace_id: Some(new_uuid_v4()), + is_pri, + cookie: None, + }); + client.send().await.ok()?.bytes().await.ok()? + }; + let server_config = GetServerConfigResponse::decode(response.as_ref()).ok()?; + uuid::Uuid::try_parse(&server_config.config_version).ok() } // pub async fn get_geo_cpp_backend_url( @@ -814,206 +847,206 @@ pub async fn get_server_config(ext_token: ExtToken, is_pri: bool) -> Option Option> { - let request = super::client::build_proto_web_request( - client, user_id, auth_token, teams_url, is_pri, EMPTY_JSON, - ); + let request = super::client::build_proto_web_request( + client, user_id, auth_token, teams_url, is_pri, EMPTY_JSON, + ); - request - .send() - .await - .ok()? - .json::() - .await - .ok() - .map(|r| r.teams) + request + .send() + .await + .ok()? + .json::() + .await + .ok() + .map(|r| r.teams) } pub async fn get_is_on_new_pricing( - client: &Client, - user_id: &str, - auth_token: &str, - is_pri: bool, + client: &Client, + user_id: &str, + auth_token: &str, + is_pri: bool, ) -> Option { - let request = super::client::build_proto_web_request( - client, - user_id, - auth_token, - is_on_new_pricing_url, - is_pri, - EMPTY_JSON, - ); + let request = super::client::build_proto_web_request( + client, + user_id, + auth_token, + is_on_new_pricing_url, + is_pri, + EMPTY_JSON, + ); - #[derive(serde::Deserialize)] - struct PricingConfig { - #[serde(rename = "isOnNewPricing")] - is_on_new_pricing: bool, - } + #[derive(serde::Deserialize)] + struct PricingConfig { + #[serde(rename = "isOnNewPricing")] + is_on_new_pricing: bool, + } - request - .send() - .await - .ok()? - .json::() - .await - .ok() - .map(|r| r.is_on_new_pricing) + request + .send() + .await + .ok()? + .json::() + .await + .ok() + .map(|r| r.is_on_new_pricing) } pub async fn get_sessions( - client: &Client, - user_id: &str, - auth_token: &str, - is_pri: bool, + client: &Client, + user_id: &str, + auth_token: &str, + is_pri: bool, ) -> Option> { - let request = super::client::build_sessions_request(client, user_id, auth_token, is_pri); + let request = super::client::build_sessions_request(client, user_id, auth_token, is_pri); - request - .send() - .await - .ok()? - .json::() - .await - .ok() - .map(|r| r.sessions) + request + .send() + .await + .ok()? + .json::() + .await + .ok() + .map(|r| r.sessions) } pub async fn get_aggregated_usage_events( - client: &Client, - user_id: &str, - auth_token: &str, - is_pri: bool, + client: &Client, + user_id: &str, + auth_token: &str, + is_pri: bool, ) -> Option { - let request = super::client::build_proto_web_request( - client, - user_id, - auth_token, - aggregated_usage_events_url, - is_pri, - bytes::Bytes::from(__unwrap!(serde_json::to_vec(&{ - const DELTA: chrono::TimeDelta = chrono::TimeDelta::new(2629743, 765840000).unwrap(); - let now = DateTime::utc_now(); - let start_date = now - DELTA; - GetAggregatedUsageEventsRequest { - team_id: -1, - start_date: Some(start_date.timestamp_millis()), - end_date: Some(now.timestamp_millis()), - user_id: None, - } - }))), - ); + let request = super::client::build_proto_web_request( + client, + user_id, + auth_token, + aggregated_usage_events_url, + is_pri, + bytes::Bytes::from(__unwrap!(serde_json::to_vec(&{ + const DELTA: chrono::TimeDelta = chrono::TimeDelta::new(2629743, 765840000).unwrap(); + let now = DateTime::utc_now(); + let start_date = now - DELTA; + GetAggregatedUsageEventsRequest { + team_id: -1, + start_date: Some(start_date.timestamp_millis()), + end_date: Some(now.timestamp_millis()), + user_id: None, + } + }))), + ); - request - .send() - .await - .ok()? - .json::() - .await - .ok() + request + .send() + .await + .ok()? + .json::() + .await + .ok() } pub struct FilteredUsageArgs { - pub start: Option, - pub end: Option, - pub model_id: Option<&'static str>, - pub size: Option, + pub start: Option, + pub end: Option, + pub model_id: Option<&'static str>, + pub size: Option, } impl From for GetFilteredUsageEventsRequest { - #[inline] - fn from(args: FilteredUsageArgs) -> Self { - const TZ: chrono::FixedOffset = chrono::FixedOffset::west_opt(16 * 3600).unwrap(); - const TIME: chrono::NaiveTime = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap(); - const START: chrono::TimeDelta = chrono::TimeDelta::days(-7); - const END: chrono::TimeDelta = chrono::TimeDelta::new(86399, 999000000).unwrap(); + #[inline] + fn from(args: FilteredUsageArgs) -> Self { + const TZ: chrono::FixedOffset = chrono::FixedOffset::west_opt(16 * 3600).unwrap(); + const TIME: chrono::NaiveTime = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap(); + const START: chrono::TimeDelta = chrono::TimeDelta::days(-7); + const END: chrono::TimeDelta = chrono::TimeDelta::new(86399, 999000000).unwrap(); - let (start_date, end_date) = if let (Some(a), Some(b)) = (args.start, args.end) { - (a.timestamp_millis(), b.timestamp_millis()) - } else { - let now = chrono::DateTime::::from_naive_utc_and_offset( - DateTime::naive_now(), - TZ, - ) - .date_naive() - .and_time(TIME); - match (args.start, args.end) { - (None, None) => ( - (now + START) - .and_local_timezone(TZ) - .unwrap() - .timestamp_millis(), - (now + END) - .and_local_timezone(TZ) - .unwrap() - .timestamp_millis(), - ), - (None, Some(b)) => ( - (now + START) - .and_local_timezone(TZ) - .unwrap() - .timestamp_millis(), - b.timestamp_millis(), - ), - (Some(a), None) => ( - a.timestamp_millis(), - (now + END) - .and_local_timezone(TZ) - .unwrap() - .timestamp_millis(), - ), - (Some(_), Some(_)) => unsafe { ::core::hint::unreachable_unchecked() }, - } - }; - Self { - team_id: 0, - start_date: Some(start_date), - end_date: Some(end_date), - user_id: None, - model_id: args.model_id.map(ToString::to_string), - page: Some(1), - page_size: Some(args.size.unwrap_or(100)), + let (start_date, end_date) = if let (Some(a), Some(b)) = (args.start, args.end) { + (a.timestamp_millis(), b.timestamp_millis()) + } else { + let now = chrono::DateTime::::from_naive_utc_and_offset( + DateTime::naive_now(), + TZ, + ) + .date_naive() + .and_time(TIME); + match (args.start, args.end) { + (None, None) => ( + (now + START) + .and_local_timezone(TZ) + .unwrap() + .timestamp_millis(), + (now + END) + .and_local_timezone(TZ) + .unwrap() + .timestamp_millis(), + ), + (None, Some(b)) => ( + (now + START) + .and_local_timezone(TZ) + .unwrap() + .timestamp_millis(), + b.timestamp_millis(), + ), + (Some(a), None) => ( + a.timestamp_millis(), + (now + END) + .and_local_timezone(TZ) + .unwrap() + .timestamp_millis(), + ), + (Some(_), Some(_)) => unsafe { ::core::hint::unreachable_unchecked() }, + } + }; + Self { + team_id: 0, + start_date: Some(start_date), + end_date: Some(end_date), + user_id: None, + model_id: args.model_id.map(ToString::to_string), + page: Some(1), + page_size: Some(args.size.unwrap_or(100)), + } } - } } pub async fn get_filtered_usage_events( - client: &Client, - user_id: &str, - auth_token: &str, - is_pri: bool, - args: FilteredUsageArgs, + client: &Client, + user_id: &str, + auth_token: &str, + is_pri: bool, + args: FilteredUsageArgs, ) -> Option { - let request = super::client::build_proto_web_request( - client, - user_id, - auth_token, - filtered_usage_events_url, - is_pri, - bytes::Bytes::from(__unwrap!(serde_json::to_vec(&{ - let req: GetFilteredUsageEventsRequest = args.into(); - req - }))), - ); + let request = super::client::build_proto_web_request( + client, + user_id, + auth_token, + filtered_usage_events_url, + is_pri, + bytes::Bytes::from(__unwrap!(serde_json::to_vec(&{ + let req: GetFilteredUsageEventsRequest = args.into(); + req + }))), + ); - let res = request.send().await.ok()?; - crate::debug!(" {}", res.status()); - let res = res.bytes().await.ok()?; - crate::debug!(" {}", unsafe { - ::core::str::from_utf8_unchecked(&res[..]) - }); - serde_json::from_slice(&res[..]).ok() - // .json::() - // .await - // .ok() + let res = request.send().await.ok()?; + crate::debug!(" {}", res.status()); + let res = res.bytes().await.ok()?; + crate::debug!(" {}", unsafe { + ::core::str::from_utf8_unchecked(&res[..]) + }); + serde_json::from_slice(&res[..]).ok() + // .json::() + // .await + // .ok() } #[inline] pub fn new_uuid_v4() -> [u8; 36] { - let mut buf = [0; 36]; - uuid::Uuid::new_v4().hyphenated().encode_lower(&mut buf); - buf + let mut buf = [0; 36]; + uuid::Uuid::new_v4().hyphenated().encode_lower(&mut buf); + buf } diff --git a/src/common/utils/string_builder.rs b/src/common/utils/string_builder.rs index 9098816..c4e147b 100644 --- a/src/common/utils/string_builder.rs +++ b/src/common/utils/string_builder.rs @@ -12,7 +12,7 @@ use ::std::borrow::Cow; mod private { use std::borrow::Cow; - pub trait Sealed {} + pub trait Sealed: Sized {} impl Sealed for &str {} impl Sealed for String {} @@ -22,9 +22,9 @@ mod private { /// A trait representing types that can be appended to a `StringBuilder`. /// This is a sealed trait and cannot be implemented for types outside this crate. -pub trait StringPart<'a>: private::Sealed + Into> + Debug + Clone {} +pub trait StringPart<'a>: private::Sealed + Into> {} -impl<'a, T> StringPart<'a> for T where T: private::Sealed + Into> + Debug + Clone {} +impl<'a, T> StringPart<'a> for T where T: private::Sealed + Into> {} /// Internal storage state for StringBuilder /// diff --git a/src/core/adapter.rs b/src/core/adapter.rs index 58cfe6d..3fe8914 100644 --- a/src/core/adapter.rs +++ b/src/core/adapter.rs @@ -128,11 +128,11 @@ fn extract_web_references_info(text: &str) -> (String, Vec, bool) } } -const trait ToOpt: Copy { +trait ToOpt: Copy { fn to_opt(self) -> Option; } -impl const ToOpt for bool { +impl ToOpt for bool { #[inline(always)] fn to_opt(self) -> Option { if self { Some(true) } else { None } } } diff --git a/src/core/model/resolver.rs b/src/core/model/resolver.rs index 8af3b3e..a341019 100644 --- a/src/core/model/resolver.rs +++ b/src/core/model/resolver.rs @@ -9,7 +9,7 @@ static mut BYPASS_MODEL_VALIDATION: bool = false; pub fn init_resolver() { unsafe { BYPASS_MODEL_VALIDATION = - crate::common::utils::parse_bool_from_env("BYPASS_MODEL_VALIDATION", false) + crate::common::utils::parse_from_env("BYPASS_MODEL_VALIDATION", false) } } diff --git a/src/core/stream/decoder.rs b/src/core/stream/decoder.rs index b3571c2..951b981 100644 --- a/src/core/stream/decoder.rs +++ b/src/core/stream/decoder.rs @@ -18,7 +18,7 @@ use std::{ time::Instant, }; -pub trait InstantExt { +pub trait InstantExt: Sized { fn duration_as_secs_f32(&mut self) -> f32; } diff --git a/src/core/stream/decoder/cpp.rs b/src/core/stream/decoder/cpp.rs index a45880d..88f670e 100644 --- a/src/core/stream/decoder/cpp.rs +++ b/src/core/stream/decoder/cpp.rs @@ -1,7 +1,8 @@ +use ::bytes::{Buf as _, BytesMut}; +use ::prost::Message as _; + use super::decompress_gzip; use crate::core::{aiserver::v1::StreamCppResponse, error::StreamError}; -use bytes::{Buf as _, BytesMut}; -use prost::Message as _; #[derive(::serde::Serialize, PartialEq, Clone)] #[serde(tag = "type", rename_all = "snake_case")] diff --git a/src/core/stream/decoder/direct.rs b/src/core/stream/decoder/direct.rs index 0070b02..cf3d273 100644 --- a/src/core/stream/decoder/direct.rs +++ b/src/core/stream/decoder/direct.rs @@ -1,6 +1,6 @@ -// use bytes::{Buf as _, BytesMut}; +// use ::bytes::{Buf as _, BytesMut}; -use std::borrow::Cow; +use ::std::borrow::Cow; use super::{ decompress_gzip, @@ -45,7 +45,7 @@ use super::{ // if let Ok(msg) = T::decode(&self.buf[..]) { // return Ok(Some(DecodedMessage::Protobuf(msg))); -// } else if let Some(text) = String::from_utf8(self.buf.to_vec()) { +// } else if let Ok(text) = String::from_utf8(self.buf.to_vec()) { // return Ok(Some(DecodedMessage::Text(text))); // } // } @@ -92,7 +92,7 @@ pub fn decode(data: &[u8]) -> Result, Deco if let Ok(msg) = T::decode(&*decompressed) { return Ok(DecodedMessage::Protobuf(msg)); - } else if let Some(text) = super::utils::string_from_utf8_cow(decompressed) { + } else if let Some(text) = super::utils::string_from_utf8(decompressed) { return Ok(DecodedMessage::Text(text)); } } diff --git a/src/core/stream/decoder/types.rs b/src/core/stream/decoder/types.rs index 03dceab..a84a80b 100644 --- a/src/core/stream/decoder/types.rs +++ b/src/core/stream/decoder/types.rs @@ -1,4 +1,4 @@ -use prost::Message; +use ::prost::Message; /// 表示可以被Protobuf编解码并可创建默认实例的消息类型 pub trait ProtobufMessage: Message + Default {} @@ -32,6 +32,7 @@ pub enum DecodedMessage { } // impl DecodedMessage { +// #[inline] // pub fn encode(&self) -> Vec // where // Self: Sized, @@ -43,8 +44,8 @@ pub enum DecodedMessage { // } // } -// impl std::fmt::Debug for DecodedMessage { -// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +// impl ::core::fmt::Debug for DecodedMessage { +// fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { // match self { // Self::Protobuf(msg) => write!(f, "\n{msg:#?}"), // Self::Text(s) => write!(f, "\n{s:?}"), @@ -52,8 +53,8 @@ pub enum DecodedMessage { // } // } -// impl std::fmt::Display for DecodedMessage { -// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +// impl ::core::fmt::Display for DecodedMessage { +// fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { // match self { // Self::Protobuf(msg) => write!(f, "\n{}", serde_json::to_string(msg).unwrap()), // Self::Text(s) => write!(f, "\n{s}"), diff --git a/src/core/stream/decoder/utils.rs b/src/core/stream/decoder/utils.rs index be88800..fdb8624 100644 --- a/src/core/stream/decoder/utils.rs +++ b/src/core/stream/decoder/utils.rs @@ -1,17 +1,36 @@ use std::borrow::Cow; +#[allow(private_bounds)] #[inline] -pub fn string_from_utf8(v: &[u8]) -> Option { - match ::core::str::from_utf8(v) { - Ok(_) => Some(unsafe { String::from_utf8_unchecked(v.to_vec()) }), +pub fn string_from_utf8(v: V) -> Option { + match ::core::str::from_utf8(v.as_bytes()) { + Ok(_) => Some(unsafe { String::from_utf8_unchecked(v.into_vec()) }), Err(_) => None, } } -#[inline] -pub fn string_from_utf8_cow(v: Cow<'_, [u8]>) -> Option { - match ::core::str::from_utf8(&v) { - Ok(_) => Some(unsafe { String::from_utf8_unchecked(v.into_owned()) }), - Err(_) => None, - } +trait StringFrom: Sized { + fn as_bytes(&self) -> &[u8]; + fn into_vec(self) -> Vec; } + +impl StringFrom for &[u8] { + #[inline(always)] + fn as_bytes(&self) -> &[u8] { *self } + #[inline(always)] + fn into_vec(self) -> Vec { self.to_vec() } +} + +impl StringFrom for Cow<'_, [u8]> { + #[inline(always)] + fn as_bytes(&self) -> &[u8] { self } + #[inline(always)] + fn into_vec(self) -> Vec { self.into_owned() } +} + +// mod private { +// pub trait Sealed: Sized {} + +// impl Sealed for &[u8] {} +// impl Sealed for super::Cow<'_, [u8]> {} +// } diff --git a/src/main.rs b/src/main.rs index d0c85fc..616e2bb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,7 +5,8 @@ hasher_prefixfree_extras, const_trait_impl, const_default, - core_intrinsics + core_intrinsics, + associated_type_defaults )] #![allow(clippy::redundant_static_lifetimes)] @@ -18,6 +19,34 @@ mod core; mod leak; mod natural_args; +use ::axum::{ + Router, middleware, + routing::{get, post}, +}; +use ::tokio::signal; +use ::tower_http::{cors::CorsLayer, limit::RequestBodyLimitLayer}; + +use app::{ + config::handle_config_update, + constant::{ + EMPTY_STRING, EXE_NAME, ROUTE_ABOUT_PATH, ROUTE_API_PATH, ROUTE_BUILD_KEY_PATH, + ROUTE_CONFIG_PATH, ROUTE_CONFIG_VERSION_GET_PATH, ROUTE_CPP_CONFIG_PATH, + ROUTE_CPP_MODELS_PATH, ROUTE_CPP_STREAM_PATH, ROUTE_ENV_EXAMPLE_PATH, ROUTE_FILE_SYNC_PATH, + ROUTE_FILE_UPLOAD_PATH, ROUTE_GEN_CHECKSUM, ROUTE_GEN_HASH, ROUTE_GEN_TOKEN, + ROUTE_GEN_UUID, ROUTE_GET_TIMESTAMP_HEADER, ROUTE_HEALTH_PATH, ROUTE_LOGS_GET_PATH, + ROUTE_LOGS_PATH, ROUTE_LOGS_TOKENS_GET_PATH, ROUTE_PROXIES_ADD_PATH, + ROUTE_PROXIES_DELETE_PATH, ROUTE_PROXIES_GET_PATH, ROUTE_PROXIES_PATH, + ROUTE_PROXIES_SET_GENERAL_PATH, ROUTE_PROXIES_SET_PATH, ROUTE_README_PATH, ROUTE_ROOT_PATH, + ROUTE_STATIC_PATH, ROUTE_TOKENS_ADD_PATH, ROUTE_TOKENS_ALIAS_SET_PATH, + ROUTE_TOKENS_CONFIG_VERSION_UPDATE_PATH, ROUTE_TOKENS_DELETE_PATH, ROUTE_TOKENS_GET_PATH, + ROUTE_TOKENS_PATH, ROUTE_TOKENS_PROFILE_UPDATE_PATH, ROUTE_TOKENS_PROXY_SET_PATH, + ROUTE_TOKENS_REFRESH_PATH, ROUTE_TOKENS_SET_PATH, ROUTE_TOKENS_STATUS_SET_PATH, + ROUTE_TOKENS_TIMEZONE_SET_PATH, VERSION, + }, + lazy::AUTH_TOKEN, + model::{AppConfig, AppState}, +}; +use common::utils::parse_from_env; use core::{ middleware::{admin_auth_middleware, cpp_auth_middleware, v1_auth_middleware}, route::{ @@ -40,34 +69,7 @@ use core::{ handle_chat_completions, handle_messages, handle_models, handle_raw_models, }, }; -use app::{ - config::handle_config_update, - constant::{ - EMPTY_STRING, EXE_NAME, ROUTE_ABOUT_PATH, ROUTE_API_PATH, ROUTE_BUILD_KEY_PATH, - ROUTE_CONFIG_PATH, ROUTE_CONFIG_VERSION_GET_PATH, ROUTE_CPP_CONFIG_PATH, - ROUTE_CPP_MODELS_PATH, ROUTE_CPP_STREAM_PATH, ROUTE_ENV_EXAMPLE_PATH, ROUTE_FILE_SYNC_PATH, - ROUTE_FILE_UPLOAD_PATH, ROUTE_GEN_CHECKSUM, ROUTE_GEN_HASH, ROUTE_GEN_TOKEN, - ROUTE_GEN_UUID, ROUTE_GET_TIMESTAMP_HEADER, ROUTE_HEALTH_PATH, ROUTE_LOGS_GET_PATH, - ROUTE_LOGS_PATH, ROUTE_LOGS_TOKENS_GET_PATH, ROUTE_PROXIES_ADD_PATH, - ROUTE_PROXIES_DELETE_PATH, ROUTE_PROXIES_GET_PATH, ROUTE_PROXIES_PATH, - ROUTE_PROXIES_SET_GENERAL_PATH, ROUTE_PROXIES_SET_PATH, ROUTE_README_PATH, ROUTE_ROOT_PATH, - ROUTE_STATIC_PATH, ROUTE_TOKENS_ADD_PATH, ROUTE_TOKENS_ALIAS_SET_PATH, - ROUTE_TOKENS_CONFIG_VERSION_UPDATE_PATH, ROUTE_TOKENS_DELETE_PATH, ROUTE_TOKENS_GET_PATH, - ROUTE_TOKENS_PATH, ROUTE_TOKENS_PROFILE_UPDATE_PATH, ROUTE_TOKENS_PROXY_SET_PATH, - ROUTE_TOKENS_REFRESH_PATH, ROUTE_TOKENS_SET_PATH, ROUTE_TOKENS_STATUS_SET_PATH, - ROUTE_TOKENS_TIMEZONE_SET_PATH, VERSION, - }, - lazy::AUTH_TOKEN, - model::{AppConfig, AppState}, -}; -use axum::{ - Router, middleware, - routing::{get, post}, -}; -use common::utils::{parse_string_from_env, parse_usize_from_env}; use natural_args::{DEFAULT_LISTEN_HOST, ENV_HOST, ENV_PORT}; -use tokio::signal; -use tower_http::{cors::CorsLayer, limit::RequestBodyLimitLayer}; #[tokio::main] async fn main() { @@ -219,7 +221,6 @@ async fn main() { route_chat_completions_path, route_messages_path, ) = { - let route_prefix = parse_string_from_env("ROUTE_PREFIX", EMPTY_STRING); define_typed_constants! { &'static str => { RAW_MODELS_PATH = "/raw/models", @@ -228,9 +229,9 @@ async fn main() { MESSAGES_PATH = "/v1/messages", } } + use ::std::borrow::Cow; - use std::borrow::Cow; - + let route_prefix = parse_from_env("ROUTE_PREFIX", EMPTY_STRING); if route_prefix.is_empty() { ( Cow::Borrowed(RAW_MODELS_PATH), @@ -361,7 +362,7 @@ async fn main() { post(handle_get_config_version), ) // .route(ROUTE_TOKEN_UPGRADE_PATH, post(handle_token_upgrade)) - .layer(RequestBodyLimitLayer::new(parse_usize_from_env( + .layer(RequestBodyLimitLayer::new(parse_from_env( "REQUEST_BODY_LIMIT", 2_000_000, ))) @@ -379,7 +380,7 @@ async fn main() { .unwrap_or(3000) }; let addr = SocketAddr::new( - IpAddr::parse_ascii(parse_string_from_env(ENV_HOST, DEFAULT_LISTEN_HOST).as_bytes()) + IpAddr::parse_ascii(parse_from_env(ENV_HOST, DEFAULT_LISTEN_HOST).as_bytes()) .unwrap_or_else(|e| { __cold_path!(); // IP解析失败是错误路径 eprintln!("无法解析IP: {e}");