From 54e38f1db038ddbfc758527ac82aefd0c9efac2b Mon Sep 17 00:00:00 2001 From: wisdgod Date: Sun, 5 Jan 2025 16:50:26 +0800 Subject: [PATCH] =?UTF-8?q?=E8=BF=99=E6=88=91=E5=B7=B2=E7=BB=8F=E4=B8=8D?= =?UTF-8?q?=E6=83=B3=E5=81=9A=E4=BA=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.example | 3 + Cargo.lock | 585 +++++++++++++++++++++++++++++++---- Cargo.toml | 13 +- src/app.rs | 1 - src/app/config.rs | 10 +- src/app/constant.rs | 23 +- src/app/db.rs | 388 ++++++++++------------- src/app/db/logs.rs | 271 ++++++++++++++++ src/app/db/tokens.rs | 305 ++++++++++++++++++ src/app/db/users.rs | 198 ++++++++++++ src/app/lazy.rs | 30 +- src/app/model.rs | 84 +---- src/app/model/db.rs | 148 +++++++++ src/app/model/usage_check.rs | 6 +- src/chat.rs | 1 + src/chat/adapter.rs | 8 +- src/chat/constant.rs | 59 ++-- src/chat/model.rs | 14 +- src/chat/route.rs | 4 +- src/chat/route/auth.rs | 113 +++++++ src/chat/route/config.rs | 37 ++- src/chat/route/health.rs | 125 ++++---- src/chat/route/logs.rs | 47 ++- src/chat/route/token.rs | 221 ++++++------- src/chat/route/usage.rs | 49 ++- src/chat/service.rs | 316 ++++++++++--------- src/common/client.rs | 44 ++- src/common/models/error.rs | 24 +- src/common/models/health.rs | 6 +- src/common/models/usage.rs | 39 ++- src/common/utils.rs | 50 ++- src/common/utils/checksum.rs | 9 +- src/common/utils/oauth.rs | 79 ++--- src/common/utils/token.rs | 148 +++++++++ src/common/utils/tokens.rs | 144 --------- src/main.rs | 71 +++-- static/tokeninfo.html | 13 + 37 files changed, 2615 insertions(+), 1071 deletions(-) create mode 100644 src/app/db/logs.rs create mode 100644 src/app/db/tokens.rs create mode 100644 src/app/db/users.rs create mode 100644 src/app/model/db.rs create mode 100644 src/chat/route/auth.rs create mode 100644 src/common/utils/token.rs delete mode 100644 src/common/utils/tokens.rs diff --git a/.env.example b/.env.example index 472c252..416e6c7 100644 --- a/.env.example +++ b/.env.example @@ -40,3 +40,6 @@ DEFAULT_INSTRUCTIONS="Respond in Chinese by default" # 反向代理服务器主机名 CURSOR_API2_HOST= + +# 管理员认证令牌 +ADMIN_AUTH_TOKEN= diff --git a/Cargo.lock b/Cargo.lock index 1334525..45a7b98 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -105,10 +105,10 @@ dependencies = [ "axum-core", "bytes", "futures-util", - "http", - "http-body", + "http 1.2.0", + "http-body 1.0.1", "http-body-util", - "hyper", + "hyper 1.5.2", "hyper-util", "itoa", "matchit", @@ -121,7 +121,7 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", - "sync_wrapper", + "sync_wrapper 1.0.2", "tokio", "tower", "tower-layer", @@ -138,13 +138,13 @@ dependencies = [ "async-trait", "bytes", "futures-util", - "http", - "http-body", + "http 1.2.0", + "http-body 1.0.1", "http-body-util", "mime", "pin-project-lite", "rustversion", - "sync_wrapper", + "sync_wrapper 1.0.2", "tower-layer", "tower-service", "tracing", @@ -162,9 +162,21 @@ dependencies = [ "miniz_oxide", "object", "rustc-demangle", - "windows-targets", + "windows-targets 0.52.6", ] +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "base64" version = "0.22.1" @@ -237,6 +249,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "chrono" version = "0.4.39" @@ -249,7 +267,7 @@ dependencies = [ "num-traits", "serde", "wasm-bindgen", - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -308,7 +326,7 @@ version = "0.1.3" dependencies = [ "anyhow", "axum", - "base64", + "base64 0.22.1", "bytes", "chrono", "dotenvy", @@ -317,13 +335,14 @@ dependencies = [ "gif", "hex", "image", - "lazy_static", + "oauth2", "paste", "prost", "prost-build", "rand", "regex", - "reqwest", + "reqwest 0.12.12", + "ring", "rusqlite", "serde", "serde_json", @@ -561,8 +580,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -581,6 +602,25 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +[[package]] +name = "h2" +version = "0.3.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http 0.2.12", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "h2" version = "0.4.7" @@ -592,7 +632,7 @@ dependencies = [ "fnv", "futures-core", "futures-sink", - "http", + "http 1.2.0", "indexmap", "slab", "tokio", @@ -636,6 +676,17 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "http" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http" version = "1.2.0" @@ -647,6 +698,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http-body" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" +dependencies = [ + "bytes", + "http 0.2.12", + "pin-project-lite", +] + [[package]] name = "http-body" version = "1.0.1" @@ -654,7 +716,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http", + "http 1.2.0", ] [[package]] @@ -665,8 +727,8 @@ checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" dependencies = [ "bytes", "futures-util", - "http", - "http-body", + "http 1.2.0", + "http-body 1.0.1", "pin-project-lite", ] @@ -682,6 +744,30 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "hyper" +version = "0.14.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "h2 0.3.26", + "http 0.2.12", + "http-body 0.4.6", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", + "want", +] + [[package]] name = "hyper" version = "1.5.2" @@ -691,9 +777,9 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "h2", - "http", - "http-body", + "h2 0.4.7", + "http 1.2.0", + "http-body 1.0.1", "httparse", "httpdate", "itoa", @@ -703,6 +789,20 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" +dependencies = [ + "futures-util", + "http 0.2.12", + "hyper 0.14.32", + "rustls 0.21.12", + "tokio", + "tokio-rustls 0.24.1", +] + [[package]] name = "hyper-rustls" version = "0.27.5" @@ -710,14 +810,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2" dependencies = [ "futures-util", - "http", - "hyper", + "http 1.2.0", + "hyper 1.5.2", "hyper-util", - "rustls", + "rustls 0.23.20", "rustls-pki-types", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.1", "tower-service", + "webpki-roots 0.26.7", ] [[package]] @@ -728,7 +829,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" dependencies = [ "bytes", "http-body-util", - "hyper", + "hyper 1.5.2", "hyper-util", "native-tls", "tokio", @@ -745,9 +846,9 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http", - "http-body", - "hyper", + "http 1.2.0", + "http-body 1.0.1", + "hyper 1.5.2", "pin-project-lite", "socket2", "tokio", @@ -985,12 +1086,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "lazy_static" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" - [[package]] name = "libc" version = "0.2.169" @@ -1106,6 +1201,26 @@ dependencies = [ "autocfg", ] +[[package]] +name = "oauth2" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c38841cdd844847e3e7c8d29cef9dcfed8877f8f56f9071f77843ecf3baf937f" +dependencies = [ + "base64 0.13.1", + "chrono", + "getrandom", + "http 0.2.12", + "rand", + "reqwest 0.11.27", + "serde", + "serde_json", + "serde_path_to_error", + "sha2", + "thiserror 1.0.69", + "url", +] + [[package]] name = "object" version = "0.36.7" @@ -1304,6 +1419,58 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" +[[package]] +name = "quinn" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef" +dependencies = [ + "bytes", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls 0.23.20", + "socket2", + "thiserror 2.0.9", + "tokio", + "tracing", +] + +[[package]] +name = "quinn-proto" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" +dependencies = [ + "bytes", + "getrandom", + "rand", + "ring", + "rustc-hash", + "rustls 0.23.20", + "rustls-pki-types", + "slab", + "thiserror 2.0.9", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c40286217b4ba3a71d644d752e6a0b71f13f1b6a2c5311acfcbe0c2418ed904" +dependencies = [ + "cfg_aliases", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.52.0", +] + [[package]] name = "quote" version = "1.0.38" @@ -1372,6 +1539,47 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +[[package]] +name = "reqwest" +version = "0.11.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" +dependencies = [ + "base64 0.21.7", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2 0.3.26", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.32", + "hyper-rustls 0.24.2", + "ipnet", + "js-sys", + "log", + "mime", + "once_cell", + "percent-encoding", + "pin-project-lite", + "rustls 0.21.12", + "rustls-pemfile 1.0.4", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper 0.1.2", + "system-configuration 0.5.1", + "tokio", + "tokio-rustls 0.24.1", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "webpki-roots 0.25.4", + "winreg", +] + [[package]] name = "reqwest" version = "0.12.12" @@ -1379,17 +1587,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da" dependencies = [ "async-compression", - "base64", + "base64 0.22.1", "bytes", "encoding_rs", "futures-core", "futures-util", - "h2", - "http", - "http-body", + "h2 0.4.7", + "http 1.2.0", + "http-body 1.0.1", "http-body-util", - "hyper", - "hyper-rustls", + "hyper 1.5.2", + "hyper-rustls 0.27.5", "hyper-tls", "hyper-util", "ipnet", @@ -1400,14 +1608,18 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls-pemfile", + "quinn", + "rustls 0.23.20", + "rustls-pemfile 2.2.0", + "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", - "sync_wrapper", - "system-configuration", + "sync_wrapper 1.0.2", + "system-configuration 0.6.1", "tokio", "tokio-native-tls", + "tokio-rustls 0.26.1", "tokio-util", "tower", "tower-service", @@ -1416,6 +1628,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", + "webpki-roots 0.26.7", "windows-registry", ] @@ -1441,6 +1654,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e" dependencies = [ "bitflags 2.6.0", + "chrono", "fallible-iterator", "fallible-streaming-iterator", "hashlink", @@ -1454,6 +1668,12 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustc-hash" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497" + [[package]] name = "rustix" version = "0.38.42" @@ -1467,6 +1687,18 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "rustls" +version = "0.21.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" +dependencies = [ + "log", + "ring", + "rustls-webpki 0.101.7", + "sct", +] + [[package]] name = "rustls" version = "0.23.20" @@ -1474,12 +1706,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5065c3f250cbd332cd894be57c40fa52387247659b14a2d6041d121547903b1b" dependencies = [ "once_cell", + "ring", "rustls-pki-types", - "rustls-webpki", + "rustls-webpki 0.102.8", "subtle", "zeroize", ] +[[package]] +name = "rustls-pemfile" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" +dependencies = [ + "base64 0.21.7", +] + [[package]] name = "rustls-pemfile" version = "2.2.0" @@ -1494,6 +1736,19 @@ name = "rustls-pki-types" version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37" +dependencies = [ + "web-time", +] + +[[package]] +name = "rustls-webpki" +version = "0.101.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring", + "untrusted", +] [[package]] name = "rustls-webpki" @@ -1527,6 +1782,16 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "sct" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "security-framework" version = "2.11.1" @@ -1681,6 +1946,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + [[package]] name = "sync_wrapper" version = "1.0.2" @@ -1714,6 +1985,17 @@ dependencies = [ "windows", ] +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys 0.5.0", +] + [[package]] name = "system-configuration" version = "0.6.1" @@ -1722,7 +2004,17 @@ checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" dependencies = [ "bitflags 2.6.0", "core-foundation", - "system-configuration-sys", + "system-configuration-sys 0.6.0", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", ] [[package]] @@ -1748,6 +2040,46 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f072643fd0190df67a8bab670c20ef5d8737177d6ac6b2e9a236cb096206b2cc" +dependencies = [ + "thiserror-impl 2.0.9", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tinystr" version = "0.7.6" @@ -1758,6 +2090,21 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinyvec" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokio" version = "1.42.0" @@ -1795,13 +2142,23 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls 0.21.12", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37" dependencies = [ - "rustls", + "rustls 0.23.20", "tokio", ] @@ -1838,7 +2195,7 @@ dependencies = [ "futures-core", "futures-util", "pin-project-lite", - "sync_wrapper", + "sync_wrapper 1.0.2", "tokio", "tower-layer", "tower-service", @@ -1853,7 +2210,7 @@ checksum = "403fa3b783d4b626a8ad51d766ab03cb6d2dbfc46b1c5d4448395e6628dc9697" dependencies = [ "bitflags 2.6.0", "bytes", - "http", + "http 1.2.0", "pin-project-lite", "tower-layer", "tower-service", @@ -1924,6 +2281,7 @@ dependencies = [ "form_urlencoded", "idna", "percent-encoding", + "serde", ] [[package]] @@ -2070,6 +2428,31 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "webpki-roots" +version = "0.25.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" + +[[package]] +name = "webpki-roots" +version = "0.26.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "weezl" version = "0.1.8" @@ -2105,7 +2488,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "12342cb4d8e3b046f3d80effd474a7a02447231330ef77d71daa6fbc40681143" dependencies = [ "windows-core 0.57.0", - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -2114,7 +2497,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -2126,7 +2509,7 @@ dependencies = [ "windows-implement", "windows-interface", "windows-result 0.1.2", - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -2159,7 +2542,7 @@ checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" dependencies = [ "windows-result 0.2.0", "windows-strings", - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -2168,7 +2551,7 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e383302e8ec8515204254685643de10811af0ed97ea37210dc26fb0032647f8" dependencies = [ - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -2177,7 +2560,7 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" dependencies = [ - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -2187,7 +2570,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" dependencies = [ "windows-result 0.2.0", - "windows-targets", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.5", ] [[package]] @@ -2196,7 +2588,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -2205,7 +2597,22 @@ version = "0.59.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" dependencies = [ - "windows-targets", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", ] [[package]] @@ -2214,28 +2621,46 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + [[package]] name = "windows_aarch64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + [[package]] name = "windows_i686_gnu" version = "0.52.6" @@ -2248,30 +2673,64 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + [[package]] name = "windows_i686_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + [[package]] name = "windows_x86_64_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + [[package]] name = "windows_x86_64_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + [[package]] name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "winreg" +version = "0.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" +dependencies = [ + "cfg-if", + "windows-sys 0.48.0", +] + [[package]] name = "write16" version = "1.0.0" diff --git a/Cargo.toml b/Cargo.toml index 2bebb0a..08f0001 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ serde_json = "1.0.134" [dependencies] anyhow = "1.0.95" -axum = { version = "0.7.9", features = ["json"] } +axum = { version = "0.7.9", features = ["http2", "json"] } base64 = { version = "0.22.1", default-features = false, features = ["std"] } # brotli = { version = "7.0.0", default-features = false, features = ["std"] } bytes = "1.9.0" @@ -24,13 +24,14 @@ futures = { version = "0.3.31", default-features = false, features = ["std"] } gif = { version = "0.13.1", default-features = false, features = ["std"] } hex = { version = "0.4.3", default-features = false, features = ["std"] } image = { version = "0.25.5", default-features = false, features = ["jpeg", "png", "gif", "webp"] } -lazy_static = "1.5.0" +oauth2 = { version = "4.4.2", default-features = false, features = ["reqwest", "rustls-tls"] } paste = "1.0.15" prost = "0.13.4" rand = { version = "0.8.5", default-features = false, features = ["std", "std_rng"] } regex = { version = "1.11.1", default-features = false, features = ["std", "perf"] } -reqwest = { version = "0.12.12", default-features = false, features = ["gzip", "json", "stream", "__tls", "charset", "default-tls", "h2", "http2", "macos-system-configuration"] } -rusqlite = { version = "0.32.1", features = ["bundled"], optional = true } +reqwest = { version = "0.12.12", default-features = false, features = ["gzip", "json", "stream", "rustls-tls", "__tls", "charset", "default-tls", "h2", "http2", "macos-system-configuration"] } +ring = { version = "0.17.8", default-features = false, features = ["alloc"] } +rusqlite = { version = "0.32.1", features = ["bundled", "chrono"] } serde = { version = "1.0.217", default-features = false, features = ["std", "derive"] } serde_json = "1.0.134" sha2 = { version = "0.10.8", default-features = false } @@ -47,7 +48,3 @@ codegen-units = 1 panic = 'abort' strip = true opt-level = 3 - -[features] -default = [] -sqlite = ["dep:rusqlite"] diff --git a/src/app.rs b/src/app.rs index 9b33bc5..9195293 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,6 +1,5 @@ pub mod config; pub mod constant; -#[cfg(feature = "sqlite")] pub mod db; pub mod model; pub mod lazy; diff --git a/src/app/config.rs b/src/app/config.rs index d7c7ab7..c4c952f 100644 --- a/src/app/config.rs +++ b/src/app/config.rs @@ -1,7 +1,7 @@ use super::{ - constant::{HEADER_NAME_AUTHORIZATION, AUTHORIZATION_BEARER_PREFIX}, + constant::AUTHORIZATION_BEARER_PREFIX, model::{AppConfig, AppState}, - lazy::AUTH_TOKEN, + lazy::ADMIN_AUTH_TOKEN, }; use crate::common::models::{ config::{ConfigData, ConfigUpdateRequest}, @@ -9,7 +9,7 @@ use crate::common::models::{ }; use axum::{ extract::State, - http::{HeaderMap, StatusCode}, + http::{header::AUTHORIZATION, HeaderMap, StatusCode}, Json, }; use std::sync::Arc; @@ -59,7 +59,7 @@ pub async fn handle_config_update( Json(request): Json, ) -> Result>, (StatusCode, Json)> { let auth_header = headers - .get(HEADER_NAME_AUTHORIZATION) + .get(AUTHORIZATION) .and_then(|h| h.to_str().ok()) .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) .ok_or(( @@ -72,7 +72,7 @@ pub async fn handle_config_update( }), ))?; - if auth_header != AUTH_TOKEN.as_str() { + if auth_header != ADMIN_AUTH_TOKEN.as_str() { return Err(( StatusCode::UNAUTHORIZED, Json(ErrorResponse { diff --git a/src/app/constant.rs b/src/app/constant.rs index a23503f..c114f05 100644 --- a/src/app/constant.rs +++ b/src/app/constant.rs @@ -5,13 +5,20 @@ macro_rules! def_pub_const { } def_pub_const!(PKG_VERSION, env!("CARGO_PKG_VERSION")); -// def_pub_const!(PKG_NAME, env!("CARGO_PKG_NAME")); +def_pub_const!(PKG_NAME, env!("CARGO_PKG_NAME")); // def_pub_const!(PKG_DESCRIPTION, env!("CARGO_PKG_DESCRIPTION")); // def_pub_const!(PKG_AUTHORS, env!("CARGO_PKG_AUTHORS")); // def_pub_const!(PKG_REPOSITORY, env!("CARGO_PKG_REPOSITORY")); def_pub_const!(EMPTY_STRING, ""); +// v1 +def_pub_const!(ROUTE_MODELS_PATH, "/models"); +def_pub_const!(ROUTE_CHAT_PATH, "/chat/completions"); + +// api +def_pub_const!(ROUTE_API_PATH, "/api"); + def_pub_const!(ROUTE_ROOT_PATH, "/"); def_pub_const!(ROUTE_HEALTH_PATH, "/health"); def_pub_const!(ROUTE_GET_CHECKSUM, "/get-checksum"); @@ -28,15 +35,15 @@ def_pub_const!(ROUTE_SHARED_JS_PATH, "/static/shared.js"); def_pub_const!(ROUTE_ABOUT_PATH, "/about"); def_pub_const!(ROUTE_README_PATH, "/readme"); -def_pub_const!(DEFAULT_TOKEN_FILE_NAME, ".token"); -def_pub_const!(DEFAULT_TOKEN_LIST_FILE_NAME, ".token-list"); +// api/auth +def_pub_const!(ROUTE_AUTH_PATH, "/auth"); +def_pub_const!(ROUTE_AUTH_CALLBACK_PATH, "/callback"); +def_pub_const!(ROUTE_AUTH_INITIATE_PATH, "/initiate"); -def_pub_const!(STATUS_SUCCESS, "success"); -def_pub_const!(STATUS_FAILED, "failed"); +def_pub_const!(FALSE, "false"); +def_pub_const!(TRUE, "true"); -def_pub_const!(HEADER_NAME_CONTENT_TYPE, "content-type"); -def_pub_const!(HEADER_NAME_AUTHORIZATION, "authorization"); -def_pub_const!(HEADER_NAME_LOCATION, "Location"); +def_pub_const!(HEADER_NAME_GHOST_MODE, "x-ghost-mode"); def_pub_const!(CONTENT_TYPE_PROTO, "application/proto"); def_pub_const!(CONTENT_TYPE_CONNECT_PROTO, "application/connect+proto"); diff --git a/src/app/db.rs b/src/app/db.rs index b29b33b..e57e67d 100644 --- a/src/app/db.rs +++ b/src/app/db.rs @@ -1,262 +1,190 @@ -use crate::app::model::{RequestLog, TokenInfo}; -use crate::common::models::usage::UserUsageInfo; -use chrono::{DateTime, Local}; -use lazy_static::lazy_static; -use rusqlite::params; +mod logs; +mod tokens; +mod users; + +use chrono::Utc; use rusqlite::{Connection, Result}; -use std::path::Path; -use std::sync::Mutex; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Mutex, OnceLock}; +use tokio::time::{self, Duration}; -const DB_PATH: &str = "logs/sqlite.db"; - -pub struct AppDb { +pub struct Database { conn: Connection, } -impl AppDb { - pub fn new() -> Result { - // 确保目录存在 - if let Some(parent) = Path::new(DB_PATH).parent() { - std::fs::create_dir_all(parent).map_err(|e| { - rusqlite::Error::SqliteFailure( - rusqlite::ffi::Error::new(rusqlite::ffi::SQLITE_IOERR), - Some(e.to_string()), - ) - })?; - } +// 全局静态 Database 实例 +static DB: OnceLock> = OnceLock::new(); - let conn = Connection::open(DB_PATH)?; +// 用于控制清理任务的标志 +static CLEANER_RUNNING: AtomicBool = AtomicBool::new(false); - // 启用WAL模式以提升性能 - conn.execute_batch("PRAGMA journal_mode = WAL")?; +impl Database { + pub fn new(path: &str) -> Result { + let conn = Connection::open(path)?; - // 创建token信息表 - conn.execute( - "CREATE TABLE IF NOT EXISTS token_infos ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - token TEXT NOT NULL UNIQUE, - checksum TEXT NOT NULL, - alias TEXT, - fast_requests INTEGER, - max_fast_requests INTEGER - )", - [], + // 启用 WAL 模式 + conn.execute_batch( + " + PRAGMA journal_mode = WAL; -- 启用 WAL 模式 + PRAGMA synchronous = NORMAL; -- 适度的同步模式 + PRAGMA cache_size = -64000; -- 64MB 缓存 + PRAGMA foreign_keys = ON; -- 启用外键约束 + PRAGMA temp_store = MEMORY; -- 临时表使用内存 + PRAGMA mmap_size = 30000000000; -- 30GB mmap + ", )?; - // 创建请求日志表 - conn.execute( - "CREATE TABLE IF NOT EXISTS request_logs ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL, - model TEXT NOT NULL, - token_id INTEGER NOT NULL, - prompt TEXT, - stream BOOLEAN NOT NULL, - status TEXT NOT NULL, - error TEXT, - FOREIGN KEY(token_id) REFERENCES token_infos(id) - )", - [], - )?; - - // 创建索引 - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_token ON token_infos(token)", - [], - )?; - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_timestamp_model ON request_logs(timestamp, model)", - [], - )?; + // 按照依赖顺序初始化表 + Self::init_users_table(&conn)?; + Self::init_tokens_table(&conn)?; + Self::init_logs_table(&conn)?; Ok(Self { conn }) } - fn get_or_create_token_info(&self, token_info: &TokenInfo) -> Result { - let mut stmt = self.conn.prepare_cached( - "INSERT OR REPLACE INTO token_infos (token, checksum, alias, fast_requests, max_fast_requests) - VALUES (?1, ?2, ?3, ?4, ?5) - RETURNING id" - )?; - - stmt.query_row( - params![ - &token_info.token, - &token_info.checksum, - &token_info.alias, - token_info.usage.as_ref().map(|u| u.fast_requests), - token_info.usage.as_ref().map(|u| u.max_fast_requests), - ], - |row| row.get(0), - ) - } - - pub fn add_log(&self, log: &RequestLog) -> Result<()> { - let token_id = self.get_or_create_token_info(&log.token_info)?; - - self.conn.execute( - "INSERT INTO request_logs (timestamp, model, token_id, prompt, stream, status, error) - VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)", - params![ - log.timestamp.to_rfc3339(), - &log.model, - token_id, - &log.prompt, - log.stream, - &log.status, - &log.error, - ], - )?; - Ok(()) - } - - fn map_row_to_log(&self, row: &rusqlite::Row) -> Result { - let token_id: i64 = row.get(3)?; - let token_info = self.get_token_info_by_id(token_id)?; - - Ok(RequestLog { - id: row.get(0)?, - timestamp: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?) - .unwrap() - .with_timezone(&Local), - model: row.get(2)?, - token_info, - prompt: row.get(4)?, - stream: row.get(5)?, - status: row.get(6)?, - error: row.get(7)?, + pub fn init(path: &str) -> Result<()> { + let db = Database::new(path)?; + DB.set(Mutex::new(db)).map_err(|_| { + rusqlite::Error::InvalidParameterName("Database already initialized".into()) }) } - fn get_token_info_by_id(&self, id: i64) -> Result { - let mut stmt = self.conn.prepare_cached( - "SELECT token, checksum, alias, fast_requests, max_fast_requests - FROM token_infos - WHERE id = ?", - )?; + pub fn global() -> &'static Mutex { + DB.get().expect("Database not initialized") + } - stmt.query_row([id], |row| { - Ok(TokenInfo { - token: row.get(0)?, - checksum: row.get(1)?, - alias: row.get(2)?, - usage: Some(UserUsageInfo { - fast_requests: row.get(3)?, - max_fast_requests: row.get(4)?, - }), - }) + pub fn conn(&self) -> &Connection { + &self.conn + } + + pub fn conn_mut(&mut self) -> &mut Connection { + &mut self.conn + } + + // 启动定时清理任务 + pub fn start_cleaner() { + // 确保只启动一次 + if CLEANER_RUNNING.swap(true, Ordering::SeqCst) { + return; + } + + tokio::spawn(async move { + loop { + // 等待到下一个 UTC 20:00 + let now = Utc::now(); + let next = (now.date_naive() + chrono::Duration::days(1)) + .and_hms_opt(20, 0, 0) + .unwrap(); + let duration = next.signed_duration_since(now.naive_utc()); + + time::sleep(Duration::from_secs(duration.num_seconds() as u64)).await; + + if let Err(e) = Self::clean_expired_tokens().await { + eprintln!("Failed to clean expired tokens: {}", e); + } + } + }); + } + + // 清理过期数据 + async fn clean_expired_tokens() -> Result<()> { + with_db_mut(|conn| { + let tx = conn.transaction()?; + + // 删除过期token相关的日志 + tx.execute( + "DELETE FROM logs WHERE token_id IN ( + SELECT id FROM tokens + WHERE status = 2 OR + (status = 1 AND duration > 0 AND + datetime(create_at, '+' || (duration / 86400) || ' days') < datetime('now')) + )", + [], + )?; + + // 删除过期token + tx.execute( + "DELETE FROM tokens + WHERE status = 2 OR + (status = 1 AND duration > 0 AND + datetime(create_at, '+' || (duration / 86400) || ' days') < datetime('now'))", + [], + )?; + + // 执行WAL清理 + tx.execute_batch("PRAGMA wal_checkpoint(TRUNCATE)")?; + + tx.commit() }) } - pub fn get_token_infos(&self) -> Result> { - let mut stmt = self.conn.prepare_cached( - "SELECT token, checksum, alias, fast_requests, max_fast_requests - FROM token_infos", - )?; + // 停止清理任务 + // pub fn stop_cleaner() { + // CLEANER_RUNNING.store(false, Ordering::SeqCst); + // } +} - let tokens = stmt.query_map([], |row| { - Ok(TokenInfo { - token: row.get(0)?, - checksum: row.get(1)?, - alias: row.get(2)?, - usage: Some(UserUsageInfo { - fast_requests: row.get(3)?, - max_fast_requests: row.get(4)?, - }), - }) - })?; - tokens.collect() +pub fn with_db(f: F) -> Result +where + F: FnOnce(&Connection) -> Result, +{ + let guard = Database::global().lock().expect("Database lock poisoned"); + f(guard.conn()) +} + +pub fn with_db_mut(f: F) -> Result +where + F: FnOnce(&mut Connection) -> Result, +{ + let mut guard = Database::global().lock().expect("Database lock poisoned"); + f(guard.conn_mut()) +} + +// 重新导出子模块 +pub use self::logs::*; +pub use self::tokens::*; +pub use self::users::*; + +/* +// 以下是可选的扩展功能,暂时注释掉 + +impl Drop for Database { + fn drop(&mut self) { + // 这里可以添加清理代码 + // Connection 会自动关闭,但如果有其他清理工作可以在这里进行 } +} - pub fn get_recent_logs(&self, limit: i64) -> Result> { - let mut stmt = self.conn.prepare_cached( - "SELECT r.id, r.timestamp, r.model, r.token_id, r.prompt, r.stream, r.status, r.error, t.token, t.checksum, t.alias, t.fast_requests, t.max_fast_requests - FROM request_logs r - JOIN token_infos t ON r.token_id = t.id - ORDER BY r.timestamp DESC - LIMIT ?", - )?; +use std::sync::LazyLock; - let logs = stmt.query_map([limit], |row| { - Ok(RequestLog { - id: row.get(0)?, - timestamp: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?) - .unwrap() - .with_timezone(&Local), - model: row.get(2)?, - token_info: TokenInfo { - token: row.get(8)?, - checksum: row.get(9)?, - alias: row.get(10)?, - usage: Some(UserUsageInfo { - fast_requests: row.get(11)?, - max_fast_requests: row.get(12)?, - }), - }, - prompt: row.get(4)?, - stream: row.get(5)?, - status: row.get(6)?, - error: row.get(7)?, - }) - })?; - logs.collect() +static DB_CONFIG: LazyLock = LazyLock::new(|| { + DbConfig { + max_connections: 10, + timeout: std::time::Duration::from_secs(30), } +}); - pub fn get_logs_by_timerange( - &self, - start: DateTime, - end: DateTime, - ) -> Result> { - let mut stmt = self.conn.prepare_cached( - "SELECT r.id, r.timestamp, r.model, r.token_id, r.prompt, r.stream, r.status, r.error, t.token, t.checksum, t.alias, t.fast_requests, t.max_fast_requests - FROM request_logs r - JOIN token_infos t ON r.token_id = t.id - WHERE r.timestamp BETWEEN ?1 AND ?2 - ORDER BY r.timestamp DESC", - )?; +struct DbConfig { + max_connections: u32, + timeout: std::time::Duration, +} - let logs = stmt.query_map([start.to_rfc3339(), end.to_rfc3339()], |row| { - Ok(RequestLog { - id: row.get(0)?, - timestamp: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?) - .unwrap() - .with_timezone(&Local), - model: row.get(2)?, - token_info: TokenInfo { - token: row.get(8)?, - checksum: row.get(9)?, - alias: row.get(10)?, - usage: Some(UserUsageInfo { - fast_requests: row.get(11)?, - max_fast_requests: row.get(12)?, - }), - }, - prompt: row.get(4)?, - stream: row.get(5)?, - status: row.get(6)?, - error: row.get(7)?, - }) - })?; - logs.collect() - } - - pub fn update_token_info(&self, token_info: &TokenInfo) -> Result<()> { - self.conn.execute( - "INSERT OR REPLACE INTO token_infos (token, checksum, alias, fast_requests, max_fast_requests) - VALUES (?1, ?2, ?3, ?4, ?5)", - params![ - &token_info.token, - &token_info.checksum, - &token_info.alias, - token_info.usage.as_ref().map(|u| u.fast_requests), - token_info.usage.as_ref().map(|u| u.max_fast_requests), - ], - )?; +pub fn example_usage() -> Result<()> { + Database::init("path/to/db.sqlite")?; + println!("Max connections: {}", DB_CONFIG.max_connections); + with_db(|conn| { Ok(()) - } + })?; + with_db_mut(|conn| { + let tx = conn.transaction()?; + tx.commit() + }) } +*/ -lazy_static! { - pub static ref APP_DB: Mutex = - Mutex::new(AppDb::new().expect("Failed to initialize database")); +// 在应用启动时初始化 +pub async fn init_database(path: &str) -> Result<()> { + Database::init(path)?; + Database::start_cleaner(); + Ok(()) } diff --git a/src/app/db/logs.rs b/src/app/db/logs.rs new file mode 100644 index 0000000..2196a81 --- /dev/null +++ b/src/app/db/logs.rs @@ -0,0 +1,271 @@ +use super::Database; +use crate::{app::model::{LogInfo, LogStatus, TokenInfo}, common::models::usage::UserUsageInfo}; +use chrono::Local; +use rusqlite::{params, Connection, OptionalExtension as _, Result}; +const MAX_PROMPT_LENGTH: usize = 100000; // 限制 prompt 长度为 100000 字符 +const MAX_MODEL_LENGTH: usize = 100; // 限制 model 名称长度为 100 字符 +const MAX_ERROR_LENGTH: usize = 1000; // 限制 error 信息长度为 1000 字符 +const MAX_QUERY_LIMIT: usize = 1000; // 最大查询数量限制 +pub fn insert_log(log_info: &LogInfo) -> Result { + super::with_db_mut(|conn| Database::insert_log(conn, log_info)) +} +pub fn get_logs_by_user_id(user_id: Option) -> Result> { + super::with_db_mut(|conn| Database::get_logs_by_user_id(conn, user_id)) +} +pub fn get_logs_by_token_id(token_id: i64) -> Result> { + super::with_db_mut(|conn| Database::get_logs_by_token_id(conn, token_id)) +} +pub fn get_log_by_id(id: i64) -> Result> { + super::with_db(|conn| Database::get_log_by_id(conn, id)) +} +pub fn update_log_status(id: i64, status: LogStatus, error: Option) -> Result<()> { + super::with_db_mut(|conn| Database::update_log_status(conn, id, status, error)) +} +pub fn clean_user_logs(user_id: i64, limit: usize) -> Result<()> { + super::with_db_mut(|conn| Database::clean_user_logs(conn, user_id, limit)) +} +pub fn get_user_logs_count(user_id: i64) -> Result { + super::with_db(|conn| Database::get_user_logs_count(conn, user_id)) +} +pub fn update_log_usage(log_id: i64, usage: Option) -> Result<()> { + super::with_db_mut(|conn| Database::update_log_usage(conn, log_id, usage)) +} +pub fn update_log_prompt(log_id: i64, prompt: Option) -> Result<()> { + super::with_db_mut(|conn| Database::update_log_prompt(conn, log_id, prompt)) +} +impl Database { + pub fn init_logs_table(conn: &Connection) -> Result<()> { + conn.execute( + "CREATE TABLE IF NOT EXISTS logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL, + token_id INTEGER NOT NULL, + prompt TEXT, + model TEXT NOT NULL, + stream BOOLEAN NOT NULL, + status INTEGER NOT NULL, + error TEXT, + FOREIGN KEY(token_id) REFERENCES tokens(id) + )", + [], + )?; + Ok(()) + } + pub fn insert_log(conn: &mut Connection, log_info: &LogInfo) -> Result { + // 输入验证 + if let Some(prompt) = &log_info.prompt { + if prompt.len() > MAX_PROMPT_LENGTH { + return Err(rusqlite::Error::InvalidParameterName( + "Prompt too long".to_string(), + )); + } + } + if log_info.model.len() > MAX_MODEL_LENGTH { + return Err(rusqlite::Error::InvalidParameterName( + "Model name too long".to_string(), + )); + } + if let Some(error) = &log_info.error { + if error.len() > MAX_ERROR_LENGTH { + return Err(rusqlite::Error::InvalidParameterName( + "Error message too long".to_string(), + )); + } + } + let tx = conn.transaction()?; + tx.execute( + "INSERT INTO logs ( + timestamp, token_id, prompt, model, + stream, status, error + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)", + params![ + Local::now(), + log_info.token_info.id, + log_info.prompt, + log_info.model, + log_info.stream, + log_info.status, + log_info.error, + ], + )?; + let id = tx.last_insert_rowid(); + tx.commit()?; + Ok(id) + } + pub fn get_logs_by_user_id( + conn: &mut Connection, + user_id: Option, + ) -> Result> { + let mut stmt = conn.prepare( + "SELECT l.id, l.timestamp, l.prompt, l.model, l.stream, l.status, l.error, + t.id, t.create_at, t.token, t.checksum, t.alias, + t.status, t.pengding_at, t.user_id, t.is_public, t.usage + FROM logs l + JOIN tokens t ON l.token_id = t.id + WHERE t.user_id IS ?1 + ORDER BY l.timestamp DESC + LIMIT 100", + )?; + let logs_iter = stmt.query_map(params![user_id], Self::row_to_log_info)?; + let mut logs = Vec::with_capacity(100); + for log in logs_iter { + logs.push(log?); + } + Ok(logs) + } + pub fn get_logs_by_token_id(conn: &mut Connection, token_id: i64) -> Result> { + // 使用事务确保一致性 + let tx = conn.transaction()?; + // 先获取token信息 + let token = Self::get_token_by_id(&tx, token_id)? + .ok_or_else(|| rusqlite::Error::QueryReturnedNoRows)?; + // 查询日志记录 + let logs = { + let mut stmt = tx.prepare( + "SELECT id, timestamp, prompt, model, stream, status, error + FROM logs + WHERE token_id = ?1 + ORDER BY timestamp DESC + LIMIT ?2", + )?; + let mut logs = Vec::with_capacity(100); + let logs_iter = stmt.query_map(params![token_id, MAX_QUERY_LIMIT as i64], |row| { + Ok(LogInfo { + id: row.get(0)?, + timestamp: row.get(1)?, + prompt: row.get(2)?, + model: row.get(3)?, + stream: row.get(4)?, + status: row.get(5)?, + error: row.get(6)?, + token_info: token.clone(), + }) + })?; + for log in logs_iter { + logs.push(log?); + } + logs + }; + tx.commit()?; + Ok(logs) + } + pub fn get_log_by_id(conn: &Connection, id: i64) -> Result> { + conn.query_row( + "SELECT l.id, l.timestamp, l.prompt, l.model, l.stream, l.status, l.error, + t.id, t.create_at, t.token, t.checksum, t.alias, + t.status, t.pengding_at, t.user_id, t.is_public, t.usage + FROM logs l + JOIN tokens t ON l.token_id = t.id + WHERE l.id = ?1", + params![id], + Self::row_to_log_info, + ) + .optional() + } + pub fn update_log_status( + conn: &mut Connection, + id: i64, + status: LogStatus, + error: Option, + ) -> Result<()> { + // 验证 error 长度 + if let Some(error_msg) = &error { + if error_msg.len() > MAX_ERROR_LENGTH { + return Err(rusqlite::Error::InvalidParameterName( + "Error message too long".to_string(), + )); + } + } + let tx = conn.transaction()?; + tx.execute( + "UPDATE logs SET status = ?1, error = ?2 WHERE id = ?3", + params![status, error, id], + )?; + tx.commit()?; + Ok(()) + } + fn row_to_log_info(row: &rusqlite::Row<'_>) -> Result { + let token_info = TokenInfo { + id: row.get(7)?, + create_at: row.get(8)?, + token: row.get(9)?, + checksum: row.get(10)?, + alias: row.get(11)?, + status: row.get(12)?, + pengding_at: row.get(13)?, + user_id: row.get(14)?, + is_public: row.get(15)?, + usage: row.get(16)?, + }; + Ok(LogInfo { + id: row.get(0)?, + timestamp: row.get(1)?, + prompt: row.get(2)?, + model: row.get(3)?, + stream: row.get(4)?, + status: row.get(5)?, + error: row.get(6)?, + token_info, + }) + } + pub fn clean_user_logs(conn: &mut Connection, user_id: i64, limit: usize) -> Result<()> { + let tx = conn.transaction()?; + // 先获取所有需要删除的日志ID + let mut stmt = tx.prepare( + "WITH RankedLogs AS ( + SELECT l.id, + ROW_NUMBER() OVER (ORDER BY l.timestamp DESC) as rn + FROM logs l + JOIN tokens t ON l.token_id = t.id + WHERE t.user_id = ?1 + ) + SELECT id FROM RankedLogs WHERE rn > ?2", + )?; + let log_ids: Vec = stmt + .query_map(params![user_id, limit as i64], |row| row.get(0))? + .collect::>>()?; + // 确保 stmt 被释放 + drop(stmt); + // 如果有需要删除的日志 + if !log_ids.is_empty() { + // 直接更新状态,不使用 IN 子句 + tx.execute( + "UPDATE logs SET status = ?1 + WHERE id IN ( + SELECT l.id + FROM logs l + JOIN tokens t ON l.token_id = t.id + WHERE t.user_id = ?2 + ORDER BY l.timestamp ASC + LIMIT -1 + OFFSET ?3 + )", + params![LogStatus::Deleted as u8, user_id, limit], + )?; + } + tx.commit()?; + Ok(()) + } + pub fn get_user_logs_count(conn: &Connection, user_id: i64) -> Result { + conn.query_row( + "SELECT COUNT(*) + FROM logs l + JOIN tokens t ON l.token_id = t.id + WHERE t.user_id = ?1 AND l.status != ?2", + params![user_id, LogStatus::Deleted], + |row| row.get(0), + ) + } + pub fn update_log_usage(conn: &mut Connection, log_id: i64, usage: Option) -> Result<()> { + let tx = conn.transaction()?; + tx.execute("UPDATE logs SET usage = ?1 WHERE id = ?2", params![usage, log_id])?; + tx.commit()?; + Ok(()) + } + pub fn update_log_prompt(conn: &mut Connection, log_id: i64, prompt: Option) -> Result<()> { + let tx = conn.transaction()?; + tx.execute("UPDATE logs SET prompt = ?1 WHERE id = ?2", params![prompt, log_id])?; + tx.commit()?; + Ok(()) + } +} diff --git a/src/app/db/tokens.rs b/src/app/db/tokens.rs new file mode 100644 index 0000000..d37075e --- /dev/null +++ b/src/app/db/tokens.rs @@ -0,0 +1,305 @@ +use crate::app::model::{TokenInfo, TokenStatus}; +use chrono::Local; +use rusqlite::{params, Connection, OptionalExtension as _, Result}; +use super::Database; +// 限制字段长度 +const MAX_TOKEN_LENGTH: usize = 1000; // 限制 token 长度为 1000 字符 +const MAX_CHECKSUM_LENGTH: usize = 200; // 限制 checksum 长度为 200 字符 +const MAX_ALIAS_LENGTH: usize = 100; // 限制 alias 长度为 100 字符 +const MAX_QUERY_LIMIT: usize = 1000; // 最大查询数量限制 +pub fn insert_token(token_info: &TokenInfo) -> Result { + super::with_db_mut(|conn| Database::insert_token(conn, token_info)) +} +pub fn get_tokens_by_user_id(user_id: Option) -> Result> { + super::with_db(|conn| Database::get_tokens_by_user_id(conn, user_id)) +} +pub fn get_available_tokens_by_user_id(user_id: Option) -> Result> { + super::with_db(|conn| Database::get_available_tokens_by_user_id(conn, user_id)) +} +pub fn get_token_by_id(id: i64) -> Result> { + super::with_db(|conn| Database::get_token_by_id(conn, id)) +} +pub fn get_token_by_token(token: &str) -> Result> { + super::with_db(|conn| Database::get_token_by_token(conn, token)) +} +pub fn update_token_status(id: i64, status: TokenStatus) -> Result<()> { + super::with_db_mut(|conn| Database::update_token_status(conn, id, status)) +} +pub fn delete_expired_tokens() -> Result<()> { + super::with_db_mut(|conn| Database::delete_expired_tokens(conn)) +} +pub fn get_token_by_alias_and_user( + alias: &str, + current_user_id: i64, + target_user_id: Option, +) -> Result> { + super::with_db(|conn| { + Database::get_token_by_alias_and_user(conn, alias, current_user_id, target_user_id) + }) +} +pub fn update_token(token_info: &TokenInfo) -> Result<()> { + super::with_db_mut(|conn| Database::update_token(conn, token_info)) +} +impl Database { + pub fn init_tokens_table(conn: &Connection) -> Result<()> { + conn.execute( + "CREATE TABLE IF NOT EXISTS tokens ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + create_at TEXT NOT NULL, + token TEXT NOT NULL UNIQUE, + checksum TEXT NOT NULL, + alias TEXT, + status INTEGER NOT NULL, + pengding_at TEXT NOT NULL, + user_id INTEGER NOT NULL, + is_public BOOLEAN NOT NULL DEFAULT 0, + usage TEXT, + FOREIGN KEY(user_id) REFERENCES users(id), + UNIQUE(alias, user_id) + )", + [], + )?; + Ok(()) + } + pub fn insert_token(conn: &mut Connection, token_info: &TokenInfo) -> Result { + // 输入验证 + if token_info.token.len() > MAX_TOKEN_LENGTH { + return Err(rusqlite::Error::InvalidParameterName( + "Token too long".to_string(), + )); + } + if token_info.checksum.len() > MAX_CHECKSUM_LENGTH { + return Err(rusqlite::Error::InvalidParameterName( + "Checksum too long".to_string(), + )); + } + if let Some(alias) = &token_info.alias { + if alias.len() > MAX_ALIAS_LENGTH { + return Err(rusqlite::Error::InvalidParameterName( + "Alias too long".to_string(), + )); + } + } + let tx = conn.transaction()?; + tx.execute( + "INSERT INTO tokens ( + create_at, token, checksum, alias, + status, pengding_at, user_id, is_public, usage + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", + params![ + Local::now(), + token_info.token, + token_info.checksum, + token_info.alias, + token_info.status, + token_info.user_id, + token_info.is_public, + token_info.usage, + ], + )?; + let id = tx.last_insert_rowid(); + tx.commit()?; + Ok(id) + } + pub fn get_tokens_by_user_id( + conn: &Connection, + user_id: Option, + ) -> Result> { + let mut stmt = conn.prepare( + "SELECT id, create_at, token, checksum, alias, + status, pengding_at, user_id, is_public, usage + FROM tokens + WHERE user_id IS ?1 + LIMIT ?2", + )?; + let tokens_iter = stmt.query_map( + params![user_id, MAX_QUERY_LIMIT as i64], + Self::row_to_token_info, + )?; + let mut tokens = Vec::with_capacity(100); + for token in tokens_iter { + tokens.push(token?); + } + Ok(tokens) + } + pub fn get_available_tokens_by_user_id( + conn: &Connection, + user_id: Option, + ) -> Result> { + let mut stmt = conn.prepare( + "SELECT id, create_at, token, checksum, alias, + status, pengding_at, user_id, is_public, usage + FROM tokens + WHERE status = 1 AND datetime('now') >= pengding_at AND user_id IS ?1 + LIMIT ?2", + )?; + let tokens_iter = stmt.query_map( + params![user_id, MAX_QUERY_LIMIT as i64], + Self::row_to_token_info, + )?; + let mut tokens = Vec::with_capacity(100); + for token in tokens_iter { + tokens.push(token?); + } + Ok(tokens) + } + pub fn get_token_by_id(conn: &Connection, id: i64) -> Result> { + conn.query_row( + "SELECT id, create_at, token, checksum, alias, + status, pengding_at, user_id, is_public, usage + FROM tokens + WHERE id = ?1", + params![id], + Self::row_to_token_info, + ) + .optional() + } + pub fn get_token_by_token(conn: &Connection, token: &str) -> Result> { + // 输入验证 + if token.len() > MAX_TOKEN_LENGTH { + return Err(rusqlite::Error::InvalidParameterName( + "Token too long".to_string(), + )); + } + conn.query_row( + "SELECT id, create_at, token, checksum, alias, + status, pengding_at, user_id, is_public, usage + FROM tokens + WHERE token = ?1", + params![token], + Self::row_to_token_info, + ) + .optional() + } + pub fn update_token_status(conn: &mut Connection, id: i64, status: TokenStatus) -> Result<()> { + let tx = conn.transaction()?; + tx.execute( + "UPDATE tokens SET status = ?1 WHERE id = ?2", + params![status, id], + )?; + if status == TokenStatus::Pending { + tx.execute( + "UPDATE tokens SET pengding_at = ?1 WHERE id = ?2", + params![Local::now() + chrono::Duration::minutes(1), id], + )?; + } + tx.commit()?; + Ok(()) + } + pub fn delete_expired_tokens(conn: &mut Connection) -> Result<()> { + // 开始事务 + let tx = conn.transaction()?; + // 删除过期token相关的日志 + tx.execute( + "DELETE FROM logs WHERE token_id IN ( + SELECT id FROM tokens + WHERE status = ?1 + )", + params![TokenStatus::Expired], + )?; + // 删除过期的token + tx.execute( + "DELETE FROM tokens WHERE status = ?1", + params![TokenStatus::Expired], + )?; + // 提交事务 + tx.commit()?; + Ok(()) + } + pub fn get_token_by_alias_and_user( + conn: &Connection, + alias: &str, + current_user_id: i64, + target_user_id: Option, + ) -> Result> { + // 管理员可以查看所有token + let is_admin = current_user_id == 0; + let sql = if is_admin { + // 管理员查询:如果指定了user_id就查指定用户的,否则查所有 + if target_user_id.is_some() { + "SELECT id, create_at, token, checksum, alias, + status, user_id, is_public, usage + FROM tokens + WHERE alias = ?1 + AND status IN (?2, ?3) + AND user_id = ?4" + } else { + "SELECT id, create_at, token, checksum, alias, + status, pengding_at, user_id, is_public, usage + FROM tokens + WHERE alias = ?1 + AND status IN (?2, ?3)" + } + } else { + // 普通用户查询:只能查看自己的token + "SELECT id, create_at, token, checksum, alias, + status, pengding_at, user_id, is_public, usage + FROM tokens + WHERE alias = ?1 + AND status IN (?2, ?3) + AND user_id = ?4" + }; + let target_id = target_user_id.map(|id| id); + let params: Vec<&dyn rusqlite::ToSql> = if is_admin { + if let Some(ref id) = target_id { + vec![&alias, &TokenStatus::Active, &TokenStatus::Pending, id] + } else { + vec![&alias, &TokenStatus::Active, &TokenStatus::Pending] + } + } else { + vec![ + &alias, + &TokenStatus::Active, + &TokenStatus::Pending, + ¤t_user_id, + ] + }; + conn.query_row(sql, params.as_slice(), Self::row_to_token_info) + .optional() + } + pub fn update_token(conn: &mut Connection, token_info: &TokenInfo) -> Result<()> { + // 输入验证 + if token_info.checksum.len() > MAX_CHECKSUM_LENGTH { + return Err(rusqlite::Error::InvalidParameterName( + "Checksum too long".to_string(), + )); + } + if let Some(alias) = &token_info.alias { + if alias.len() > MAX_ALIAS_LENGTH { + return Err(rusqlite::Error::InvalidParameterName( + "Alias too long".to_string(), + )); + } + } + let tx = conn.transaction()?; + tx.execute( + "UPDATE tokens SET + checksum = ?1, + alias = ?2, + is_public = ?3 + WHERE id = ?4", + params![ + token_info.checksum, + token_info.alias, + token_info.is_public, + token_info.id, + ], + )?; + tx.commit()?; + Ok(()) + } + fn row_to_token_info(row: &rusqlite::Row<'_>) -> Result { + Ok(TokenInfo { + id: row.get(0)?, + create_at: row.get(1)?, + token: row.get(2)?, + checksum: row.get(3)?, + alias: row.get(4)?, + status: row.get(5)?, + pengding_at: row.get(6)?, + user_id: row.get(7)?, + is_public: row.get(8)?, + usage: row.get(9)?, + }) + } +} diff --git a/src/app/db/users.rs b/src/app/db/users.rs new file mode 100644 index 0000000..eee2f28 --- /dev/null +++ b/src/app/db/users.rs @@ -0,0 +1,198 @@ +use crate::app::model::UserInfo; +use crate::common::utils::oauth::ForumUser; +use chrono::{DateTime, Local}; +use rusqlite::{params, Connection, OptionalExtension as _, Result}; +use crate::app::lazy::ADMIN_AUTH_TOKEN; +use super::Database; +// 限制字段长度 +const MAX_USERNAME_LENGTH: usize = 100; // 限制用户名长度为 100 字符 +const MAX_NAME_LENGTH: usize = 100; // 限制姓名长度为 100 字符 +const MAX_QUERY_LIMIT: usize = 1000; // 最大查询数量限制 +pub fn insert_user(user: &ForumUser) -> Result { + super::with_db_mut(|conn| Database::insert_user(conn, user)) +} +pub fn get_user_by_id(id: i64) -> Result> { + super::with_db(|conn| Database::get_user_by_id(conn, id)) +} +pub fn get_user_by_forum_id(forum_id: i64) -> Result> { + super::with_db(|conn| Database::get_user_by_forum_id(conn, forum_id)) +} +pub fn update_user_ban(forum_id: i64, ban_expired_at: Option>, ban_count: u32) -> Result<()> { + super::with_db_mut(|conn| Database::update_user_ban(conn, forum_id, ban_expired_at, ban_count)) +} +pub fn update_user_auth_token(forum_id: i64, auth_token: Option) -> Result<()> { + super::with_db_mut(|conn| Database::update_user_auth_token(conn, forum_id, auth_token)) +} +pub fn get_user_by_auth_token(auth_token: &str) -> Result> { + super::with_db(|conn| Database::get_user_by_auth_token(conn, auth_token)) +} +impl Database { + pub fn init_users_table(conn: &Connection) -> Result<()> { + conn.execute( + "CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + forum_id INTEGER NOT NULL UNIQUE, + username TEXT NOT NULL, + name TEXT NOT NULL, + trust_level INTEGER NOT NULL, + created_at TEXT NOT NULL, + ban_expired_at TEXT, + ban_count INTEGER NOT NULL, + auth_token TEXT + )", + [], + )?; + let admin_exists: bool = conn + .query_row("SELECT EXISTS(SELECT 1 FROM users WHERE id = 0)", [], |row| { + row.get(0) + })?; + if !admin_exists { + conn.execute( + "INSERT INTO users ( + id, forum_id, username, name, trust_level, + created_at, ban_expired_at, ban_count, auth_token + ) VALUES ( + 0, 0, 'admin', 'Administrator', 255, + ?1, NULL, 0, ?2 + )", + params![Local::now(), &*ADMIN_AUTH_TOKEN], + )?; + } + Ok(()) + } + pub fn insert_user(conn: &mut Connection, user: &ForumUser) -> Result { + // 输入验证 + if user.username.len() > MAX_USERNAME_LENGTH { + return Err(rusqlite::Error::InvalidParameterName( + "Username too long".to_string(), + )); + } + if user.name.len() > MAX_NAME_LENGTH { + return Err(rusqlite::Error::InvalidParameterName( + "Name too long".to_string(), + )); + } + let tx = conn.transaction()?; + tx.execute( + "INSERT INTO users (forum_id, username, name, trust_level, created_at, ban_expired_at, ban_count, auth_token) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", + params![ + user.id, + user.username, + user.name, + user.trust_level, + Local::now(), + Option::>::None, + 0, + Option::::None + ], + )?; + let id = tx.last_insert_rowid(); + tx.commit()?; + Ok(id) + } + pub fn get_user_by_id(conn: &Connection, id: i64) -> Result> { + conn.query_row( + "SELECT id, forum_id, username, name, trust_level, created_at, ban_expired_at, ban_count, auth_token + FROM users + WHERE id = ?1 + LIMIT 1", + params![id], + |row| { + Ok(UserInfo { + id: row.get(0)?, + forum_id: row.get(1)?, + username: row.get(2)?, + name: row.get(3)?, + trust_level: row.get(4)?, + created_at: row.get(5)?, + ban_expired_at: row.get(6)?, + ban_count: row.get(7)?, + auth_token: row.get(8)?, + }) + }, + ) + .optional() + } + pub fn get_user_by_forum_id(conn: &Connection, forum_id: i64) -> Result> { + conn.query_row( + "SELECT id, forum_id, username, name, trust_level, created_at, ban_expired_at, ban_count, auth_token + FROM users + WHERE forum_id = ?1 + LIMIT 1", + params![forum_id], + |row| { + Ok(UserInfo { + id: row.get(0)?, + forum_id: row.get(1)?, + username: row.get(2)?, + name: row.get(3)?, + trust_level: row.get(4)?, + created_at: row.get(5)?, + ban_expired_at: row.get(6)?, + ban_count: row.get(7)?, + auth_token: row.get(8)?, + }) + }, + ) + .optional() + } + pub fn update_user_ban(conn: &mut Connection, forum_id: i64, ban_expired_at: Option>, ban_count: u32) -> Result<()> { + let tx = conn.transaction()?; + tx.execute( + "UPDATE users SET ban_expired_at = ?1, ban_count = ?2 WHERE forum_id = ?3", + params![ban_expired_at, ban_count, forum_id], + )?; + tx.commit()?; + Ok(()) + } + pub fn update_user_auth_token( + conn: &mut Connection, + forum_id: i64, + auth_token: Option, + ) -> Result<()> { + let tx = conn.transaction()?; + // 检查 auth_token 是否已存在 + if let Some(token) = &auth_token { + let existing = tx.query_row( + "SELECT forum_id FROM users WHERE auth_token = ?1 AND forum_id != ?2", + params![token, forum_id], + |_| Ok(()), + ); + if existing.optional()?.is_some() { + return Err(rusqlite::Error::InvalidParameterName( + "Auth token already exists".to_string(), + )); + } + } + tx.execute( + "UPDATE users SET auth_token = ?1 WHERE forum_id = ?2", + params![auth_token, forum_id], + )?; + tx.commit()?; + Ok(()) + } + pub fn get_user_by_auth_token(conn: &Connection, auth_token: &str) -> Result> { + conn.query_row( + "SELECT id, forum_id, username, name, trust_level, created_at, ban_expired_at, ban_count, auth_token + FROM users + WHERE auth_token = ?1 + LIMIT 1", + params![auth_token], + |row| { + Ok(UserInfo { + id: row.get(0)?, + forum_id: row.get(1)?, + username: row.get(2)?, + name: row.get(3)?, + trust_level: row.get(4)?, + created_at: row.get(5)?, + ban_expired_at: row.get(6)?, + ban_count: row.get(7)?, + auth_token: row.get(8)?, + }) + }, + ) + .optional() + } +} diff --git a/src/app/lazy.rs b/src/app/lazy.rs index 74ab67e..d1defaf 100644 --- a/src/app/lazy.rs +++ b/src/app/lazy.rs @@ -1,5 +1,5 @@ use crate::{ - app::constant::{DEFAULT_TOKEN_FILE_NAME, DEFAULT_TOKEN_LIST_FILE_NAME, EMPTY_STRING}, + app::constant::EMPTY_STRING, common::utils::parse_string_from_env, }; use std::sync::LazyLock; @@ -27,18 +27,8 @@ macro_rules! def_pub_static { // }; // } -def_pub_static!(ROUTE_PREFIX, env: "ROUTE_PREFIX", default: EMPTY_STRING); -def_pub_static!(AUTH_TOKEN, env: "AUTH_TOKEN", default: EMPTY_STRING); -def_pub_static!(TOKEN_FILE, env: "TOKEN_FILE", default: DEFAULT_TOKEN_FILE_NAME); -def_pub_static!(TOKEN_LIST_FILE, env: "TOKEN_LIST_FILE", default: DEFAULT_TOKEN_LIST_FILE_NAME); -def_pub_static!( - ROUTE_MODELS_PATH, - format!("{}/v1/models", *ROUTE_PREFIX) -); -def_pub_static!( - ROUTE_CHAT_PATH, - format!("{}/v1/chat/completions", *ROUTE_PREFIX) -); +def_pub_static!(ROUTE_PREFIX, env: "ROUTE_PREFIX", default: "/v1"); +def_pub_static!(PUBLIC_AUTH_TOKEN, env: "PUBLIC_AUTH_TOKEN", default: EMPTY_STRING); pub static START_TIME: LazyLock> = LazyLock::new(chrono::Local::now); @@ -55,6 +45,18 @@ pub static CURSOR_API2_BASE_URL: LazyLock = LazyLock::new(|| { format!("https://{}/aiserver.v1.AiService/", *CURSOR_API2_HOST) }); +pub static OAUTH_CLIENT_ID: LazyLock = LazyLock::new(|| { + parse_string_from_env("OAUTH_CLIENT_ID", EMPTY_STRING).trim().to_string() +}); + +pub static OAUTH_CLIENT_SECRET: LazyLock = LazyLock::new(|| { + parse_string_from_env("OAUTH_CLIENT_SECRET", EMPTY_STRING).trim().to_string() +}); + +pub static OAUTH_REDIRECT_URI: LazyLock = LazyLock::new(|| { + parse_string_from_env("OAUTH_REDIRECT_URI", EMPTY_STRING).trim().to_string() +}); + // pub static DEBUG: LazyLock = LazyLock::new(|| parse_bool_from_env("DEBUG", false)); // #[macro_export] @@ -65,3 +67,5 @@ pub static CURSOR_API2_BASE_URL: LazyLock = LazyLock::new(|| { // } // }; // } + +def_pub_static!(ADMIN_AUTH_TOKEN, env: "ADMIN_AUTH_TOKEN", default: EMPTY_STRING); diff --git a/src/app/model.rs b/src/app/model.rs index bba146f..9fe3808 100644 --- a/src/app/model.rs +++ b/src/app/model.rs @@ -1,15 +1,10 @@ -use crate::{ - app::constant::{ - ERR_INVALID_PATH, ERR_RESET_CONFIG, ERR_UPDATE_CONFIG, ROUTE_ABOUT_PATH, ROUTE_CONFIG_PATH, - ROUTE_LOGS_PATH, ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_SHARED_JS_PATH, - ROUTE_SHARED_STYLES_PATH, ROUTE_TOKENINFO_PATH, - }, - common::models::usage::UserUsageInfo, +use crate::app::constant::{ + ERR_INVALID_PATH, ERR_RESET_CONFIG, ERR_UPDATE_CONFIG, ROUTE_ABOUT_PATH, ROUTE_CONFIG_PATH, + ROUTE_LOGS_PATH, ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_SHARED_JS_PATH, + ROUTE_SHARED_STYLES_PATH, ROUTE_TOKENINFO_PATH, }; -use crate::chat::model::Message; -use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; -use std::sync::RwLock; +use std::sync::{LazyLock, RwLock}; // 页面内容类型枚举 #[derive(Clone, Serialize, Deserialize)] @@ -87,16 +82,10 @@ pub struct Pages { pub struct AppState { pub total_requests: u64, pub active_requests: u64, - #[cfg(not(feature = "sqlite"))] - pub request_logs: Vec, - #[cfg(not(feature = "sqlite"))] - pub token_infos: Vec, } // 全局配置实例 -lazy_static! { - pub static ref APP_CONFIG: RwLock = RwLock::new(AppConfig::default()); -} +pub static APP_CONFIG: LazyLock> = LazyLock::new(|| RwLock::new(AppConfig::default())); impl Default for AppConfig { fn default() -> Self { @@ -275,17 +264,6 @@ impl AppConfig { } impl AppState { - #[cfg(not(feature = "sqlite"))] - pub fn new(token_infos: Vec) -> Self { - Self { - total_requests: 0, - active_requests: 0, - request_logs: Vec::new(), - token_infos, - } - } - - #[cfg(feature = "sqlite")] pub fn new() -> Self { Self { total_requests: 0, @@ -294,51 +272,5 @@ impl AppState { } } -// 请求日志 -#[derive(Serialize, Clone)] -pub struct RequestLog { - pub id: u64, - pub timestamp: chrono::DateTime, - pub model: String, - pub token_info: TokenInfo, - #[serde(skip_serializing_if = "Option::is_none")] - pub prompt: Option, - pub stream: bool, - pub status: &'static str, - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, -} - -// pub struct PromptList(Option); - -// impl PromptList { -// pub fn to_vec(&self) -> Vec<> -// } - -// 聊天请求 -#[derive(Deserialize)] -pub struct ChatRequest { - pub model: String, - pub messages: Vec, - #[serde(default)] - pub stream: bool, -} - -// 用于存储 token 信息 -#[derive(Serialize, Clone)] -pub struct TokenInfo { - pub token: String, - pub checksum: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub alias: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub usage: Option, -} - -// TokenUpdateRequest 结构体 -#[derive(Deserialize)] -pub struct TokenUpdateRequest { - pub tokens: String, - #[serde(default)] - pub token_list: Option, -} +mod db; +pub use db::*; diff --git a/src/app/model/db.rs b/src/app/model/db.rs new file mode 100644 index 0000000..073584b --- /dev/null +++ b/src/app/model/db.rs @@ -0,0 +1,148 @@ +use crate::{chat::model::Message, common::models::usage::UserUsageInfo}; +use chrono::{DateTime, Local}; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Clone)] +pub enum LogStatus { + #[serde(rename = "pending")] + Pending, + #[serde(rename = "success")] + Success, + #[serde(rename = "failed")] + Failed, + #[serde(rename = "deleted")] + Deleted, +} + +impl rusqlite::types::FromSql for LogStatus { + fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult { + match value.as_i64()? { + 0 => Ok(LogStatus::Pending), + 1 => Ok(LogStatus::Success), + 2 => Ok(LogStatus::Failed), + 3 => Ok(LogStatus::Deleted), + _ => Err(rusqlite::types::FromSqlError::OutOfRange(value.as_i64()?)), + } + } +} + +impl rusqlite::ToSql for LogStatus { + fn to_sql(&self) -> rusqlite::Result> { + Ok(rusqlite::types::ToSqlOutput::from(match self { + LogStatus::Pending => 0u8, + LogStatus::Success => 1u8, + LogStatus::Failed => 2u8, + LogStatus::Deleted => 3u8, + })) + } +} + +// 请求日志 +#[derive(Serialize, Clone)] +pub struct LogInfo { + #[serde(skip_serializing)] + pub id: i64, + pub timestamp: DateTime, + #[serde(skip_serializing_if = "TokenInfo::is_hide")] + pub token_info: TokenInfo, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt: Option, + pub model: String, + pub stream: bool, + pub status: LogStatus, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +// 聊天请求 +#[derive(Deserialize)] +pub struct ChatRequest { + pub model: String, + pub messages: Vec, + #[serde(default)] + pub stream: bool, +} + +#[derive(Serialize, PartialEq, Clone)] +pub enum TokenStatus { + #[serde(rename = "pending")] + Pending, + #[serde(rename = "active")] + Active, + #[serde(rename = "expired")] + Expired, + #[serde(rename = "deleted")] + Deleted, +} + +impl rusqlite::types::FromSql for TokenStatus { + fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult { + match value.as_i64()? { + 0 => Ok(TokenStatus::Pending), + 1 => Ok(TokenStatus::Active), + 2 => Ok(TokenStatus::Expired), + 3 => Ok(TokenStatus::Deleted), + _ => Err(rusqlite::types::FromSqlError::OutOfRange(value.as_i64()?)), + } + } +} + +impl rusqlite::ToSql for TokenStatus { + fn to_sql(&self) -> rusqlite::Result> { + Ok(rusqlite::types::ToSqlOutput::from(match self { + TokenStatus::Pending => 0u8, + TokenStatus::Active => 1u8, + TokenStatus::Expired => 2u8, + TokenStatus::Deleted => 3u8, + })) + } +} + +// 用于存储 token 信息 +#[derive(Serialize, Clone)] +pub struct TokenInfo { + #[serde(skip_serializing)] + pub id: i64, + pub create_at: DateTime, + pub token: String, + pub checksum: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub alias: Option, + pub status: TokenStatus, + pub pengding_at: DateTime, + #[serde(skip_serializing)] + pub user_id: i64, + pub is_public: bool, // 公益 + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, +} + +impl TokenInfo { + pub fn is_hide(&self) -> bool { + self.status == TokenStatus::Deleted + } +} + +// TokenUpdateRequest 结构体 +#[derive(Deserialize)] +pub struct TokenUpdateRequest { + pub token: String, + pub checksum: String, + #[serde(default)] + pub alias: Option, + #[serde(default)] + pub is_public: bool, +} + +#[derive(Clone)] +pub struct UserInfo { + pub id: i64, + pub forum_id: i64, + pub username: String, // 论坛用户名 + pub name: String, // 论坛昵称 + pub trust_level: u8, + pub created_at: DateTime, + pub ban_expired_at: Option>, // 封禁到期时间 + pub ban_count: u32, // 封禁次数 + pub auth_token: Option, +} diff --git a/src/app/model/usage_check.rs b/src/app/model/usage_check.rs index b7c4848..38d5d16 100644 --- a/src/app/model/usage_check.rs +++ b/src/app/model/usage_check.rs @@ -6,7 +6,7 @@ pub enum UsageCheck { None, Default, All, - Custom(Vec<&'static str>), + Custom(Vec), } impl Default for UsageCheck { @@ -69,14 +69,14 @@ impl<'de> Deserialize<'de> for UsageCheck { return Ok(UsageCheck::None); } - let models: Vec<&'static str> = list + let models: Vec = list .split(',') .filter_map(|model| { let model = model.trim(); AVAILABLE_MODELS .iter() .find(|m| m.id == model) - .map(|m| m.id) + .map(|m| m.id.clone()) }) .collect(); diff --git a/src/chat.rs b/src/chat.rs index b557892..e5f7607 100644 --- a/src/chat.rs +++ b/src/chat.rs @@ -2,6 +2,7 @@ pub mod adapter; pub mod aiserver; pub mod constant; pub mod error; +// pub mod middleware; pub mod model; pub mod route; pub mod service; diff --git a/src/chat/adapter.rs b/src/chat/adapter.rs index f57f0ea..397548d 100644 --- a/src/chat/adapter.rs +++ b/src/chat/adapter.rs @@ -14,9 +14,9 @@ use super::{ conversation_message, image_proto, ConversationMessage, ExplicitContext, GetChatRequest, ImageProto, ModelDetails, }, - constant::{ERR_UNSUPPORTED_GIF, ERR_UNSUPPORTED_IMAGE_FORMAT, LONG_CONTEXT_MODELS}, - model::{Message, MessageContent, Role}, -}; + constant::{ERR_UNSUPPORTED_GIF, ERR_UNSUPPORTED_IMAGE_FORMAT}, + model::{Message, MessageContent, Role, Model}, +}; async fn process_chat_inputs(inputs: Vec) -> (String, Vec) { // 收集 system 指令 @@ -380,7 +380,7 @@ pub async fn encode_chat_message( workspace_id: None, external_links: vec![], commit_notes: vec![], - long_context_mode: if LONG_CONTEXT_MODELS.contains(&model_name) { + long_context_mode: if Model::is_long_context(model_name) { Some(true) } else { None diff --git a/src/chat/constant.rs b/src/chat/constant.rs index 84a8d6a..543187b 100644 --- a/src/chat/constant.rs +++ b/src/chat/constant.rs @@ -1,3 +1,5 @@ +use std::sync::LazyLock; + use super::model::Model; macro_rules! def_pub_const { @@ -9,8 +11,8 @@ def_pub_const!(ERR_UNSUPPORTED_GIF, "不支持动态 GIF"); def_pub_const!(ERR_UNSUPPORTED_IMAGE_FORMAT, "不支持的图片格式,仅支持 PNG、JPEG、WEBP 和非动态 GIF"); def_pub_const!(ERR_NODATA, "No data"); -const MODEL_OBJECT: &str = "model"; -const CREATED: &i64 = &1706659200; +pub const MODEL_OBJECT: &str = "model"; +pub const CREATED: &i64 = &1706659200; def_pub_const!(ANTHROPIC, "anthropic"); def_pub_const!(CURSOR, "cursor"); @@ -42,134 +44,134 @@ def_pub_const!( ); def_pub_const!(GEMINI_2_0_FLASH_EXP, "gemini-2.0-flash-exp"); -pub const AVAILABLE_MODELS: [Model; 21] = [ +pub const AVAILABLE_MODELS: LazyLock<[Model; 21]> = LazyLock::new(|| [ Model { - id: CLAUDE_3_5_SONNET, + id: CLAUDE_3_5_SONNET.to_string(), created: CREATED, object: MODEL_OBJECT, owned_by: ANTHROPIC, }, Model { - id: GPT_4, + id: GPT_4.to_string(), created: CREATED, object: MODEL_OBJECT, owned_by: OPENAI, }, Model { - id: GPT_4O, + id: GPT_4O.to_string(), created: CREATED, object: MODEL_OBJECT, owned_by: OPENAI, }, Model { - id: CLAUDE_3_OPUS, + id: CLAUDE_3_OPUS.to_string(), created: CREATED, object: MODEL_OBJECT, owned_by: ANTHROPIC, }, Model { - id: CURSOR_FAST, + id: CURSOR_FAST.to_string(), created: CREATED, object: MODEL_OBJECT, owned_by: CURSOR, }, Model { - id: CURSOR_SMALL, + id: CURSOR_SMALL.to_string(), created: CREATED, object: MODEL_OBJECT, owned_by: CURSOR, }, Model { - id: GPT_3_5_TURBO, + id: GPT_3_5_TURBO.to_string(), created: CREATED, object: MODEL_OBJECT, owned_by: OPENAI, }, Model { - id: GPT_4_TURBO_2024_04_09, + id: GPT_4_TURBO_2024_04_09.to_string(), created: CREATED, object: MODEL_OBJECT, owned_by: OPENAI, }, Model { - id: GPT_4O_128K, + id: GPT_4O_128K.to_string(), created: CREATED, object: MODEL_OBJECT, owned_by: OPENAI, }, Model { - id: GEMINI_1_5_FLASH_500K, + id: GEMINI_1_5_FLASH_500K.to_string(), created: CREATED, object: MODEL_OBJECT, owned_by: GOOGLE, }, Model { - id: CLAUDE_3_HAIKU_200K, + id: CLAUDE_3_HAIKU_200K.to_string(), created: CREATED, object: MODEL_OBJECT, owned_by: ANTHROPIC, }, Model { - id: CLAUDE_3_5_SONNET_200K, + id: CLAUDE_3_5_SONNET_200K.to_string(), created: CREATED, object: MODEL_OBJECT, owned_by: ANTHROPIC, }, Model { - id: CLAUDE_3_5_SONNET_20241022, + id: CLAUDE_3_5_SONNET_20241022.to_string(), created: CREATED, object: MODEL_OBJECT, owned_by: ANTHROPIC, }, Model { - id: GPT_4O_MINI, + id: GPT_4O_MINI.to_string(), created: CREATED, object: MODEL_OBJECT, owned_by: OPENAI, }, Model { - id: O1_MINI, + id: O1_MINI.to_string(), created: CREATED, object: MODEL_OBJECT, owned_by: OPENAI, }, Model { - id: O1_PREVIEW, + id: O1_PREVIEW.to_string(), created: CREATED, object: MODEL_OBJECT, owned_by: OPENAI, }, Model { - id: O1, + id: O1.to_string(), created: CREATED, object: MODEL_OBJECT, owned_by: OPENAI, }, Model { - id: CLAUDE_3_5_HAIKU, + id: CLAUDE_3_5_HAIKU.to_string(), created: CREATED, object: MODEL_OBJECT, owned_by: ANTHROPIC, }, Model { - id: GEMINI_EXP_1206, + id: GEMINI_EXP_1206.to_string(), created: CREATED, object: MODEL_OBJECT, owned_by: GOOGLE, }, Model { - id: GEMINI_2_0_FLASH_THINKING_EXP, + id: GEMINI_2_0_FLASH_THINKING_EXP.to_string(), created: CREATED, object: MODEL_OBJECT, owned_by: GOOGLE, }, Model { - id: GEMINI_2_0_FLASH_EXP, + id: GEMINI_2_0_FLASH_EXP.to_string(), created: CREATED, object: MODEL_OBJECT, owned_by: GOOGLE, }, -]; +]); pub const USAGE_CHECK_MODELS: [&str; 11] = [ CLAUDE_3_5_SONNET_20241022, @@ -184,10 +186,3 @@ pub const USAGE_CHECK_MODELS: [&str; 11] = [ CLAUDE_3_HAIKU_200K, CLAUDE_3_5_SONNET_200K, ]; - -pub const LONG_CONTEXT_MODELS: [&str; 4] = [ - GPT_4O_128K, - GEMINI_1_5_FLASH_500K, - CLAUDE_3_HAIKU_200K, - CLAUDE_3_5_SONNET_200K, -]; diff --git a/src/chat/model.rs b/src/chat/model.rs index 0091dc6..6f704ce 100644 --- a/src/chat/model.rs +++ b/src/chat/model.rs @@ -80,12 +80,18 @@ pub struct Usage { // 模型定义 #[derive(Serialize, Clone)] pub struct Model { - pub id: &'static str, + pub id: String, pub created: &'static i64, pub object: &'static str, pub owned_by: &'static str, } +impl Model { + pub fn is_long_context(id :&str) -> bool { + id.ends_with("128k") || id.ends_with("500k") || id.ends_with("200k") + } +} + use crate::app::model::{AppConfig, UsageCheck}; use super::constant::USAGE_CHECK_MODELS; @@ -93,9 +99,9 @@ impl Model { pub fn is_usage_check(&self) -> bool { match AppConfig::get_usage_check() { UsageCheck::None => false, - UsageCheck::Default => USAGE_CHECK_MODELS.contains(&self.id), + UsageCheck::Default => USAGE_CHECK_MODELS.iter().any(|&x| x == self.id.as_str()), UsageCheck::All => true, - UsageCheck::Custom(models) => models.contains(&self.id), + UsageCheck::Custom(models) => models.iter().any(|x| x == &self.id), } } } @@ -103,5 +109,5 @@ impl Model { #[derive(Serialize)] pub struct ModelsResponse { pub object: &'static str, - pub data: &'static [Model], + pub data: Vec, } diff --git a/src/chat/route.rs b/src/chat/route.rs index e345a32..f43007e 100644 --- a/src/chat/route.rs +++ b/src/chat/route.rs @@ -3,8 +3,10 @@ pub use logs::{handle_logs, handle_logs_post}; mod health; pub use health::{handle_root, handle_health}; mod token; -pub use token::{handle_get_checksum, handle_update_tokeninfo, handle_get_tokeninfo, handle_update_tokeninfo_post, handle_tokeninfo_page}; +pub use token::{handle_get_checksum, handle_get_tokeninfo, handle_update_tokeninfo_post, handle_tokeninfo_page}; mod usage; pub use usage::get_user_info; mod config; pub use config::{handle_env_example, handle_config_page, handle_static, handle_readme, handle_about}; +mod auth; +pub use auth::{handle_auth_callback, handle_auth_initiate}; diff --git a/src/chat/route/auth.rs b/src/chat/route/auth.rs new file mode 100644 index 0000000..9718e31 --- /dev/null +++ b/src/chat/route/auth.rs @@ -0,0 +1,113 @@ +use crate::app::constant::ROUTE_TOKENINFO_PATH; +use crate::app::lazy::{OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OAUTH_REDIRECT_URI}; +use crate::common::utils::oauth::ForumOAuth; +use axum::http::{header::SET_COOKIE, HeaderMap}; +use axum::{extract::Query, response::Redirect}; +use base64::{engine::general_purpose::URL_SAFE, Engine}; +use ring::rand::SecureRandom as _; +use ring::{aead, rand}; +use serde::Deserialize; +use std::sync::OnceLock; + +#[derive(Deserialize)] +pub struct AuthCallback { + code: String, + state: String, +} + +// 用于加密的密钥,使用 OnceLock 确保只初始化一次 +static ENCRYPTION_KEY: OnceLock = OnceLock::new(); + +// 初始化加密密钥 +fn get_encryption_key() -> &'static aead::LessSafeKey { + ENCRYPTION_KEY.get_or_init(|| { + let rng = rand::SystemRandom::new(); + let mut key_bytes = [0u8; 32]; + rng.fill(&mut key_bytes).unwrap(); + let key = aead::UnboundKey::new(&aead::CHACHA20_POLY1305, &key_bytes).unwrap(); + aead::LessSafeKey::new(key) + }) +} + +// 加密 state +fn encrypt_state(state: &str) -> Result { + let key = get_encryption_key(); + let nonce = aead::Nonce::assume_unique_for_key([0; 12]); // 在生产环境中应使用随机 nonce + let mut in_out = state.as_bytes().to_vec(); + key.seal_in_place_append_tag(nonce, aead::Aad::empty(), &mut in_out) + .map_err(|e| e.to_string())?; + Ok(URL_SAFE.encode(in_out)) +} + +// 解密 state +fn decrypt_state(encrypted_state: &str) -> Result { + let key = get_encryption_key(); + let nonce = aead::Nonce::assume_unique_for_key([0; 12]); + let mut encrypted_data = URL_SAFE + .decode(encrypted_state) + .map_err(|e| e.to_string())?; + let decrypted = key + .open_in_place(nonce, aead::Aad::empty(), &mut encrypted_data) + .map_err(|e| e.to_string())?; + String::from_utf8(decrypted.to_vec()).map_err(|e| e.to_string()) +} + +pub async fn handle_auth_callback( + headers: HeaderMap, + Query(params): Query, +) -> Result { + let cookie_header = headers + .get("cookie") + .ok_or_else(|| "Missing cookie header".to_string())?; + + let cookie_str = cookie_header.to_str().map_err(|e| e.to_string())?; + + let encrypted_state = cookie_str + .split(';') + .find(|s| s.trim().starts_with("oauth_state=")) + .and_then(|s| s.trim().strip_prefix("oauth_state=")) + .ok_or_else(|| "Missing state cookie".to_string())?; + + // 解密 state + let expected_state = decrypt_state(encrypted_state)?; + + let oauth = ForumOAuth::new( + OAUTH_CLIENT_ID.to_string(), + OAUTH_CLIENT_SECRET.to_string(), + OAUTH_REDIRECT_URI.to_string(), + ) + .map_err(|e| e.to_string())?; + + let access_token = oauth + .exchange_code_for_token(¶ms.code, ¶ms.state, &expected_state) + .await + .map_err(|e| e.to_string())?; + + let redirect_url = format!("{}?auth={}", ROUTE_TOKENINFO_PATH, access_token); + Ok(Redirect::to(&redirect_url)) +} + +pub async fn handle_auth_initiate() -> Result<(HeaderMap, Redirect), String> { + let oauth = ForumOAuth::new( + OAUTH_CLIENT_ID.to_string(), + OAUTH_CLIENT_SECRET.to_string(), + OAUTH_REDIRECT_URI.to_string(), + ) + .map_err(|e| e.to_string())?; + + let (auth_url, state) = oauth.get_authorize_url(); + + // 加密 state + let encrypted_state = encrypt_state(state.secret())?; + + // 创建安全的 cookie + let cookie = format!( + "oauth_state={}; Path=/; HttpOnly; Secure; SameSite=Lax; Max-Age=300", + encrypted_state + ); + + let mut headers = HeaderMap::new(); + headers.insert(SET_COOKIE, cookie.parse().unwrap()); + + Ok((headers, Redirect::to(&auth_url.to_string()))) +} diff --git a/src/chat/route/config.rs b/src/chat/route/config.rs index 328d541..df96682 100644 --- a/src/chat/route/config.rs +++ b/src/chat/route/config.rs @@ -2,21 +2,24 @@ use crate::app::{ constant::{ CONTENT_TYPE_TEXT_CSS_WITH_UTF8, CONTENT_TYPE_TEXT_HTML_WITH_UTF8, CONTENT_TYPE_TEXT_JS_WITH_UTF8, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, - HEADER_NAME_CONTENT_TYPE, HEADER_NAME_LOCATION, ROUTE_ABOUT_PATH, ROUTE_CONFIG_PATH, - ROUTE_README_PATH, ROUTE_SHARED_JS_PATH, ROUTE_SHARED_STYLES_PATH, + ROUTE_ABOUT_PATH, ROUTE_CONFIG_PATH, ROUTE_README_PATH, ROUTE_SHARED_JS_PATH, + ROUTE_SHARED_STYLES_PATH, }, model::{AppConfig, PageContent}, }; use axum::{ body::Body, extract::Path, - http::StatusCode, + http::{ + header::{CONTENT_TYPE, LOCATION}, + StatusCode, + }, response::{IntoResponse, Response}, }; pub async fn handle_env_example() -> impl IntoResponse { Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) .body(include_str!("../../../.env.example").to_string()) .unwrap() } @@ -25,15 +28,15 @@ pub async fn handle_env_example() -> impl IntoResponse { pub async fn handle_config_page() -> impl IntoResponse { match AppConfig::get_page_content(ROUTE_CONFIG_PATH).unwrap_or_default() { PageContent::Default => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(include_str!("../../../static/config.min.html").to_string()) .unwrap(), PageContent::Text(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) .body(content.clone()) .unwrap(), PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(content.clone()) .unwrap(), } @@ -44,11 +47,11 @@ pub async fn handle_static(Path(path): Path) -> impl IntoResponse { "shared-styles.css" => { match AppConfig::get_page_content(ROUTE_SHARED_STYLES_PATH).unwrap_or_default() { PageContent::Default => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8) .body(include_str!("../../../static/shared-styles.min.css").to_string()) .unwrap(), PageContent::Text(content) | PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8) .body(content.clone()) .unwrap(), } @@ -56,11 +59,11 @@ pub async fn handle_static(Path(path): Path) -> impl IntoResponse { "shared.js" => { match AppConfig::get_page_content(ROUTE_SHARED_JS_PATH).unwrap_or_default() { PageContent::Default => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8) .body(include_str!("../../../static/shared.min.js").to_string()) .unwrap(), PageContent::Text(content) | PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8) .body(content.clone()) .unwrap(), } @@ -75,15 +78,15 @@ pub async fn handle_static(Path(path): Path) -> impl IntoResponse { pub async fn handle_about() -> impl IntoResponse { match AppConfig::get_page_content(ROUTE_ABOUT_PATH).unwrap_or_default() { PageContent::Default => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(include_str!("../../../static/readme.min.html").to_string()) .unwrap(), PageContent::Text(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) .body(content.clone()) .unwrap(), PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(content.clone()) .unwrap(), } @@ -93,15 +96,15 @@ pub async fn handle_readme() -> impl IntoResponse { match AppConfig::get_page_content(ROUTE_README_PATH).unwrap_or_default() { PageContent::Default => Response::builder() .status(StatusCode::TEMPORARY_REDIRECT) - .header(HEADER_NAME_LOCATION, ROUTE_ABOUT_PATH) + .header(LOCATION, ROUTE_ABOUT_PATH) .body(Body::empty()) .unwrap(), PageContent::Text(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) .body(Body::from(content.clone())) .unwrap(), PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(Body::from(content.clone())) .unwrap(), } diff --git a/src/chat/route/health.rs b/src/chat/route/health.rs index e5acdb1..5e3df92 100644 --- a/src/chat/route/health.rs +++ b/src/chat/route/health.rs @@ -1,15 +1,12 @@ use crate::{ app::{ constant::{ - CONTENT_TYPE_TEXT_HTML_WITH_UTF8, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, - HEADER_NAME_CONTENT_TYPE, HEADER_NAME_LOCATION, PKG_VERSION, ROUTE_ABOUT_PATH, - ROUTE_CONFIG_PATH, ROUTE_ENV_EXAMPLE_PATH, ROUTE_GET_CHECKSUM, - ROUTE_GET_TOKENINFO_PATH, ROUTE_GET_USER_INFO_PATH, ROUTE_HEALTH_PATH, ROUTE_LOGS_PATH, - ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_STATIC_PATH, ROUTE_TOKENINFO_PATH, - ROUTE_UPDATE_TOKENINFO_PATH, + AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_TEXT_HTML_WITH_UTF8, + CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, PKG_VERSION, ROUTE_HEALTH_PATH, ROUTE_ROOT_PATH, }, + db, + lazy::get_start_time, model::{AppConfig, AppState, PageContent}, - lazy::{get_start_time, ROUTE_CHAT_PATH, ROUTE_MODELS_PATH}, }, chat::constant::AVAILABLE_MODELS, common::models::{ @@ -20,7 +17,10 @@ use crate::{ use axum::{ body::Body, extract::State, - http::StatusCode, + http::{ + header::{CONTENT_TYPE, LOCATION}, + HeaderMap, StatusCode, + }, response::{IntoResponse, Response}, Json, }; @@ -33,80 +33,85 @@ pub async fn handle_root() -> impl IntoResponse { match AppConfig::get_page_content(ROUTE_ROOT_PATH).unwrap_or_default() { PageContent::Default => Response::builder() .status(StatusCode::TEMPORARY_REDIRECT) - .header(HEADER_NAME_LOCATION, ROUTE_HEALTH_PATH) + .header(LOCATION, ROUTE_HEALTH_PATH) .body(Body::empty()) .unwrap(), PageContent::Text(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) .body(Body::from(content.clone())) .unwrap(), PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(Body::from(content.clone())) .unwrap(), } } -pub async fn handle_health(State(state): State>>) -> Json { +pub async fn handle_health( + headers: HeaderMap, + State(state): State>>, +) -> Json { let start_time = get_start_time(); - // 创建系统信息实例,只监控 CPU 和内存 - let mut sys = System::new_with_specifics( - RefreshKind::nothing() - .with_memory(MemoryRefreshKind::everything()) - .with_cpu(CpuRefreshKind::everything()), - ); + // 尝试从请求头获取token并验证用户 + let mut stats = None; + let token = headers + .get(axum::http::header::AUTHORIZATION) + .and_then(|h| h.to_str().ok()) + .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)); - std::thread::sleep(sysinfo::MINIMUM_CPU_UPDATE_INTERVAL); + if let Some(token) = token { + if let Ok(Some(user)) = db::get_user_by_auth_token(token) { + if user.id == 0 && user.ban_expired_at.map_or(true, |t| t <= Local::now()) { + // 创建系统信息实例,只监控 CPU 和内存 + let mut sys = System::new_with_specifics( + RefreshKind::nothing() + .with_memory(MemoryRefreshKind::everything()) + .with_cpu(CpuRefreshKind::everything()), + ); - // 刷新 CPU 和内存信息 - sys.refresh_memory(); - sys.refresh_cpu_usage(); + std::thread::sleep(sysinfo::MINIMUM_CPU_UPDATE_INTERVAL); - let pid = std::process::id() as usize; - let process = sys.process(pid.into()); + // 刷新 CPU 和内存信息 + sys.refresh_memory(); + sys.refresh_cpu_usage(); - // 获取内存信息 - let memory = process.map(|p| p.memory()).unwrap_or(0); + let pid = std::process::id() as usize; + let process = sys.process(pid.into()); - // 获取 CPU 使用率 - let cpu_usage = sys.global_cpu_usage(); + // 获取内存信息 + let memory = process.map(|p| p.memory()).unwrap_or(0); - let state = state.lock().await; - let uptime = (Local::now() - start_time).num_seconds(); + // 获取 CPU 使用率 + let cpu_usage = sys.global_cpu_usage(); + + let state = state.lock().await; + + stats = Some(SystemStats { + started: start_time.to_string(), + total_requests: state.total_requests, + active_requests: state.active_requests, + system: SystemInfo { + memory: MemoryInfo { + rss: memory, // 物理内存使用量(字节) + }, + cpu: CpuInfo { + usage: cpu_usage, // CPU 使用率(百分比) + }, + }, + }); + } + } + } Json(HealthCheckResponse { status: ApiStatus::Healthy, version: PKG_VERSION, - uptime, - stats: SystemStats { - started: start_time.to_string(), - total_requests: state.total_requests, - active_requests: state.active_requests, - system: SystemInfo { - memory: MemoryInfo { - rss: memory, // 物理内存使用量(字节) - }, - cpu: CpuInfo { - usage: cpu_usage, // CPU 使用率(百分比) - }, - }, - }, - models: AVAILABLE_MODELS.iter().map(|m| m.id).collect::>(), - endpoints: vec![ - ROUTE_CHAT_PATH.as_str(), - ROUTE_MODELS_PATH.as_str(), - ROUTE_GET_CHECKSUM, - ROUTE_TOKENINFO_PATH, - ROUTE_UPDATE_TOKENINFO_PATH, - ROUTE_GET_TOKENINFO_PATH, - ROUTE_LOGS_PATH, - ROUTE_GET_USER_INFO_PATH, - ROUTE_ENV_EXAMPLE_PATH, - ROUTE_CONFIG_PATH, - ROUTE_STATIC_PATH, - ROUTE_ABOUT_PATH, - ROUTE_README_PATH, - ], + uptime: (Local::now() - start_time).num_seconds(), + stats, + models: AVAILABLE_MODELS + .iter() + .map(|m| m.id.clone()) + .collect::>(), }) } diff --git a/src/chat/route/logs.rs b/src/chat/route/logs.rs index a7576bb..ea7fb82 100644 --- a/src/chat/route/logs.rs +++ b/src/chat/route/logs.rs @@ -2,67 +2,62 @@ use crate::{ app::{ constant::{ AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_TEXT_HTML_WITH_UTF8, - CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, HEADER_NAME_AUTHORIZATION, HEADER_NAME_CONTENT_TYPE, - ROUTE_LOGS_PATH, + CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, ROUTE_LOGS_PATH, }, - model::{AppConfig, AppState, PageContent, RequestLog}, - lazy::AUTH_TOKEN, + db, + model::{AppConfig, LogInfo, PageContent}, }, common::models::ApiStatus, }; use axum::{ body::Body, - extract::State, - http::{HeaderMap, StatusCode}, + http::{header::{AUTHORIZATION, CONTENT_TYPE}, HeaderMap, StatusCode}, response::{IntoResponse, Response}, Json, }; use chrono::Local; -use std::sync::Arc; -use tokio::sync::Mutex; // 日志处理 pub async fn handle_logs() -> impl IntoResponse { match AppConfig::get_page_content(ROUTE_LOGS_PATH).unwrap_or_default() { PageContent::Default => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(Body::from( include_str!("../../../static/logs.min.html").to_string(), )) .unwrap(), PageContent::Text(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) .body(Body::from(content.clone())) .unwrap(), PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(Body::from(content.clone())) .unwrap(), } } -pub async fn handle_logs_post( - State(state): State>>, - headers: HeaderMap, -) -> Result, StatusCode> { - let auth_token = AUTH_TOKEN.as_str(); - - // 验证 AUTH_TOKEN +pub async fn handle_logs_post(headers: HeaderMap) -> Result, StatusCode> { + // 验证 auth_token let auth_header = headers - .get(HEADER_NAME_AUTHORIZATION) + .get(AUTHORIZATION) .and_then(|h| h.to_str().ok()) .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) .ok_or(StatusCode::UNAUTHORIZED)?; - if auth_header != auth_token { - return Err(StatusCode::UNAUTHORIZED); - } + // 通过 auth_token 获取用户信息 + let user = db::get_user_by_auth_token(auth_header) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .ok_or(StatusCode::UNAUTHORIZED)?; + + // 获取用户的日志记录 + let logs = + db::get_logs_by_user_id(Some(user.id)).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - let state = state.lock().await; Ok(Json(LogsResponse { status: ApiStatus::Success, - total: state.request_logs.len(), - logs: state.request_logs.clone(), + total: logs.len(), + logs, timestamp: Local::now().to_string(), })) } @@ -71,6 +66,6 @@ pub async fn handle_logs_post( pub struct LogsResponse { pub status: ApiStatus, pub total: usize, - pub logs: Vec, + pub logs: Vec, pub timestamp: String, } diff --git a/src/chat/route/token.rs b/src/chat/route/token.rs index 95a2687..37a4579 100644 --- a/src/chat/route/token.rs +++ b/src/chat/route/token.rs @@ -2,34 +2,30 @@ use crate::{ app::{ constant::{ AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_TEXT_HTML_WITH_UTF8, - CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, HEADER_NAME_AUTHORIZATION, HEADER_NAME_CONTENT_TYPE, - ROUTE_TOKENINFO_PATH, + CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, ROUTE_TOKENINFO_PATH, }, - model::{AppConfig, PageContent, TokenUpdateRequest}, - lazy::{AUTH_TOKEN, TOKEN_FILE, TOKEN_LIST_FILE}, + db::{ + get_token_by_id, get_token_by_token, get_tokens_by_user_id, get_user_by_auth_token, + insert_token, update_token, + }, + model::{AppConfig, PageContent, TokenInfo, TokenStatus, TokenUpdateRequest}, }, common::{ - models::{ApiStatus, NormalResponseNoData}, - utils::{generate_checksum, generate_hash, tokens::load_tokens}, + models::ApiStatus, + utils::{extract_user_id, extract_time, generate_checksum, generate_hash, validate_checksum}, }, }; -#[cfg(not(feature = "sqlite"))] -use crate::app::model::AppState; -#[cfg(feature = "sqlite")] -use crate::app::db::APP_DB; use axum::{ - http::HeaderMap, + http::{ + header::{AUTHORIZATION, CONTENT_TYPE}, + HeaderMap, + }, response::{IntoResponse, Response}, Json, }; -#[cfg(not(feature = "sqlite"))] -use axum::extract::State; +use chrono::Local; use reqwest::StatusCode; use serde::Serialize; -#[cfg(not(feature = "sqlite"))] -use std::sync::Arc; -#[cfg(not(feature = "sqlite"))] -use tokio::sync::Mutex; #[derive(Serialize)] pub struct ChecksumResponse { @@ -41,79 +37,38 @@ pub async fn handle_get_checksum() -> Json { Json(ChecksumResponse { checksum }) } -// 更新 TokenInfo 处理 -pub async fn handle_update_tokeninfo( - #[cfg(not(feature = "sqlite"))] State(state): State>>, -) -> Json { - // 重新加载 tokens - let token_infos = load_tokens(); - - // 更新应用状态 - #[cfg(not(feature = "sqlite"))] - { - let mut state = state.lock().await; - state.token_infos = token_infos; - } - - #[cfg(feature = "sqlite")] - { - // 使用 APP_DB 更新 token_infos - if let Ok(db) = APP_DB.lock() { - for token_info in token_infos { - let _ = db.update_token_info(&token_info); - } - } - } - - Json(NormalResponseNoData { - status: ApiStatus::Success, - message: Some("Token list has been reloaded".to_string()), - }) -} - // 获取 TokenInfo 处理 pub async fn handle_get_tokeninfo( headers: HeaderMap, ) -> Result, StatusCode> { - // 验证 AUTH_TOKEN + // 验证用户身份 let auth_header = headers - .get(HEADER_NAME_AUTHORIZATION) + .get(AUTHORIZATION) .and_then(|h| h.to_str().ok()) .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) .ok_or(StatusCode::UNAUTHORIZED)?; - if auth_header != AUTH_TOKEN.as_str() { - return Err(StatusCode::UNAUTHORIZED); + // 获取用户信息 + let user = get_user_by_auth_token(auth_header) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .ok_or(StatusCode::UNAUTHORIZED)?; + + // 获取用户的tokens + let tokens = if user.id == 0 { + // 管理员可以查看所有tokens + get_tokens_by_user_id(None) + } else { + // 普通用户只能查看自己的tokens + get_tokens_by_user_id(Some(user.id)) } + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - let token_file = TOKEN_FILE.as_str(); - let token_list_file = TOKEN_LIST_FILE.as_str(); - - // 读取文件内容 - let tokens = std::fs::read_to_string(&token_file).unwrap_or_else(|_| String::new()); - let token_list = std::fs::read_to_string(&token_list_file).unwrap_or_else(|_| String::new()); - - // 获取 tokens_count - let tokens_count = { - #[cfg(feature = "sqlite")] - { - APP_DB.lock() - .map(|db| db.get_token_infos().map(|v| v.len()).unwrap_or(0)) - .unwrap_or(0) - } - #[cfg(not(feature = "sqlite"))] - { - tokens.len() - } - }; + let token_num = tokens.len(); Ok(Json(TokenInfoResponse { status: ApiStatus::Success, - token_file: token_file.to_string(), - token_list_file: token_list_file.to_string(), tokens: Some(tokens), - tokens_count: Some(tokens_count), - token_list: Some(token_list), + num: Some(token_num), message: None, })) } @@ -121,89 +76,113 @@ pub async fn handle_get_tokeninfo( #[derive(Serialize)] pub struct TokenInfoResponse { pub status: ApiStatus, - pub token_file: String, - pub token_list_file: String, #[serde(skip_serializing_if = "Option::is_none")] - pub tokens: Option, + pub tokens: Option>, #[serde(skip_serializing_if = "Option::is_none")] - pub tokens_count: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub token_list: Option, + pub num: Option, #[serde(skip_serializing_if = "Option::is_none")] pub message: Option, } pub async fn handle_update_tokeninfo_post( - #[cfg(not(feature = "sqlite"))] State(state): State>>, headers: HeaderMap, Json(request): Json, ) -> Result, StatusCode> { - // 验证 AUTH_TOKEN + // 验证用户身份 let auth_header = headers - .get(HEADER_NAME_AUTHORIZATION) + .get(AUTHORIZATION) .and_then(|h| h.to_str().ok()) .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) .ok_or(StatusCode::UNAUTHORIZED)?; - if auth_header != AUTH_TOKEN.as_str() { - return Err(StatusCode::UNAUTHORIZED); + // 获取用户信息 + let user = get_user_by_auth_token(auth_header) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .ok_or(StatusCode::UNAUTHORIZED)?; + + if !validate_checksum(&request.checksum) { + return Err(StatusCode::BAD_REQUEST); } - let token_file = TOKEN_FILE.as_str(); - let token_list_file = TOKEN_LIST_FILE.as_str(); + // 检查token是否已存在 + let existing_token = + get_token_by_token(&request.token).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - // 写入文件 - std::fs::write(&token_file, &request.tokens) - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let is_update = existing_token.is_some(); - if let Some(token_list) = &request.token_list { - std::fs::write(&token_list_file, token_list) - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - } + let token_info = match existing_token { + Some(mut token) => { + // 更新现有token + token.checksum = request.checksum; + token.alias = request.alias; + token.is_public = request.is_public; - // 重新加载 tokens - let token_infos = load_tokens(); - let token_infos_len = token_infos.len(); - - // 更新应用状态 - #[cfg(not(feature = "sqlite"))] - { - let mut state = state.lock().await; - state.token_infos = token_infos; - } - - #[cfg(feature = "sqlite")] - { - if let Ok(db) = APP_DB.lock() { - for token_info in token_infos { - let _ = db.update_token_info(&token_info); - } + // 更新数据库 + update_token(&token).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + token } - } + None => { + let now = Local::now(); + let alias = if request.alias.is_none() { + match extract_user_id(&request.token) { + Some(user_id) => Some(user_id), + None => None, + } + } else { + request.alias + }; + // 创建新token + let new_token = TokenInfo { + id: 0, // 数据库会自动分配ID + create_at: extract_time(&request.token).unwrap_or_else(|| now), + token: request.token, + checksum: request.checksum, + alias, + status: TokenStatus::Active, + pengding_at: now, + user_id: user.id, + is_public: request.is_public, + usage: None, + }; + + // 插入数据库 + let id = insert_token(&new_token).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + // 获取插入后的完整token信息 + get_token_by_id(id) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .ok_or(StatusCode::INTERNAL_SERVER_ERROR)? + } + }; Ok(Json(TokenInfoResponse { status: ApiStatus::Success, - token_file: token_file.to_string(), - token_list_file: token_list_file.to_string(), tokens: None, - tokens_count: Some(token_infos_len), - token_list: None, - message: Some("Token files have been updated and reloaded".to_string()), + num: None, + message: Some(format!( + "Token {} has been {}", + token_info.token, + if is_update { + "updated" + } else { + "created" + } + )), })) } pub async fn handle_tokeninfo_page() -> impl IntoResponse { match AppConfig::get_page_content(ROUTE_TOKENINFO_PATH).unwrap_or_default() { PageContent::Default => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(include_str!("../../../static/tokeninfo.min.html").to_string()) .unwrap(), PageContent::Text(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) .body(content.clone()) .unwrap(), PageContent::Html(content) => Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) .body(content.clone()) .unwrap(), } diff --git a/src/chat/route/usage.rs b/src/chat/route/usage.rs index ebd9466..dc7892c 100644 --- a/src/chat/route/usage.rs +++ b/src/chat/route/usage.rs @@ -1,37 +1,54 @@ use crate::{ - app::model::AppState, + app::{ + constant::AUTHORIZATION_BEARER_PREFIX, + db::{get_token_by_alias_and_user, get_user_by_auth_token}, + }, chat::constant::ERR_NODATA, common::{models::usage::GetUserInfo, utils::get_user_usage}, }; use axum::{ - extract::{Query, State}, + extract::Query, + http::{header::AUTHORIZATION, HeaderMap, StatusCode}, Json, }; use serde::Deserialize; -use std::sync::Arc; -use tokio::sync::Mutex; #[derive(Deserialize)] pub struct GetUserInfoQuery { alias: String, + user_id: Option, } pub async fn get_user_info( - State(state): State>>, + headers: HeaderMap, Query(query): Query, -) -> Json { - let token_infos = &state.lock().await.token_infos; - let token_info = token_infos - .iter() - .find(|token_info| token_info.alias == Some(query.alias.clone())); +) -> Result, StatusCode> { + // 1. 验证用户身份 + let auth_header = headers + .get(AUTHORIZATION) + .and_then(|h| h.to_str().ok()) + .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) + .ok_or(StatusCode::UNAUTHORIZED)?; - let (auth_token, checksum) = match token_info { - Some(token_info) => (token_info.token.clone(), token_info.checksum.clone()), - None => return Json(GetUserInfo::Error(ERR_NODATA.to_string())), + // 2. 获取用户信息 + let user = get_user_by_auth_token(auth_header) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .ok_or(StatusCode::UNAUTHORIZED)?; + + // 3. 查询token信息 + let token_info = match get_token_by_alias_and_user( + &query.alias, + user.id, + query.user_id + ) { + Ok(Some(token)) => token, + Ok(None) => return Ok(Json(GetUserInfo::Error(ERR_NODATA.to_string()))), + Err(_) => return Ok(Json(GetUserInfo::Error("Database error".to_string()))), }; - match get_user_usage(&auth_token, &checksum).await { - Some(usage) => Json(GetUserInfo::Usage(usage)), - None => Json(GetUserInfo::Error(ERR_NODATA.to_string())), + // 4. 获取使用情况 + match get_user_usage(&token_info.token, &token_info.checksum).await { + Some(usage) => Ok(Json(GetUserInfo::Usage(usage))), + None => Ok(Json(GetUserInfo::Error(ERR_NODATA.to_string()))), } } diff --git a/src/chat/service.rs b/src/chat/service.rs index fea3787..666756d 100644 --- a/src/chat/service.rs +++ b/src/chat/service.rs @@ -3,11 +3,10 @@ use crate::{ app::{ constant::{ AUTHORIZATION_BEARER_PREFIX, CURSOR_API2_STREAM_CHAT, FINISH_REASON_STOP, - HEADER_NAME_CONTENT_TYPE, OBJECT_CHAT_COMPLETION, OBJECT_CHAT_COMPLETION_CHUNK, - STATUS_FAILED, STATUS_SUCCESS, + OBJECT_CHAT_COMPLETION, OBJECT_CHAT_COMPLETION_CHUNK, }, - model::{AppConfig, AppState, ChatRequest, RequestLog, TokenInfo}, - lazy::AUTH_TOKEN, + db, + model::{AppConfig, AppState, ChatRequest, LogInfo, LogStatus, TokenStatus}, }, chat::{ error::StreamError, @@ -25,19 +24,21 @@ use crate::{ use axum::{ body::Body, extract::State, - http::{HeaderMap, StatusCode}, + http::{header::CONTENT_TYPE, HeaderMap, StatusCode}, response::Response, Json, }; use bytes::Bytes; +use chrono::Local; use futures::{Stream, StreamExt}; +use rand::seq::IteratorRandom as _; use std::{ convert::Infallible, sync::{atomic::AtomicBool, Arc}, }; use std::{ pin::Pin, - sync::atomic::{AtomicUsize, Ordering}, + sync::atomic::Ordering, }; use tokio::sync::Mutex; use uuid::Uuid; @@ -46,7 +47,7 @@ use uuid::Uuid; pub async fn handle_models() -> Json { Json(ModelsResponse { object: "list", - data: &AVAILABLE_MODELS, + data: AVAILABLE_MODELS.to_vec(), }) } @@ -56,119 +57,144 @@ pub async fn handle_chat( headers: HeaderMap, Json(request): Json, ) -> Result, (StatusCode, Json)> { - let allow_claude = AppConfig::get_allow_claude(); - // 验证模型是否支持并获取模型信息 - let model = AVAILABLE_MODELS.iter().find(|m| m.id == request.model); - let model_supported = model.is_some(); - - if !(model_supported || allow_claude && request.model.starts_with("claude")) { - return Err(( - StatusCode::BAD_REQUEST, - Json(ChatError::ModelNotSupported(request.model).to_json()), - )); - } - - let request_time = chrono::Local::now(); - - // 验证请求 - if request.messages.is_empty() { - return Err(( - StatusCode::BAD_REQUEST, - Json(ChatError::EmptyMessages.to_json()), - )); - } - - // 获取并处理认证令牌 - let auth_token = headers + // 从请求头获取token + let token = headers .get(axum::http::header::AUTHORIZATION) .and_then(|h| h.to_str().ok()) .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) .ok_or(( StatusCode::UNAUTHORIZED, - Json(ChatError::Unauthorized.to_json()), + Json(ChatError::MissingToken.to_error_response()), ))?; - // 验证 AuthToken - if auth_token != AUTH_TOKEN.as_str() { + // 验证token并获取用户信息 + let user = db::get_user_by_auth_token(token).map_err(|err| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ChatError::DatabaseError(err.to_string()).to_error_response()), + ) + })?; + + let user = user.ok_or(( + StatusCode::UNAUTHORIZED, + Json(ChatError::InvalidToken.to_error_response()), + ))?; + + // 检查用户是否在封禁期 + if let Some(ban_expired_at) = user.ban_expired_at { + if ban_expired_at > Local::now() { + return Err(( + StatusCode::FORBIDDEN, + Json(ChatError::UserBanned(ban_expired_at).to_error_response()), + )); + } + } + + let tokens = db::get_available_tokens_by_user_id(Some(user.id)).map_err(|err| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ChatError::DatabaseError(err.to_string()).to_error_response()), + ) + })?; + + if tokens.is_empty() { return Err(( - StatusCode::UNAUTHORIZED, - Json(ChatError::Unauthorized.to_json()), + StatusCode::SERVICE_UNAVAILABLE, + Json(ChatError::NoTokens.to_error_response()), )); } - // 完整的令牌处理逻辑和对应的 checksum - let (auth_token, checksum, alias) = { - static CURRENT_KEY_INDEX: AtomicUsize = AtomicUsize::new(0); - let state_guard = state.lock().await; - let token_infos = &state_guard.token_infos; + // 随机选择一个可用的token + let token_info = tokens.into_iter().choose(&mut rand::thread_rng()).ok_or(( + StatusCode::SERVICE_UNAVAILABLE, + Json(ChatError::NoTokens.to_error_response()), + ))?; - if token_infos.is_empty() { - return Err(( - StatusCode::SERVICE_UNAVAILABLE, - Json(ChatError::NoTokens.to_json()), - )); - } + let allow_claude = AppConfig::get_allow_claude(); + // 验证模型是否支持并获取模型信息 + let model = AVAILABLE_MODELS + .iter() + .find(|m| m.id == request.model) + .cloned(); + let model_supported = model.is_some(); - let index = CURRENT_KEY_INDEX.fetch_add(1, Ordering::SeqCst) % token_infos.len(); - let token_info = &token_infos[index]; - ( - token_info.token.clone(), - token_info.checksum.clone(), - token_info.alias.clone(), - ) + if !(model_supported || allow_claude && request.model.starts_with("claude")) { + return Err(( + StatusCode::BAD_REQUEST, + Json(ChatError::ModelNotSupported(request.model).to_error_response()), + )); + } + + let request_time = Local::now(); + + // 验证请求 + if request.messages.is_empty() { + return Err(( + StatusCode::BAD_REQUEST, + Json(ChatError::EmptyMessages.to_error_response()), + )); + } + + let log_info = LogInfo { + id: 0, // 数据库会自动生成 + timestamp: request_time, + token_info: token_info.clone(), + prompt: None, + model: request.model.clone(), + stream: request.stream, + status: LogStatus::Pending, + error: None, }; + let log_id = db::insert_log(&log_info).map_err(|err| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ChatError::DatabaseError(err.to_string()).to_error_response()), + ) + })?; + // 更新请求日志 { - let state_clone = state.clone(); let mut state = state.lock().await; state.total_requests += 1; state.active_requests += 1; + let token = token_info.token.clone(); + let checksum = token_info.checksum.clone(); // 如果有model且需要获取使用情况,创建后台任务获取 if let Some(model) = model { if model.is_usage_check() { - let auth_token_clone = auth_token.clone(); - let checksum_clone = checksum.clone(); - let state_clone = state_clone.clone(); - tokio::spawn(async move { - let usage = get_user_usage(&auth_token_clone, &checksum_clone).await; - let mut state = state_clone.lock().await; - // 根据时间戳找到对应的日志 - if let Some(log) = state - .request_logs - .iter_mut() - .find(|log| log.timestamp == request_time) - { - log.token_info.usage = usage; + let usage = get_user_usage(&token, &checksum).await; + if let Err(err) = db::update_log_usage(log_id, usage) { + eprintln!("Failed to update log usage: {}", err); } }); } } - - let next_id = state.request_logs.last().map_or(1, |log| log.id + 1); - state.request_logs.push(RequestLog { - id: next_id, - timestamp: request_time, - model: request.model.clone(), - token_info: TokenInfo { - token: auth_token.clone(), - checksum: checksum.clone(), - alias: alias.clone(), - usage: None, - }, - prompt: None, - stream: request.stream, - status: "pending", - error: None, - }); - - if state.request_logs.len() > 100 { - state.request_logs.remove(0); - } } + if db::get_user_logs_count(user.id).map_err(|err| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ChatError::DatabaseError(err.to_string()).to_error_response()), + ) + })? >= 100 { + db::clean_user_logs(user.id, 100).map_err(|err| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ChatError::DatabaseError(err.to_string()).to_error_response()), + ) + })?; + } + + db::update_token_status(token_info.id, TokenStatus::Pending).map_err(|err| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ChatError::DatabaseError(err.to_string()).to_error_response()), + ) + })?; + // 将消息转换为hex格式 let hex_data = super::adapter::encode_chat_message(request.messages, &request.model) .await @@ -176,37 +202,39 @@ pub async fn handle_chat( ( StatusCode::INTERNAL_SERVER_ERROR, Json( - ChatError::RequestFailed("Failed to encode chat message".to_string()).to_json(), + ChatError::RequestFailed("Failed to encode chat message".to_string()) + .to_error_response(), ), ) })?; // 构建请求客户端 - let client = build_client(&auth_token, &checksum, CURSOR_API2_STREAM_CHAT); + let client = build_client(&token_info.token, &token_info.checksum, CURSOR_API2_STREAM_CHAT); let response = client.body(hex_data).send().await; // 处理请求结果 let response = match response { Ok(resp) => { // 更新请求日志为成功 - { - let mut state = state.lock().await; - state.request_logs.last_mut().unwrap().status = STATUS_SUCCESS; - } + db::update_log_status(log_id, LogStatus::Success, None).map_err(|err| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ChatError::DatabaseError(err.to_string()).to_error_response()), + ) + })?; resp } Err(e) => { // 更新请求日志为失败 - { - let mut state = state.lock().await; - if let Some(last_log) = state.request_logs.last_mut() { - last_log.status = STATUS_FAILED; - last_log.error = Some(e.to_string()); - } - } + db::update_log_status(log_id, LogStatus::Failed, Some(e.to_string())).map_err(|err| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ChatError::DatabaseError(err.to_string()).to_error_response()), + ) + })?; return Err(( StatusCode::INTERNAL_SERVER_ERROR, - Json(ChatError::RequestFailed(e.to_string()).to_json()), + Json(ChatError::RequestFailed(e.to_string()).to_error_response()), )); } }; @@ -237,7 +265,7 @@ pub async fn handle_chat( // 理论上,若程序正常,必定成功,因为前面判断过了 ( StatusCode::INTERNAL_SERVER_ERROR, - Json(ChatError::RequestFailed(error_message).to_json()), + Json(ChatError::RequestFailed(error_message).to_error_response()), ) })?; @@ -245,13 +273,12 @@ pub async fn handle_chat( Err(StreamError::ChatError(error)) => { let error_respone = error.to_error_response(); // 更新请求日志为失败 - { - let mut state = state.lock().await; - if let Some(last_log) = state.request_logs.last_mut() { - last_log.status = STATUS_FAILED; - last_log.error = Some(error_respone.native_code()); - } - } + db::update_log_status(log_id, LogStatus::Failed, Some(error_respone.native_code())).map_err(|err| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ChatError::DatabaseError(err.to_string()).to_error_response()), + ) + })?; return Err(( error_respone.status_code(), Json(error_respone.to_common()), @@ -274,18 +301,17 @@ pub async fn handle_chat( // Box::pin(stream) // as Pin> + Send>> // 更新请求日志为失败 - { - let mut state = state.lock().await; - if let Some(last_log) = state.request_logs.last_mut() { - last_log.status = STATUS_FAILED; - last_log.error = Some("Empty stream response".to_string()); - } - } + db::update_log_status(log_id, LogStatus::Failed, Some("Empty stream response".to_string())).map_err(|err| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ChatError::DatabaseError(err.to_string()).to_error_response()), + ) + })?; return Err(( StatusCode::INTERNAL_SERVER_ERROR, Json( ChatError::RequestFailed("Empty stream response".to_string()) - .to_json(), + .to_error_response(), ), )); } @@ -304,7 +330,6 @@ pub async fn handle_chat( let model = request.model.clone(); let is_start = is_start.clone(); let full_text = full_text.clone(); - let state = state.clone(); async move { let chunk = chunk.unwrap_or_default(); @@ -415,10 +440,8 @@ pub async fn handle_chat( } Ok(StreamMessage::Debug(debug_prompt)) => { buffer_guard.clear(); - if let Ok(mut state) = state.try_lock() { - if let Some(last_log) = state.request_logs.last_mut() { - last_log.prompt = Some(debug_prompt.clone()); - } + if let Err(err) = db::update_log_prompt(log_id, Some(debug_prompt.clone())) { + eprintln!("Failed to update log prompt: {}", err); } Ok(Bytes::new()) } @@ -435,7 +458,7 @@ pub async fn handle_chat( Ok(Response::builder() .header("Cache-Control", "no-cache") .header("Connection", "keep-alive") - .header(HEADER_NAME_CONTENT_TYPE, "text/event-stream") + .header(CONTENT_TYPE, "text/event-stream") .body(Body::from_stream(stream)) .unwrap()) } else { @@ -451,7 +474,7 @@ pub async fn handle_chat( StatusCode::INTERNAL_SERVER_ERROR, Json( ChatError::RequestFailed(format!("Failed to read response chunk: {}", e)) - .to_json(), + .to_error_response(), ), ) })?; @@ -496,29 +519,34 @@ pub async fn handle_chat( // 检查响应是否为空 if full_text.is_empty() { // 更新请求日志为失败 - { - let mut state = state.lock().await; - if let Some(last_log) = state.request_logs.last_mut() { - last_log.status = STATUS_FAILED; - last_log.error = Some("Empty response received".to_string()); - if let Some(p) = prompt { - last_log.prompt = Some(p); - } - } - } + db::update_log_status(log_id, LogStatus::Failed, Some("Empty response received".to_string())).map_err(|err| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ChatError::DatabaseError(err.to_string()).to_error_response()), + ) + })?; + db::update_log_prompt(log_id, None).map_err(|err| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ChatError::DatabaseError(err.to_string()).to_error_response()), + ) + })?; return Err(( StatusCode::INTERNAL_SERVER_ERROR, - Json(ChatError::RequestFailed("Empty response received".to_string()).to_json()), + Json( + ChatError::RequestFailed("Empty response received".to_string()) + .to_error_response(), + ), )); } // 更新请求日志提示词 - { - let mut state = state.lock().await; - if let Some(last_log) = state.request_logs.last_mut() { - last_log.prompt = prompt; - } - } + db::update_log_prompt(log_id, prompt).map_err(|err| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ChatError::DatabaseError(err.to_string()).to_error_response()), + ) + })?; let response_data = ChatResponse { id: format!("chatcmpl-{}", Uuid::new_v4().simple()), @@ -542,7 +570,7 @@ pub async fn handle_chat( }; Ok(Response::builder() - .header(HEADER_NAME_CONTENT_TYPE, "application/json") + .header(CONTENT_TYPE, "application/json") .body(Body::from(serde_json::to_string(&response_data).unwrap())) .unwrap()) } diff --git a/src/common/client.rs b/src/common/client.rs index b5e69e1..f826b14 100644 --- a/src/common/client.rs +++ b/src/common/client.rs @@ -1,12 +1,12 @@ use crate::app::{ constant::{ AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_CONNECT_PROTO, CONTENT_TYPE_PROTO, - CURSOR_API2_STREAM_CHAT, HEADER_NAME_AUTHORIZATION, - HEADER_NAME_CONTENT_TYPE, + CURSOR_API2_STREAM_CHAT, HEADER_NAME_GHOST_MODE, + TRUE, FALSE }, - lazy::{CURSOR_API2_HOST, CURSOR_API2_BASE_URL}, + lazy::{CURSOR_API2_BASE_URL, CURSOR_API2_HOST}, }; -use reqwest::Client; +use reqwest::{header::{CONTENT_TYPE,AUTHORIZATION,USER_AGENT,HOST}, Client}; use uuid::Uuid; /// 返回预构建的 Cursor API 客户端 @@ -21,19 +21,45 @@ pub fn build_client(auth_token: &str, checksum: &str, endpoint: &str) -> reqwest client .post(format!("{}{}", *CURSOR_API2_BASE_URL, endpoint)) - .header(HEADER_NAME_CONTENT_TYPE, content_type) + .header(CONTENT_TYPE, content_type) .header( - HEADER_NAME_AUTHORIZATION, + AUTHORIZATION, format!("{}{}", AUTHORIZATION_BEARER_PREFIX, auth_token), ) .header("connect-accept-encoding", "gzip,br") .header("connect-protocol-version", "1") - .header("user-agent", "connect-es/1.6.1") + .header(USER_AGENT, "connect-es/1.6.1") .header("x-amzn-trace-id", format!("Root={}", trace_id)) .header("x-cursor-checksum", checksum) .header("x-cursor-client-version", "0.42.5") .header("x-cursor-timezone", "Asia/Shanghai") - .header("x-ghost-mode", "false") + .header(HEADER_NAME_GHOST_MODE, FALSE) .header("x-request-id", trace_id) - .header("Host", CURSOR_API2_HOST.clone()) + .header(HOST, CURSOR_API2_HOST.clone()) +} + +/// 返回预构建的获取 Stripe 账户信息的 Cursor API 客户端 +pub fn build_profile_client(auth_token: &str) -> reqwest::RequestBuilder { + let client = Client::new(); + + client + .get(format!("{}/auth/full_stripe_profile", *CURSOR_API2_BASE_URL)) + .header(HOST, CURSOR_API2_HOST.clone()) + .header("sec-ch-ua", "\"Not-A.Brand\";v=\"99\", \"Chromium\";v=\"124\"") + .header(HEADER_NAME_GHOST_MODE, TRUE) + .header("sec-ch-ua-mobile", "?0") + .header( + AUTHORIZATION, + format!("{}{}", AUTHORIZATION_BEARER_PREFIX, auth_token), + ) + .header(USER_AGENT, "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Cursor/0.42.5 Chrome/124.0.6367.243 Electron/30.4.0 Safari/537.36") + .header("sec-ch-ua-platform", "\"Windows\"") + .header("accept", "*/*") + .header("origin", "vscode-file://vscode-app") + .header("sec-fetch-site", "cross-site") + .header("sec-fetch-mode", "cors") + .header("sec-fetch-dest", "empty") + .header("accept-encoding", "gzip, deflate, br") + .header("accept-language", "zh-CN") + .header("priority", "u=1, i") } diff --git a/src/common/models/error.rs b/src/common/models/error.rs index 79430db..ce552fc 100644 --- a/src/common/models/error.rs +++ b/src/common/models/error.rs @@ -6,10 +6,14 @@ pub enum ChatError { NoTokens, RequestFailed(String), Unauthorized, + MissingToken, + InvalidToken, + UserBanned(chrono::DateTime), + DatabaseError(String), } impl ChatError { - pub fn to_json(&self) -> ErrorResponse { + pub fn to_error_response(&self) -> ErrorResponse { let (error, message) = match self { ChatError::ModelNotSupported(model) => ( "model_not_supported", @@ -22,13 +26,23 @@ impl ChatError { ChatError::NoTokens => ("no_tokens", "No available tokens".to_string()), ChatError::RequestFailed(err) => ("request_failed", format!("Request failed: {}", err)), ChatError::Unauthorized => ("unauthorized", "Invalid authorization token".to_string()), + ChatError::MissingToken => ("missing_token", "Missing authorization token".to_string()), + ChatError::InvalidToken => ("invalid_token", "Invalid authorization token".to_string()), + ChatError::UserBanned(expired_at) => ( + "user_banned", + format!("User is banned until {}", expired_at), + ), + ChatError::DatabaseError(err) => ( + "database_error", + format!("Database error occurred: {}", err), + ), }; ErrorResponse { - status: super::ApiStatus::Error, - code: None, - error: Some(error.to_string()), - message: Some(message), + status: super::ApiStatus::Error, + code: None, + error: Some(error.to_string()), + message: Some(message), } } } diff --git a/src/common/models/health.rs b/src/common/models/health.rs index 43d2241..6760116 100644 --- a/src/common/models/health.rs +++ b/src/common/models/health.rs @@ -7,9 +7,9 @@ pub struct HealthCheckResponse { pub status: ApiStatus, pub version: &'static str, pub uptime: i64, - pub stats: SystemStats, - pub models: Vec<&'static str>, - pub endpoints: Vec<&'static str>, + #[serde(skip_serializing_if = "Option::is_none")] + pub stats: Option, + pub models: Vec, } #[derive(Serialize)] diff --git a/src/common/models/usage.rs b/src/common/models/usage.rs index 9a61e05..ea14c0e 100644 --- a/src/common/models/usage.rs +++ b/src/common/models/usage.rs @@ -1,4 +1,4 @@ -use serde::Serialize; +use serde::{Deserialize, Serialize}; #[derive(Serialize)] pub enum GetUserInfo { @@ -12,4 +12,41 @@ pub enum GetUserInfo { pub struct UserUsageInfo { pub fast_requests: u32, pub max_fast_requests: u32, + pub mtype: String, + pub trial_days: u32, +} + +impl rusqlite::types::FromSql for UserUsageInfo { + fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult { + let str = value.as_str()?; + let parts: Vec<&str> = str.split(',').collect(); + if parts.len() != 4 { + return Err(rusqlite::types::FromSqlError::InvalidType); + } + + Ok(UserUsageInfo { + fast_requests: parts[0].parse().map_err(|_| rusqlite::types::FromSqlError::InvalidType)?, + max_fast_requests: parts[1].parse().map_err(|_| rusqlite::types::FromSqlError::InvalidType)?, + mtype: parts[2].to_string(), + trial_days: parts[3].parse().map_err(|_| rusqlite::types::FromSqlError::InvalidType)?, + }) + } +} + +impl rusqlite::ToSql for UserUsageInfo { + fn to_sql(&self) -> rusqlite::Result> { + let str = format!("{},{},{},{}", + self.fast_requests, + self.max_fast_requests, + self.mtype, + self.trial_days + ); + Ok(rusqlite::types::ToSqlOutput::from(str)) + } +} + +#[derive(Deserialize)] +pub struct StripeProfile { + pub membership_type: String, + pub days_remaining_on_trial: i32, } diff --git a/src/common/utils.rs b/src/common/utils.rs index 33dfa7b..80fe6ef 100644 --- a/src/common/utils.rs +++ b/src/common/utils.rs @@ -1,19 +1,20 @@ mod checksum; pub use checksum::*; -pub mod tokens; +mod token; +pub use token::*; pub mod oauth; use prost::Message as _; -use crate::{app::constant::CURSOR_API2_GET_USER_INFO, chat::aiserver::v1::GetUserInfoResponse}; +use crate::{app::constant::{CURSOR_API2_GET_USER_INFO, TRUE, FALSE}, chat::aiserver::v1::GetUserInfoResponse}; -use super::models::usage::UserUsageInfo; +use super::models::usage::{StripeProfile, UserUsageInfo}; pub fn parse_bool_from_env(key: &str, default: bool) -> bool { std::env::var(key) .ok() .map(|v| match v.to_lowercase().as_str() { - "true" | "1" => true, - "false" | "0" => false, + TRUE | "1" => true, + FALSE | "0" => false, _ => default, }) .unwrap_or(default) @@ -44,8 +45,47 @@ pub async fn get_user_usage(auth_token: &str, checksum: &str) -> Option Option> { +// let client = super::client::build_client(auth_token, checksum, CURSOR_API2_AVAILABLE_MODELS); +// let response = client +// .body(Vec::new()) +// .send() +// .await +// .ok()? +// .bytes() +// .await +// .ok()?; +// let available_models = AvailableModelsResponse::decode(response.as_ref()).ok()?; +// Some(available_models.models.into_iter().map(|model| Model { +// id: model.name.clone(), +// created: CREATED, +// object: MODEL_OBJECT, +// owned_by: { +// if model.name.starts_with("gpt") || model.name.starts_with("o1") { +// OPENAI +// } else if model.name.starts_with("claude") { +// ANTHROPIC +// } else if model.name.starts_with("gemini") { +// GOOGLE +// } else { +// CURSOR +// } +// }, +// }).collect()) +// } + +pub async fn get_stripe_profile(auth_token: &str) -> Option<(String, u32)> { + let client = super::client::build_profile_client(auth_token); + let response = client.send().await.ok()?.json::().await.ok()?; + Some((response.membership_type, i32_to_u32(response.days_remaining_on_trial))) +} diff --git a/src/common/utils/checksum.rs b/src/common/utils/checksum.rs index 1015c2a..2bf6f0e 100644 --- a/src/common/utils/checksum.rs +++ b/src/common/utils/checksum.rs @@ -56,15 +56,8 @@ pub fn validate_checksum(checksum: &str) -> bool { return false; } - // 验证 BASE64 部分 - let base64_len = 8; - let encoded_part = &checksum[..base64_len]; - if !BASE64.decode(encoded_part).is_ok() { - return false; - } - // 验证 device_id hash 部分 - let device_hash = &checksum[base64_len..]; + let device_hash = &checksum[8..]; is_valid_hash(device_hash) } // 包含 MAC hash 的情况 diff --git a/src/common/utils/oauth.rs b/src/common/utils/oauth.rs index 4732f3f..d33f8b9 100644 --- a/src/common/utils/oauth.rs +++ b/src/common/utils/oauth.rs @@ -1,12 +1,16 @@ use anyhow::Result; -use reqwest::Client; +use oauth2::{ + basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, + TokenResponse, TokenUrl, +}; +use reqwest::{Client, Url}; use serde::{Deserialize, Serialize}; const OAUTH_AUTHORIZE_URL: &str = "https://connect.linux.do/oauth2/authorize"; const OAUTH_TOKEN_URL: &str = "https://connect.linux.do/oauth2/token"; const OAUTH_USER_INFO_URL: &str = "https://connect.linux.do/api/user"; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Serialize, Deserialize)] pub struct ForumUser { pub id: i64, pub username: String, @@ -17,52 +21,49 @@ pub struct ForumUser { } pub struct ForumOAuth { - client_id: String, - client_secret: String, - redirect_uri: String, + oauth_client: BasicClient, http_client: Client, } impl ForumOAuth { - pub fn new(client_id: String, client_secret: String, redirect_uri: String) -> Self { - Self { - client_id, - client_secret, - redirect_uri, - http_client: Client::new(), - } - } - - pub fn get_authorize_url(&self, state: &str) -> String { - format!( - "{}?response_type=code&client_id={}&redirect_uri={}&state={}", - OAUTH_AUTHORIZE_URL, - self.client_id, - urlencoding::encode(&self.redirect_uri), - state + pub fn new(client_id: String, client_secret: String, redirect_url: String) -> Result { + let oauth_client = BasicClient::new( + ClientId::new(client_id), + Some(ClientSecret::new(client_secret)), + AuthUrl::new(OAUTH_AUTHORIZE_URL.to_string())?, + Some(TokenUrl::new(OAUTH_TOKEN_URL.to_string())?), ) + .set_redirect_uri(RedirectUrl::new(redirect_url)?); + + Ok(Self { + oauth_client, + http_client: Client::new(), + }) } - pub async fn exchange_code_for_token(&self, code: &str) -> Result { - let response = self - .http_client - .post(OAUTH_TOKEN_URL) - .form(&[ - ("grant_type", "authorization_code"), - ("code", code), - ("client_id", &self.client_id), - ("client_secret", &self.client_secret), - ("redirect_uri", &self.redirect_uri), - ]) - .send() - .await? - .json::() + pub fn get_authorize_url(&self) -> (Url, CsrfToken) { + self.oauth_client + .authorize_url(|| CsrfToken::new_random()) + .url() + } + + pub async fn exchange_code_for_token( + &self, + code: &str, + returned_state: &str, + expected_state: &str, + ) -> Result { + if returned_state != expected_state { + return Err(anyhow::anyhow!("Invalid state parameter")); + } + + let token = self + .oauth_client + .exchange_code(AuthorizationCode::new(code.to_string())) + .request_async(oauth2::reqwest::async_http_client) .await?; - Ok(response["access_token"] - .as_str() - .ok_or_else(|| anyhow::anyhow!("No access token found"))? - .to_string()) + Ok(token.access_token().secret().clone()) } pub async fn get_user_info(&self, access_token: &str) -> Result { diff --git a/src/common/utils/token.rs b/src/common/utils/token.rs new file mode 100644 index 0000000..25cb0a6 --- /dev/null +++ b/src/common/utils/token.rs @@ -0,0 +1,148 @@ +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; +use chrono::{DateTime, Local, TimeZone}; + +// 验证jwt token是否有效 +pub fn validate_token(token: &str) -> bool { + // 检查 token 格式 + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() != 3 { + return false; + } + + // 解码 payload + let payload = match URL_SAFE_NO_PAD.decode(parts[1]) { + Ok(decoded) => decoded, + Err(_) => return false, + }; + + // 转换为字符串 + let payload_str = match String::from_utf8(payload) { + Ok(s) => s, + Err(_) => return false, + }; + + // 解析 JSON + let payload_json: serde_json::Value = match serde_json::from_str(&payload_str) { + Ok(v) => v, + Err(_) => return false, + }; + + // 验证必要字段是否存在且有效 + let required_fields = ["sub", "exp", "iss", "aud", "randomness", "time"]; + for field in required_fields { + if !payload_json.get(field).is_some() { + return false; + } + } + + // 验证 randomness 长度 + if let Some(randomness) = payload_json["randomness"].as_str() { + if randomness.len() != 18 { + return false; + } + } else { + return false; + } + + // 验证 time 字段 + if let Some(time) = payload_json["time"].as_str() { + // 验证 time 是否为有效的数字字符串 + if let Ok(time_value) = time.parse::() { + let current_time = chrono::Utc::now().timestamp(); + if time_value > current_time { + return false; + } + } else { + return false; + } + } else { + return false; + } + + // 验证过期时间 + if let Some(exp) = payload_json["exp"].as_i64() { + let current_time = chrono::Utc::now().timestamp(); + if current_time > exp { + return false; + } + } else { + return false; + } + + // 验证发行者 + if payload_json["iss"].as_str() != Some("https://authentication.cursor.sh") { + return false; + } + + // 验证受众 + if payload_json["aud"].as_str() != Some("https://cursor.com") { + return false; + } + + true +} + +// 从 JWT token 中提取用户 ID +pub fn extract_user_id(token: &str) -> Option { + // JWT token 由3部分组成,用 . 分隔 + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() != 3 { + return None; + } + + // 解码 payload (第二部分) + let payload = match URL_SAFE_NO_PAD.decode(parts[1]) { + Ok(decoded) => decoded, + Err(_) => return None, + }; + + // 将 payload 转换为字符串 + let payload_str = match String::from_utf8(payload) { + Ok(s) => s, + Err(_) => return None, + }; + + // 解析 JSON + let payload_json: serde_json::Value = match serde_json::from_str(&payload_str) { + Ok(v) => v, + Err(_) => return None, + }; + + // 提取 sub 字段 + payload_json["sub"] + .as_str() + .map(|s| s.split('|').nth(1).unwrap_or(s).to_string()) +} + +// 从 JWT token 中提取 time 字段 +pub fn extract_time(token: &str) -> Option> { + // JWT token 由3部分组成,用 . 分隔 + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() != 3 { + return None; + } + + // 解码 payload (第二部分) + let payload = match URL_SAFE_NO_PAD.decode(parts[1]) { + Ok(decoded) => decoded, + Err(_) => return None, + }; + + // 将 payload 转换为字符串 + let payload_str = match String::from_utf8(payload) { + Ok(s) => s, + Err(_) => return None, + }; + + // 解析 JSON + let payload_json: serde_json::Value = match serde_json::from_str(&payload_str) { + Ok(v) => v, + Err(_) => return None, + }; + + // 提取时间戳并转换为本地时间 + payload_json["time"] + .as_str() + .and_then(|t| t.parse::().ok()) + .and_then(|timestamp| Local.timestamp_opt(timestamp, 0).single()) +} diff --git a/src/common/utils/tokens.rs b/src/common/utils/tokens.rs deleted file mode 100644 index acd1d7e..0000000 --- a/src/common/utils/tokens.rs +++ /dev/null @@ -1,144 +0,0 @@ -use crate::{ - app::{ - constant::EMPTY_STRING, - model::TokenInfo, - lazy::{TOKEN_FILE, TOKEN_LIST_FILE}, - }, - common::utils::{generate_checksum, generate_hash}, -}; - -// 规范化文件内容并写入 -fn normalize_and_write(content: &str, file_path: &str) -> String { - let normalized = content.replace("\r\n", "\n"); - if normalized != content { - if let Err(e) = std::fs::write(file_path, &normalized) { - eprintln!("警告: 无法更新规范化的文件: {}", e); - } - } - normalized -} - -// 解析token和别名 -fn parse_token_alias(token_part: &str, line: &str) -> Option<(String, Option)> { - match token_part.split("::").collect::>() { - parts if parts.len() == 1 => Some((parts[0].to_string(), None)), - parts if parts.len() == 2 => Some((parts[1].to_string(), Some(parts[0].to_string()))), - _ => { - eprintln!("警告: 忽略无效的行: {}", line); - None - } - } -} - -// Token 加载函数 -pub fn load_tokens() -> Vec { - let token_file = TOKEN_FILE.as_str(); - let token_list_file = TOKEN_LIST_FILE.as_str(); - - // 确保文件存在 - for file in [&token_file, &token_list_file] { - if !std::path::Path::new(file).exists() { - if let Err(e) = std::fs::write(file, EMPTY_STRING) { - eprintln!("警告: 无法创建文件 '{}': {}", file, e); - } - } - } - - // 读取和规范化 token 文件 - let token_entries = match std::fs::read_to_string(&token_file) { - Ok(content) => { - let normalized = normalize_and_write(&content, &token_file); - normalized - .lines() - .filter_map(|line| { - let line = line.trim(); - if line.is_empty() || line.starts_with('#') { - return None; - } - parse_token_alias(line, line) - }) - .collect::>() - } - Err(e) => { - eprintln!("警告: 无法读取token文件 '{}': {}", token_file, e); - Vec::new() - } - }; - - // 读取和规范化 token-list 文件 - let mut token_map: std::collections::HashMap)> = - match std::fs::read_to_string(&token_list_file) { - Ok(content) => { - let normalized = normalize_and_write(&content, &token_list_file); - normalized - .lines() - .filter_map(|line| { - let line = line.trim(); - if line.is_empty() || line.starts_with('#') { - return None; - } - - let parts: Vec<&str> = line.split(',').collect(); - match parts[..] { - [token_part, checksum] => { - let (token, alias) = parse_token_alias(token_part, line)?; - Some((token, (checksum.to_string(), alias))) - } - _ => { - eprintln!("警告: 忽略无效的token-list行: {}", line); - None - } - } - }) - .collect() - } - Err(e) => { - eprintln!("警告: 无法读取token-list文件: {}", e); - std::collections::HashMap::new() - } - }; - - // 更新或添加新token - for (token, alias) in token_entries { - if let Some((_, existing_alias)) = token_map.get(&token) { - // 只在alias不同时更新已存在的token - if alias != *existing_alias { - if let Some((checksum, _)) = token_map.get(&token) { - token_map.insert(token.clone(), (checksum.clone(), alias)); - } - } - } else { - // 为新token生成checksum - let checksum = generate_checksum(&generate_hash(), Some(&generate_hash())); - token_map.insert(token, (checksum, alias)); - } - } - - // 更新 token-list 文件 - let token_list_content = token_map - .iter() - .map(|(token, (checksum, alias))| { - if let Some(alias) = alias { - format!("{}::{},{}", alias, token, checksum) - } else { - format!("{},{}", token, checksum) - } - }) - .collect::>() - .join("\n"); - - if let Err(e) = std::fs::write(&token_list_file, token_list_content) { - eprintln!("警告: 无法更新token-list文件: {}", e); - } - - // 转换为 TokenInfo vector - token_map - .into_iter() - .map(|(token, (checksum, alias))| TokenInfo { - token, - checksum, - alias, - usage: None, - }) - .collect() -} diff --git a/src/main.rs b/src/main.rs index e01c3b4..7033136 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,15 +3,16 @@ mod chat; mod common; use app::{ - config::handle_config_update, - constant::{ - EMPTY_STRING, PKG_VERSION, ROUTE_ABOUT_PATH, ROUTE_CONFIG_PATH, ROUTE_ENV_EXAMPLE_PATH, - ROUTE_GET_CHECKSUM, ROUTE_GET_TOKENINFO_PATH, ROUTE_GET_USER_INFO_PATH, ROUTE_HEALTH_PATH, - ROUTE_LOGS_PATH, ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_STATIC_PATH, - ROUTE_TOKENINFO_PATH, ROUTE_UPDATE_TOKENINFO_PATH, - }, - model::*, - lazy::{AUTH_TOKEN, ROUTE_CHAT_PATH, ROUTE_MODELS_PATH}, + config::handle_config_update, constant::{ + EMPTY_STRING, PKG_NAME, PKG_VERSION, ROUTE_ABOUT_PATH, ROUTE_API_PATH, + ROUTE_AUTH_CALLBACK_PATH, ROUTE_AUTH_INITIATE_PATH, ROUTE_AUTH_PATH, ROUTE_CHAT_PATH, + ROUTE_CONFIG_PATH, ROUTE_ENV_EXAMPLE_PATH, ROUTE_GET_CHECKSUM, ROUTE_GET_TOKENINFO_PATH, + ROUTE_GET_USER_INFO_PATH, ROUTE_HEALTH_PATH, ROUTE_LOGS_PATH, ROUTE_MODELS_PATH, + ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_STATIC_PATH, ROUTE_TOKENINFO_PATH, + ROUTE_UPDATE_TOKENINFO_PATH, + }, db::init_database, lazy::{ + OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OAUTH_REDIRECT_URI, PUBLIC_AUTH_TOKEN, ROUTE_PREFIX, + }, model::{AppConfig, AppState, VisionAbility} }; use axum::{ routing::{get, post}, @@ -19,14 +20,14 @@ use axum::{ }; use chat::{ route::{ - get_user_info, handle_about, handle_config_page, handle_env_example, handle_get_checksum, - handle_get_tokeninfo, handle_health, handle_logs, handle_logs_post, handle_readme, - handle_root, handle_static, handle_tokeninfo_page, handle_update_tokeninfo, - handle_update_tokeninfo_post, + get_user_info, handle_about, handle_auth_callback, handle_auth_initiate, + handle_config_page, handle_env_example, handle_get_checksum, handle_get_tokeninfo, + handle_health, handle_logs, handle_logs_post, handle_readme, handle_root, handle_static, + handle_tokeninfo_page, handle_update_tokeninfo_post, }, service::{handle_chat, handle_models}, }; -use common::utils::{parse_bool_from_env, parse_string_from_env, tokens::load_tokens}; +use common::utils::{parse_bool_from_env, parse_string_from_env}; use std::sync::Arc; use tokio::sync::Mutex; use tower_http::cors::CorsLayer; @@ -46,8 +47,20 @@ async fn main() { // 加载环境变量 dotenvy::dotenv().ok(); - if AUTH_TOKEN.is_empty() { - panic!("AUTH_TOKEN must be set") + if PUBLIC_AUTH_TOKEN.is_empty() { + panic!("PUBLIC_AUTH_TOKEN must be set") + }; + + if OAUTH_CLIENT_ID.is_empty() { + panic!("OAUTH_CLIENT_ID must be set") + }; + + if OAUTH_CLIENT_SECRET.is_empty() { + panic!("OAUTH_CLIENT_SECRET must be set") + }; + + if OAUTH_REDIRECT_URI.is_empty() { + panic!("OAUTH_REDIRECT_URI must be set") }; // 初始化全局配置 @@ -59,30 +72,29 @@ async fn main() { parse_bool_from_env("PASS_ANY_CLAUDE", false), ); - // 加载 tokens - let token_infos = load_tokens(); - // 初始化应用状态 - #[cfg(feature = "sqlite")] let state = Arc::new(Mutex::new(AppState::new())); - #[cfg(not(feature = "sqlite"))] - let state = Arc::new(Mutex::new(AppState::new(token_infos))); + + init_database(format!("{}.db", PKG_NAME).as_str()).await.unwrap(); // 设置路由 let app = Router::new() + .nest( + ROUTE_PREFIX.as_str(), + Router::new() + .route(ROUTE_MODELS_PATH, get(handle_models)) + .route(ROUTE_CHAT_PATH, post(handle_chat)), + ) .route(ROUTE_ROOT_PATH, get(handle_root)) .route(ROUTE_HEALTH_PATH, get(handle_health)) .route(ROUTE_TOKENINFO_PATH, get(handle_tokeninfo_page)) - .route(ROUTE_MODELS_PATH.as_str(), get(handle_models)) .route(ROUTE_GET_CHECKSUM, get(handle_get_checksum)) .route(ROUTE_GET_USER_INFO_PATH, get(get_user_info)) - .route(ROUTE_UPDATE_TOKENINFO_PATH, get(handle_update_tokeninfo)) .route(ROUTE_GET_TOKENINFO_PATH, post(handle_get_tokeninfo)) .route( ROUTE_UPDATE_TOKENINFO_PATH, post(handle_update_tokeninfo_post), ) - .route(ROUTE_CHAT_PATH.as_str(), post(handle_chat)) .route(ROUTE_LOGS_PATH, get(handle_logs)) .route(ROUTE_LOGS_PATH, post(handle_logs_post)) .route(ROUTE_ENV_EXAMPLE_PATH, get(handle_env_example)) @@ -91,6 +103,15 @@ async fn main() { .route(ROUTE_STATIC_PATH, get(handle_static)) .route(ROUTE_ABOUT_PATH, get(handle_about)) .route(ROUTE_README_PATH, get(handle_readme)) + .nest( + ROUTE_API_PATH, + Router::new().nest( + ROUTE_AUTH_PATH, + Router::new() + .route(ROUTE_AUTH_CALLBACK_PATH, get(handle_auth_callback)) + .route(ROUTE_AUTH_INITIATE_PATH, get(handle_auth_initiate)), + ), + ) .layer(CorsLayer::permissive()) .with_state(state); diff --git a/static/tokeninfo.html b/static/tokeninfo.html index 0dd5b26..e13edb2 100644 --- a/static/tokeninfo.html +++ b/static/tokeninfo.html @@ -76,6 +76,19 @@