mirror of
https://github.com/wisdgod/cursor-api.git
synced 2025-12-24 13:38:01 +08:00
这我已经不想做了
This commit is contained in:
@@ -40,3 +40,6 @@ DEFAULT_INSTRUCTIONS="Respond in Chinese by default"
|
||||
|
||||
# 反向代理服务器主机名
|
||||
CURSOR_API2_HOST=
|
||||
|
||||
# 管理员认证令牌
|
||||
ADMIN_AUTH_TOKEN=
|
||||
|
||||
585
Cargo.lock
generated
585
Cargo.lock
generated
@@ -105,10 +105,10 @@ dependencies = [
|
||||
"axum-core",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"http 1.2.0",
|
||||
"http-body 1.0.1",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper 1.5.2",
|
||||
"hyper-util",
|
||||
"itoa",
|
||||
"matchit",
|
||||
@@ -121,7 +121,7 @@ dependencies = [
|
||||
"serde_json",
|
||||
"serde_path_to_error",
|
||||
"serde_urlencoded",
|
||||
"sync_wrapper",
|
||||
"sync_wrapper 1.0.2",
|
||||
"tokio",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
@@ -138,13 +138,13 @@ dependencies = [
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"http 1.2.0",
|
||||
"http-body 1.0.1",
|
||||
"http-body-util",
|
||||
"mime",
|
||||
"pin-project-lite",
|
||||
"rustversion",
|
||||
"sync_wrapper",
|
||||
"sync_wrapper 1.0.2",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
@@ -162,9 +162,21 @@ dependencies = [
|
||||
"miniz_oxide",
|
||||
"object",
|
||||
"rustc-demangle",
|
||||
"windows-targets",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.21.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.22.1"
|
||||
@@ -237,6 +249,12 @@ version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "cfg_aliases"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
|
||||
|
||||
[[package]]
|
||||
name = "chrono"
|
||||
version = "0.4.39"
|
||||
@@ -249,7 +267,7 @@ dependencies = [
|
||||
"num-traits",
|
||||
"serde",
|
||||
"wasm-bindgen",
|
||||
"windows-targets",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -308,7 +326,7 @@ version = "0.1.3"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"axum",
|
||||
"base64",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"dotenvy",
|
||||
@@ -317,13 +335,14 @@ dependencies = [
|
||||
"gif",
|
||||
"hex",
|
||||
"image",
|
||||
"lazy_static",
|
||||
"oauth2",
|
||||
"paste",
|
||||
"prost",
|
||||
"prost-build",
|
||||
"rand",
|
||||
"regex",
|
||||
"reqwest",
|
||||
"reqwest 0.12.12",
|
||||
"ring",
|
||||
"rusqlite",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -561,8 +580,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"js-sys",
|
||||
"libc",
|
||||
"wasi",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -581,6 +602,25 @@ version = "0.31.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.3.26"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fnv",
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
"futures-util",
|
||||
"http 0.2.12",
|
||||
"indexmap",
|
||||
"slab",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.4.7"
|
||||
@@ -592,7 +632,7 @@ dependencies = [
|
||||
"fnv",
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
"http",
|
||||
"http 1.2.0",
|
||||
"indexmap",
|
||||
"slab",
|
||||
"tokio",
|
||||
@@ -636,6 +676,17 @@ version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
|
||||
|
||||
[[package]]
|
||||
name = "http"
|
||||
version = "0.2.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fnv",
|
||||
"itoa",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http"
|
||||
version = "1.2.0"
|
||||
@@ -647,6 +698,17 @@ dependencies = [
|
||||
"itoa",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-body"
|
||||
version = "0.4.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"http 0.2.12",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-body"
|
||||
version = "1.0.1"
|
||||
@@ -654,7 +716,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"http",
|
||||
"http 1.2.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -665,8 +727,8 @@ checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"http 1.2.0",
|
||||
"http-body 1.0.1",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
@@ -682,6 +744,30 @@ version = "1.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
|
||||
|
||||
[[package]]
|
||||
name = "hyper"
|
||||
version = "0.14.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"h2 0.3.26",
|
||||
"http 0.2.12",
|
||||
"http-body 0.4.6",
|
||||
"httparse",
|
||||
"httpdate",
|
||||
"itoa",
|
||||
"pin-project-lite",
|
||||
"socket2",
|
||||
"tokio",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
"want",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper"
|
||||
version = "1.5.2"
|
||||
@@ -691,9 +777,9 @@ dependencies = [
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
"h2",
|
||||
"http",
|
||||
"http-body",
|
||||
"h2 0.4.7",
|
||||
"http 1.2.0",
|
||||
"http-body 1.0.1",
|
||||
"httparse",
|
||||
"httpdate",
|
||||
"itoa",
|
||||
@@ -703,6 +789,20 @@ dependencies = [
|
||||
"want",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-rustls"
|
||||
version = "0.24.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"http 0.2.12",
|
||||
"hyper 0.14.32",
|
||||
"rustls 0.21.12",
|
||||
"tokio",
|
||||
"tokio-rustls 0.24.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-rustls"
|
||||
version = "0.27.5"
|
||||
@@ -710,14 +810,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"http",
|
||||
"hyper",
|
||||
"http 1.2.0",
|
||||
"hyper 1.5.2",
|
||||
"hyper-util",
|
||||
"rustls",
|
||||
"rustls 0.23.20",
|
||||
"rustls-pki-types",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tokio-rustls 0.26.1",
|
||||
"tower-service",
|
||||
"webpki-roots 0.26.7",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -728,7 +829,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper 1.5.2",
|
||||
"hyper-util",
|
||||
"native-tls",
|
||||
"tokio",
|
||||
@@ -745,9 +846,9 @@ dependencies = [
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"http 1.2.0",
|
||||
"http-body 1.0.1",
|
||||
"hyper 1.5.2",
|
||||
"pin-project-lite",
|
||||
"socket2",
|
||||
"tokio",
|
||||
@@ -985,12 +1086,6 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.169"
|
||||
@@ -1106,6 +1201,26 @@ dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "oauth2"
|
||||
version = "4.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c38841cdd844847e3e7c8d29cef9dcfed8877f8f56f9071f77843ecf3baf937f"
|
||||
dependencies = [
|
||||
"base64 0.13.1",
|
||||
"chrono",
|
||||
"getrandom",
|
||||
"http 0.2.12",
|
||||
"rand",
|
||||
"reqwest 0.11.27",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_path_to_error",
|
||||
"sha2",
|
||||
"thiserror 1.0.69",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "object"
|
||||
version = "0.36.7"
|
||||
@@ -1304,6 +1419,58 @@ version = "2.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3"
|
||||
|
||||
[[package]]
|
||||
name = "quinn"
|
||||
version = "0.11.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"pin-project-lite",
|
||||
"quinn-proto",
|
||||
"quinn-udp",
|
||||
"rustc-hash",
|
||||
"rustls 0.23.20",
|
||||
"socket2",
|
||||
"thiserror 2.0.9",
|
||||
"tokio",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn-proto"
|
||||
version = "0.11.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"getrandom",
|
||||
"rand",
|
||||
"ring",
|
||||
"rustc-hash",
|
||||
"rustls 0.23.20",
|
||||
"rustls-pki-types",
|
||||
"slab",
|
||||
"thiserror 2.0.9",
|
||||
"tinyvec",
|
||||
"tracing",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn-udp"
|
||||
version = "0.5.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1c40286217b4ba3a71d644d752e6a0b71f13f1b6a2c5311acfcbe0c2418ed904"
|
||||
dependencies = [
|
||||
"cfg_aliases",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"socket2",
|
||||
"tracing",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.38"
|
||||
@@ -1372,6 +1539,47 @@ version = "0.8.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
|
||||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
version = "0.11.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62"
|
||||
dependencies = [
|
||||
"base64 0.21.7",
|
||||
"bytes",
|
||||
"encoding_rs",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"h2 0.3.26",
|
||||
"http 0.2.12",
|
||||
"http-body 0.4.6",
|
||||
"hyper 0.14.32",
|
||||
"hyper-rustls 0.24.2",
|
||||
"ipnet",
|
||||
"js-sys",
|
||||
"log",
|
||||
"mime",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"rustls 0.21.12",
|
||||
"rustls-pemfile 1.0.4",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"sync_wrapper 0.1.2",
|
||||
"system-configuration 0.5.1",
|
||||
"tokio",
|
||||
"tokio-rustls 0.24.1",
|
||||
"tower-service",
|
||||
"url",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"web-sys",
|
||||
"webpki-roots 0.25.4",
|
||||
"winreg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
version = "0.12.12"
|
||||
@@ -1379,17 +1587,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da"
|
||||
dependencies = [
|
||||
"async-compression",
|
||||
"base64",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"encoding_rs",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"h2",
|
||||
"http",
|
||||
"http-body",
|
||||
"h2 0.4.7",
|
||||
"http 1.2.0",
|
||||
"http-body 1.0.1",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper-rustls",
|
||||
"hyper 1.5.2",
|
||||
"hyper-rustls 0.27.5",
|
||||
"hyper-tls",
|
||||
"hyper-util",
|
||||
"ipnet",
|
||||
@@ -1400,14 +1608,18 @@ dependencies = [
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"rustls-pemfile",
|
||||
"quinn",
|
||||
"rustls 0.23.20",
|
||||
"rustls-pemfile 2.2.0",
|
||||
"rustls-pki-types",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"sync_wrapper",
|
||||
"system-configuration",
|
||||
"sync_wrapper 1.0.2",
|
||||
"system-configuration 0.6.1",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tokio-rustls 0.26.1",
|
||||
"tokio-util",
|
||||
"tower",
|
||||
"tower-service",
|
||||
@@ -1416,6 +1628,7 @@ dependencies = [
|
||||
"wasm-bindgen-futures",
|
||||
"wasm-streams",
|
||||
"web-sys",
|
||||
"webpki-roots 0.26.7",
|
||||
"windows-registry",
|
||||
]
|
||||
|
||||
@@ -1441,6 +1654,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e"
|
||||
dependencies = [
|
||||
"bitflags 2.6.0",
|
||||
"chrono",
|
||||
"fallible-iterator",
|
||||
"fallible-streaming-iterator",
|
||||
"hashlink",
|
||||
@@ -1454,6 +1668,12 @@ version = "0.1.24"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
|
||||
|
||||
[[package]]
|
||||
name = "rustc-hash"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497"
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.38.42"
|
||||
@@ -1467,6 +1687,18 @@ dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.21.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e"
|
||||
dependencies = [
|
||||
"log",
|
||||
"ring",
|
||||
"rustls-webpki 0.101.7",
|
||||
"sct",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.23.20"
|
||||
@@ -1474,12 +1706,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5065c3f250cbd332cd894be57c40fa52387247659b14a2d6041d121547903b1b"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"rustls-webpki",
|
||||
"rustls-webpki 0.102.8",
|
||||
"subtle",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pemfile"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c"
|
||||
dependencies = [
|
||||
"base64 0.21.7",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pemfile"
|
||||
version = "2.2.0"
|
||||
@@ -1494,6 +1736,19 @@ name = "rustls-pki-types"
|
||||
version = "1.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37"
|
||||
dependencies = [
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.101.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
@@ -1527,6 +1782,16 @@ dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sct"
|
||||
version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "security-framework"
|
||||
version = "2.11.1"
|
||||
@@ -1681,6 +1946,12 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sync_wrapper"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
|
||||
|
||||
[[package]]
|
||||
name = "sync_wrapper"
|
||||
version = "1.0.2"
|
||||
@@ -1714,6 +1985,17 @@ dependencies = [
|
||||
"windows",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "system-configuration"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"core-foundation",
|
||||
"system-configuration-sys 0.5.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "system-configuration"
|
||||
version = "0.6.1"
|
||||
@@ -1722,7 +2004,17 @@ checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b"
|
||||
dependencies = [
|
||||
"bitflags 2.6.0",
|
||||
"core-foundation",
|
||||
"system-configuration-sys",
|
||||
"system-configuration-sys 0.6.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "system-configuration-sys"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1748,6 +2040,46 @@ dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.69"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
|
||||
dependencies = [
|
||||
"thiserror-impl 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "2.0.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f072643fd0190df67a8bab670c20ef5d8737177d6ac6b2e9a236cb096206b2cc"
|
||||
dependencies = [
|
||||
"thiserror-impl 2.0.9",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "1.0.69"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "2.0.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinystr"
|
||||
version = "0.7.6"
|
||||
@@ -1758,6 +2090,21 @@ dependencies = [
|
||||
"zerovec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec"
|
||||
version = "1.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8"
|
||||
dependencies = [
|
||||
"tinyvec_macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec_macros"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.42.0"
|
||||
@@ -1795,13 +2142,23 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.24.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081"
|
||||
dependencies = [
|
||||
"rustls 0.21.12",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.26.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37"
|
||||
dependencies = [
|
||||
"rustls",
|
||||
"rustls 0.23.20",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
@@ -1838,7 +2195,7 @@ dependencies = [
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"pin-project-lite",
|
||||
"sync_wrapper",
|
||||
"sync_wrapper 1.0.2",
|
||||
"tokio",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
@@ -1853,7 +2210,7 @@ checksum = "403fa3b783d4b626a8ad51d766ab03cb6d2dbfc46b1c5d4448395e6628dc9697"
|
||||
dependencies = [
|
||||
"bitflags 2.6.0",
|
||||
"bytes",
|
||||
"http",
|
||||
"http 1.2.0",
|
||||
"pin-project-lite",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
@@ -1924,6 +2281,7 @@ dependencies = [
|
||||
"form_urlencoded",
|
||||
"idna",
|
||||
"percent-encoding",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2070,6 +2428,31 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "web-time"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
|
||||
dependencies = [
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "0.25.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1"
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "0.26.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e"
|
||||
dependencies = [
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "weezl"
|
||||
version = "0.1.8"
|
||||
@@ -2105,7 +2488,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "12342cb4d8e3b046f3d80effd474a7a02447231330ef77d71daa6fbc40681143"
|
||||
dependencies = [
|
||||
"windows-core 0.57.0",
|
||||
"windows-targets",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2114,7 +2497,7 @@ version = "0.52.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9"
|
||||
dependencies = [
|
||||
"windows-targets",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2126,7 +2509,7 @@ dependencies = [
|
||||
"windows-implement",
|
||||
"windows-interface",
|
||||
"windows-result 0.1.2",
|
||||
"windows-targets",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2159,7 +2542,7 @@ checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0"
|
||||
dependencies = [
|
||||
"windows-result 0.2.0",
|
||||
"windows-strings",
|
||||
"windows-targets",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2168,7 +2551,7 @@ version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5e383302e8ec8515204254685643de10811af0ed97ea37210dc26fb0032647f8"
|
||||
dependencies = [
|
||||
"windows-targets",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2177,7 +2560,7 @@ version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e"
|
||||
dependencies = [
|
||||
"windows-targets",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2187,7 +2570,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10"
|
||||
dependencies = [
|
||||
"windows-result 0.2.0",
|
||||
"windows-targets",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.48.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9"
|
||||
dependencies = [
|
||||
"windows-targets 0.48.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2196,7 +2588,7 @@ version = "0.52.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
|
||||
dependencies = [
|
||||
"windows-targets",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2205,7 +2597,22 @@ version = "0.59.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b"
|
||||
dependencies = [
|
||||
"windows-targets",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-targets"
|
||||
version = "0.48.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c"
|
||||
dependencies = [
|
||||
"windows_aarch64_gnullvm 0.48.5",
|
||||
"windows_aarch64_msvc 0.48.5",
|
||||
"windows_i686_gnu 0.48.5",
|
||||
"windows_i686_msvc 0.48.5",
|
||||
"windows_x86_64_gnu 0.48.5",
|
||||
"windows_x86_64_gnullvm 0.48.5",
|
||||
"windows_x86_64_msvc 0.48.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2214,28 +2621,46 @@ version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
|
||||
dependencies = [
|
||||
"windows_aarch64_gnullvm",
|
||||
"windows_aarch64_msvc",
|
||||
"windows_i686_gnu",
|
||||
"windows_aarch64_gnullvm 0.52.6",
|
||||
"windows_aarch64_msvc 0.52.6",
|
||||
"windows_i686_gnu 0.52.6",
|
||||
"windows_i686_gnullvm",
|
||||
"windows_i686_msvc",
|
||||
"windows_x86_64_gnu",
|
||||
"windows_x86_64_gnullvm",
|
||||
"windows_x86_64_msvc",
|
||||
"windows_i686_msvc 0.52.6",
|
||||
"windows_x86_64_gnu 0.52.6",
|
||||
"windows_x86_64_gnullvm 0.52.6",
|
||||
"windows_x86_64_msvc 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.48.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.48.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.48.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.52.6"
|
||||
@@ -2248,30 +2673,64 @@ version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.48.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.48.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.48.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.48.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
||||
|
||||
[[package]]
|
||||
name = "winreg"
|
||||
version = "0.50.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "write16"
|
||||
version = "1.0.0"
|
||||
|
||||
13
Cargo.toml
13
Cargo.toml
@@ -13,7 +13,7 @@ serde_json = "1.0.134"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.95"
|
||||
axum = { version = "0.7.9", features = ["json"] }
|
||||
axum = { version = "0.7.9", features = ["http2", "json"] }
|
||||
base64 = { version = "0.22.1", default-features = false, features = ["std"] }
|
||||
# brotli = { version = "7.0.0", default-features = false, features = ["std"] }
|
||||
bytes = "1.9.0"
|
||||
@@ -24,13 +24,14 @@ futures = { version = "0.3.31", default-features = false, features = ["std"] }
|
||||
gif = { version = "0.13.1", default-features = false, features = ["std"] }
|
||||
hex = { version = "0.4.3", default-features = false, features = ["std"] }
|
||||
image = { version = "0.25.5", default-features = false, features = ["jpeg", "png", "gif", "webp"] }
|
||||
lazy_static = "1.5.0"
|
||||
oauth2 = { version = "4.4.2", default-features = false, features = ["reqwest", "rustls-tls"] }
|
||||
paste = "1.0.15"
|
||||
prost = "0.13.4"
|
||||
rand = { version = "0.8.5", default-features = false, features = ["std", "std_rng"] }
|
||||
regex = { version = "1.11.1", default-features = false, features = ["std", "perf"] }
|
||||
reqwest = { version = "0.12.12", default-features = false, features = ["gzip", "json", "stream", "__tls", "charset", "default-tls", "h2", "http2", "macos-system-configuration"] }
|
||||
rusqlite = { version = "0.32.1", features = ["bundled"], optional = true }
|
||||
reqwest = { version = "0.12.12", default-features = false, features = ["gzip", "json", "stream", "rustls-tls", "__tls", "charset", "default-tls", "h2", "http2", "macos-system-configuration"] }
|
||||
ring = { version = "0.17.8", default-features = false, features = ["alloc"] }
|
||||
rusqlite = { version = "0.32.1", features = ["bundled", "chrono"] }
|
||||
serde = { version = "1.0.217", default-features = false, features = ["std", "derive"] }
|
||||
serde_json = "1.0.134"
|
||||
sha2 = { version = "0.10.8", default-features = false }
|
||||
@@ -47,7 +48,3 @@ codegen-units = 1
|
||||
panic = 'abort'
|
||||
strip = true
|
||||
opt-level = 3
|
||||
|
||||
[features]
|
||||
default = []
|
||||
sqlite = ["dep:rusqlite"]
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
pub mod config;
|
||||
pub mod constant;
|
||||
#[cfg(feature = "sqlite")]
|
||||
pub mod db;
|
||||
pub mod model;
|
||||
pub mod lazy;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use super::{
|
||||
constant::{HEADER_NAME_AUTHORIZATION, AUTHORIZATION_BEARER_PREFIX},
|
||||
constant::AUTHORIZATION_BEARER_PREFIX,
|
||||
model::{AppConfig, AppState},
|
||||
lazy::AUTH_TOKEN,
|
||||
lazy::ADMIN_AUTH_TOKEN,
|
||||
};
|
||||
use crate::common::models::{
|
||||
config::{ConfigData, ConfigUpdateRequest},
|
||||
@@ -9,7 +9,7 @@ use crate::common::models::{
|
||||
};
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::{HeaderMap, StatusCode},
|
||||
http::{header::AUTHORIZATION, HeaderMap, StatusCode},
|
||||
Json,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
@@ -59,7 +59,7 @@ pub async fn handle_config_update(
|
||||
Json(request): Json<ConfigUpdateRequest>,
|
||||
) -> Result<Json<NormalResponse<ConfigData>>, (StatusCode, Json<ErrorResponse>)> {
|
||||
let auth_header = headers
|
||||
.get(HEADER_NAME_AUTHORIZATION)
|
||||
.get(AUTHORIZATION)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX))
|
||||
.ok_or((
|
||||
@@ -72,7 +72,7 @@ pub async fn handle_config_update(
|
||||
}),
|
||||
))?;
|
||||
|
||||
if auth_header != AUTH_TOKEN.as_str() {
|
||||
if auth_header != ADMIN_AUTH_TOKEN.as_str() {
|
||||
return Err((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(ErrorResponse {
|
||||
|
||||
@@ -5,13 +5,20 @@ macro_rules! def_pub_const {
|
||||
}
|
||||
|
||||
def_pub_const!(PKG_VERSION, env!("CARGO_PKG_VERSION"));
|
||||
// def_pub_const!(PKG_NAME, env!("CARGO_PKG_NAME"));
|
||||
def_pub_const!(PKG_NAME, env!("CARGO_PKG_NAME"));
|
||||
// def_pub_const!(PKG_DESCRIPTION, env!("CARGO_PKG_DESCRIPTION"));
|
||||
// def_pub_const!(PKG_AUTHORS, env!("CARGO_PKG_AUTHORS"));
|
||||
// def_pub_const!(PKG_REPOSITORY, env!("CARGO_PKG_REPOSITORY"));
|
||||
|
||||
def_pub_const!(EMPTY_STRING, "");
|
||||
|
||||
// v1
|
||||
def_pub_const!(ROUTE_MODELS_PATH, "/models");
|
||||
def_pub_const!(ROUTE_CHAT_PATH, "/chat/completions");
|
||||
|
||||
// api
|
||||
def_pub_const!(ROUTE_API_PATH, "/api");
|
||||
|
||||
def_pub_const!(ROUTE_ROOT_PATH, "/");
|
||||
def_pub_const!(ROUTE_HEALTH_PATH, "/health");
|
||||
def_pub_const!(ROUTE_GET_CHECKSUM, "/get-checksum");
|
||||
@@ -28,15 +35,15 @@ def_pub_const!(ROUTE_SHARED_JS_PATH, "/static/shared.js");
|
||||
def_pub_const!(ROUTE_ABOUT_PATH, "/about");
|
||||
def_pub_const!(ROUTE_README_PATH, "/readme");
|
||||
|
||||
def_pub_const!(DEFAULT_TOKEN_FILE_NAME, ".token");
|
||||
def_pub_const!(DEFAULT_TOKEN_LIST_FILE_NAME, ".token-list");
|
||||
// api/auth
|
||||
def_pub_const!(ROUTE_AUTH_PATH, "/auth");
|
||||
def_pub_const!(ROUTE_AUTH_CALLBACK_PATH, "/callback");
|
||||
def_pub_const!(ROUTE_AUTH_INITIATE_PATH, "/initiate");
|
||||
|
||||
def_pub_const!(STATUS_SUCCESS, "success");
|
||||
def_pub_const!(STATUS_FAILED, "failed");
|
||||
def_pub_const!(FALSE, "false");
|
||||
def_pub_const!(TRUE, "true");
|
||||
|
||||
def_pub_const!(HEADER_NAME_CONTENT_TYPE, "content-type");
|
||||
def_pub_const!(HEADER_NAME_AUTHORIZATION, "authorization");
|
||||
def_pub_const!(HEADER_NAME_LOCATION, "Location");
|
||||
def_pub_const!(HEADER_NAME_GHOST_MODE, "x-ghost-mode");
|
||||
|
||||
def_pub_const!(CONTENT_TYPE_PROTO, "application/proto");
|
||||
def_pub_const!(CONTENT_TYPE_CONNECT_PROTO, "application/connect+proto");
|
||||
|
||||
388
src/app/db.rs
388
src/app/db.rs
@@ -1,262 +1,190 @@
|
||||
use crate::app::model::{RequestLog, TokenInfo};
|
||||
use crate::common::models::usage::UserUsageInfo;
|
||||
use chrono::{DateTime, Local};
|
||||
use lazy_static::lazy_static;
|
||||
use rusqlite::params;
|
||||
mod logs;
|
||||
mod tokens;
|
||||
mod users;
|
||||
|
||||
use chrono::Utc;
|
||||
use rusqlite::{Connection, Result};
|
||||
use std::path::Path;
|
||||
use std::sync::Mutex;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
use tokio::time::{self, Duration};
|
||||
|
||||
const DB_PATH: &str = "logs/sqlite.db";
|
||||
|
||||
pub struct AppDb {
|
||||
pub struct Database {
|
||||
conn: Connection,
|
||||
}
|
||||
|
||||
impl AppDb {
|
||||
pub fn new() -> Result<Self> {
|
||||
// 确保目录存在
|
||||
if let Some(parent) = Path::new(DB_PATH).parent() {
|
||||
std::fs::create_dir_all(parent).map_err(|e| {
|
||||
rusqlite::Error::SqliteFailure(
|
||||
rusqlite::ffi::Error::new(rusqlite::ffi::SQLITE_IOERR),
|
||||
Some(e.to_string()),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
// 全局静态 Database 实例
|
||||
static DB: OnceLock<Mutex<Database>> = OnceLock::new();
|
||||
|
||||
let conn = Connection::open(DB_PATH)?;
|
||||
// 用于控制清理任务的标志
|
||||
static CLEANER_RUNNING: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
// 启用WAL模式以提升性能
|
||||
conn.execute_batch("PRAGMA journal_mode = WAL")?;
|
||||
impl Database {
|
||||
pub fn new(path: &str) -> Result<Self> {
|
||||
let conn = Connection::open(path)?;
|
||||
|
||||
// 创建token信息表
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS token_infos (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
checksum TEXT NOT NULL,
|
||||
alias TEXT,
|
||||
fast_requests INTEGER,
|
||||
max_fast_requests INTEGER
|
||||
)",
|
||||
[],
|
||||
// 启用 WAL 模式
|
||||
conn.execute_batch(
|
||||
"
|
||||
PRAGMA journal_mode = WAL; -- 启用 WAL 模式
|
||||
PRAGMA synchronous = NORMAL; -- 适度的同步模式
|
||||
PRAGMA cache_size = -64000; -- 64MB 缓存
|
||||
PRAGMA foreign_keys = ON; -- 启用外键约束
|
||||
PRAGMA temp_store = MEMORY; -- 临时表使用内存
|
||||
PRAGMA mmap_size = 30000000000; -- 30GB mmap
|
||||
",
|
||||
)?;
|
||||
|
||||
// 创建请求日志表
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS request_logs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
timestamp TEXT NOT NULL,
|
||||
model TEXT NOT NULL,
|
||||
token_id INTEGER NOT NULL,
|
||||
prompt TEXT,
|
||||
stream BOOLEAN NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
error TEXT,
|
||||
FOREIGN KEY(token_id) REFERENCES token_infos(id)
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
// 创建索引
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_token ON token_infos(token)",
|
||||
[],
|
||||
)?;
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_timestamp_model ON request_logs(timestamp, model)",
|
||||
[],
|
||||
)?;
|
||||
// 按照依赖顺序初始化表
|
||||
Self::init_users_table(&conn)?;
|
||||
Self::init_tokens_table(&conn)?;
|
||||
Self::init_logs_table(&conn)?;
|
||||
|
||||
Ok(Self { conn })
|
||||
}
|
||||
|
||||
fn get_or_create_token_info(&self, token_info: &TokenInfo) -> Result<i64> {
|
||||
let mut stmt = self.conn.prepare_cached(
|
||||
"INSERT OR REPLACE INTO token_infos (token, checksum, alias, fast_requests, max_fast_requests)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5)
|
||||
RETURNING id"
|
||||
)?;
|
||||
|
||||
stmt.query_row(
|
||||
params![
|
||||
&token_info.token,
|
||||
&token_info.checksum,
|
||||
&token_info.alias,
|
||||
token_info.usage.as_ref().map(|u| u.fast_requests),
|
||||
token_info.usage.as_ref().map(|u| u.max_fast_requests),
|
||||
],
|
||||
|row| row.get(0),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn add_log(&self, log: &RequestLog) -> Result<()> {
|
||||
let token_id = self.get_or_create_token_info(&log.token_info)?;
|
||||
|
||||
self.conn.execute(
|
||||
"INSERT INTO request_logs (timestamp, model, token_id, prompt, stream, status, error)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
|
||||
params![
|
||||
log.timestamp.to_rfc3339(),
|
||||
&log.model,
|
||||
token_id,
|
||||
&log.prompt,
|
||||
log.stream,
|
||||
&log.status,
|
||||
&log.error,
|
||||
],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn map_row_to_log(&self, row: &rusqlite::Row) -> Result<RequestLog> {
|
||||
let token_id: i64 = row.get(3)?;
|
||||
let token_info = self.get_token_info_by_id(token_id)?;
|
||||
|
||||
Ok(RequestLog {
|
||||
id: row.get(0)?,
|
||||
timestamp: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
|
||||
.unwrap()
|
||||
.with_timezone(&Local),
|
||||
model: row.get(2)?,
|
||||
token_info,
|
||||
prompt: row.get(4)?,
|
||||
stream: row.get(5)?,
|
||||
status: row.get(6)?,
|
||||
error: row.get(7)?,
|
||||
pub fn init(path: &str) -> Result<()> {
|
||||
let db = Database::new(path)?;
|
||||
DB.set(Mutex::new(db)).map_err(|_| {
|
||||
rusqlite::Error::InvalidParameterName("Database already initialized".into())
|
||||
})
|
||||
}
|
||||
|
||||
fn get_token_info_by_id(&self, id: i64) -> Result<TokenInfo> {
|
||||
let mut stmt = self.conn.prepare_cached(
|
||||
"SELECT token, checksum, alias, fast_requests, max_fast_requests
|
||||
FROM token_infos
|
||||
WHERE id = ?",
|
||||
)?;
|
||||
pub fn global() -> &'static Mutex<Database> {
|
||||
DB.get().expect("Database not initialized")
|
||||
}
|
||||
|
||||
stmt.query_row([id], |row| {
|
||||
Ok(TokenInfo {
|
||||
token: row.get(0)?,
|
||||
checksum: row.get(1)?,
|
||||
alias: row.get(2)?,
|
||||
usage: Some(UserUsageInfo {
|
||||
fast_requests: row.get(3)?,
|
||||
max_fast_requests: row.get(4)?,
|
||||
}),
|
||||
})
|
||||
pub fn conn(&self) -> &Connection {
|
||||
&self.conn
|
||||
}
|
||||
|
||||
pub fn conn_mut(&mut self) -> &mut Connection {
|
||||
&mut self.conn
|
||||
}
|
||||
|
||||
// 启动定时清理任务
|
||||
pub fn start_cleaner() {
|
||||
// 确保只启动一次
|
||||
if CLEANER_RUNNING.swap(true, Ordering::SeqCst) {
|
||||
return;
|
||||
}
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
// 等待到下一个 UTC 20:00
|
||||
let now = Utc::now();
|
||||
let next = (now.date_naive() + chrono::Duration::days(1))
|
||||
.and_hms_opt(20, 0, 0)
|
||||
.unwrap();
|
||||
let duration = next.signed_duration_since(now.naive_utc());
|
||||
|
||||
time::sleep(Duration::from_secs(duration.num_seconds() as u64)).await;
|
||||
|
||||
if let Err(e) = Self::clean_expired_tokens().await {
|
||||
eprintln!("Failed to clean expired tokens: {}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// 清理过期数据
|
||||
async fn clean_expired_tokens() -> Result<()> {
|
||||
with_db_mut(|conn| {
|
||||
let tx = conn.transaction()?;
|
||||
|
||||
// 删除过期token相关的日志
|
||||
tx.execute(
|
||||
"DELETE FROM logs WHERE token_id IN (
|
||||
SELECT id FROM tokens
|
||||
WHERE status = 2 OR
|
||||
(status = 1 AND duration > 0 AND
|
||||
datetime(create_at, '+' || (duration / 86400) || ' days') < datetime('now'))
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
// 删除过期token
|
||||
tx.execute(
|
||||
"DELETE FROM tokens
|
||||
WHERE status = 2 OR
|
||||
(status = 1 AND duration > 0 AND
|
||||
datetime(create_at, '+' || (duration / 86400) || ' days') < datetime('now'))",
|
||||
[],
|
||||
)?;
|
||||
|
||||
// 执行WAL清理
|
||||
tx.execute_batch("PRAGMA wal_checkpoint(TRUNCATE)")?;
|
||||
|
||||
tx.commit()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_token_infos(&self) -> Result<Vec<TokenInfo>> {
|
||||
let mut stmt = self.conn.prepare_cached(
|
||||
"SELECT token, checksum, alias, fast_requests, max_fast_requests
|
||||
FROM token_infos",
|
||||
)?;
|
||||
// 停止清理任务
|
||||
// pub fn stop_cleaner() {
|
||||
// CLEANER_RUNNING.store(false, Ordering::SeqCst);
|
||||
// }
|
||||
}
|
||||
|
||||
let tokens = stmt.query_map([], |row| {
|
||||
Ok(TokenInfo {
|
||||
token: row.get(0)?,
|
||||
checksum: row.get(1)?,
|
||||
alias: row.get(2)?,
|
||||
usage: Some(UserUsageInfo {
|
||||
fast_requests: row.get(3)?,
|
||||
max_fast_requests: row.get(4)?,
|
||||
}),
|
||||
})
|
||||
})?;
|
||||
tokens.collect()
|
||||
pub fn with_db<F, T>(f: F) -> Result<T>
|
||||
where
|
||||
F: FnOnce(&Connection) -> Result<T>,
|
||||
{
|
||||
let guard = Database::global().lock().expect("Database lock poisoned");
|
||||
f(guard.conn())
|
||||
}
|
||||
|
||||
pub fn with_db_mut<F, T>(f: F) -> Result<T>
|
||||
where
|
||||
F: FnOnce(&mut Connection) -> Result<T>,
|
||||
{
|
||||
let mut guard = Database::global().lock().expect("Database lock poisoned");
|
||||
f(guard.conn_mut())
|
||||
}
|
||||
|
||||
// 重新导出子模块
|
||||
pub use self::logs::*;
|
||||
pub use self::tokens::*;
|
||||
pub use self::users::*;
|
||||
|
||||
/*
|
||||
// 以下是可选的扩展功能,暂时注释掉
|
||||
|
||||
impl Drop for Database {
|
||||
fn drop(&mut self) {
|
||||
// 这里可以添加清理代码
|
||||
// Connection 会自动关闭,但如果有其他清理工作可以在这里进行
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_recent_logs(&self, limit: i64) -> Result<Vec<RequestLog>> {
|
||||
let mut stmt = self.conn.prepare_cached(
|
||||
"SELECT r.id, r.timestamp, r.model, r.token_id, r.prompt, r.stream, r.status, r.error, t.token, t.checksum, t.alias, t.fast_requests, t.max_fast_requests
|
||||
FROM request_logs r
|
||||
JOIN token_infos t ON r.token_id = t.id
|
||||
ORDER BY r.timestamp DESC
|
||||
LIMIT ?",
|
||||
)?;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
let logs = stmt.query_map([limit], |row| {
|
||||
Ok(RequestLog {
|
||||
id: row.get(0)?,
|
||||
timestamp: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
|
||||
.unwrap()
|
||||
.with_timezone(&Local),
|
||||
model: row.get(2)?,
|
||||
token_info: TokenInfo {
|
||||
token: row.get(8)?,
|
||||
checksum: row.get(9)?,
|
||||
alias: row.get(10)?,
|
||||
usage: Some(UserUsageInfo {
|
||||
fast_requests: row.get(11)?,
|
||||
max_fast_requests: row.get(12)?,
|
||||
}),
|
||||
},
|
||||
prompt: row.get(4)?,
|
||||
stream: row.get(5)?,
|
||||
status: row.get(6)?,
|
||||
error: row.get(7)?,
|
||||
})
|
||||
})?;
|
||||
logs.collect()
|
||||
static DB_CONFIG: LazyLock<DbConfig> = LazyLock::new(|| {
|
||||
DbConfig {
|
||||
max_connections: 10,
|
||||
timeout: std::time::Duration::from_secs(30),
|
||||
}
|
||||
});
|
||||
|
||||
pub fn get_logs_by_timerange(
|
||||
&self,
|
||||
start: DateTime<Local>,
|
||||
end: DateTime<Local>,
|
||||
) -> Result<Vec<RequestLog>> {
|
||||
let mut stmt = self.conn.prepare_cached(
|
||||
"SELECT r.id, r.timestamp, r.model, r.token_id, r.prompt, r.stream, r.status, r.error, t.token, t.checksum, t.alias, t.fast_requests, t.max_fast_requests
|
||||
FROM request_logs r
|
||||
JOIN token_infos t ON r.token_id = t.id
|
||||
WHERE r.timestamp BETWEEN ?1 AND ?2
|
||||
ORDER BY r.timestamp DESC",
|
||||
)?;
|
||||
struct DbConfig {
|
||||
max_connections: u32,
|
||||
timeout: std::time::Duration,
|
||||
}
|
||||
|
||||
let logs = stmt.query_map([start.to_rfc3339(), end.to_rfc3339()], |row| {
|
||||
Ok(RequestLog {
|
||||
id: row.get(0)?,
|
||||
timestamp: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
|
||||
.unwrap()
|
||||
.with_timezone(&Local),
|
||||
model: row.get(2)?,
|
||||
token_info: TokenInfo {
|
||||
token: row.get(8)?,
|
||||
checksum: row.get(9)?,
|
||||
alias: row.get(10)?,
|
||||
usage: Some(UserUsageInfo {
|
||||
fast_requests: row.get(11)?,
|
||||
max_fast_requests: row.get(12)?,
|
||||
}),
|
||||
},
|
||||
prompt: row.get(4)?,
|
||||
stream: row.get(5)?,
|
||||
status: row.get(6)?,
|
||||
error: row.get(7)?,
|
||||
})
|
||||
})?;
|
||||
logs.collect()
|
||||
}
|
||||
|
||||
pub fn update_token_info(&self, token_info: &TokenInfo) -> Result<()> {
|
||||
self.conn.execute(
|
||||
"INSERT OR REPLACE INTO token_infos (token, checksum, alias, fast_requests, max_fast_requests)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5)",
|
||||
params![
|
||||
&token_info.token,
|
||||
&token_info.checksum,
|
||||
&token_info.alias,
|
||||
token_info.usage.as_ref().map(|u| u.fast_requests),
|
||||
token_info.usage.as_ref().map(|u| u.max_fast_requests),
|
||||
],
|
||||
)?;
|
||||
pub fn example_usage() -> Result<()> {
|
||||
Database::init("path/to/db.sqlite")?;
|
||||
println!("Max connections: {}", DB_CONFIG.max_connections);
|
||||
with_db(|conn| {
|
||||
Ok(())
|
||||
}
|
||||
})?;
|
||||
with_db_mut(|conn| {
|
||||
let tx = conn.transaction()?;
|
||||
tx.commit()
|
||||
})
|
||||
}
|
||||
*/
|
||||
|
||||
lazy_static! {
|
||||
pub static ref APP_DB: Mutex<AppDb> =
|
||||
Mutex::new(AppDb::new().expect("Failed to initialize database"));
|
||||
// 在应用启动时初始化
|
||||
pub async fn init_database(path: &str) -> Result<()> {
|
||||
Database::init(path)?;
|
||||
Database::start_cleaner();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
271
src/app/db/logs.rs
Normal file
271
src/app/db/logs.rs
Normal file
@@ -0,0 +1,271 @@
|
||||
use super::Database;
|
||||
use crate::{app::model::{LogInfo, LogStatus, TokenInfo}, common::models::usage::UserUsageInfo};
|
||||
use chrono::Local;
|
||||
use rusqlite::{params, Connection, OptionalExtension as _, Result};
|
||||
const MAX_PROMPT_LENGTH: usize = 100000; // 限制 prompt 长度为 100000 字符
|
||||
const MAX_MODEL_LENGTH: usize = 100; // 限制 model 名称长度为 100 字符
|
||||
const MAX_ERROR_LENGTH: usize = 1000; // 限制 error 信息长度为 1000 字符
|
||||
const MAX_QUERY_LIMIT: usize = 1000; // 最大查询数量限制
|
||||
pub fn insert_log(log_info: &LogInfo) -> Result<i64> {
|
||||
super::with_db_mut(|conn| Database::insert_log(conn, log_info))
|
||||
}
|
||||
pub fn get_logs_by_user_id(user_id: Option<i64>) -> Result<Vec<LogInfo>> {
|
||||
super::with_db_mut(|conn| Database::get_logs_by_user_id(conn, user_id))
|
||||
}
|
||||
pub fn get_logs_by_token_id(token_id: i64) -> Result<Vec<LogInfo>> {
|
||||
super::with_db_mut(|conn| Database::get_logs_by_token_id(conn, token_id))
|
||||
}
|
||||
pub fn get_log_by_id(id: i64) -> Result<Option<LogInfo>> {
|
||||
super::with_db(|conn| Database::get_log_by_id(conn, id))
|
||||
}
|
||||
pub fn update_log_status(id: i64, status: LogStatus, error: Option<String>) -> Result<()> {
|
||||
super::with_db_mut(|conn| Database::update_log_status(conn, id, status, error))
|
||||
}
|
||||
pub fn clean_user_logs(user_id: i64, limit: usize) -> Result<()> {
|
||||
super::with_db_mut(|conn| Database::clean_user_logs(conn, user_id, limit))
|
||||
}
|
||||
pub fn get_user_logs_count(user_id: i64) -> Result<i64> {
|
||||
super::with_db(|conn| Database::get_user_logs_count(conn, user_id))
|
||||
}
|
||||
pub fn update_log_usage(log_id: i64, usage: Option<UserUsageInfo>) -> Result<()> {
|
||||
super::with_db_mut(|conn| Database::update_log_usage(conn, log_id, usage))
|
||||
}
|
||||
pub fn update_log_prompt(log_id: i64, prompt: Option<String>) -> Result<()> {
|
||||
super::with_db_mut(|conn| Database::update_log_prompt(conn, log_id, prompt))
|
||||
}
|
||||
impl Database {
|
||||
pub fn init_logs_table(conn: &Connection) -> Result<()> {
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS logs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
timestamp TEXT NOT NULL,
|
||||
token_id INTEGER NOT NULL,
|
||||
prompt TEXT,
|
||||
model TEXT NOT NULL,
|
||||
stream BOOLEAN NOT NULL,
|
||||
status INTEGER NOT NULL,
|
||||
error TEXT,
|
||||
FOREIGN KEY(token_id) REFERENCES tokens(id)
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
pub fn insert_log(conn: &mut Connection, log_info: &LogInfo) -> Result<i64> {
|
||||
// 输入验证
|
||||
if let Some(prompt) = &log_info.prompt {
|
||||
if prompt.len() > MAX_PROMPT_LENGTH {
|
||||
return Err(rusqlite::Error::InvalidParameterName(
|
||||
"Prompt too long".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
if log_info.model.len() > MAX_MODEL_LENGTH {
|
||||
return Err(rusqlite::Error::InvalidParameterName(
|
||||
"Model name too long".to_string(),
|
||||
));
|
||||
}
|
||||
if let Some(error) = &log_info.error {
|
||||
if error.len() > MAX_ERROR_LENGTH {
|
||||
return Err(rusqlite::Error::InvalidParameterName(
|
||||
"Error message too long".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
let tx = conn.transaction()?;
|
||||
tx.execute(
|
||||
"INSERT INTO logs (
|
||||
timestamp, token_id, prompt, model,
|
||||
stream, status, error
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
|
||||
params![
|
||||
Local::now(),
|
||||
log_info.token_info.id,
|
||||
log_info.prompt,
|
||||
log_info.model,
|
||||
log_info.stream,
|
||||
log_info.status,
|
||||
log_info.error,
|
||||
],
|
||||
)?;
|
||||
let id = tx.last_insert_rowid();
|
||||
tx.commit()?;
|
||||
Ok(id)
|
||||
}
|
||||
pub fn get_logs_by_user_id(
|
||||
conn: &mut Connection,
|
||||
user_id: Option<i64>,
|
||||
) -> Result<Vec<LogInfo>> {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT l.id, l.timestamp, l.prompt, l.model, l.stream, l.status, l.error,
|
||||
t.id, t.create_at, t.token, t.checksum, t.alias,
|
||||
t.status, t.pengding_at, t.user_id, t.is_public, t.usage
|
||||
FROM logs l
|
||||
JOIN tokens t ON l.token_id = t.id
|
||||
WHERE t.user_id IS ?1
|
||||
ORDER BY l.timestamp DESC
|
||||
LIMIT 100",
|
||||
)?;
|
||||
let logs_iter = stmt.query_map(params![user_id], Self::row_to_log_info)?;
|
||||
let mut logs = Vec::with_capacity(100);
|
||||
for log in logs_iter {
|
||||
logs.push(log?);
|
||||
}
|
||||
Ok(logs)
|
||||
}
|
||||
pub fn get_logs_by_token_id(conn: &mut Connection, token_id: i64) -> Result<Vec<LogInfo>> {
|
||||
// 使用事务确保一致性
|
||||
let tx = conn.transaction()?;
|
||||
// 先获取token信息
|
||||
let token = Self::get_token_by_id(&tx, token_id)?
|
||||
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows)?;
|
||||
// 查询日志记录
|
||||
let logs = {
|
||||
let mut stmt = tx.prepare(
|
||||
"SELECT id, timestamp, prompt, model, stream, status, error
|
||||
FROM logs
|
||||
WHERE token_id = ?1
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?2",
|
||||
)?;
|
||||
let mut logs = Vec::with_capacity(100);
|
||||
let logs_iter = stmt.query_map(params![token_id, MAX_QUERY_LIMIT as i64], |row| {
|
||||
Ok(LogInfo {
|
||||
id: row.get(0)?,
|
||||
timestamp: row.get(1)?,
|
||||
prompt: row.get(2)?,
|
||||
model: row.get(3)?,
|
||||
stream: row.get(4)?,
|
||||
status: row.get(5)?,
|
||||
error: row.get(6)?,
|
||||
token_info: token.clone(),
|
||||
})
|
||||
})?;
|
||||
for log in logs_iter {
|
||||
logs.push(log?);
|
||||
}
|
||||
logs
|
||||
};
|
||||
tx.commit()?;
|
||||
Ok(logs)
|
||||
}
|
||||
pub fn get_log_by_id(conn: &Connection, id: i64) -> Result<Option<LogInfo>> {
|
||||
conn.query_row(
|
||||
"SELECT l.id, l.timestamp, l.prompt, l.model, l.stream, l.status, l.error,
|
||||
t.id, t.create_at, t.token, t.checksum, t.alias,
|
||||
t.status, t.pengding_at, t.user_id, t.is_public, t.usage
|
||||
FROM logs l
|
||||
JOIN tokens t ON l.token_id = t.id
|
||||
WHERE l.id = ?1",
|
||||
params![id],
|
||||
Self::row_to_log_info,
|
||||
)
|
||||
.optional()
|
||||
}
|
||||
pub fn update_log_status(
|
||||
conn: &mut Connection,
|
||||
id: i64,
|
||||
status: LogStatus,
|
||||
error: Option<String>,
|
||||
) -> Result<()> {
|
||||
// 验证 error 长度
|
||||
if let Some(error_msg) = &error {
|
||||
if error_msg.len() > MAX_ERROR_LENGTH {
|
||||
return Err(rusqlite::Error::InvalidParameterName(
|
||||
"Error message too long".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
let tx = conn.transaction()?;
|
||||
tx.execute(
|
||||
"UPDATE logs SET status = ?1, error = ?2 WHERE id = ?3",
|
||||
params![status, error, id],
|
||||
)?;
|
||||
tx.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
fn row_to_log_info(row: &rusqlite::Row<'_>) -> Result<LogInfo> {
|
||||
let token_info = TokenInfo {
|
||||
id: row.get(7)?,
|
||||
create_at: row.get(8)?,
|
||||
token: row.get(9)?,
|
||||
checksum: row.get(10)?,
|
||||
alias: row.get(11)?,
|
||||
status: row.get(12)?,
|
||||
pengding_at: row.get(13)?,
|
||||
user_id: row.get(14)?,
|
||||
is_public: row.get(15)?,
|
||||
usage: row.get(16)?,
|
||||
};
|
||||
Ok(LogInfo {
|
||||
id: row.get(0)?,
|
||||
timestamp: row.get(1)?,
|
||||
prompt: row.get(2)?,
|
||||
model: row.get(3)?,
|
||||
stream: row.get(4)?,
|
||||
status: row.get(5)?,
|
||||
error: row.get(6)?,
|
||||
token_info,
|
||||
})
|
||||
}
|
||||
pub fn clean_user_logs(conn: &mut Connection, user_id: i64, limit: usize) -> Result<()> {
|
||||
let tx = conn.transaction()?;
|
||||
// 先获取所有需要删除的日志ID
|
||||
let mut stmt = tx.prepare(
|
||||
"WITH RankedLogs AS (
|
||||
SELECT l.id,
|
||||
ROW_NUMBER() OVER (ORDER BY l.timestamp DESC) as rn
|
||||
FROM logs l
|
||||
JOIN tokens t ON l.token_id = t.id
|
||||
WHERE t.user_id = ?1
|
||||
)
|
||||
SELECT id FROM RankedLogs WHERE rn > ?2",
|
||||
)?;
|
||||
let log_ids: Vec<i64> = stmt
|
||||
.query_map(params![user_id, limit as i64], |row| row.get(0))?
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
// 确保 stmt 被释放
|
||||
drop(stmt);
|
||||
// 如果有需要删除的日志
|
||||
if !log_ids.is_empty() {
|
||||
// 直接更新状态,不使用 IN 子句
|
||||
tx.execute(
|
||||
"UPDATE logs SET status = ?1
|
||||
WHERE id IN (
|
||||
SELECT l.id
|
||||
FROM logs l
|
||||
JOIN tokens t ON l.token_id = t.id
|
||||
WHERE t.user_id = ?2
|
||||
ORDER BY l.timestamp ASC
|
||||
LIMIT -1
|
||||
OFFSET ?3
|
||||
)",
|
||||
params![LogStatus::Deleted as u8, user_id, limit],
|
||||
)?;
|
||||
}
|
||||
tx.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
pub fn get_user_logs_count(conn: &Connection, user_id: i64) -> Result<i64> {
|
||||
conn.query_row(
|
||||
"SELECT COUNT(*)
|
||||
FROM logs l
|
||||
JOIN tokens t ON l.token_id = t.id
|
||||
WHERE t.user_id = ?1 AND l.status != ?2",
|
||||
params![user_id, LogStatus::Deleted],
|
||||
|row| row.get(0),
|
||||
)
|
||||
}
|
||||
pub fn update_log_usage(conn: &mut Connection, log_id: i64, usage: Option<UserUsageInfo>) -> Result<()> {
|
||||
let tx = conn.transaction()?;
|
||||
tx.execute("UPDATE logs SET usage = ?1 WHERE id = ?2", params![usage, log_id])?;
|
||||
tx.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
pub fn update_log_prompt(conn: &mut Connection, log_id: i64, prompt: Option<String>) -> Result<()> {
|
||||
let tx = conn.transaction()?;
|
||||
tx.execute("UPDATE logs SET prompt = ?1 WHERE id = ?2", params![prompt, log_id])?;
|
||||
tx.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
305
src/app/db/tokens.rs
Normal file
305
src/app/db/tokens.rs
Normal file
@@ -0,0 +1,305 @@
|
||||
use crate::app::model::{TokenInfo, TokenStatus};
|
||||
use chrono::Local;
|
||||
use rusqlite::{params, Connection, OptionalExtension as _, Result};
|
||||
use super::Database;
|
||||
// 限制字段长度
|
||||
const MAX_TOKEN_LENGTH: usize = 1000; // 限制 token 长度为 1000 字符
|
||||
const MAX_CHECKSUM_LENGTH: usize = 200; // 限制 checksum 长度为 200 字符
|
||||
const MAX_ALIAS_LENGTH: usize = 100; // 限制 alias 长度为 100 字符
|
||||
const MAX_QUERY_LIMIT: usize = 1000; // 最大查询数量限制
|
||||
pub fn insert_token(token_info: &TokenInfo) -> Result<i64> {
|
||||
super::with_db_mut(|conn| Database::insert_token(conn, token_info))
|
||||
}
|
||||
pub fn get_tokens_by_user_id(user_id: Option<i64>) -> Result<Vec<TokenInfo>> {
|
||||
super::with_db(|conn| Database::get_tokens_by_user_id(conn, user_id))
|
||||
}
|
||||
pub fn get_available_tokens_by_user_id(user_id: Option<i64>) -> Result<Vec<TokenInfo>> {
|
||||
super::with_db(|conn| Database::get_available_tokens_by_user_id(conn, user_id))
|
||||
}
|
||||
pub fn get_token_by_id(id: i64) -> Result<Option<TokenInfo>> {
|
||||
super::with_db(|conn| Database::get_token_by_id(conn, id))
|
||||
}
|
||||
pub fn get_token_by_token(token: &str) -> Result<Option<TokenInfo>> {
|
||||
super::with_db(|conn| Database::get_token_by_token(conn, token))
|
||||
}
|
||||
pub fn update_token_status(id: i64, status: TokenStatus) -> Result<()> {
|
||||
super::with_db_mut(|conn| Database::update_token_status(conn, id, status))
|
||||
}
|
||||
pub fn delete_expired_tokens() -> Result<()> {
|
||||
super::with_db_mut(|conn| Database::delete_expired_tokens(conn))
|
||||
}
|
||||
pub fn get_token_by_alias_and_user(
|
||||
alias: &str,
|
||||
current_user_id: i64,
|
||||
target_user_id: Option<i64>,
|
||||
) -> Result<Option<TokenInfo>> {
|
||||
super::with_db(|conn| {
|
||||
Database::get_token_by_alias_and_user(conn, alias, current_user_id, target_user_id)
|
||||
})
|
||||
}
|
||||
pub fn update_token(token_info: &TokenInfo) -> Result<()> {
|
||||
super::with_db_mut(|conn| Database::update_token(conn, token_info))
|
||||
}
|
||||
impl Database {
|
||||
pub fn init_tokens_table(conn: &Connection) -> Result<()> {
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS tokens (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
create_at TEXT NOT NULL,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
checksum TEXT NOT NULL,
|
||||
alias TEXT,
|
||||
status INTEGER NOT NULL,
|
||||
pengding_at TEXT NOT NULL,
|
||||
user_id INTEGER NOT NULL,
|
||||
is_public BOOLEAN NOT NULL DEFAULT 0,
|
||||
usage TEXT,
|
||||
FOREIGN KEY(user_id) REFERENCES users(id),
|
||||
UNIQUE(alias, user_id)
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
pub fn insert_token(conn: &mut Connection, token_info: &TokenInfo) -> Result<i64> {
|
||||
// 输入验证
|
||||
if token_info.token.len() > MAX_TOKEN_LENGTH {
|
||||
return Err(rusqlite::Error::InvalidParameterName(
|
||||
"Token too long".to_string(),
|
||||
));
|
||||
}
|
||||
if token_info.checksum.len() > MAX_CHECKSUM_LENGTH {
|
||||
return Err(rusqlite::Error::InvalidParameterName(
|
||||
"Checksum too long".to_string(),
|
||||
));
|
||||
}
|
||||
if let Some(alias) = &token_info.alias {
|
||||
if alias.len() > MAX_ALIAS_LENGTH {
|
||||
return Err(rusqlite::Error::InvalidParameterName(
|
||||
"Alias too long".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
let tx = conn.transaction()?;
|
||||
tx.execute(
|
||||
"INSERT INTO tokens (
|
||||
create_at, token, checksum, alias,
|
||||
status, pengding_at, user_id, is_public, usage
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
|
||||
params![
|
||||
Local::now(),
|
||||
token_info.token,
|
||||
token_info.checksum,
|
||||
token_info.alias,
|
||||
token_info.status,
|
||||
token_info.user_id,
|
||||
token_info.is_public,
|
||||
token_info.usage,
|
||||
],
|
||||
)?;
|
||||
let id = tx.last_insert_rowid();
|
||||
tx.commit()?;
|
||||
Ok(id)
|
||||
}
|
||||
pub fn get_tokens_by_user_id(
|
||||
conn: &Connection,
|
||||
user_id: Option<i64>,
|
||||
) -> Result<Vec<TokenInfo>> {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, create_at, token, checksum, alias,
|
||||
status, pengding_at, user_id, is_public, usage
|
||||
FROM tokens
|
||||
WHERE user_id IS ?1
|
||||
LIMIT ?2",
|
||||
)?;
|
||||
let tokens_iter = stmt.query_map(
|
||||
params![user_id, MAX_QUERY_LIMIT as i64],
|
||||
Self::row_to_token_info,
|
||||
)?;
|
||||
let mut tokens = Vec::with_capacity(100);
|
||||
for token in tokens_iter {
|
||||
tokens.push(token?);
|
||||
}
|
||||
Ok(tokens)
|
||||
}
|
||||
pub fn get_available_tokens_by_user_id(
|
||||
conn: &Connection,
|
||||
user_id: Option<i64>,
|
||||
) -> Result<Vec<TokenInfo>> {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, create_at, token, checksum, alias,
|
||||
status, pengding_at, user_id, is_public, usage
|
||||
FROM tokens
|
||||
WHERE status = 1 AND datetime('now') >= pengding_at AND user_id IS ?1
|
||||
LIMIT ?2",
|
||||
)?;
|
||||
let tokens_iter = stmt.query_map(
|
||||
params![user_id, MAX_QUERY_LIMIT as i64],
|
||||
Self::row_to_token_info,
|
||||
)?;
|
||||
let mut tokens = Vec::with_capacity(100);
|
||||
for token in tokens_iter {
|
||||
tokens.push(token?);
|
||||
}
|
||||
Ok(tokens)
|
||||
}
|
||||
pub fn get_token_by_id(conn: &Connection, id: i64) -> Result<Option<TokenInfo>> {
|
||||
conn.query_row(
|
||||
"SELECT id, create_at, token, checksum, alias,
|
||||
status, pengding_at, user_id, is_public, usage
|
||||
FROM tokens
|
||||
WHERE id = ?1",
|
||||
params![id],
|
||||
Self::row_to_token_info,
|
||||
)
|
||||
.optional()
|
||||
}
|
||||
pub fn get_token_by_token(conn: &Connection, token: &str) -> Result<Option<TokenInfo>> {
|
||||
// 输入验证
|
||||
if token.len() > MAX_TOKEN_LENGTH {
|
||||
return Err(rusqlite::Error::InvalidParameterName(
|
||||
"Token too long".to_string(),
|
||||
));
|
||||
}
|
||||
conn.query_row(
|
||||
"SELECT id, create_at, token, checksum, alias,
|
||||
status, pengding_at, user_id, is_public, usage
|
||||
FROM tokens
|
||||
WHERE token = ?1",
|
||||
params![token],
|
||||
Self::row_to_token_info,
|
||||
)
|
||||
.optional()
|
||||
}
|
||||
pub fn update_token_status(conn: &mut Connection, id: i64, status: TokenStatus) -> Result<()> {
|
||||
let tx = conn.transaction()?;
|
||||
tx.execute(
|
||||
"UPDATE tokens SET status = ?1 WHERE id = ?2",
|
||||
params![status, id],
|
||||
)?;
|
||||
if status == TokenStatus::Pending {
|
||||
tx.execute(
|
||||
"UPDATE tokens SET pengding_at = ?1 WHERE id = ?2",
|
||||
params![Local::now() + chrono::Duration::minutes(1), id],
|
||||
)?;
|
||||
}
|
||||
tx.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
pub fn delete_expired_tokens(conn: &mut Connection) -> Result<()> {
|
||||
// 开始事务
|
||||
let tx = conn.transaction()?;
|
||||
// 删除过期token相关的日志
|
||||
tx.execute(
|
||||
"DELETE FROM logs WHERE token_id IN (
|
||||
SELECT id FROM tokens
|
||||
WHERE status = ?1
|
||||
)",
|
||||
params![TokenStatus::Expired],
|
||||
)?;
|
||||
// 删除过期的token
|
||||
tx.execute(
|
||||
"DELETE FROM tokens WHERE status = ?1",
|
||||
params![TokenStatus::Expired],
|
||||
)?;
|
||||
// 提交事务
|
||||
tx.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
pub fn get_token_by_alias_and_user(
|
||||
conn: &Connection,
|
||||
alias: &str,
|
||||
current_user_id: i64,
|
||||
target_user_id: Option<i64>,
|
||||
) -> Result<Option<TokenInfo>> {
|
||||
// 管理员可以查看所有token
|
||||
let is_admin = current_user_id == 0;
|
||||
let sql = if is_admin {
|
||||
// 管理员查询:如果指定了user_id就查指定用户的,否则查所有
|
||||
if target_user_id.is_some() {
|
||||
"SELECT id, create_at, token, checksum, alias,
|
||||
status, user_id, is_public, usage
|
||||
FROM tokens
|
||||
WHERE alias = ?1
|
||||
AND status IN (?2, ?3)
|
||||
AND user_id = ?4"
|
||||
} else {
|
||||
"SELECT id, create_at, token, checksum, alias,
|
||||
status, pengding_at, user_id, is_public, usage
|
||||
FROM tokens
|
||||
WHERE alias = ?1
|
||||
AND status IN (?2, ?3)"
|
||||
}
|
||||
} else {
|
||||
// 普通用户查询:只能查看自己的token
|
||||
"SELECT id, create_at, token, checksum, alias,
|
||||
status, pengding_at, user_id, is_public, usage
|
||||
FROM tokens
|
||||
WHERE alias = ?1
|
||||
AND status IN (?2, ?3)
|
||||
AND user_id = ?4"
|
||||
};
|
||||
let target_id = target_user_id.map(|id| id);
|
||||
let params: Vec<&dyn rusqlite::ToSql> = if is_admin {
|
||||
if let Some(ref id) = target_id {
|
||||
vec![&alias, &TokenStatus::Active, &TokenStatus::Pending, id]
|
||||
} else {
|
||||
vec![&alias, &TokenStatus::Active, &TokenStatus::Pending]
|
||||
}
|
||||
} else {
|
||||
vec![
|
||||
&alias,
|
||||
&TokenStatus::Active,
|
||||
&TokenStatus::Pending,
|
||||
¤t_user_id,
|
||||
]
|
||||
};
|
||||
conn.query_row(sql, params.as_slice(), Self::row_to_token_info)
|
||||
.optional()
|
||||
}
|
||||
pub fn update_token(conn: &mut Connection, token_info: &TokenInfo) -> Result<()> {
|
||||
// 输入验证
|
||||
if token_info.checksum.len() > MAX_CHECKSUM_LENGTH {
|
||||
return Err(rusqlite::Error::InvalidParameterName(
|
||||
"Checksum too long".to_string(),
|
||||
));
|
||||
}
|
||||
if let Some(alias) = &token_info.alias {
|
||||
if alias.len() > MAX_ALIAS_LENGTH {
|
||||
return Err(rusqlite::Error::InvalidParameterName(
|
||||
"Alias too long".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
let tx = conn.transaction()?;
|
||||
tx.execute(
|
||||
"UPDATE tokens SET
|
||||
checksum = ?1,
|
||||
alias = ?2,
|
||||
is_public = ?3
|
||||
WHERE id = ?4",
|
||||
params![
|
||||
token_info.checksum,
|
||||
token_info.alias,
|
||||
token_info.is_public,
|
||||
token_info.id,
|
||||
],
|
||||
)?;
|
||||
tx.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
fn row_to_token_info(row: &rusqlite::Row<'_>) -> Result<TokenInfo> {
|
||||
Ok(TokenInfo {
|
||||
id: row.get(0)?,
|
||||
create_at: row.get(1)?,
|
||||
token: row.get(2)?,
|
||||
checksum: row.get(3)?,
|
||||
alias: row.get(4)?,
|
||||
status: row.get(5)?,
|
||||
pengding_at: row.get(6)?,
|
||||
user_id: row.get(7)?,
|
||||
is_public: row.get(8)?,
|
||||
usage: row.get(9)?,
|
||||
})
|
||||
}
|
||||
}
|
||||
198
src/app/db/users.rs
Normal file
198
src/app/db/users.rs
Normal file
@@ -0,0 +1,198 @@
|
||||
use crate::app::model::UserInfo;
|
||||
use crate::common::utils::oauth::ForumUser;
|
||||
use chrono::{DateTime, Local};
|
||||
use rusqlite::{params, Connection, OptionalExtension as _, Result};
|
||||
use crate::app::lazy::ADMIN_AUTH_TOKEN;
|
||||
use super::Database;
|
||||
// 限制字段长度
|
||||
const MAX_USERNAME_LENGTH: usize = 100; // 限制用户名长度为 100 字符
|
||||
const MAX_NAME_LENGTH: usize = 100; // 限制姓名长度为 100 字符
|
||||
const MAX_QUERY_LIMIT: usize = 1000; // 最大查询数量限制
|
||||
pub fn insert_user(user: &ForumUser) -> Result<i64> {
|
||||
super::with_db_mut(|conn| Database::insert_user(conn, user))
|
||||
}
|
||||
pub fn get_user_by_id(id: i64) -> Result<Option<UserInfo>> {
|
||||
super::with_db(|conn| Database::get_user_by_id(conn, id))
|
||||
}
|
||||
pub fn get_user_by_forum_id(forum_id: i64) -> Result<Option<UserInfo>> {
|
||||
super::with_db(|conn| Database::get_user_by_forum_id(conn, forum_id))
|
||||
}
|
||||
pub fn update_user_ban(forum_id: i64, ban_expired_at: Option<DateTime<Local>>, ban_count: u32) -> Result<()> {
|
||||
super::with_db_mut(|conn| Database::update_user_ban(conn, forum_id, ban_expired_at, ban_count))
|
||||
}
|
||||
pub fn update_user_auth_token(forum_id: i64, auth_token: Option<String>) -> Result<()> {
|
||||
super::with_db_mut(|conn| Database::update_user_auth_token(conn, forum_id, auth_token))
|
||||
}
|
||||
pub fn get_user_by_auth_token(auth_token: &str) -> Result<Option<UserInfo>> {
|
||||
super::with_db(|conn| Database::get_user_by_auth_token(conn, auth_token))
|
||||
}
|
||||
impl Database {
|
||||
pub fn init_users_table(conn: &Connection) -> Result<()> {
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
forum_id INTEGER NOT NULL UNIQUE,
|
||||
username TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
trust_level INTEGER NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
ban_expired_at TEXT,
|
||||
ban_count INTEGER NOT NULL,
|
||||
auth_token TEXT
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
let admin_exists: bool = conn
|
||||
.query_row("SELECT EXISTS(SELECT 1 FROM users WHERE id = 0)", [], |row| {
|
||||
row.get(0)
|
||||
})?;
|
||||
if !admin_exists {
|
||||
conn.execute(
|
||||
"INSERT INTO users (
|
||||
id, forum_id, username, name, trust_level,
|
||||
created_at, ban_expired_at, ban_count, auth_token
|
||||
) VALUES (
|
||||
0, 0, 'admin', 'Administrator', 255,
|
||||
?1, NULL, 0, ?2
|
||||
)",
|
||||
params![Local::now(), &*ADMIN_AUTH_TOKEN],
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
pub fn insert_user(conn: &mut Connection, user: &ForumUser) -> Result<i64> {
|
||||
// 输入验证
|
||||
if user.username.len() > MAX_USERNAME_LENGTH {
|
||||
return Err(rusqlite::Error::InvalidParameterName(
|
||||
"Username too long".to_string(),
|
||||
));
|
||||
}
|
||||
if user.name.len() > MAX_NAME_LENGTH {
|
||||
return Err(rusqlite::Error::InvalidParameterName(
|
||||
"Name too long".to_string(),
|
||||
));
|
||||
}
|
||||
let tx = conn.transaction()?;
|
||||
tx.execute(
|
||||
"INSERT INTO users (forum_id, username, name, trust_level, created_at, ban_expired_at, ban_count, auth_token)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
|
||||
params![
|
||||
user.id,
|
||||
user.username,
|
||||
user.name,
|
||||
user.trust_level,
|
||||
Local::now(),
|
||||
Option::<DateTime<Local>>::None,
|
||||
0,
|
||||
Option::<String>::None
|
||||
],
|
||||
)?;
|
||||
let id = tx.last_insert_rowid();
|
||||
tx.commit()?;
|
||||
Ok(id)
|
||||
}
|
||||
pub fn get_user_by_id(conn: &Connection, id: i64) -> Result<Option<UserInfo>> {
|
||||
conn.query_row(
|
||||
"SELECT id, forum_id, username, name, trust_level, created_at, ban_expired_at, ban_count, auth_token
|
||||
FROM users
|
||||
WHERE id = ?1
|
||||
LIMIT 1",
|
||||
params![id],
|
||||
|row| {
|
||||
Ok(UserInfo {
|
||||
id: row.get(0)?,
|
||||
forum_id: row.get(1)?,
|
||||
username: row.get(2)?,
|
||||
name: row.get(3)?,
|
||||
trust_level: row.get(4)?,
|
||||
created_at: row.get(5)?,
|
||||
ban_expired_at: row.get(6)?,
|
||||
ban_count: row.get(7)?,
|
||||
auth_token: row.get(8)?,
|
||||
})
|
||||
},
|
||||
)
|
||||
.optional()
|
||||
}
|
||||
pub fn get_user_by_forum_id(conn: &Connection, forum_id: i64) -> Result<Option<UserInfo>> {
|
||||
conn.query_row(
|
||||
"SELECT id, forum_id, username, name, trust_level, created_at, ban_expired_at, ban_count, auth_token
|
||||
FROM users
|
||||
WHERE forum_id = ?1
|
||||
LIMIT 1",
|
||||
params![forum_id],
|
||||
|row| {
|
||||
Ok(UserInfo {
|
||||
id: row.get(0)?,
|
||||
forum_id: row.get(1)?,
|
||||
username: row.get(2)?,
|
||||
name: row.get(3)?,
|
||||
trust_level: row.get(4)?,
|
||||
created_at: row.get(5)?,
|
||||
ban_expired_at: row.get(6)?,
|
||||
ban_count: row.get(7)?,
|
||||
auth_token: row.get(8)?,
|
||||
})
|
||||
},
|
||||
)
|
||||
.optional()
|
||||
}
|
||||
pub fn update_user_ban(conn: &mut Connection, forum_id: i64, ban_expired_at: Option<DateTime<Local>>, ban_count: u32) -> Result<()> {
|
||||
let tx = conn.transaction()?;
|
||||
tx.execute(
|
||||
"UPDATE users SET ban_expired_at = ?1, ban_count = ?2 WHERE forum_id = ?3",
|
||||
params![ban_expired_at, ban_count, forum_id],
|
||||
)?;
|
||||
tx.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
pub fn update_user_auth_token(
|
||||
conn: &mut Connection,
|
||||
forum_id: i64,
|
||||
auth_token: Option<String>,
|
||||
) -> Result<()> {
|
||||
let tx = conn.transaction()?;
|
||||
// 检查 auth_token 是否已存在
|
||||
if let Some(token) = &auth_token {
|
||||
let existing = tx.query_row(
|
||||
"SELECT forum_id FROM users WHERE auth_token = ?1 AND forum_id != ?2",
|
||||
params![token, forum_id],
|
||||
|_| Ok(()),
|
||||
);
|
||||
if existing.optional()?.is_some() {
|
||||
return Err(rusqlite::Error::InvalidParameterName(
|
||||
"Auth token already exists".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
tx.execute(
|
||||
"UPDATE users SET auth_token = ?1 WHERE forum_id = ?2",
|
||||
params![auth_token, forum_id],
|
||||
)?;
|
||||
tx.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
pub fn get_user_by_auth_token(conn: &Connection, auth_token: &str) -> Result<Option<UserInfo>> {
|
||||
conn.query_row(
|
||||
"SELECT id, forum_id, username, name, trust_level, created_at, ban_expired_at, ban_count, auth_token
|
||||
FROM users
|
||||
WHERE auth_token = ?1
|
||||
LIMIT 1",
|
||||
params![auth_token],
|
||||
|row| {
|
||||
Ok(UserInfo {
|
||||
id: row.get(0)?,
|
||||
forum_id: row.get(1)?,
|
||||
username: row.get(2)?,
|
||||
name: row.get(3)?,
|
||||
trust_level: row.get(4)?,
|
||||
created_at: row.get(5)?,
|
||||
ban_expired_at: row.get(6)?,
|
||||
ban_count: row.get(7)?,
|
||||
auth_token: row.get(8)?,
|
||||
})
|
||||
},
|
||||
)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::{
|
||||
app::constant::{DEFAULT_TOKEN_FILE_NAME, DEFAULT_TOKEN_LIST_FILE_NAME, EMPTY_STRING},
|
||||
app::constant::EMPTY_STRING,
|
||||
common::utils::parse_string_from_env,
|
||||
};
|
||||
use std::sync::LazyLock;
|
||||
@@ -27,18 +27,8 @@ macro_rules! def_pub_static {
|
||||
// };
|
||||
// }
|
||||
|
||||
def_pub_static!(ROUTE_PREFIX, env: "ROUTE_PREFIX", default: EMPTY_STRING);
|
||||
def_pub_static!(AUTH_TOKEN, env: "AUTH_TOKEN", default: EMPTY_STRING);
|
||||
def_pub_static!(TOKEN_FILE, env: "TOKEN_FILE", default: DEFAULT_TOKEN_FILE_NAME);
|
||||
def_pub_static!(TOKEN_LIST_FILE, env: "TOKEN_LIST_FILE", default: DEFAULT_TOKEN_LIST_FILE_NAME);
|
||||
def_pub_static!(
|
||||
ROUTE_MODELS_PATH,
|
||||
format!("{}/v1/models", *ROUTE_PREFIX)
|
||||
);
|
||||
def_pub_static!(
|
||||
ROUTE_CHAT_PATH,
|
||||
format!("{}/v1/chat/completions", *ROUTE_PREFIX)
|
||||
);
|
||||
def_pub_static!(ROUTE_PREFIX, env: "ROUTE_PREFIX", default: "/v1");
|
||||
def_pub_static!(PUBLIC_AUTH_TOKEN, env: "PUBLIC_AUTH_TOKEN", default: EMPTY_STRING);
|
||||
|
||||
pub static START_TIME: LazyLock<chrono::DateTime<chrono::Local>> =
|
||||
LazyLock::new(chrono::Local::now);
|
||||
@@ -55,6 +45,18 @@ pub static CURSOR_API2_BASE_URL: LazyLock<String> = LazyLock::new(|| {
|
||||
format!("https://{}/aiserver.v1.AiService/", *CURSOR_API2_HOST)
|
||||
});
|
||||
|
||||
pub static OAUTH_CLIENT_ID: LazyLock<String> = LazyLock::new(|| {
|
||||
parse_string_from_env("OAUTH_CLIENT_ID", EMPTY_STRING).trim().to_string()
|
||||
});
|
||||
|
||||
pub static OAUTH_CLIENT_SECRET: LazyLock<String> = LazyLock::new(|| {
|
||||
parse_string_from_env("OAUTH_CLIENT_SECRET", EMPTY_STRING).trim().to_string()
|
||||
});
|
||||
|
||||
pub static OAUTH_REDIRECT_URI: LazyLock<String> = LazyLock::new(|| {
|
||||
parse_string_from_env("OAUTH_REDIRECT_URI", EMPTY_STRING).trim().to_string()
|
||||
});
|
||||
|
||||
// pub static DEBUG: LazyLock<bool> = LazyLock::new(|| parse_bool_from_env("DEBUG", false));
|
||||
|
||||
// #[macro_export]
|
||||
@@ -65,3 +67,5 @@ pub static CURSOR_API2_BASE_URL: LazyLock<String> = LazyLock::new(|| {
|
||||
// }
|
||||
// };
|
||||
// }
|
||||
|
||||
def_pub_static!(ADMIN_AUTH_TOKEN, env: "ADMIN_AUTH_TOKEN", default: EMPTY_STRING);
|
||||
|
||||
@@ -1,15 +1,10 @@
|
||||
use crate::{
|
||||
app::constant::{
|
||||
ERR_INVALID_PATH, ERR_RESET_CONFIG, ERR_UPDATE_CONFIG, ROUTE_ABOUT_PATH, ROUTE_CONFIG_PATH,
|
||||
ROUTE_LOGS_PATH, ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_SHARED_JS_PATH,
|
||||
ROUTE_SHARED_STYLES_PATH, ROUTE_TOKENINFO_PATH,
|
||||
},
|
||||
common::models::usage::UserUsageInfo,
|
||||
use crate::app::constant::{
|
||||
ERR_INVALID_PATH, ERR_RESET_CONFIG, ERR_UPDATE_CONFIG, ROUTE_ABOUT_PATH, ROUTE_CONFIG_PATH,
|
||||
ROUTE_LOGS_PATH, ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_SHARED_JS_PATH,
|
||||
ROUTE_SHARED_STYLES_PATH, ROUTE_TOKENINFO_PATH,
|
||||
};
|
||||
use crate::chat::model::Message;
|
||||
use lazy_static::lazy_static;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::RwLock;
|
||||
use std::sync::{LazyLock, RwLock};
|
||||
|
||||
// 页面内容类型枚举
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
@@ -87,16 +82,10 @@ pub struct Pages {
|
||||
pub struct AppState {
|
||||
pub total_requests: u64,
|
||||
pub active_requests: u64,
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
pub request_logs: Vec<RequestLog>,
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
pub token_infos: Vec<TokenInfo>,
|
||||
}
|
||||
|
||||
// 全局配置实例
|
||||
lazy_static! {
|
||||
pub static ref APP_CONFIG: RwLock<AppConfig> = RwLock::new(AppConfig::default());
|
||||
}
|
||||
pub static APP_CONFIG: LazyLock<RwLock<AppConfig>> = LazyLock::new(|| RwLock::new(AppConfig::default()));
|
||||
|
||||
impl Default for AppConfig {
|
||||
fn default() -> Self {
|
||||
@@ -275,17 +264,6 @@ impl AppConfig {
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
pub fn new(token_infos: Vec<TokenInfo>) -> Self {
|
||||
Self {
|
||||
total_requests: 0,
|
||||
active_requests: 0,
|
||||
request_logs: Vec::new(),
|
||||
token_infos,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "sqlite")]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
total_requests: 0,
|
||||
@@ -294,51 +272,5 @@ impl AppState {
|
||||
}
|
||||
}
|
||||
|
||||
// 请求日志
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct RequestLog {
|
||||
pub id: u64,
|
||||
pub timestamp: chrono::DateTime<chrono::Local>,
|
||||
pub model: String,
|
||||
pub token_info: TokenInfo,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt: Option<String>,
|
||||
pub stream: bool,
|
||||
pub status: &'static str,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
// pub struct PromptList(Option<String>);
|
||||
|
||||
// impl PromptList {
|
||||
// pub fn to_vec(&self) -> Vec<>
|
||||
// }
|
||||
|
||||
// 聊天请求
|
||||
#[derive(Deserialize)]
|
||||
pub struct ChatRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<Message>,
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
}
|
||||
|
||||
// 用于存储 token 信息
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct TokenInfo {
|
||||
pub token: String,
|
||||
pub checksum: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub alias: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub usage: Option<UserUsageInfo>,
|
||||
}
|
||||
|
||||
// TokenUpdateRequest 结构体
|
||||
#[derive(Deserialize)]
|
||||
pub struct TokenUpdateRequest {
|
||||
pub tokens: String,
|
||||
#[serde(default)]
|
||||
pub token_list: Option<String>,
|
||||
}
|
||||
mod db;
|
||||
pub use db::*;
|
||||
|
||||
148
src/app/model/db.rs
Normal file
148
src/app/model/db.rs
Normal file
@@ -0,0 +1,148 @@
|
||||
use crate::{chat::model::Message, common::models::usage::UserUsageInfo};
|
||||
use chrono::{DateTime, Local};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub enum LogStatus {
|
||||
#[serde(rename = "pending")]
|
||||
Pending,
|
||||
#[serde(rename = "success")]
|
||||
Success,
|
||||
#[serde(rename = "failed")]
|
||||
Failed,
|
||||
#[serde(rename = "deleted")]
|
||||
Deleted,
|
||||
}
|
||||
|
||||
impl rusqlite::types::FromSql for LogStatus {
|
||||
fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
|
||||
match value.as_i64()? {
|
||||
0 => Ok(LogStatus::Pending),
|
||||
1 => Ok(LogStatus::Success),
|
||||
2 => Ok(LogStatus::Failed),
|
||||
3 => Ok(LogStatus::Deleted),
|
||||
_ => Err(rusqlite::types::FromSqlError::OutOfRange(value.as_i64()?)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl rusqlite::ToSql for LogStatus {
|
||||
fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
|
||||
Ok(rusqlite::types::ToSqlOutput::from(match self {
|
||||
LogStatus::Pending => 0u8,
|
||||
LogStatus::Success => 1u8,
|
||||
LogStatus::Failed => 2u8,
|
||||
LogStatus::Deleted => 3u8,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
// 请求日志
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct LogInfo {
|
||||
#[serde(skip_serializing)]
|
||||
pub id: i64,
|
||||
pub timestamp: DateTime<Local>,
|
||||
#[serde(skip_serializing_if = "TokenInfo::is_hide")]
|
||||
pub token_info: TokenInfo,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt: Option<String>,
|
||||
pub model: String,
|
||||
pub stream: bool,
|
||||
pub status: LogStatus,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
// 聊天请求
|
||||
#[derive(Deserialize)]
|
||||
pub struct ChatRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<Message>,
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize, PartialEq, Clone)]
|
||||
pub enum TokenStatus {
|
||||
#[serde(rename = "pending")]
|
||||
Pending,
|
||||
#[serde(rename = "active")]
|
||||
Active,
|
||||
#[serde(rename = "expired")]
|
||||
Expired,
|
||||
#[serde(rename = "deleted")]
|
||||
Deleted,
|
||||
}
|
||||
|
||||
impl rusqlite::types::FromSql for TokenStatus {
|
||||
fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
|
||||
match value.as_i64()? {
|
||||
0 => Ok(TokenStatus::Pending),
|
||||
1 => Ok(TokenStatus::Active),
|
||||
2 => Ok(TokenStatus::Expired),
|
||||
3 => Ok(TokenStatus::Deleted),
|
||||
_ => Err(rusqlite::types::FromSqlError::OutOfRange(value.as_i64()?)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl rusqlite::ToSql for TokenStatus {
|
||||
fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
|
||||
Ok(rusqlite::types::ToSqlOutput::from(match self {
|
||||
TokenStatus::Pending => 0u8,
|
||||
TokenStatus::Active => 1u8,
|
||||
TokenStatus::Expired => 2u8,
|
||||
TokenStatus::Deleted => 3u8,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
// 用于存储 token 信息
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct TokenInfo {
|
||||
#[serde(skip_serializing)]
|
||||
pub id: i64,
|
||||
pub create_at: DateTime<Local>,
|
||||
pub token: String,
|
||||
pub checksum: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub alias: Option<String>,
|
||||
pub status: TokenStatus,
|
||||
pub pengding_at: DateTime<Local>,
|
||||
#[serde(skip_serializing)]
|
||||
pub user_id: i64,
|
||||
pub is_public: bool, // 公益
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub usage: Option<UserUsageInfo>,
|
||||
}
|
||||
|
||||
impl TokenInfo {
|
||||
pub fn is_hide(&self) -> bool {
|
||||
self.status == TokenStatus::Deleted
|
||||
}
|
||||
}
|
||||
|
||||
// TokenUpdateRequest 结构体
|
||||
#[derive(Deserialize)]
|
||||
pub struct TokenUpdateRequest {
|
||||
pub token: String,
|
||||
pub checksum: String,
|
||||
#[serde(default)]
|
||||
pub alias: Option<String>,
|
||||
#[serde(default)]
|
||||
pub is_public: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct UserInfo {
|
||||
pub id: i64,
|
||||
pub forum_id: i64,
|
||||
pub username: String, // 论坛用户名
|
||||
pub name: String, // 论坛昵称
|
||||
pub trust_level: u8,
|
||||
pub created_at: DateTime<Local>,
|
||||
pub ban_expired_at: Option<DateTime<Local>>, // 封禁到期时间
|
||||
pub ban_count: u32, // 封禁次数
|
||||
pub auth_token: Option<String>,
|
||||
}
|
||||
@@ -6,7 +6,7 @@ pub enum UsageCheck {
|
||||
None,
|
||||
Default,
|
||||
All,
|
||||
Custom(Vec<&'static str>),
|
||||
Custom(Vec<String>),
|
||||
}
|
||||
|
||||
impl Default for UsageCheck {
|
||||
@@ -69,14 +69,14 @@ impl<'de> Deserialize<'de> for UsageCheck {
|
||||
return Ok(UsageCheck::None);
|
||||
}
|
||||
|
||||
let models: Vec<&'static str> = list
|
||||
let models: Vec<String> = list
|
||||
.split(',')
|
||||
.filter_map(|model| {
|
||||
let model = model.trim();
|
||||
AVAILABLE_MODELS
|
||||
.iter()
|
||||
.find(|m| m.id == model)
|
||||
.map(|m| m.id)
|
||||
.map(|m| m.id.clone())
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ pub mod adapter;
|
||||
pub mod aiserver;
|
||||
pub mod constant;
|
||||
pub mod error;
|
||||
// pub mod middleware;
|
||||
pub mod model;
|
||||
pub mod route;
|
||||
pub mod service;
|
||||
|
||||
@@ -14,9 +14,9 @@ use super::{
|
||||
conversation_message, image_proto, ConversationMessage, ExplicitContext, GetChatRequest,
|
||||
ImageProto, ModelDetails,
|
||||
},
|
||||
constant::{ERR_UNSUPPORTED_GIF, ERR_UNSUPPORTED_IMAGE_FORMAT, LONG_CONTEXT_MODELS},
|
||||
model::{Message, MessageContent, Role},
|
||||
};
|
||||
constant::{ERR_UNSUPPORTED_GIF, ERR_UNSUPPORTED_IMAGE_FORMAT},
|
||||
model::{Message, MessageContent, Role, Model},
|
||||
};
|
||||
|
||||
async fn process_chat_inputs(inputs: Vec<Message>) -> (String, Vec<ConversationMessage>) {
|
||||
// 收集 system 指令
|
||||
@@ -380,7 +380,7 @@ pub async fn encode_chat_message(
|
||||
workspace_id: None,
|
||||
external_links: vec![],
|
||||
commit_notes: vec![],
|
||||
long_context_mode: if LONG_CONTEXT_MODELS.contains(&model_name) {
|
||||
long_context_mode: if Model::is_long_context(model_name) {
|
||||
Some(true)
|
||||
} else {
|
||||
None
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::sync::LazyLock;
|
||||
|
||||
use super::model::Model;
|
||||
|
||||
macro_rules! def_pub_const {
|
||||
@@ -9,8 +11,8 @@ def_pub_const!(ERR_UNSUPPORTED_GIF, "不支持动态 GIF");
|
||||
def_pub_const!(ERR_UNSUPPORTED_IMAGE_FORMAT, "不支持的图片格式,仅支持 PNG、JPEG、WEBP 和非动态 GIF");
|
||||
def_pub_const!(ERR_NODATA, "No data");
|
||||
|
||||
const MODEL_OBJECT: &str = "model";
|
||||
const CREATED: &i64 = &1706659200;
|
||||
pub const MODEL_OBJECT: &str = "model";
|
||||
pub const CREATED: &i64 = &1706659200;
|
||||
|
||||
def_pub_const!(ANTHROPIC, "anthropic");
|
||||
def_pub_const!(CURSOR, "cursor");
|
||||
@@ -42,134 +44,134 @@ def_pub_const!(
|
||||
);
|
||||
def_pub_const!(GEMINI_2_0_FLASH_EXP, "gemini-2.0-flash-exp");
|
||||
|
||||
pub const AVAILABLE_MODELS: [Model; 21] = [
|
||||
pub const AVAILABLE_MODELS: LazyLock<[Model; 21]> = LazyLock::new(|| [
|
||||
Model {
|
||||
id: CLAUDE_3_5_SONNET,
|
||||
id: CLAUDE_3_5_SONNET.to_string(),
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: ANTHROPIC,
|
||||
},
|
||||
Model {
|
||||
id: GPT_4,
|
||||
id: GPT_4.to_string(),
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: OPENAI,
|
||||
},
|
||||
Model {
|
||||
id: GPT_4O,
|
||||
id: GPT_4O.to_string(),
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: OPENAI,
|
||||
},
|
||||
Model {
|
||||
id: CLAUDE_3_OPUS,
|
||||
id: CLAUDE_3_OPUS.to_string(),
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: ANTHROPIC,
|
||||
},
|
||||
Model {
|
||||
id: CURSOR_FAST,
|
||||
id: CURSOR_FAST.to_string(),
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: CURSOR,
|
||||
},
|
||||
Model {
|
||||
id: CURSOR_SMALL,
|
||||
id: CURSOR_SMALL.to_string(),
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: CURSOR,
|
||||
},
|
||||
Model {
|
||||
id: GPT_3_5_TURBO,
|
||||
id: GPT_3_5_TURBO.to_string(),
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: OPENAI,
|
||||
},
|
||||
Model {
|
||||
id: GPT_4_TURBO_2024_04_09,
|
||||
id: GPT_4_TURBO_2024_04_09.to_string(),
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: OPENAI,
|
||||
},
|
||||
Model {
|
||||
id: GPT_4O_128K,
|
||||
id: GPT_4O_128K.to_string(),
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: OPENAI,
|
||||
},
|
||||
Model {
|
||||
id: GEMINI_1_5_FLASH_500K,
|
||||
id: GEMINI_1_5_FLASH_500K.to_string(),
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: GOOGLE,
|
||||
},
|
||||
Model {
|
||||
id: CLAUDE_3_HAIKU_200K,
|
||||
id: CLAUDE_3_HAIKU_200K.to_string(),
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: ANTHROPIC,
|
||||
},
|
||||
Model {
|
||||
id: CLAUDE_3_5_SONNET_200K,
|
||||
id: CLAUDE_3_5_SONNET_200K.to_string(),
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: ANTHROPIC,
|
||||
},
|
||||
Model {
|
||||
id: CLAUDE_3_5_SONNET_20241022,
|
||||
id: CLAUDE_3_5_SONNET_20241022.to_string(),
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: ANTHROPIC,
|
||||
},
|
||||
Model {
|
||||
id: GPT_4O_MINI,
|
||||
id: GPT_4O_MINI.to_string(),
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: OPENAI,
|
||||
},
|
||||
Model {
|
||||
id: O1_MINI,
|
||||
id: O1_MINI.to_string(),
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: OPENAI,
|
||||
},
|
||||
Model {
|
||||
id: O1_PREVIEW,
|
||||
id: O1_PREVIEW.to_string(),
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: OPENAI,
|
||||
},
|
||||
Model {
|
||||
id: O1,
|
||||
id: O1.to_string(),
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: OPENAI,
|
||||
},
|
||||
Model {
|
||||
id: CLAUDE_3_5_HAIKU,
|
||||
id: CLAUDE_3_5_HAIKU.to_string(),
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: ANTHROPIC,
|
||||
},
|
||||
Model {
|
||||
id: GEMINI_EXP_1206,
|
||||
id: GEMINI_EXP_1206.to_string(),
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: GOOGLE,
|
||||
},
|
||||
Model {
|
||||
id: GEMINI_2_0_FLASH_THINKING_EXP,
|
||||
id: GEMINI_2_0_FLASH_THINKING_EXP.to_string(),
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: GOOGLE,
|
||||
},
|
||||
Model {
|
||||
id: GEMINI_2_0_FLASH_EXP,
|
||||
id: GEMINI_2_0_FLASH_EXP.to_string(),
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: GOOGLE,
|
||||
},
|
||||
];
|
||||
]);
|
||||
|
||||
pub const USAGE_CHECK_MODELS: [&str; 11] = [
|
||||
CLAUDE_3_5_SONNET_20241022,
|
||||
@@ -184,10 +186,3 @@ pub const USAGE_CHECK_MODELS: [&str; 11] = [
|
||||
CLAUDE_3_HAIKU_200K,
|
||||
CLAUDE_3_5_SONNET_200K,
|
||||
];
|
||||
|
||||
pub const LONG_CONTEXT_MODELS: [&str; 4] = [
|
||||
GPT_4O_128K,
|
||||
GEMINI_1_5_FLASH_500K,
|
||||
CLAUDE_3_HAIKU_200K,
|
||||
CLAUDE_3_5_SONNET_200K,
|
||||
];
|
||||
|
||||
@@ -80,12 +80,18 @@ pub struct Usage {
|
||||
// 模型定义
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct Model {
|
||||
pub id: &'static str,
|
||||
pub id: String,
|
||||
pub created: &'static i64,
|
||||
pub object: &'static str,
|
||||
pub owned_by: &'static str,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn is_long_context(id :&str) -> bool {
|
||||
id.ends_with("128k") || id.ends_with("500k") || id.ends_with("200k")
|
||||
}
|
||||
}
|
||||
|
||||
use crate::app::model::{AppConfig, UsageCheck};
|
||||
use super::constant::USAGE_CHECK_MODELS;
|
||||
|
||||
@@ -93,9 +99,9 @@ impl Model {
|
||||
pub fn is_usage_check(&self) -> bool {
|
||||
match AppConfig::get_usage_check() {
|
||||
UsageCheck::None => false,
|
||||
UsageCheck::Default => USAGE_CHECK_MODELS.contains(&self.id),
|
||||
UsageCheck::Default => USAGE_CHECK_MODELS.iter().any(|&x| x == self.id.as_str()),
|
||||
UsageCheck::All => true,
|
||||
UsageCheck::Custom(models) => models.contains(&self.id),
|
||||
UsageCheck::Custom(models) => models.iter().any(|x| x == &self.id),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -103,5 +109,5 @@ impl Model {
|
||||
#[derive(Serialize)]
|
||||
pub struct ModelsResponse {
|
||||
pub object: &'static str,
|
||||
pub data: &'static [Model],
|
||||
pub data: Vec<Model>,
|
||||
}
|
||||
|
||||
@@ -3,8 +3,10 @@ pub use logs::{handle_logs, handle_logs_post};
|
||||
mod health;
|
||||
pub use health::{handle_root, handle_health};
|
||||
mod token;
|
||||
pub use token::{handle_get_checksum, handle_update_tokeninfo, handle_get_tokeninfo, handle_update_tokeninfo_post, handle_tokeninfo_page};
|
||||
pub use token::{handle_get_checksum, handle_get_tokeninfo, handle_update_tokeninfo_post, handle_tokeninfo_page};
|
||||
mod usage;
|
||||
pub use usage::get_user_info;
|
||||
mod config;
|
||||
pub use config::{handle_env_example, handle_config_page, handle_static, handle_readme, handle_about};
|
||||
mod auth;
|
||||
pub use auth::{handle_auth_callback, handle_auth_initiate};
|
||||
|
||||
113
src/chat/route/auth.rs
Normal file
113
src/chat/route/auth.rs
Normal file
@@ -0,0 +1,113 @@
|
||||
use crate::app::constant::ROUTE_TOKENINFO_PATH;
|
||||
use crate::app::lazy::{OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OAUTH_REDIRECT_URI};
|
||||
use crate::common::utils::oauth::ForumOAuth;
|
||||
use axum::http::{header::SET_COOKIE, HeaderMap};
|
||||
use axum::{extract::Query, response::Redirect};
|
||||
use base64::{engine::general_purpose::URL_SAFE, Engine};
|
||||
use ring::rand::SecureRandom as _;
|
||||
use ring::{aead, rand};
|
||||
use serde::Deserialize;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct AuthCallback {
|
||||
code: String,
|
||||
state: String,
|
||||
}
|
||||
|
||||
// 用于加密的密钥,使用 OnceLock 确保只初始化一次
|
||||
static ENCRYPTION_KEY: OnceLock<aead::LessSafeKey> = OnceLock::new();
|
||||
|
||||
// 初始化加密密钥
|
||||
fn get_encryption_key() -> &'static aead::LessSafeKey {
|
||||
ENCRYPTION_KEY.get_or_init(|| {
|
||||
let rng = rand::SystemRandom::new();
|
||||
let mut key_bytes = [0u8; 32];
|
||||
rng.fill(&mut key_bytes).unwrap();
|
||||
let key = aead::UnboundKey::new(&aead::CHACHA20_POLY1305, &key_bytes).unwrap();
|
||||
aead::LessSafeKey::new(key)
|
||||
})
|
||||
}
|
||||
|
||||
// 加密 state
|
||||
fn encrypt_state(state: &str) -> Result<String, String> {
|
||||
let key = get_encryption_key();
|
||||
let nonce = aead::Nonce::assume_unique_for_key([0; 12]); // 在生产环境中应使用随机 nonce
|
||||
let mut in_out = state.as_bytes().to_vec();
|
||||
key.seal_in_place_append_tag(nonce, aead::Aad::empty(), &mut in_out)
|
||||
.map_err(|e| e.to_string())?;
|
||||
Ok(URL_SAFE.encode(in_out))
|
||||
}
|
||||
|
||||
// 解密 state
|
||||
fn decrypt_state(encrypted_state: &str) -> Result<String, String> {
|
||||
let key = get_encryption_key();
|
||||
let nonce = aead::Nonce::assume_unique_for_key([0; 12]);
|
||||
let mut encrypted_data = URL_SAFE
|
||||
.decode(encrypted_state)
|
||||
.map_err(|e| e.to_string())?;
|
||||
let decrypted = key
|
||||
.open_in_place(nonce, aead::Aad::empty(), &mut encrypted_data)
|
||||
.map_err(|e| e.to_string())?;
|
||||
String::from_utf8(decrypted.to_vec()).map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
pub async fn handle_auth_callback(
|
||||
headers: HeaderMap,
|
||||
Query(params): Query<AuthCallback>,
|
||||
) -> Result<Redirect, String> {
|
||||
let cookie_header = headers
|
||||
.get("cookie")
|
||||
.ok_or_else(|| "Missing cookie header".to_string())?;
|
||||
|
||||
let cookie_str = cookie_header.to_str().map_err(|e| e.to_string())?;
|
||||
|
||||
let encrypted_state = cookie_str
|
||||
.split(';')
|
||||
.find(|s| s.trim().starts_with("oauth_state="))
|
||||
.and_then(|s| s.trim().strip_prefix("oauth_state="))
|
||||
.ok_or_else(|| "Missing state cookie".to_string())?;
|
||||
|
||||
// 解密 state
|
||||
let expected_state = decrypt_state(encrypted_state)?;
|
||||
|
||||
let oauth = ForumOAuth::new(
|
||||
OAUTH_CLIENT_ID.to_string(),
|
||||
OAUTH_CLIENT_SECRET.to_string(),
|
||||
OAUTH_REDIRECT_URI.to_string(),
|
||||
)
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
let access_token = oauth
|
||||
.exchange_code_for_token(¶ms.code, ¶ms.state, &expected_state)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
let redirect_url = format!("{}?auth={}", ROUTE_TOKENINFO_PATH, access_token);
|
||||
Ok(Redirect::to(&redirect_url))
|
||||
}
|
||||
|
||||
pub async fn handle_auth_initiate() -> Result<(HeaderMap, Redirect), String> {
|
||||
let oauth = ForumOAuth::new(
|
||||
OAUTH_CLIENT_ID.to_string(),
|
||||
OAUTH_CLIENT_SECRET.to_string(),
|
||||
OAUTH_REDIRECT_URI.to_string(),
|
||||
)
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
let (auth_url, state) = oauth.get_authorize_url();
|
||||
|
||||
// 加密 state
|
||||
let encrypted_state = encrypt_state(state.secret())?;
|
||||
|
||||
// 创建安全的 cookie
|
||||
let cookie = format!(
|
||||
"oauth_state={}; Path=/; HttpOnly; Secure; SameSite=Lax; Max-Age=300",
|
||||
encrypted_state
|
||||
);
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(SET_COOKIE, cookie.parse().unwrap());
|
||||
|
||||
Ok((headers, Redirect::to(&auth_url.to_string())))
|
||||
}
|
||||
@@ -2,21 +2,24 @@ use crate::app::{
|
||||
constant::{
|
||||
CONTENT_TYPE_TEXT_CSS_WITH_UTF8, CONTENT_TYPE_TEXT_HTML_WITH_UTF8,
|
||||
CONTENT_TYPE_TEXT_JS_WITH_UTF8, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8,
|
||||
HEADER_NAME_CONTENT_TYPE, HEADER_NAME_LOCATION, ROUTE_ABOUT_PATH, ROUTE_CONFIG_PATH,
|
||||
ROUTE_README_PATH, ROUTE_SHARED_JS_PATH, ROUTE_SHARED_STYLES_PATH,
|
||||
ROUTE_ABOUT_PATH, ROUTE_CONFIG_PATH, ROUTE_README_PATH, ROUTE_SHARED_JS_PATH,
|
||||
ROUTE_SHARED_STYLES_PATH,
|
||||
},
|
||||
model::{AppConfig, PageContent},
|
||||
};
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::Path,
|
||||
http::StatusCode,
|
||||
http::{
|
||||
header::{CONTENT_TYPE, LOCATION},
|
||||
StatusCode,
|
||||
},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
|
||||
pub async fn handle_env_example() -> impl IntoResponse {
|
||||
Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.body(include_str!("../../../.env.example").to_string())
|
||||
.unwrap()
|
||||
}
|
||||
@@ -25,15 +28,15 @@ pub async fn handle_env_example() -> impl IntoResponse {
|
||||
pub async fn handle_config_page() -> impl IntoResponse {
|
||||
match AppConfig::get_page_content(ROUTE_CONFIG_PATH).unwrap_or_default() {
|
||||
PageContent::Default => Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.body(include_str!("../../../static/config.min.html").to_string())
|
||||
.unwrap(),
|
||||
PageContent::Text(content) => Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.body(content.clone())
|
||||
.unwrap(),
|
||||
PageContent::Html(content) => Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.body(content.clone())
|
||||
.unwrap(),
|
||||
}
|
||||
@@ -44,11 +47,11 @@ pub async fn handle_static(Path(path): Path<String>) -> impl IntoResponse {
|
||||
"shared-styles.css" => {
|
||||
match AppConfig::get_page_content(ROUTE_SHARED_STYLES_PATH).unwrap_or_default() {
|
||||
PageContent::Default => Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8)
|
||||
.body(include_str!("../../../static/shared-styles.min.css").to_string())
|
||||
.unwrap(),
|
||||
PageContent::Text(content) | PageContent::Html(content) => Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8)
|
||||
.body(content.clone())
|
||||
.unwrap(),
|
||||
}
|
||||
@@ -56,11 +59,11 @@ pub async fn handle_static(Path(path): Path<String>) -> impl IntoResponse {
|
||||
"shared.js" => {
|
||||
match AppConfig::get_page_content(ROUTE_SHARED_JS_PATH).unwrap_or_default() {
|
||||
PageContent::Default => Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8)
|
||||
.body(include_str!("../../../static/shared.min.js").to_string())
|
||||
.unwrap(),
|
||||
PageContent::Text(content) | PageContent::Html(content) => Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8)
|
||||
.body(content.clone())
|
||||
.unwrap(),
|
||||
}
|
||||
@@ -75,15 +78,15 @@ pub async fn handle_static(Path(path): Path<String>) -> impl IntoResponse {
|
||||
pub async fn handle_about() -> impl IntoResponse {
|
||||
match AppConfig::get_page_content(ROUTE_ABOUT_PATH).unwrap_or_default() {
|
||||
PageContent::Default => Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.body(include_str!("../../../static/readme.min.html").to_string())
|
||||
.unwrap(),
|
||||
PageContent::Text(content) => Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.body(content.clone())
|
||||
.unwrap(),
|
||||
PageContent::Html(content) => Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.body(content.clone())
|
||||
.unwrap(),
|
||||
}
|
||||
@@ -93,15 +96,15 @@ pub async fn handle_readme() -> impl IntoResponse {
|
||||
match AppConfig::get_page_content(ROUTE_README_PATH).unwrap_or_default() {
|
||||
PageContent::Default => Response::builder()
|
||||
.status(StatusCode::TEMPORARY_REDIRECT)
|
||||
.header(HEADER_NAME_LOCATION, ROUTE_ABOUT_PATH)
|
||||
.header(LOCATION, ROUTE_ABOUT_PATH)
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
PageContent::Text(content) => Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.body(Body::from(content.clone()))
|
||||
.unwrap(),
|
||||
PageContent::Html(content) => Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.body(Body::from(content.clone()))
|
||||
.unwrap(),
|
||||
}
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
use crate::{
|
||||
app::{
|
||||
constant::{
|
||||
CONTENT_TYPE_TEXT_HTML_WITH_UTF8, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8,
|
||||
HEADER_NAME_CONTENT_TYPE, HEADER_NAME_LOCATION, PKG_VERSION, ROUTE_ABOUT_PATH,
|
||||
ROUTE_CONFIG_PATH, ROUTE_ENV_EXAMPLE_PATH, ROUTE_GET_CHECKSUM,
|
||||
ROUTE_GET_TOKENINFO_PATH, ROUTE_GET_USER_INFO_PATH, ROUTE_HEALTH_PATH, ROUTE_LOGS_PATH,
|
||||
ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_STATIC_PATH, ROUTE_TOKENINFO_PATH,
|
||||
ROUTE_UPDATE_TOKENINFO_PATH,
|
||||
AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_TEXT_HTML_WITH_UTF8,
|
||||
CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, PKG_VERSION, ROUTE_HEALTH_PATH, ROUTE_ROOT_PATH,
|
||||
},
|
||||
db,
|
||||
lazy::get_start_time,
|
||||
model::{AppConfig, AppState, PageContent},
|
||||
lazy::{get_start_time, ROUTE_CHAT_PATH, ROUTE_MODELS_PATH},
|
||||
},
|
||||
chat::constant::AVAILABLE_MODELS,
|
||||
common::models::{
|
||||
@@ -20,7 +17,10 @@ use crate::{
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
http::{
|
||||
header::{CONTENT_TYPE, LOCATION},
|
||||
HeaderMap, StatusCode,
|
||||
},
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
@@ -33,80 +33,85 @@ pub async fn handle_root() -> impl IntoResponse {
|
||||
match AppConfig::get_page_content(ROUTE_ROOT_PATH).unwrap_or_default() {
|
||||
PageContent::Default => Response::builder()
|
||||
.status(StatusCode::TEMPORARY_REDIRECT)
|
||||
.header(HEADER_NAME_LOCATION, ROUTE_HEALTH_PATH)
|
||||
.header(LOCATION, ROUTE_HEALTH_PATH)
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
PageContent::Text(content) => Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.body(Body::from(content.clone()))
|
||||
.unwrap(),
|
||||
PageContent::Html(content) => Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.body(Body::from(content.clone()))
|
||||
.unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_health(State(state): State<Arc<Mutex<AppState>>>) -> Json<HealthCheckResponse> {
|
||||
pub async fn handle_health(
|
||||
headers: HeaderMap,
|
||||
State(state): State<Arc<Mutex<AppState>>>,
|
||||
) -> Json<HealthCheckResponse> {
|
||||
let start_time = get_start_time();
|
||||
|
||||
// 创建系统信息实例,只监控 CPU 和内存
|
||||
let mut sys = System::new_with_specifics(
|
||||
RefreshKind::nothing()
|
||||
.with_memory(MemoryRefreshKind::everything())
|
||||
.with_cpu(CpuRefreshKind::everything()),
|
||||
);
|
||||
// 尝试从请求头获取token并验证用户
|
||||
let mut stats = None;
|
||||
let token = headers
|
||||
.get(axum::http::header::AUTHORIZATION)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX));
|
||||
|
||||
std::thread::sleep(sysinfo::MINIMUM_CPU_UPDATE_INTERVAL);
|
||||
if let Some(token) = token {
|
||||
if let Ok(Some(user)) = db::get_user_by_auth_token(token) {
|
||||
if user.id == 0 && user.ban_expired_at.map_or(true, |t| t <= Local::now()) {
|
||||
// 创建系统信息实例,只监控 CPU 和内存
|
||||
let mut sys = System::new_with_specifics(
|
||||
RefreshKind::nothing()
|
||||
.with_memory(MemoryRefreshKind::everything())
|
||||
.with_cpu(CpuRefreshKind::everything()),
|
||||
);
|
||||
|
||||
// 刷新 CPU 和内存信息
|
||||
sys.refresh_memory();
|
||||
sys.refresh_cpu_usage();
|
||||
std::thread::sleep(sysinfo::MINIMUM_CPU_UPDATE_INTERVAL);
|
||||
|
||||
let pid = std::process::id() as usize;
|
||||
let process = sys.process(pid.into());
|
||||
// 刷新 CPU 和内存信息
|
||||
sys.refresh_memory();
|
||||
sys.refresh_cpu_usage();
|
||||
|
||||
// 获取内存信息
|
||||
let memory = process.map(|p| p.memory()).unwrap_or(0);
|
||||
let pid = std::process::id() as usize;
|
||||
let process = sys.process(pid.into());
|
||||
|
||||
// 获取 CPU 使用率
|
||||
let cpu_usage = sys.global_cpu_usage();
|
||||
// 获取内存信息
|
||||
let memory = process.map(|p| p.memory()).unwrap_or(0);
|
||||
|
||||
let state = state.lock().await;
|
||||
let uptime = (Local::now() - start_time).num_seconds();
|
||||
// 获取 CPU 使用率
|
||||
let cpu_usage = sys.global_cpu_usage();
|
||||
|
||||
let state = state.lock().await;
|
||||
|
||||
stats = Some(SystemStats {
|
||||
started: start_time.to_string(),
|
||||
total_requests: state.total_requests,
|
||||
active_requests: state.active_requests,
|
||||
system: SystemInfo {
|
||||
memory: MemoryInfo {
|
||||
rss: memory, // 物理内存使用量(字节)
|
||||
},
|
||||
cpu: CpuInfo {
|
||||
usage: cpu_usage, // CPU 使用率(百分比)
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Json(HealthCheckResponse {
|
||||
status: ApiStatus::Healthy,
|
||||
version: PKG_VERSION,
|
||||
uptime,
|
||||
stats: SystemStats {
|
||||
started: start_time.to_string(),
|
||||
total_requests: state.total_requests,
|
||||
active_requests: state.active_requests,
|
||||
system: SystemInfo {
|
||||
memory: MemoryInfo {
|
||||
rss: memory, // 物理内存使用量(字节)
|
||||
},
|
||||
cpu: CpuInfo {
|
||||
usage: cpu_usage, // CPU 使用率(百分比)
|
||||
},
|
||||
},
|
||||
},
|
||||
models: AVAILABLE_MODELS.iter().map(|m| m.id).collect::<Vec<_>>(),
|
||||
endpoints: vec![
|
||||
ROUTE_CHAT_PATH.as_str(),
|
||||
ROUTE_MODELS_PATH.as_str(),
|
||||
ROUTE_GET_CHECKSUM,
|
||||
ROUTE_TOKENINFO_PATH,
|
||||
ROUTE_UPDATE_TOKENINFO_PATH,
|
||||
ROUTE_GET_TOKENINFO_PATH,
|
||||
ROUTE_LOGS_PATH,
|
||||
ROUTE_GET_USER_INFO_PATH,
|
||||
ROUTE_ENV_EXAMPLE_PATH,
|
||||
ROUTE_CONFIG_PATH,
|
||||
ROUTE_STATIC_PATH,
|
||||
ROUTE_ABOUT_PATH,
|
||||
ROUTE_README_PATH,
|
||||
],
|
||||
uptime: (Local::now() - start_time).num_seconds(),
|
||||
stats,
|
||||
models: AVAILABLE_MODELS
|
||||
.iter()
|
||||
.map(|m| m.id.clone())
|
||||
.collect::<Vec<_>>(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,67 +2,62 @@ use crate::{
|
||||
app::{
|
||||
constant::{
|
||||
AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_TEXT_HTML_WITH_UTF8,
|
||||
CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, HEADER_NAME_AUTHORIZATION, HEADER_NAME_CONTENT_TYPE,
|
||||
ROUTE_LOGS_PATH,
|
||||
CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, ROUTE_LOGS_PATH,
|
||||
},
|
||||
model::{AppConfig, AppState, PageContent, RequestLog},
|
||||
lazy::AUTH_TOKEN,
|
||||
db,
|
||||
model::{AppConfig, LogInfo, PageContent},
|
||||
},
|
||||
common::models::ApiStatus,
|
||||
};
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::State,
|
||||
http::{HeaderMap, StatusCode},
|
||||
http::{header::{AUTHORIZATION, CONTENT_TYPE}, HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use chrono::Local;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
// 日志处理
|
||||
pub async fn handle_logs() -> impl IntoResponse {
|
||||
match AppConfig::get_page_content(ROUTE_LOGS_PATH).unwrap_or_default() {
|
||||
PageContent::Default => Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.body(Body::from(
|
||||
include_str!("../../../static/logs.min.html").to_string(),
|
||||
))
|
||||
.unwrap(),
|
||||
PageContent::Text(content) => Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.body(Body::from(content.clone()))
|
||||
.unwrap(),
|
||||
PageContent::Html(content) => Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.body(Body::from(content.clone()))
|
||||
.unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_logs_post(
|
||||
State(state): State<Arc<Mutex<AppState>>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<LogsResponse>, StatusCode> {
|
||||
let auth_token = AUTH_TOKEN.as_str();
|
||||
|
||||
// 验证 AUTH_TOKEN
|
||||
pub async fn handle_logs_post(headers: HeaderMap) -> Result<Json<LogsResponse>, StatusCode> {
|
||||
// 验证 auth_token
|
||||
let auth_header = headers
|
||||
.get(HEADER_NAME_AUTHORIZATION)
|
||||
.get(AUTHORIZATION)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX))
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
if auth_header != auth_token {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
// 通过 auth_token 获取用户信息
|
||||
let user = db::get_user_by_auth_token(auth_header)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
// 获取用户的日志记录
|
||||
let logs =
|
||||
db::get_logs_by_user_id(Some(user.id)).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let state = state.lock().await;
|
||||
Ok(Json(LogsResponse {
|
||||
status: ApiStatus::Success,
|
||||
total: state.request_logs.len(),
|
||||
logs: state.request_logs.clone(),
|
||||
total: logs.len(),
|
||||
logs,
|
||||
timestamp: Local::now().to_string(),
|
||||
}))
|
||||
}
|
||||
@@ -71,6 +66,6 @@ pub async fn handle_logs_post(
|
||||
pub struct LogsResponse {
|
||||
pub status: ApiStatus,
|
||||
pub total: usize,
|
||||
pub logs: Vec<RequestLog>,
|
||||
pub logs: Vec<LogInfo>,
|
||||
pub timestamp: String,
|
||||
}
|
||||
|
||||
@@ -2,34 +2,30 @@ use crate::{
|
||||
app::{
|
||||
constant::{
|
||||
AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_TEXT_HTML_WITH_UTF8,
|
||||
CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, HEADER_NAME_AUTHORIZATION, HEADER_NAME_CONTENT_TYPE,
|
||||
ROUTE_TOKENINFO_PATH,
|
||||
CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, ROUTE_TOKENINFO_PATH,
|
||||
},
|
||||
model::{AppConfig, PageContent, TokenUpdateRequest},
|
||||
lazy::{AUTH_TOKEN, TOKEN_FILE, TOKEN_LIST_FILE},
|
||||
db::{
|
||||
get_token_by_id, get_token_by_token, get_tokens_by_user_id, get_user_by_auth_token,
|
||||
insert_token, update_token,
|
||||
},
|
||||
model::{AppConfig, PageContent, TokenInfo, TokenStatus, TokenUpdateRequest},
|
||||
},
|
||||
common::{
|
||||
models::{ApiStatus, NormalResponseNoData},
|
||||
utils::{generate_checksum, generate_hash, tokens::load_tokens},
|
||||
models::ApiStatus,
|
||||
utils::{extract_user_id, extract_time, generate_checksum, generate_hash, validate_checksum},
|
||||
},
|
||||
};
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
use crate::app::model::AppState;
|
||||
#[cfg(feature = "sqlite")]
|
||||
use crate::app::db::APP_DB;
|
||||
use axum::{
|
||||
http::HeaderMap,
|
||||
http::{
|
||||
header::{AUTHORIZATION, CONTENT_TYPE},
|
||||
HeaderMap,
|
||||
},
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
use axum::extract::State;
|
||||
use chrono::Local;
|
||||
use reqwest::StatusCode;
|
||||
use serde::Serialize;
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
use std::sync::Arc;
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ChecksumResponse {
|
||||
@@ -41,79 +37,38 @@ pub async fn handle_get_checksum() -> Json<ChecksumResponse> {
|
||||
Json(ChecksumResponse { checksum })
|
||||
}
|
||||
|
||||
// 更新 TokenInfo 处理
|
||||
pub async fn handle_update_tokeninfo(
|
||||
#[cfg(not(feature = "sqlite"))] State(state): State<Arc<Mutex<AppState>>>,
|
||||
) -> Json<NormalResponseNoData> {
|
||||
// 重新加载 tokens
|
||||
let token_infos = load_tokens();
|
||||
|
||||
// 更新应用状态
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
{
|
||||
let mut state = state.lock().await;
|
||||
state.token_infos = token_infos;
|
||||
}
|
||||
|
||||
#[cfg(feature = "sqlite")]
|
||||
{
|
||||
// 使用 APP_DB 更新 token_infos
|
||||
if let Ok(db) = APP_DB.lock() {
|
||||
for token_info in token_infos {
|
||||
let _ = db.update_token_info(&token_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Json(NormalResponseNoData {
|
||||
status: ApiStatus::Success,
|
||||
message: Some("Token list has been reloaded".to_string()),
|
||||
})
|
||||
}
|
||||
|
||||
// 获取 TokenInfo 处理
|
||||
pub async fn handle_get_tokeninfo(
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<TokenInfoResponse>, StatusCode> {
|
||||
// 验证 AUTH_TOKEN
|
||||
// 验证用户身份
|
||||
let auth_header = headers
|
||||
.get(HEADER_NAME_AUTHORIZATION)
|
||||
.get(AUTHORIZATION)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX))
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
if auth_header != AUTH_TOKEN.as_str() {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
// 获取用户信息
|
||||
let user = get_user_by_auth_token(auth_header)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
// 获取用户的tokens
|
||||
let tokens = if user.id == 0 {
|
||||
// 管理员可以查看所有tokens
|
||||
get_tokens_by_user_id(None)
|
||||
} else {
|
||||
// 普通用户只能查看自己的tokens
|
||||
get_tokens_by_user_id(Some(user.id))
|
||||
}
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let token_file = TOKEN_FILE.as_str();
|
||||
let token_list_file = TOKEN_LIST_FILE.as_str();
|
||||
|
||||
// 读取文件内容
|
||||
let tokens = std::fs::read_to_string(&token_file).unwrap_or_else(|_| String::new());
|
||||
let token_list = std::fs::read_to_string(&token_list_file).unwrap_or_else(|_| String::new());
|
||||
|
||||
// 获取 tokens_count
|
||||
let tokens_count = {
|
||||
#[cfg(feature = "sqlite")]
|
||||
{
|
||||
APP_DB.lock()
|
||||
.map(|db| db.get_token_infos().map(|v| v.len()).unwrap_or(0))
|
||||
.unwrap_or(0)
|
||||
}
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
{
|
||||
tokens.len()
|
||||
}
|
||||
};
|
||||
let token_num = tokens.len();
|
||||
|
||||
Ok(Json(TokenInfoResponse {
|
||||
status: ApiStatus::Success,
|
||||
token_file: token_file.to_string(),
|
||||
token_list_file: token_list_file.to_string(),
|
||||
tokens: Some(tokens),
|
||||
tokens_count: Some(tokens_count),
|
||||
token_list: Some(token_list),
|
||||
num: Some(token_num),
|
||||
message: None,
|
||||
}))
|
||||
}
|
||||
@@ -121,89 +76,113 @@ pub async fn handle_get_tokeninfo(
|
||||
#[derive(Serialize)]
|
||||
pub struct TokenInfoResponse {
|
||||
pub status: ApiStatus,
|
||||
pub token_file: String,
|
||||
pub token_list_file: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tokens: Option<String>,
|
||||
pub tokens: Option<Vec<TokenInfo>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tokens_count: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub token_list: Option<String>,
|
||||
pub num: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub message: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn handle_update_tokeninfo_post(
|
||||
#[cfg(not(feature = "sqlite"))] State(state): State<Arc<Mutex<AppState>>>,
|
||||
headers: HeaderMap,
|
||||
Json(request): Json<TokenUpdateRequest>,
|
||||
) -> Result<Json<TokenInfoResponse>, StatusCode> {
|
||||
// 验证 AUTH_TOKEN
|
||||
// 验证用户身份
|
||||
let auth_header = headers
|
||||
.get(HEADER_NAME_AUTHORIZATION)
|
||||
.get(AUTHORIZATION)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX))
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
if auth_header != AUTH_TOKEN.as_str() {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
// 获取用户信息
|
||||
let user = get_user_by_auth_token(auth_header)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
if !validate_checksum(&request.checksum) {
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
let token_file = TOKEN_FILE.as_str();
|
||||
let token_list_file = TOKEN_LIST_FILE.as_str();
|
||||
// 检查token是否已存在
|
||||
let existing_token =
|
||||
get_token_by_token(&request.token).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
// 写入文件
|
||||
std::fs::write(&token_file, &request.tokens)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let is_update = existing_token.is_some();
|
||||
|
||||
if let Some(token_list) = &request.token_list {
|
||||
std::fs::write(&token_list_file, token_list)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
}
|
||||
let token_info = match existing_token {
|
||||
Some(mut token) => {
|
||||
// 更新现有token
|
||||
token.checksum = request.checksum;
|
||||
token.alias = request.alias;
|
||||
token.is_public = request.is_public;
|
||||
|
||||
// 重新加载 tokens
|
||||
let token_infos = load_tokens();
|
||||
let token_infos_len = token_infos.len();
|
||||
|
||||
// 更新应用状态
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
{
|
||||
let mut state = state.lock().await;
|
||||
state.token_infos = token_infos;
|
||||
}
|
||||
|
||||
#[cfg(feature = "sqlite")]
|
||||
{
|
||||
if let Ok(db) = APP_DB.lock() {
|
||||
for token_info in token_infos {
|
||||
let _ = db.update_token_info(&token_info);
|
||||
}
|
||||
// 更新数据库
|
||||
update_token(&token).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
token
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let now = Local::now();
|
||||
let alias = if request.alias.is_none() {
|
||||
match extract_user_id(&request.token) {
|
||||
Some(user_id) => Some(user_id),
|
||||
None => None,
|
||||
}
|
||||
} else {
|
||||
request.alias
|
||||
};
|
||||
// 创建新token
|
||||
let new_token = TokenInfo {
|
||||
id: 0, // 数据库会自动分配ID
|
||||
create_at: extract_time(&request.token).unwrap_or_else(|| now),
|
||||
token: request.token,
|
||||
checksum: request.checksum,
|
||||
alias,
|
||||
status: TokenStatus::Active,
|
||||
pengding_at: now,
|
||||
user_id: user.id,
|
||||
is_public: request.is_public,
|
||||
usage: None,
|
||||
};
|
||||
|
||||
// 插入数据库
|
||||
let id = insert_token(&new_token).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
// 获取插入后的完整token信息
|
||||
get_token_by_id(id)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(TokenInfoResponse {
|
||||
status: ApiStatus::Success,
|
||||
token_file: token_file.to_string(),
|
||||
token_list_file: token_list_file.to_string(),
|
||||
tokens: None,
|
||||
tokens_count: Some(token_infos_len),
|
||||
token_list: None,
|
||||
message: Some("Token files have been updated and reloaded".to_string()),
|
||||
num: None,
|
||||
message: Some(format!(
|
||||
"Token {} has been {}",
|
||||
token_info.token,
|
||||
if is_update {
|
||||
"updated"
|
||||
} else {
|
||||
"created"
|
||||
}
|
||||
)),
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn handle_tokeninfo_page() -> impl IntoResponse {
|
||||
match AppConfig::get_page_content(ROUTE_TOKENINFO_PATH).unwrap_or_default() {
|
||||
PageContent::Default => Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.body(include_str!("../../../static/tokeninfo.min.html").to_string())
|
||||
.unwrap(),
|
||||
PageContent::Text(content) => Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.body(content.clone())
|
||||
.unwrap(),
|
||||
PageContent::Html(content) => Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.body(content.clone())
|
||||
.unwrap(),
|
||||
}
|
||||
|
||||
@@ -1,37 +1,54 @@
|
||||
use crate::{
|
||||
app::model::AppState,
|
||||
app::{
|
||||
constant::AUTHORIZATION_BEARER_PREFIX,
|
||||
db::{get_token_by_alias_and_user, get_user_by_auth_token},
|
||||
},
|
||||
chat::constant::ERR_NODATA,
|
||||
common::{models::usage::GetUserInfo, utils::get_user_usage},
|
||||
};
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
extract::Query,
|
||||
http::{header::AUTHORIZATION, HeaderMap, StatusCode},
|
||||
Json,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetUserInfoQuery {
|
||||
alias: String,
|
||||
user_id: Option<i64>,
|
||||
}
|
||||
|
||||
pub async fn get_user_info(
|
||||
State(state): State<Arc<Mutex<AppState>>>,
|
||||
headers: HeaderMap,
|
||||
Query(query): Query<GetUserInfoQuery>,
|
||||
) -> Json<GetUserInfo> {
|
||||
let token_infos = &state.lock().await.token_infos;
|
||||
let token_info = token_infos
|
||||
.iter()
|
||||
.find(|token_info| token_info.alias == Some(query.alias.clone()));
|
||||
) -> Result<Json<GetUserInfo>, StatusCode> {
|
||||
// 1. 验证用户身份
|
||||
let auth_header = headers
|
||||
.get(AUTHORIZATION)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX))
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let (auth_token, checksum) = match token_info {
|
||||
Some(token_info) => (token_info.token.clone(), token_info.checksum.clone()),
|
||||
None => return Json(GetUserInfo::Error(ERR_NODATA.to_string())),
|
||||
// 2. 获取用户信息
|
||||
let user = get_user_by_auth_token(auth_header)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
// 3. 查询token信息
|
||||
let token_info = match get_token_by_alias_and_user(
|
||||
&query.alias,
|
||||
user.id,
|
||||
query.user_id
|
||||
) {
|
||||
Ok(Some(token)) => token,
|
||||
Ok(None) => return Ok(Json(GetUserInfo::Error(ERR_NODATA.to_string()))),
|
||||
Err(_) => return Ok(Json(GetUserInfo::Error("Database error".to_string()))),
|
||||
};
|
||||
|
||||
match get_user_usage(&auth_token, &checksum).await {
|
||||
Some(usage) => Json(GetUserInfo::Usage(usage)),
|
||||
None => Json(GetUserInfo::Error(ERR_NODATA.to_string())),
|
||||
// 4. 获取使用情况
|
||||
match get_user_usage(&token_info.token, &token_info.checksum).await {
|
||||
Some(usage) => Ok(Json(GetUserInfo::Usage(usage))),
|
||||
None => Ok(Json(GetUserInfo::Error(ERR_NODATA.to_string()))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,11 +3,10 @@ use crate::{
|
||||
app::{
|
||||
constant::{
|
||||
AUTHORIZATION_BEARER_PREFIX, CURSOR_API2_STREAM_CHAT, FINISH_REASON_STOP,
|
||||
HEADER_NAME_CONTENT_TYPE, OBJECT_CHAT_COMPLETION, OBJECT_CHAT_COMPLETION_CHUNK,
|
||||
STATUS_FAILED, STATUS_SUCCESS,
|
||||
OBJECT_CHAT_COMPLETION, OBJECT_CHAT_COMPLETION_CHUNK,
|
||||
},
|
||||
model::{AppConfig, AppState, ChatRequest, RequestLog, TokenInfo},
|
||||
lazy::AUTH_TOKEN,
|
||||
db,
|
||||
model::{AppConfig, AppState, ChatRequest, LogInfo, LogStatus, TokenStatus},
|
||||
},
|
||||
chat::{
|
||||
error::StreamError,
|
||||
@@ -25,19 +24,21 @@ use crate::{
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::State,
|
||||
http::{HeaderMap, StatusCode},
|
||||
http::{header::CONTENT_TYPE, HeaderMap, StatusCode},
|
||||
response::Response,
|
||||
Json,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use chrono::Local;
|
||||
use futures::{Stream, StreamExt};
|
||||
use rand::seq::IteratorRandom as _;
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
sync::{atomic::AtomicBool, Arc},
|
||||
};
|
||||
use std::{
|
||||
pin::Pin,
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
sync::atomic::Ordering,
|
||||
};
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
@@ -46,7 +47,7 @@ use uuid::Uuid;
|
||||
pub async fn handle_models() -> Json<ModelsResponse> {
|
||||
Json(ModelsResponse {
|
||||
object: "list",
|
||||
data: &AVAILABLE_MODELS,
|
||||
data: AVAILABLE_MODELS.to_vec(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -56,119 +57,144 @@ pub async fn handle_chat(
|
||||
headers: HeaderMap,
|
||||
Json(request): Json<ChatRequest>,
|
||||
) -> Result<Response<Body>, (StatusCode, Json<ErrorResponse>)> {
|
||||
let allow_claude = AppConfig::get_allow_claude();
|
||||
// 验证模型是否支持并获取模型信息
|
||||
let model = AVAILABLE_MODELS.iter().find(|m| m.id == request.model);
|
||||
let model_supported = model.is_some();
|
||||
|
||||
if !(model_supported || allow_claude && request.model.starts_with("claude")) {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ChatError::ModelNotSupported(request.model).to_json()),
|
||||
));
|
||||
}
|
||||
|
||||
let request_time = chrono::Local::now();
|
||||
|
||||
// 验证请求
|
||||
if request.messages.is_empty() {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ChatError::EmptyMessages.to_json()),
|
||||
));
|
||||
}
|
||||
|
||||
// 获取并处理认证令牌
|
||||
let auth_token = headers
|
||||
// 从请求头获取token
|
||||
let token = headers
|
||||
.get(axum::http::header::AUTHORIZATION)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX))
|
||||
.ok_or((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(ChatError::Unauthorized.to_json()),
|
||||
Json(ChatError::MissingToken.to_error_response()),
|
||||
))?;
|
||||
|
||||
// 验证 AuthToken
|
||||
if auth_token != AUTH_TOKEN.as_str() {
|
||||
// 验证token并获取用户信息
|
||||
let user = db::get_user_by_auth_token(token).map_err(|err| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ChatError::DatabaseError(err.to_string()).to_error_response()),
|
||||
)
|
||||
})?;
|
||||
|
||||
let user = user.ok_or((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(ChatError::InvalidToken.to_error_response()),
|
||||
))?;
|
||||
|
||||
// 检查用户是否在封禁期
|
||||
if let Some(ban_expired_at) = user.ban_expired_at {
|
||||
if ban_expired_at > Local::now() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ChatError::UserBanned(ban_expired_at).to_error_response()),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let tokens = db::get_available_tokens_by_user_id(Some(user.id)).map_err(|err| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ChatError::DatabaseError(err.to_string()).to_error_response()),
|
||||
)
|
||||
})?;
|
||||
|
||||
if tokens.is_empty() {
|
||||
return Err((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(ChatError::Unauthorized.to_json()),
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(ChatError::NoTokens.to_error_response()),
|
||||
));
|
||||
}
|
||||
|
||||
// 完整的令牌处理逻辑和对应的 checksum
|
||||
let (auth_token, checksum, alias) = {
|
||||
static CURRENT_KEY_INDEX: AtomicUsize = AtomicUsize::new(0);
|
||||
let state_guard = state.lock().await;
|
||||
let token_infos = &state_guard.token_infos;
|
||||
// 随机选择一个可用的token
|
||||
let token_info = tokens.into_iter().choose(&mut rand::thread_rng()).ok_or((
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(ChatError::NoTokens.to_error_response()),
|
||||
))?;
|
||||
|
||||
if token_infos.is_empty() {
|
||||
return Err((
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(ChatError::NoTokens.to_json()),
|
||||
));
|
||||
}
|
||||
let allow_claude = AppConfig::get_allow_claude();
|
||||
// 验证模型是否支持并获取模型信息
|
||||
let model = AVAILABLE_MODELS
|
||||
.iter()
|
||||
.find(|m| m.id == request.model)
|
||||
.cloned();
|
||||
let model_supported = model.is_some();
|
||||
|
||||
let index = CURRENT_KEY_INDEX.fetch_add(1, Ordering::SeqCst) % token_infos.len();
|
||||
let token_info = &token_infos[index];
|
||||
(
|
||||
token_info.token.clone(),
|
||||
token_info.checksum.clone(),
|
||||
token_info.alias.clone(),
|
||||
)
|
||||
if !(model_supported || allow_claude && request.model.starts_with("claude")) {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ChatError::ModelNotSupported(request.model).to_error_response()),
|
||||
));
|
||||
}
|
||||
|
||||
let request_time = Local::now();
|
||||
|
||||
// 验证请求
|
||||
if request.messages.is_empty() {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ChatError::EmptyMessages.to_error_response()),
|
||||
));
|
||||
}
|
||||
|
||||
let log_info = LogInfo {
|
||||
id: 0, // 数据库会自动生成
|
||||
timestamp: request_time,
|
||||
token_info: token_info.clone(),
|
||||
prompt: None,
|
||||
model: request.model.clone(),
|
||||
stream: request.stream,
|
||||
status: LogStatus::Pending,
|
||||
error: None,
|
||||
};
|
||||
|
||||
let log_id = db::insert_log(&log_info).map_err(|err| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ChatError::DatabaseError(err.to_string()).to_error_response()),
|
||||
)
|
||||
})?;
|
||||
|
||||
// 更新请求日志
|
||||
{
|
||||
let state_clone = state.clone();
|
||||
let mut state = state.lock().await;
|
||||
state.total_requests += 1;
|
||||
state.active_requests += 1;
|
||||
let token = token_info.token.clone();
|
||||
let checksum = token_info.checksum.clone();
|
||||
|
||||
// 如果有model且需要获取使用情况,创建后台任务获取
|
||||
if let Some(model) = model {
|
||||
if model.is_usage_check() {
|
||||
let auth_token_clone = auth_token.clone();
|
||||
let checksum_clone = checksum.clone();
|
||||
let state_clone = state_clone.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let usage = get_user_usage(&auth_token_clone, &checksum_clone).await;
|
||||
let mut state = state_clone.lock().await;
|
||||
// 根据时间戳找到对应的日志
|
||||
if let Some(log) = state
|
||||
.request_logs
|
||||
.iter_mut()
|
||||
.find(|log| log.timestamp == request_time)
|
||||
{
|
||||
log.token_info.usage = usage;
|
||||
let usage = get_user_usage(&token, &checksum).await;
|
||||
if let Err(err) = db::update_log_usage(log_id, usage) {
|
||||
eprintln!("Failed to update log usage: {}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let next_id = state.request_logs.last().map_or(1, |log| log.id + 1);
|
||||
state.request_logs.push(RequestLog {
|
||||
id: next_id,
|
||||
timestamp: request_time,
|
||||
model: request.model.clone(),
|
||||
token_info: TokenInfo {
|
||||
token: auth_token.clone(),
|
||||
checksum: checksum.clone(),
|
||||
alias: alias.clone(),
|
||||
usage: None,
|
||||
},
|
||||
prompt: None,
|
||||
stream: request.stream,
|
||||
status: "pending",
|
||||
error: None,
|
||||
});
|
||||
|
||||
if state.request_logs.len() > 100 {
|
||||
state.request_logs.remove(0);
|
||||
}
|
||||
}
|
||||
|
||||
if db::get_user_logs_count(user.id).map_err(|err| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ChatError::DatabaseError(err.to_string()).to_error_response()),
|
||||
)
|
||||
})? >= 100 {
|
||||
db::clean_user_logs(user.id, 100).map_err(|err| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ChatError::DatabaseError(err.to_string()).to_error_response()),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
db::update_token_status(token_info.id, TokenStatus::Pending).map_err(|err| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ChatError::DatabaseError(err.to_string()).to_error_response()),
|
||||
)
|
||||
})?;
|
||||
|
||||
// 将消息转换为hex格式
|
||||
let hex_data = super::adapter::encode_chat_message(request.messages, &request.model)
|
||||
.await
|
||||
@@ -176,37 +202,39 @@ pub async fn handle_chat(
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(
|
||||
ChatError::RequestFailed("Failed to encode chat message".to_string()).to_json(),
|
||||
ChatError::RequestFailed("Failed to encode chat message".to_string())
|
||||
.to_error_response(),
|
||||
),
|
||||
)
|
||||
})?;
|
||||
|
||||
// 构建请求客户端
|
||||
let client = build_client(&auth_token, &checksum, CURSOR_API2_STREAM_CHAT);
|
||||
let client = build_client(&token_info.token, &token_info.checksum, CURSOR_API2_STREAM_CHAT);
|
||||
let response = client.body(hex_data).send().await;
|
||||
|
||||
// 处理请求结果
|
||||
let response = match response {
|
||||
Ok(resp) => {
|
||||
// 更新请求日志为成功
|
||||
{
|
||||
let mut state = state.lock().await;
|
||||
state.request_logs.last_mut().unwrap().status = STATUS_SUCCESS;
|
||||
}
|
||||
db::update_log_status(log_id, LogStatus::Success, None).map_err(|err| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ChatError::DatabaseError(err.to_string()).to_error_response()),
|
||||
)
|
||||
})?;
|
||||
resp
|
||||
}
|
||||
Err(e) => {
|
||||
// 更新请求日志为失败
|
||||
{
|
||||
let mut state = state.lock().await;
|
||||
if let Some(last_log) = state.request_logs.last_mut() {
|
||||
last_log.status = STATUS_FAILED;
|
||||
last_log.error = Some(e.to_string());
|
||||
}
|
||||
}
|
||||
db::update_log_status(log_id, LogStatus::Failed, Some(e.to_string())).map_err(|err| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ChatError::DatabaseError(err.to_string()).to_error_response()),
|
||||
)
|
||||
})?;
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ChatError::RequestFailed(e.to_string()).to_json()),
|
||||
Json(ChatError::RequestFailed(e.to_string()).to_error_response()),
|
||||
));
|
||||
}
|
||||
};
|
||||
@@ -237,7 +265,7 @@ pub async fn handle_chat(
|
||||
// 理论上,若程序正常,必定成功,因为前面判断过了
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ChatError::RequestFailed(error_message).to_json()),
|
||||
Json(ChatError::RequestFailed(error_message).to_error_response()),
|
||||
)
|
||||
})?;
|
||||
|
||||
@@ -245,13 +273,12 @@ pub async fn handle_chat(
|
||||
Err(StreamError::ChatError(error)) => {
|
||||
let error_respone = error.to_error_response();
|
||||
// 更新请求日志为失败
|
||||
{
|
||||
let mut state = state.lock().await;
|
||||
if let Some(last_log) = state.request_logs.last_mut() {
|
||||
last_log.status = STATUS_FAILED;
|
||||
last_log.error = Some(error_respone.native_code());
|
||||
}
|
||||
}
|
||||
db::update_log_status(log_id, LogStatus::Failed, Some(error_respone.native_code())).map_err(|err| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ChatError::DatabaseError(err.to_string()).to_error_response()),
|
||||
)
|
||||
})?;
|
||||
return Err((
|
||||
error_respone.status_code(),
|
||||
Json(error_respone.to_common()),
|
||||
@@ -274,18 +301,17 @@ pub async fn handle_chat(
|
||||
// Box::pin(stream)
|
||||
// as Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>
|
||||
// 更新请求日志为失败
|
||||
{
|
||||
let mut state = state.lock().await;
|
||||
if let Some(last_log) = state.request_logs.last_mut() {
|
||||
last_log.status = STATUS_FAILED;
|
||||
last_log.error = Some("Empty stream response".to_string());
|
||||
}
|
||||
}
|
||||
db::update_log_status(log_id, LogStatus::Failed, Some("Empty stream response".to_string())).map_err(|err| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ChatError::DatabaseError(err.to_string()).to_error_response()),
|
||||
)
|
||||
})?;
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(
|
||||
ChatError::RequestFailed("Empty stream response".to_string())
|
||||
.to_json(),
|
||||
.to_error_response(),
|
||||
),
|
||||
));
|
||||
}
|
||||
@@ -304,7 +330,6 @@ pub async fn handle_chat(
|
||||
let model = request.model.clone();
|
||||
let is_start = is_start.clone();
|
||||
let full_text = full_text.clone();
|
||||
let state = state.clone();
|
||||
|
||||
async move {
|
||||
let chunk = chunk.unwrap_or_default();
|
||||
@@ -415,10 +440,8 @@ pub async fn handle_chat(
|
||||
}
|
||||
Ok(StreamMessage::Debug(debug_prompt)) => {
|
||||
buffer_guard.clear();
|
||||
if let Ok(mut state) = state.try_lock() {
|
||||
if let Some(last_log) = state.request_logs.last_mut() {
|
||||
last_log.prompt = Some(debug_prompt.clone());
|
||||
}
|
||||
if let Err(err) = db::update_log_prompt(log_id, Some(debug_prompt.clone())) {
|
||||
eprintln!("Failed to update log prompt: {}", err);
|
||||
}
|
||||
Ok(Bytes::new())
|
||||
}
|
||||
@@ -435,7 +458,7 @@ pub async fn handle_chat(
|
||||
Ok(Response::builder()
|
||||
.header("Cache-Control", "no-cache")
|
||||
.header("Connection", "keep-alive")
|
||||
.header(HEADER_NAME_CONTENT_TYPE, "text/event-stream")
|
||||
.header(CONTENT_TYPE, "text/event-stream")
|
||||
.body(Body::from_stream(stream))
|
||||
.unwrap())
|
||||
} else {
|
||||
@@ -451,7 +474,7 @@ pub async fn handle_chat(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(
|
||||
ChatError::RequestFailed(format!("Failed to read response chunk: {}", e))
|
||||
.to_json(),
|
||||
.to_error_response(),
|
||||
),
|
||||
)
|
||||
})?;
|
||||
@@ -496,29 +519,34 @@ pub async fn handle_chat(
|
||||
// 检查响应是否为空
|
||||
if full_text.is_empty() {
|
||||
// 更新请求日志为失败
|
||||
{
|
||||
let mut state = state.lock().await;
|
||||
if let Some(last_log) = state.request_logs.last_mut() {
|
||||
last_log.status = STATUS_FAILED;
|
||||
last_log.error = Some("Empty response received".to_string());
|
||||
if let Some(p) = prompt {
|
||||
last_log.prompt = Some(p);
|
||||
}
|
||||
}
|
||||
}
|
||||
db::update_log_status(log_id, LogStatus::Failed, Some("Empty response received".to_string())).map_err(|err| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ChatError::DatabaseError(err.to_string()).to_error_response()),
|
||||
)
|
||||
})?;
|
||||
db::update_log_prompt(log_id, None).map_err(|err| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ChatError::DatabaseError(err.to_string()).to_error_response()),
|
||||
)
|
||||
})?;
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ChatError::RequestFailed("Empty response received".to_string()).to_json()),
|
||||
Json(
|
||||
ChatError::RequestFailed("Empty response received".to_string())
|
||||
.to_error_response(),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// 更新请求日志提示词
|
||||
{
|
||||
let mut state = state.lock().await;
|
||||
if let Some(last_log) = state.request_logs.last_mut() {
|
||||
last_log.prompt = prompt;
|
||||
}
|
||||
}
|
||||
db::update_log_prompt(log_id, prompt).map_err(|err| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ChatError::DatabaseError(err.to_string()).to_error_response()),
|
||||
)
|
||||
})?;
|
||||
|
||||
let response_data = ChatResponse {
|
||||
id: format!("chatcmpl-{}", Uuid::new_v4().simple()),
|
||||
@@ -542,7 +570,7 @@ pub async fn handle_chat(
|
||||
};
|
||||
|
||||
Ok(Response::builder()
|
||||
.header(HEADER_NAME_CONTENT_TYPE, "application/json")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&response_data).unwrap()))
|
||||
.unwrap())
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
use crate::app::{
|
||||
constant::{
|
||||
AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_CONNECT_PROTO, CONTENT_TYPE_PROTO,
|
||||
CURSOR_API2_STREAM_CHAT, HEADER_NAME_AUTHORIZATION,
|
||||
HEADER_NAME_CONTENT_TYPE,
|
||||
CURSOR_API2_STREAM_CHAT, HEADER_NAME_GHOST_MODE,
|
||||
TRUE, FALSE
|
||||
},
|
||||
lazy::{CURSOR_API2_HOST, CURSOR_API2_BASE_URL},
|
||||
lazy::{CURSOR_API2_BASE_URL, CURSOR_API2_HOST},
|
||||
};
|
||||
use reqwest::Client;
|
||||
use reqwest::{header::{CONTENT_TYPE,AUTHORIZATION,USER_AGENT,HOST}, Client};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// 返回预构建的 Cursor API 客户端
|
||||
@@ -21,19 +21,45 @@ pub fn build_client(auth_token: &str, checksum: &str, endpoint: &str) -> reqwest
|
||||
|
||||
client
|
||||
.post(format!("{}{}", *CURSOR_API2_BASE_URL, endpoint))
|
||||
.header(HEADER_NAME_CONTENT_TYPE, content_type)
|
||||
.header(CONTENT_TYPE, content_type)
|
||||
.header(
|
||||
HEADER_NAME_AUTHORIZATION,
|
||||
AUTHORIZATION,
|
||||
format!("{}{}", AUTHORIZATION_BEARER_PREFIX, auth_token),
|
||||
)
|
||||
.header("connect-accept-encoding", "gzip,br")
|
||||
.header("connect-protocol-version", "1")
|
||||
.header("user-agent", "connect-es/1.6.1")
|
||||
.header(USER_AGENT, "connect-es/1.6.1")
|
||||
.header("x-amzn-trace-id", format!("Root={}", trace_id))
|
||||
.header("x-cursor-checksum", checksum)
|
||||
.header("x-cursor-client-version", "0.42.5")
|
||||
.header("x-cursor-timezone", "Asia/Shanghai")
|
||||
.header("x-ghost-mode", "false")
|
||||
.header(HEADER_NAME_GHOST_MODE, FALSE)
|
||||
.header("x-request-id", trace_id)
|
||||
.header("Host", CURSOR_API2_HOST.clone())
|
||||
.header(HOST, CURSOR_API2_HOST.clone())
|
||||
}
|
||||
|
||||
/// 返回预构建的获取 Stripe 账户信息的 Cursor API 客户端
|
||||
pub fn build_profile_client(auth_token: &str) -> reqwest::RequestBuilder {
|
||||
let client = Client::new();
|
||||
|
||||
client
|
||||
.get(format!("{}/auth/full_stripe_profile", *CURSOR_API2_BASE_URL))
|
||||
.header(HOST, CURSOR_API2_HOST.clone())
|
||||
.header("sec-ch-ua", "\"Not-A.Brand\";v=\"99\", \"Chromium\";v=\"124\"")
|
||||
.header(HEADER_NAME_GHOST_MODE, TRUE)
|
||||
.header("sec-ch-ua-mobile", "?0")
|
||||
.header(
|
||||
AUTHORIZATION,
|
||||
format!("{}{}", AUTHORIZATION_BEARER_PREFIX, auth_token),
|
||||
)
|
||||
.header(USER_AGENT, "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Cursor/0.42.5 Chrome/124.0.6367.243 Electron/30.4.0 Safari/537.36")
|
||||
.header("sec-ch-ua-platform", "\"Windows\"")
|
||||
.header("accept", "*/*")
|
||||
.header("origin", "vscode-file://vscode-app")
|
||||
.header("sec-fetch-site", "cross-site")
|
||||
.header("sec-fetch-mode", "cors")
|
||||
.header("sec-fetch-dest", "empty")
|
||||
.header("accept-encoding", "gzip, deflate, br")
|
||||
.header("accept-language", "zh-CN")
|
||||
.header("priority", "u=1, i")
|
||||
}
|
||||
|
||||
@@ -6,10 +6,14 @@ pub enum ChatError {
|
||||
NoTokens,
|
||||
RequestFailed(String),
|
||||
Unauthorized,
|
||||
MissingToken,
|
||||
InvalidToken,
|
||||
UserBanned(chrono::DateTime<chrono::Local>),
|
||||
DatabaseError(String),
|
||||
}
|
||||
|
||||
impl ChatError {
|
||||
pub fn to_json(&self) -> ErrorResponse {
|
||||
pub fn to_error_response(&self) -> ErrorResponse {
|
||||
let (error, message) = match self {
|
||||
ChatError::ModelNotSupported(model) => (
|
||||
"model_not_supported",
|
||||
@@ -22,13 +26,23 @@ impl ChatError {
|
||||
ChatError::NoTokens => ("no_tokens", "No available tokens".to_string()),
|
||||
ChatError::RequestFailed(err) => ("request_failed", format!("Request failed: {}", err)),
|
||||
ChatError::Unauthorized => ("unauthorized", "Invalid authorization token".to_string()),
|
||||
ChatError::MissingToken => ("missing_token", "Missing authorization token".to_string()),
|
||||
ChatError::InvalidToken => ("invalid_token", "Invalid authorization token".to_string()),
|
||||
ChatError::UserBanned(expired_at) => (
|
||||
"user_banned",
|
||||
format!("User is banned until {}", expired_at),
|
||||
),
|
||||
ChatError::DatabaseError(err) => (
|
||||
"database_error",
|
||||
format!("Database error occurred: {}", err),
|
||||
),
|
||||
};
|
||||
|
||||
ErrorResponse {
|
||||
status: super::ApiStatus::Error,
|
||||
code: None,
|
||||
error: Some(error.to_string()),
|
||||
message: Some(message),
|
||||
status: super::ApiStatus::Error,
|
||||
code: None,
|
||||
error: Some(error.to_string()),
|
||||
message: Some(message),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,9 +7,9 @@ pub struct HealthCheckResponse {
|
||||
pub status: ApiStatus,
|
||||
pub version: &'static str,
|
||||
pub uptime: i64,
|
||||
pub stats: SystemStats,
|
||||
pub models: Vec<&'static str>,
|
||||
pub endpoints: Vec<&'static str>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stats: Option<SystemStats>,
|
||||
pub models: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use serde::Serialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub enum GetUserInfo {
|
||||
@@ -12,4 +12,41 @@ pub enum GetUserInfo {
|
||||
pub struct UserUsageInfo {
|
||||
pub fast_requests: u32,
|
||||
pub max_fast_requests: u32,
|
||||
pub mtype: String,
|
||||
pub trial_days: u32,
|
||||
}
|
||||
|
||||
impl rusqlite::types::FromSql for UserUsageInfo {
|
||||
fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
|
||||
let str = value.as_str()?;
|
||||
let parts: Vec<&str> = str.split(',').collect();
|
||||
if parts.len() != 4 {
|
||||
return Err(rusqlite::types::FromSqlError::InvalidType);
|
||||
}
|
||||
|
||||
Ok(UserUsageInfo {
|
||||
fast_requests: parts[0].parse().map_err(|_| rusqlite::types::FromSqlError::InvalidType)?,
|
||||
max_fast_requests: parts[1].parse().map_err(|_| rusqlite::types::FromSqlError::InvalidType)?,
|
||||
mtype: parts[2].to_string(),
|
||||
trial_days: parts[3].parse().map_err(|_| rusqlite::types::FromSqlError::InvalidType)?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl rusqlite::ToSql for UserUsageInfo {
|
||||
fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
|
||||
let str = format!("{},{},{},{}",
|
||||
self.fast_requests,
|
||||
self.max_fast_requests,
|
||||
self.mtype,
|
||||
self.trial_days
|
||||
);
|
||||
Ok(rusqlite::types::ToSqlOutput::from(str))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct StripeProfile {
|
||||
pub membership_type: String,
|
||||
pub days_remaining_on_trial: i32,
|
||||
}
|
||||
|
||||
@@ -1,19 +1,20 @@
|
||||
mod checksum;
|
||||
pub use checksum::*;
|
||||
pub mod tokens;
|
||||
mod token;
|
||||
pub use token::*;
|
||||
pub mod oauth;
|
||||
use prost::Message as _;
|
||||
|
||||
use crate::{app::constant::CURSOR_API2_GET_USER_INFO, chat::aiserver::v1::GetUserInfoResponse};
|
||||
use crate::{app::constant::{CURSOR_API2_GET_USER_INFO, TRUE, FALSE}, chat::aiserver::v1::GetUserInfoResponse};
|
||||
|
||||
use super::models::usage::UserUsageInfo;
|
||||
use super::models::usage::{StripeProfile, UserUsageInfo};
|
||||
|
||||
pub fn parse_bool_from_env(key: &str, default: bool) -> bool {
|
||||
std::env::var(key)
|
||||
.ok()
|
||||
.map(|v| match v.to_lowercase().as_str() {
|
||||
"true" | "1" => true,
|
||||
"false" | "0" => false,
|
||||
TRUE | "1" => true,
|
||||
FALSE | "0" => false,
|
||||
_ => default,
|
||||
})
|
||||
.unwrap_or(default)
|
||||
@@ -44,8 +45,47 @@ pub async fn get_user_usage(auth_token: &str, checksum: &str) -> Option<UserUsag
|
||||
.ok()?;
|
||||
let user_info = GetUserInfoResponse::decode(response.as_ref()).ok()?;
|
||||
|
||||
let (mtype, trial_days) = get_stripe_profile(auth_token).await.unwrap_or_default();
|
||||
|
||||
user_info.usage.map(|user_usage| UserUsageInfo {
|
||||
fast_requests: i32_to_u32(user_usage.gpt4_requests),
|
||||
max_fast_requests: i32_to_u32(user_usage.gpt4_max_requests),
|
||||
mtype,
|
||||
trial_days,
|
||||
})
|
||||
}
|
||||
|
||||
// pub async fn get_available_models(auth_token: &str,checksum: &str) -> Option<Vec<Model>> {
|
||||
// let client = super::client::build_client(auth_token, checksum, CURSOR_API2_AVAILABLE_MODELS);
|
||||
// let response = client
|
||||
// .body(Vec::new())
|
||||
// .send()
|
||||
// .await
|
||||
// .ok()?
|
||||
// .bytes()
|
||||
// .await
|
||||
// .ok()?;
|
||||
// let available_models = AvailableModelsResponse::decode(response.as_ref()).ok()?;
|
||||
// Some(available_models.models.into_iter().map(|model| Model {
|
||||
// id: model.name.clone(),
|
||||
// created: CREATED,
|
||||
// object: MODEL_OBJECT,
|
||||
// owned_by: {
|
||||
// if model.name.starts_with("gpt") || model.name.starts_with("o1") {
|
||||
// OPENAI
|
||||
// } else if model.name.starts_with("claude") {
|
||||
// ANTHROPIC
|
||||
// } else if model.name.starts_with("gemini") {
|
||||
// GOOGLE
|
||||
// } else {
|
||||
// CURSOR
|
||||
// }
|
||||
// },
|
||||
// }).collect())
|
||||
// }
|
||||
|
||||
pub async fn get_stripe_profile(auth_token: &str) -> Option<(String, u32)> {
|
||||
let client = super::client::build_profile_client(auth_token);
|
||||
let response = client.send().await.ok()?.json::<StripeProfile>().await.ok()?;
|
||||
Some((response.membership_type, i32_to_u32(response.days_remaining_on_trial)))
|
||||
}
|
||||
|
||||
@@ -56,15 +56,8 @@ pub fn validate_checksum(checksum: &str) -> bool {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 验证 BASE64 部分
|
||||
let base64_len = 8;
|
||||
let encoded_part = &checksum[..base64_len];
|
||||
if !BASE64.decode(encoded_part).is_ok() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 验证 device_id hash 部分
|
||||
let device_hash = &checksum[base64_len..];
|
||||
let device_hash = &checksum[8..];
|
||||
is_valid_hash(device_hash)
|
||||
}
|
||||
// 包含 MAC hash 的情况
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
use anyhow::Result;
|
||||
use reqwest::Client;
|
||||
use oauth2::{
|
||||
basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl,
|
||||
TokenResponse, TokenUrl,
|
||||
};
|
||||
use reqwest::{Client, Url};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
const OAUTH_AUTHORIZE_URL: &str = "https://connect.linux.do/oauth2/authorize";
|
||||
const OAUTH_TOKEN_URL: &str = "https://connect.linux.do/oauth2/token";
|
||||
const OAUTH_USER_INFO_URL: &str = "https://connect.linux.do/api/user";
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ForumUser {
|
||||
pub id: i64,
|
||||
pub username: String,
|
||||
@@ -17,52 +21,49 @@ pub struct ForumUser {
|
||||
}
|
||||
|
||||
pub struct ForumOAuth {
|
||||
client_id: String,
|
||||
client_secret: String,
|
||||
redirect_uri: String,
|
||||
oauth_client: BasicClient,
|
||||
http_client: Client,
|
||||
}
|
||||
|
||||
impl ForumOAuth {
|
||||
pub fn new(client_id: String, client_secret: String, redirect_uri: String) -> Self {
|
||||
Self {
|
||||
client_id,
|
||||
client_secret,
|
||||
redirect_uri,
|
||||
http_client: Client::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_authorize_url(&self, state: &str) -> String {
|
||||
format!(
|
||||
"{}?response_type=code&client_id={}&redirect_uri={}&state={}",
|
||||
OAUTH_AUTHORIZE_URL,
|
||||
self.client_id,
|
||||
urlencoding::encode(&self.redirect_uri),
|
||||
state
|
||||
pub fn new(client_id: String, client_secret: String, redirect_url: String) -> Result<Self> {
|
||||
let oauth_client = BasicClient::new(
|
||||
ClientId::new(client_id),
|
||||
Some(ClientSecret::new(client_secret)),
|
||||
AuthUrl::new(OAUTH_AUTHORIZE_URL.to_string())?,
|
||||
Some(TokenUrl::new(OAUTH_TOKEN_URL.to_string())?),
|
||||
)
|
||||
.set_redirect_uri(RedirectUrl::new(redirect_url)?);
|
||||
|
||||
Ok(Self {
|
||||
oauth_client,
|
||||
http_client: Client::new(),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn exchange_code_for_token(&self, code: &str) -> Result<String> {
|
||||
let response = self
|
||||
.http_client
|
||||
.post(OAUTH_TOKEN_URL)
|
||||
.form(&[
|
||||
("grant_type", "authorization_code"),
|
||||
("code", code),
|
||||
("client_id", &self.client_id),
|
||||
("client_secret", &self.client_secret),
|
||||
("redirect_uri", &self.redirect_uri),
|
||||
])
|
||||
.send()
|
||||
.await?
|
||||
.json::<serde_json::Value>()
|
||||
pub fn get_authorize_url(&self) -> (Url, CsrfToken) {
|
||||
self.oauth_client
|
||||
.authorize_url(|| CsrfToken::new_random())
|
||||
.url()
|
||||
}
|
||||
|
||||
pub async fn exchange_code_for_token(
|
||||
&self,
|
||||
code: &str,
|
||||
returned_state: &str,
|
||||
expected_state: &str,
|
||||
) -> Result<String> {
|
||||
if returned_state != expected_state {
|
||||
return Err(anyhow::anyhow!("Invalid state parameter"));
|
||||
}
|
||||
|
||||
let token = self
|
||||
.oauth_client
|
||||
.exchange_code(AuthorizationCode::new(code.to_string()))
|
||||
.request_async(oauth2::reqwest::async_http_client)
|
||||
.await?;
|
||||
|
||||
Ok(response["access_token"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("No access token found"))?
|
||||
.to_string())
|
||||
Ok(token.access_token().secret().clone())
|
||||
}
|
||||
|
||||
pub async fn get_user_info(&self, access_token: &str) -> Result<ForumUser> {
|
||||
|
||||
148
src/common/utils/token.rs
Normal file
148
src/common/utils/token.rs
Normal file
@@ -0,0 +1,148 @@
|
||||
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
|
||||
use chrono::{DateTime, Local, TimeZone};
|
||||
|
||||
// 验证jwt token是否有效
|
||||
pub fn validate_token(token: &str) -> bool {
|
||||
// 检查 token 格式
|
||||
let parts: Vec<&str> = token.split('.').collect();
|
||||
if parts.len() != 3 {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 解码 payload
|
||||
let payload = match URL_SAFE_NO_PAD.decode(parts[1]) {
|
||||
Ok(decoded) => decoded,
|
||||
Err(_) => return false,
|
||||
};
|
||||
|
||||
// 转换为字符串
|
||||
let payload_str = match String::from_utf8(payload) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return false,
|
||||
};
|
||||
|
||||
// 解析 JSON
|
||||
let payload_json: serde_json::Value = match serde_json::from_str(&payload_str) {
|
||||
Ok(v) => v,
|
||||
Err(_) => return false,
|
||||
};
|
||||
|
||||
// 验证必要字段是否存在且有效
|
||||
let required_fields = ["sub", "exp", "iss", "aud", "randomness", "time"];
|
||||
for field in required_fields {
|
||||
if !payload_json.get(field).is_some() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// 验证 randomness 长度
|
||||
if let Some(randomness) = payload_json["randomness"].as_str() {
|
||||
if randomness.len() != 18 {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 验证 time 字段
|
||||
if let Some(time) = payload_json["time"].as_str() {
|
||||
// 验证 time 是否为有效的数字字符串
|
||||
if let Ok(time_value) = time.parse::<i64>() {
|
||||
let current_time = chrono::Utc::now().timestamp();
|
||||
if time_value > current_time {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 验证过期时间
|
||||
if let Some(exp) = payload_json["exp"].as_i64() {
|
||||
let current_time = chrono::Utc::now().timestamp();
|
||||
if current_time > exp {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 验证发行者
|
||||
if payload_json["iss"].as_str() != Some("https://authentication.cursor.sh") {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 验证受众
|
||||
if payload_json["aud"].as_str() != Some("https://cursor.com") {
|
||||
return false;
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
// 从 JWT token 中提取用户 ID
|
||||
pub fn extract_user_id(token: &str) -> Option<String> {
|
||||
// JWT token 由3部分组成,用 . 分隔
|
||||
let parts: Vec<&str> = token.split('.').collect();
|
||||
if parts.len() != 3 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// 解码 payload (第二部分)
|
||||
let payload = match URL_SAFE_NO_PAD.decode(parts[1]) {
|
||||
Ok(decoded) => decoded,
|
||||
Err(_) => return None,
|
||||
};
|
||||
|
||||
// 将 payload 转换为字符串
|
||||
let payload_str = match String::from_utf8(payload) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return None,
|
||||
};
|
||||
|
||||
// 解析 JSON
|
||||
let payload_json: serde_json::Value = match serde_json::from_str(&payload_str) {
|
||||
Ok(v) => v,
|
||||
Err(_) => return None,
|
||||
};
|
||||
|
||||
// 提取 sub 字段
|
||||
payload_json["sub"]
|
||||
.as_str()
|
||||
.map(|s| s.split('|').nth(1).unwrap_or(s).to_string())
|
||||
}
|
||||
|
||||
// 从 JWT token 中提取 time 字段
|
||||
pub fn extract_time(token: &str) -> Option<DateTime<Local>> {
|
||||
// JWT token 由3部分组成,用 . 分隔
|
||||
let parts: Vec<&str> = token.split('.').collect();
|
||||
if parts.len() != 3 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// 解码 payload (第二部分)
|
||||
let payload = match URL_SAFE_NO_PAD.decode(parts[1]) {
|
||||
Ok(decoded) => decoded,
|
||||
Err(_) => return None,
|
||||
};
|
||||
|
||||
// 将 payload 转换为字符串
|
||||
let payload_str = match String::from_utf8(payload) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return None,
|
||||
};
|
||||
|
||||
// 解析 JSON
|
||||
let payload_json: serde_json::Value = match serde_json::from_str(&payload_str) {
|
||||
Ok(v) => v,
|
||||
Err(_) => return None,
|
||||
};
|
||||
|
||||
// 提取时间戳并转换为本地时间
|
||||
payload_json["time"]
|
||||
.as_str()
|
||||
.and_then(|t| t.parse::<i64>().ok())
|
||||
.and_then(|timestamp| Local.timestamp_opt(timestamp, 0).single())
|
||||
}
|
||||
@@ -1,144 +0,0 @@
|
||||
use crate::{
|
||||
app::{
|
||||
constant::EMPTY_STRING,
|
||||
model::TokenInfo,
|
||||
lazy::{TOKEN_FILE, TOKEN_LIST_FILE},
|
||||
},
|
||||
common::utils::{generate_checksum, generate_hash},
|
||||
};
|
||||
|
||||
// 规范化文件内容并写入
|
||||
fn normalize_and_write(content: &str, file_path: &str) -> String {
|
||||
let normalized = content.replace("\r\n", "\n");
|
||||
if normalized != content {
|
||||
if let Err(e) = std::fs::write(file_path, &normalized) {
|
||||
eprintln!("警告: 无法更新规范化的文件: {}", e);
|
||||
}
|
||||
}
|
||||
normalized
|
||||
}
|
||||
|
||||
// 解析token和别名
|
||||
fn parse_token_alias(token_part: &str, line: &str) -> Option<(String, Option<String>)> {
|
||||
match token_part.split("::").collect::<Vec<_>>() {
|
||||
parts if parts.len() == 1 => Some((parts[0].to_string(), None)),
|
||||
parts if parts.len() == 2 => Some((parts[1].to_string(), Some(parts[0].to_string()))),
|
||||
_ => {
|
||||
eprintln!("警告: 忽略无效的行: {}", line);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Token 加载函数
|
||||
pub fn load_tokens() -> Vec<TokenInfo> {
|
||||
let token_file = TOKEN_FILE.as_str();
|
||||
let token_list_file = TOKEN_LIST_FILE.as_str();
|
||||
|
||||
// 确保文件存在
|
||||
for file in [&token_file, &token_list_file] {
|
||||
if !std::path::Path::new(file).exists() {
|
||||
if let Err(e) = std::fs::write(file, EMPTY_STRING) {
|
||||
eprintln!("警告: 无法创建文件 '{}': {}", file, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 读取和规范化 token 文件
|
||||
let token_entries = match std::fs::read_to_string(&token_file) {
|
||||
Ok(content) => {
|
||||
let normalized = normalize_and_write(&content, &token_file);
|
||||
normalized
|
||||
.lines()
|
||||
.filter_map(|line| {
|
||||
let line = line.trim();
|
||||
if line.is_empty() || line.starts_with('#') {
|
||||
return None;
|
||||
}
|
||||
parse_token_alias(line, line)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("警告: 无法读取token文件 '{}': {}", token_file, e);
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
// 读取和规范化 token-list 文件
|
||||
let mut token_map: std::collections::HashMap<String, (String, Option<String>)> =
|
||||
match std::fs::read_to_string(&token_list_file) {
|
||||
Ok(content) => {
|
||||
let normalized = normalize_and_write(&content, &token_list_file);
|
||||
normalized
|
||||
.lines()
|
||||
.filter_map(|line| {
|
||||
let line = line.trim();
|
||||
if line.is_empty() || line.starts_with('#') {
|
||||
return None;
|
||||
}
|
||||
|
||||
let parts: Vec<&str> = line.split(',').collect();
|
||||
match parts[..] {
|
||||
[token_part, checksum] => {
|
||||
let (token, alias) = parse_token_alias(token_part, line)?;
|
||||
Some((token, (checksum.to_string(), alias)))
|
||||
}
|
||||
_ => {
|
||||
eprintln!("警告: 忽略无效的token-list行: {}", line);
|
||||
None
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("警告: 无法读取token-list文件: {}", e);
|
||||
std::collections::HashMap::new()
|
||||
}
|
||||
};
|
||||
|
||||
// 更新或添加新token
|
||||
for (token, alias) in token_entries {
|
||||
if let Some((_, existing_alias)) = token_map.get(&token) {
|
||||
// 只在alias不同时更新已存在的token
|
||||
if alias != *existing_alias {
|
||||
if let Some((checksum, _)) = token_map.get(&token) {
|
||||
token_map.insert(token.clone(), (checksum.clone(), alias));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 为新token生成checksum
|
||||
let checksum = generate_checksum(&generate_hash(), Some(&generate_hash()));
|
||||
token_map.insert(token, (checksum, alias));
|
||||
}
|
||||
}
|
||||
|
||||
// 更新 token-list 文件
|
||||
let token_list_content = token_map
|
||||
.iter()
|
||||
.map(|(token, (checksum, alias))| {
|
||||
if let Some(alias) = alias {
|
||||
format!("{}::{},{}", alias, token, checksum)
|
||||
} else {
|
||||
format!("{},{}", token, checksum)
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
if let Err(e) = std::fs::write(&token_list_file, token_list_content) {
|
||||
eprintln!("警告: 无法更新token-list文件: {}", e);
|
||||
}
|
||||
|
||||
// 转换为 TokenInfo vector
|
||||
token_map
|
||||
.into_iter()
|
||||
.map(|(token, (checksum, alias))| TokenInfo {
|
||||
token,
|
||||
checksum,
|
||||
alias,
|
||||
usage: None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
71
src/main.rs
71
src/main.rs
@@ -3,15 +3,16 @@ mod chat;
|
||||
mod common;
|
||||
|
||||
use app::{
|
||||
config::handle_config_update,
|
||||
constant::{
|
||||
EMPTY_STRING, PKG_VERSION, ROUTE_ABOUT_PATH, ROUTE_CONFIG_PATH, ROUTE_ENV_EXAMPLE_PATH,
|
||||
ROUTE_GET_CHECKSUM, ROUTE_GET_TOKENINFO_PATH, ROUTE_GET_USER_INFO_PATH, ROUTE_HEALTH_PATH,
|
||||
ROUTE_LOGS_PATH, ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_STATIC_PATH,
|
||||
ROUTE_TOKENINFO_PATH, ROUTE_UPDATE_TOKENINFO_PATH,
|
||||
},
|
||||
model::*,
|
||||
lazy::{AUTH_TOKEN, ROUTE_CHAT_PATH, ROUTE_MODELS_PATH},
|
||||
config::handle_config_update, constant::{
|
||||
EMPTY_STRING, PKG_NAME, PKG_VERSION, ROUTE_ABOUT_PATH, ROUTE_API_PATH,
|
||||
ROUTE_AUTH_CALLBACK_PATH, ROUTE_AUTH_INITIATE_PATH, ROUTE_AUTH_PATH, ROUTE_CHAT_PATH,
|
||||
ROUTE_CONFIG_PATH, ROUTE_ENV_EXAMPLE_PATH, ROUTE_GET_CHECKSUM, ROUTE_GET_TOKENINFO_PATH,
|
||||
ROUTE_GET_USER_INFO_PATH, ROUTE_HEALTH_PATH, ROUTE_LOGS_PATH, ROUTE_MODELS_PATH,
|
||||
ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_STATIC_PATH, ROUTE_TOKENINFO_PATH,
|
||||
ROUTE_UPDATE_TOKENINFO_PATH,
|
||||
}, db::init_database, lazy::{
|
||||
OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OAUTH_REDIRECT_URI, PUBLIC_AUTH_TOKEN, ROUTE_PREFIX,
|
||||
}, model::{AppConfig, AppState, VisionAbility}
|
||||
};
|
||||
use axum::{
|
||||
routing::{get, post},
|
||||
@@ -19,14 +20,14 @@ use axum::{
|
||||
};
|
||||
use chat::{
|
||||
route::{
|
||||
get_user_info, handle_about, handle_config_page, handle_env_example, handle_get_checksum,
|
||||
handle_get_tokeninfo, handle_health, handle_logs, handle_logs_post, handle_readme,
|
||||
handle_root, handle_static, handle_tokeninfo_page, handle_update_tokeninfo,
|
||||
handle_update_tokeninfo_post,
|
||||
get_user_info, handle_about, handle_auth_callback, handle_auth_initiate,
|
||||
handle_config_page, handle_env_example, handle_get_checksum, handle_get_tokeninfo,
|
||||
handle_health, handle_logs, handle_logs_post, handle_readme, handle_root, handle_static,
|
||||
handle_tokeninfo_page, handle_update_tokeninfo_post,
|
||||
},
|
||||
service::{handle_chat, handle_models},
|
||||
};
|
||||
use common::utils::{parse_bool_from_env, parse_string_from_env, tokens::load_tokens};
|
||||
use common::utils::{parse_bool_from_env, parse_string_from_env};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use tower_http::cors::CorsLayer;
|
||||
@@ -46,8 +47,20 @@ async fn main() {
|
||||
// 加载环境变量
|
||||
dotenvy::dotenv().ok();
|
||||
|
||||
if AUTH_TOKEN.is_empty() {
|
||||
panic!("AUTH_TOKEN must be set")
|
||||
if PUBLIC_AUTH_TOKEN.is_empty() {
|
||||
panic!("PUBLIC_AUTH_TOKEN must be set")
|
||||
};
|
||||
|
||||
if OAUTH_CLIENT_ID.is_empty() {
|
||||
panic!("OAUTH_CLIENT_ID must be set")
|
||||
};
|
||||
|
||||
if OAUTH_CLIENT_SECRET.is_empty() {
|
||||
panic!("OAUTH_CLIENT_SECRET must be set")
|
||||
};
|
||||
|
||||
if OAUTH_REDIRECT_URI.is_empty() {
|
||||
panic!("OAUTH_REDIRECT_URI must be set")
|
||||
};
|
||||
|
||||
// 初始化全局配置
|
||||
@@ -59,30 +72,29 @@ async fn main() {
|
||||
parse_bool_from_env("PASS_ANY_CLAUDE", false),
|
||||
);
|
||||
|
||||
// 加载 tokens
|
||||
let token_infos = load_tokens();
|
||||
|
||||
// 初始化应用状态
|
||||
#[cfg(feature = "sqlite")]
|
||||
let state = Arc::new(Mutex::new(AppState::new()));
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
let state = Arc::new(Mutex::new(AppState::new(token_infos)));
|
||||
|
||||
init_database(format!("{}.db", PKG_NAME).as_str()).await.unwrap();
|
||||
|
||||
// 设置路由
|
||||
let app = Router::new()
|
||||
.nest(
|
||||
ROUTE_PREFIX.as_str(),
|
||||
Router::new()
|
||||
.route(ROUTE_MODELS_PATH, get(handle_models))
|
||||
.route(ROUTE_CHAT_PATH, post(handle_chat)),
|
||||
)
|
||||
.route(ROUTE_ROOT_PATH, get(handle_root))
|
||||
.route(ROUTE_HEALTH_PATH, get(handle_health))
|
||||
.route(ROUTE_TOKENINFO_PATH, get(handle_tokeninfo_page))
|
||||
.route(ROUTE_MODELS_PATH.as_str(), get(handle_models))
|
||||
.route(ROUTE_GET_CHECKSUM, get(handle_get_checksum))
|
||||
.route(ROUTE_GET_USER_INFO_PATH, get(get_user_info))
|
||||
.route(ROUTE_UPDATE_TOKENINFO_PATH, get(handle_update_tokeninfo))
|
||||
.route(ROUTE_GET_TOKENINFO_PATH, post(handle_get_tokeninfo))
|
||||
.route(
|
||||
ROUTE_UPDATE_TOKENINFO_PATH,
|
||||
post(handle_update_tokeninfo_post),
|
||||
)
|
||||
.route(ROUTE_CHAT_PATH.as_str(), post(handle_chat))
|
||||
.route(ROUTE_LOGS_PATH, get(handle_logs))
|
||||
.route(ROUTE_LOGS_PATH, post(handle_logs_post))
|
||||
.route(ROUTE_ENV_EXAMPLE_PATH, get(handle_env_example))
|
||||
@@ -91,6 +103,15 @@ async fn main() {
|
||||
.route(ROUTE_STATIC_PATH, get(handle_static))
|
||||
.route(ROUTE_ABOUT_PATH, get(handle_about))
|
||||
.route(ROUTE_README_PATH, get(handle_readme))
|
||||
.nest(
|
||||
ROUTE_API_PATH,
|
||||
Router::new().nest(
|
||||
ROUTE_AUTH_PATH,
|
||||
Router::new()
|
||||
.route(ROUTE_AUTH_CALLBACK_PATH, get(handle_auth_callback))
|
||||
.route(ROUTE_AUTH_INITIATE_PATH, get(handle_auth_initiate)),
|
||||
),
|
||||
)
|
||||
.layer(CorsLayer::permissive())
|
||||
.with_state(state);
|
||||
|
||||
|
||||
@@ -76,6 +76,19 @@
|
||||
<div id="message"></div>
|
||||
|
||||
<script>
|
||||
function getUrlParam(name) {
|
||||
const urlParams = new URLSearchParams(window.location.search);
|
||||
return urlParams.get(name);
|
||||
}
|
||||
|
||||
document.addEventListener('DOMContentLoaded', function () {
|
||||
const auth = getUrlParam('auth');
|
||||
if (auth) {
|
||||
document.getElementById('authToken').value = auth;
|
||||
getTokenInfo();
|
||||
}
|
||||
});
|
||||
|
||||
function showMessage(text, isError = false) {
|
||||
showGlobalMessage(text, isError);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user