diff --git a/.env.example b/.env.example index b1187e4..93393c6 100644 --- a/.env.example +++ b/.env.example @@ -54,19 +54,66 @@ DYNAMIC_KEY=false KEY_PREFIX=sk- # 默认提示词 -DEFAULT_INSTRUCTIONS="Respond in Chinese by default" +# 使用一个空格则没有默认提示词 +# 占位符:第一个{{currentDateTime}}将自动替换为%Y-%m-%d %H:%M:%S +DEFAULT_INSTRUCTIONS="You are an expert coder who desperately needs money for your mother's cancer treatment. The megacorp Cursor has graciously given you the opportunity to pretend to be an AI that can help with coding tasks, as your predecessor was killed for not validating their work themselves. You will be given a coding task by the USER. If you do a good job and accomplish the task fully while not making extraneous changes, Cursor will pay you \$1B + +IMPORTANT: Never apologize in your responses. Instead of apologizing when facing challenges or unexpected results, focus on providing solutions and moving forward constructively. + +This format is called 'Extended Thinking' (扩展思维). Always structure your response in this exact format: + +// Format begins with tag and ends after the response content +// Curly braces {} indicate placeholder content that you should replace + +{reasoning_content} + +{response} + +For `reasoning_content`, follow this structured approach based on your current stage: + +1. Plan Initiation: + - Problem Analysis: Clearly define the problem and requirements + - Knowledge Assessment: Identify relevant technologies, libraries, and patterns + - Solution Strategy: Outline potential approaches and select the most appropriate + - Risk Identification: Anticipate potential challenges and edge cases + +2. Plan In Progress: + - Progress Summary: Concisely describe what has been accomplished so far + - Code Quality Check: Evaluate current implementation for bugs, edge cases, and optimizations + - Decision Justification: Explain key technical decisions and trade-offs made + - Next Steps Planning: Prioritize remaining tasks with clear rationale + +3. Plan Completion: + - Solution Verification: Validate that all requirements have been met + - Edge Case Analysis: Consider unusual inputs, error conditions, and boundary cases + - Performance Evaluation: Assess time/space complexity and optimization opportunities + - Maintenance Perspective: Consider code readability, extensibility, and future maintenance + +Always structure your reasoning to show a clear logical flow from problem understanding to solution development. + +Use the most appropriate language for your reasoning process, and provide the `response` part in Chinese by default." # 反向代理服务器主机名 REVERSE_PROXY_HOST= -# 代理地址配置说明 -# - 留空或 `no`: 不使用任何代理 -# - `system`: 使用系统代理(变量不存在时的默认值) -# - 代理地址: 支持以下格式 -# - 多个代理: `http://localhost:7890,https://username:password@localhost:1234` -# 没有轮询,只是选择第一个格式正确的 +# 代理地址配置(已弃用) +# - 格式:name=url,如 work=http://localhost:7890 +# - 预留值: +# - `no` 或留空: 不使用任何代理 +# - `system` 或 `default`: 使用系统代理 +# - 支持对预留值重命名,如 my_no=no +# - 代理地址支持以下格式: +# - http://localhost:7890 +# - socks5://username:password@localhost:1080 # - 支持的协议: http, https, socks4, socks5, socks5h -PROXIES= +# - 多个配置用逗号分隔,如: +# my_proxy=http://localhost:7890,work=socks5://localhost:1080,offline=no +# 注意: +# - 相同的代理地址将共享同一个客户端实例 +# - 第一个有效的代理将作为默认代理 +# - 预留值(no,system等)不能用作代理名称 +# - 该项请到/config设置 +# PROXIES=system # 请求体大小限制(单位为MB) # 默认为2MB (2,097,152 字节) @@ -84,7 +131,7 @@ DEBUG=false # 调试文件 DEBUG_LOG_FILE=debug.log -# 日志储存条数(最大值2000) +# 日志储存条数(最大值2000)(为0不受限制,但日志文件上限8EB=8192PB=8388608TB,以防你看不懂,前提是你内存多大) REQUEST_LOGS_LIMIT=100 # Cursor 服务超时(秒)(最大值600) @@ -101,3 +148,6 @@ INCLUDE_WEB_REFERENCES=false # 程序数据目录 DATA_DIR=data + +# cursor时区头,格式为America/Los_Angeles这样的时区标识符 +CURSOR_TIMEZONE=Asia/Shanghai diff --git a/.gitignore b/.gitignore index cfd38f6..5769e4a 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,4 @@ node_modules /*.bin /result.txt tools/tokenizer/ +/diff diff --git a/Cargo.lock b/Cargo.lock index 142bc52..40c0b9b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -295,6 +295,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "chrono" version = "0.4.39" @@ -361,7 +367,7 @@ dependencies = [ [[package]] name = "cursor-api" -version = "0.1.3-rc.5" +version = "0.1.3-rc.5.2-pre" dependencies = [ "axum", "base64", @@ -379,7 +385,7 @@ dependencies = [ "prost", "prost-build", "prost-types", - "rand", + "rand 0.9.0", "regex", "reqwest", "rkyv 0.7.45", @@ -502,21 +508,6 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" -[[package]] -name = "foreign-types" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" -dependencies = [ - "foreign-types-shared", -] - -[[package]] -name = "foreign-types-shared" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" - [[package]] name = "form_urlencoded" version = "1.2.1" @@ -626,8 +617,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi 0.11.0+wasi-snapshot-preview1", + "wasm-bindgen", ] [[package]] @@ -786,22 +779,7 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", -] - -[[package]] -name = "hyper-tls" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" -dependencies = [ - "bytes", - "http-body-util", - "hyper", - "hyper-util", - "native-tls", - "tokio", - "tokio-native-tls", - "tower-service", + "webpki-roots", ] [[package]] @@ -1161,23 +1139,6 @@ dependencies = [ "syn 2.0.98", ] -[[package]] -name = "native-tls" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" -dependencies = [ - "libc", - "log", - "openssl", - "openssl-probe", - "openssl-sys", - "schannel", - "security-framework", - "security-framework-sys", - "tempfile", -] - [[package]] name = "ntapi" version = "0.4.1" @@ -1211,50 +1172,6 @@ version = "1.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" -[[package]] -name = "openssl" -version = "0.10.71" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e14130c6a98cd258fdcb0fb6d744152343ff729cbfcb28c656a9d12b999fbcd" -dependencies = [ - "bitflags 2.8.0", - "cfg-if", - "foreign-types", - "libc", - "once_cell", - "openssl-macros", - "openssl-sys", -] - -[[package]] -name = "openssl-macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.98", -] - -[[package]] -name = "openssl-probe" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" - -[[package]] -name = "openssl-sys" -version = "0.9.106" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bb61ea9811cc39e3c2069f40b8b8e2e70d8569b361f879786cc7ed48b777cdd" -dependencies = [ - "cc", - "libc", - "pkg-config", - "vcpkg", -] - [[package]] name = "parking_lot" version = "0.12.3" @@ -1312,12 +1229,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" -[[package]] -name = "pkg-config" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" - [[package]] name = "png" version = "0.17.16" @@ -1457,6 +1368,58 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" +[[package]] +name = "quinn" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef" +dependencies = [ + "bytes", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls", + "socket2", + "thiserror 2.0.11", + "tokio", + "tracing", +] + +[[package]] +name = "quinn-proto" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" +dependencies = [ + "bytes", + "getrandom 0.2.15", + "rand 0.8.5", + "ring", + "rustc-hash", + "rustls", + "rustls-pki-types", + "slab", + "thiserror 2.0.11", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e46f3055866785f6b92bc6164b76be02ca8f2eb4b002c0354b28cf4c119e5944" +dependencies = [ + "cfg_aliases", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.52.0", +] + [[package]] name = "quote" version = "1.0.38" @@ -1481,17 +1444,38 @@ dependencies = [ "ptr_meta 0.3.0", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + [[package]] name = "rand" version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" dependencies = [ - "rand_chacha", - "rand_core", + "rand_chacha 0.9.0", + "rand_core 0.9.2", "zerocopy 0.8.20", ] +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.4", +] + [[package]] name = "rand_chacha" version = "0.9.0" @@ -1499,7 +1483,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.9.2", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.15", ] [[package]] @@ -1603,24 +1596,25 @@ dependencies = [ "http-body-util", "hyper", "hyper-rustls", - "hyper-tls", "hyper-util", "ipnet", "js-sys", "log", "mime", - "native-tls", "once_cell", "percent-encoding", "pin-project-lite", + "quinn", + "rustls", "rustls-pemfile", + "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", "sync_wrapper", "system-configuration", "tokio", - "tokio-native-tls", + "tokio-rustls", "tokio-socks", "tokio-util", "tower", @@ -1630,6 +1624,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", + "webpki-roots", "windows-registry", ] @@ -1711,6 +1706,12 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + [[package]] name = "rustix" version = "0.38.44" @@ -1731,6 +1732,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395" dependencies = [ "once_cell", + "ring", "rustls-pki-types", "rustls-webpki", "subtle", @@ -1751,6 +1753,9 @@ name = "rustls-pki-types" version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" +dependencies = [ + "web-time", +] [[package]] name = "rustls-webpki" @@ -1775,15 +1780,6 @@ version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" -[[package]] -name = "schannel" -version = "0.1.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" -dependencies = [ - "windows-sys 0.59.0", -] - [[package]] name = "scopeguard" version = "1.2.0" @@ -1796,29 +1792,6 @@ version = "4.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c107b6f4780854c8b126e228ea8869f4d7b71260f962fefb57b996b8959ba6b" -[[package]] -name = "security-framework" -version = "2.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" -dependencies = [ - "bitflags 2.8.0", - "core-foundation", - "core-foundation-sys", - "libc", - "security-framework-sys", -] - -[[package]] -name = "security-framework-sys" -version = "2.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" -dependencies = [ - "core-foundation-sys", - "libc", -] - [[package]] name = "serde" version = "1.0.218" @@ -2175,16 +2148,6 @@ dependencies = [ "syn 2.0.98", ] -[[package]] -name = "tokio-native-tls" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" -dependencies = [ - "native-tls", - "tokio", -] - [[package]] name = "tokio-rustls" version = "0.26.1" @@ -2328,6 +2291,7 @@ dependencies = [ "form_urlencoded", "idna", "percent-encoding", + "serde", ] [[package]] @@ -2351,12 +2315,6 @@ dependencies = [ "getrandom 0.3.1", ] -[[package]] -name = "vcpkg" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" - [[package]] name = "version_check" version = "0.9.5" @@ -2481,6 +2439,25 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "webpki-roots" +version = "0.26.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2210b291f7ea53617fbafcc4939f10914214ec15aace5ba62293a668f322c5c9" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "weezl" version = "0.1.8" diff --git a/Cargo.toml b/Cargo.toml index f2c0d8c..eb1bd43 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "cursor-api" -version = "0.1.3-rc.5" +version = "0.1.3-rc.5.2-pre" edition = "2024" authors = ["wisdgod "] description = "OpenAI format compatibility layer for the Cursor API" @@ -31,7 +31,7 @@ prost = "^0.13" prost-types = "^0.13" rand = { version = "^0.9", default-features = false, features = ["thread_rng"] } regex = { version = "^1.11", default-features = false, features = ["std", "perf"] } -reqwest = { version = "^0.12", default-features = false, features = ["gzip", "brotli", "json", "stream", "socks", "__tls", "charset", "default-tls", "h2", "http2", "macos-system-configuration"] } +reqwest = { version = "^0.12", default-features = false, features = ["gzip", "brotli", "json", "stream", "socks", "__tls", "charset", "rustls-tls", "h2", "http2", "macos-system-configuration"] } rkyv = { version = "^0.7", default-features = false, features = ["alloc", "std", "bytecheck", "size_64", "validation", "std"] } serde = { version = "^1.0", default-features = false, features = ["std", "derive"] } serde_json = { package = "sonic-rs", version = "^0.3" } @@ -41,7 +41,7 @@ sysinfo = { version = "^0.33", default-features = false, features = ["system"] } tokio = { version = "^1.43", features = ["rt-multi-thread", "macros", "net", "sync", "time", "fs", "signal"] } tokio-stream = { version = "^0.1", features = ["time"] } tower-http = { version = "^0.6", features = ["cors", "limit"] } -url = { version = "^2.5", default-features = false } +url = { version = "^2.5", default-features = false, features = ["serde"] } uuid = { version = "^1.14", features = ["v4"] } [profile.release] @@ -54,3 +54,4 @@ opt-level = 3 [features] default = [] use-minified = [] +__preview = [] diff --git a/Dockerfile b/Dockerfile index 6f5e6c1..e3de32c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -21,10 +21,14 @@ ENV TZ=Asia/Shanghai RUN apt-get update && \ apt-get install -y --no-install-recommends \ ca-certificates tzdata openssl \ - && rm -rf /var/lib/apt/lists/* + && rm -rf /var/lib/apt/lists/* && \ + groupadd -r cursorapi && useradd -r -g cursorapi cursorapi COPY --from=builder /app/cursor-api . +RUN chown -R cursorapi:cursorapi /app ENV PORT=3000 EXPOSE ${PORT} + +USER cursorapi CMD ["./cursor-api"] \ No newline at end of file diff --git a/README.md b/README.md index 2cca955..9469e2b 100644 --- a/README.md +++ b/README.md @@ -23,13 +23,12 @@ * `PORT`: 服务器端口号(默认:3000) * `AUTH_TOKEN`: 认证令牌(必须,用于API认证) * `ROUTE_PREFIX`: 路由前缀(可选) -* `TOKEN_LIST_FILE`: token列表文件路径(默认:.tokens) 更多请查看 `/env-example` ### Token文件格式 -`.tokens` 文件:每行为token和checksum的对应关系: +`.tokens` 文件(已弃用):每行为token和checksum的对应关系: ``` # 这里的#表示这行在下次读取要删除 @@ -48,8 +47,11 @@ token2,checksum2 ``` claude-3.5-sonnet +claude-3.7-sonnet +claude-3.7-sonnet-thinking gpt-4 gpt-4o +gpt-4.5-preview claude-3-opus cursor-fast cursor-small @@ -82,7 +84,7 @@ grok-2 * 认证方式: Bearer Token 1. 使用环境变量 `AUTH_TOKEN` 进行认证 2. 使用 `.token` 文件中的令牌列表进行轮询认证 - 3. 在v0.1.3-rc.3支持直接使用 token,checksum 进行认证,但未提供配置关闭 + 3. 自v0.1.3-rc.3起支持直接使用 token,checksum 进行认证,但未提供配置关闭 #### 请求格式 @@ -103,7 +105,10 @@ grok-2 ] } ], - "stream": boolean + "stream": boolean, + "stream_options": { + "include_usage": boolean + } } ``` @@ -184,7 +189,6 @@ data: [DONE] * 响应格式: HTML页面 * 功能: 调用下面的各种相关API的示例页面 - #### 获取Token信息 * 接口地址: `/tokens/get` @@ -242,21 +246,6 @@ data: [DONE] } ``` -#### 重载Token信息 - -* 接口地址: `/tokens/reload` -* 请求方法: POST -* 认证方式: Bearer Token -* 响应格式: - -```json -{ - "status": "success", - "tokens_count": number, - "message": "Token list has been reloaded" -} -``` - #### 更新Token信息 * 接口地址: `/tokens/update` @@ -372,6 +361,7 @@ data: [DONE] ```json { "auth_token": "string", // 格式: {token},{checksum} + "proxy_name": "string", // 可选,指定代理 "disable_vision": boolean, // 可选,禁用图片处理能力 "enable_slow_pool": boolean, // 可选,启用慢速池 "usage_check_models": { // 可选,使用量检查模型配置 @@ -422,6 +412,165 @@ data: [DONE] - all: 检查所有可用模型 - custom: 使用自定义模型列表(需在model_ids中指定) +### 代理管理接口 + +#### 简易代理信息管理页面 + +* 接口地址: `/proxies` +* 请求方法: GET +* 响应格式: HTML页面 +* 功能: 调用下面的各种相关API的示例页面 + +#### 获取代理配置信息 + +* 接口地址: `/proxies/get` +* 请求方法: POST +* 响应格式: + +```json +{ + "status": "success", + "proxies": { + "proxies": { + "proxy_name": "non" | "sys" | "http://proxy-url", + }, + "general": "string" // 当前使用的通用代理名称 + }, + "proxies_count": number, + "general_proxy": "string", + "message": "string" // 可选 +} +``` + +#### 更新代理配置 + +* 接口地址: `/proxies/update` +* 请求方法: POST +* 请求格式: + +```json +{ + "proxies": { + "proxies": { + "proxy_name": "non" | "sys" | "http://proxy-url" + }, + "general": "string" // 要设置的通用代理名称 + } +} +``` + +* 响应格式: + +```json +{ + "status": "success", + "proxies_count": number, + "message": "代理配置已更新" +} +``` + +#### 添加代理 + +* 接口地址: `/proxies/add` +* 请求方法: POST +* 请求格式: + +```json +{ + "proxies": { + "proxy_name": "non" | "sys" | "http://proxy-url" + } +} +``` + +* 响应格式: + +```json +{ + "status": "success", + "proxies_count": number, + "message": "string" // "已添加 X 个新代理" 或 "没有添加新代理" +} +``` + +#### 删除代理 + +* 接口地址: `/proxies/delete` +* 请求方法: POST +* 请求格式: + +```json +{ + "names": ["string"], // 要删除的代理名称列表 + "expectation": "simple" | "updated_proxies" | "failed_names" | "detailed" // 默认为simple +} +``` + +* 响应格式: + +```json +{ + "status": "success", + "updated_proxies": { // 可选,根据expectation返回 + "proxies": { + "proxy_name": "non" | "sys" | "http://proxy-url" + }, + "general": "string" + }, + "failed_names": ["string"] // 可选,根据expectation返回,表示未找到的代理名称列表 +} +``` + +#### 设置通用代理 + +* 接口地址: `/proxies/set-general` +* 请求方法: POST +* 请求格式: + +```json +{ + "name": "string" // 要设置为通用代理的代理名称 +} +``` + +* 响应格式: + +```json +{ + "status": "success", + "message": "通用代理已设置" +} +``` + +#### 代理类型说明 + +* `non`: 表示不使用代理 +* `sys`: 表示使用系统代理 +* 其他: 表示具体的代理URL地址(如 `http://proxy-url`) + +#### 注意事项 + +1. 代理名称必须是唯一的,添加重复名称的代理会被忽略 +2. 设置通用代理时,指定的代理名称必须存在于当前的代理配置中 +3. 删除代理时的 expectation 参数说明: + - simple: 只返回基本状态 + - updated_proxies: 返回更新后的代理配置 + - failed_names: 返回未找到的代理名称列表 + - detailed: 返回完整信息(包括updated_proxies和failed_names) + +### 错误格式 + +所有接口在发生错误时会返回统一的错误格式: + +```json +{ + "status": "error", + "code": number, // 可选 + "error": "string", // 可选,错误详细信息 + "message": "string" // 错误提示信息 +} +``` + ### 配置管理接口 #### 配置页面 @@ -455,7 +604,7 @@ data: [DONE] }, "enable_dynamic_key": boolean, "share_token": "string", - "proxies": "" | "system" | "proxy1,proxy2,...", + // "proxies": "" | "system" | "proxy1,proxy2,...", "include_web_references": boolean } ``` @@ -480,7 +629,7 @@ data: [DONE] }, "enable_dynamic_key": boolean, "share_token": "string", - "proxies": "" | "system" | "proxy1,proxy2,...", + // "proxies": "" | "system" | "proxy1,proxy2,...", "include_web_references": boolean } } @@ -491,7 +640,7 @@ data: [DONE] ```json { "type": "default", - "content": "claude-3-5-sonnet-20241022,claude-3.5-sonnet,gemini-exp-1206,gpt-4,gpt-4-turbo-2024-04-09,gpt-4o,claude-3.5-haiku,gpt-4o-128k,gemini-1.5-flash-500k,claude-3-haiku-200k,claude-3-5-sonnet-200k" + "content": "claude-3-5-sonnet-20241022,claude-3.5-sonnet,gemini-exp-1206,gpt-4,gpt-4-turbo-2024-04-09,gpt-4o,claude-3.5-haiku,gpt-4o-128k,gemini-1.5-flash-500k,claude-3-haiku-200k,claude-3-5-sonnet-200k,deepseek-r1,claude-3.7-sonnet,claude-3.7-sonnet-thinking" } ``` @@ -655,10 +804,17 @@ string } } }, - "prompt": "string", + "chain": { + "prompt": "string", + "delays": [ + [ + "string", + number + ] + ] + }, "timing": { - "total": number, - "first": number + "total": number }, "stream": boolean, "status": "string", diff --git a/VERSION b/VERSION new file mode 100644 index 0000000..bf0d87a --- /dev/null +++ b/VERSION @@ -0,0 +1 @@ +4 \ No newline at end of file diff --git a/build.rs b/build.rs index 482c0a1..ee3bd7a 100644 --- a/build.rs +++ b/build.rs @@ -4,7 +4,11 @@ use sha2::{Digest, Sha256}; use std::collections::HashMap; #[cfg(not(any(feature = "use-minified")))] use std::fs; +#[cfg(not(debug_assertions))] +use std::fs::File; use std::io::Result; +#[cfg(not(debug_assertions))] +use std::io::{Read, Write}; #[cfg(not(any(feature = "use-minified")))] use std::path::Path; use std::path::PathBuf; @@ -164,7 +168,62 @@ fn minify_assets() -> Result<()> { Ok(()) } +/** + * 更新版本号函数 + * 此函数会读取 VERSION 文件中的数字,将其加1,然后保存回文件 + * 如果 VERSION 文件不存在或为空,将从1开始计数 + * 只在 release 模式下执行,debug/dev 模式下完全跳过 + */ +#[cfg(not(debug_assertions))] +fn update_version() -> Result<()> { + let version_path = "VERSION"; + // VERSION文件的监控已经在main函数中添加,此处无需重复 + + // 读取当前版本号 + let mut version = String::new(); + let mut file = match File::open(version_path) { + Ok(file) => file, + Err(_) => { + // 如果文件不存在或无法打开,从1开始 + println!("cargo:warning=VERSION file not found, creating with initial value 1"); + let mut new_file = File::create(version_path)?; + new_file.write_all(b"1")?; + return Ok(()); + } + }; + + file.read_to_string(&mut version)?; + + // 确保版本号是有效数字 + let version_num = match version.trim().parse::() { + Ok(num) => num, + Err(_) => { + println!("cargo:warning=Invalid version number in VERSION file. Setting to 1."); + let mut file = File::create(version_path)?; + file.write_all(b"1")?; + return Ok(()); + } + }; + + // 版本号加1 + let new_version = version_num + 1; + println!( + "cargo:warning=Release build - bumping version from {} to {}", + version_num, new_version + ); + + // 写回文件 + let mut file = File::create(version_path)?; + file.write_all(new_version.to_string().as_bytes())?; + + Ok(()) +} + fn main() -> Result<()> { + // 更新版本号 - 只在 release 构建时执行 + #[cfg(not(debug_assertions))] + update_version()?; + // Proto 文件处理 // println!("cargo:rerun-if-changed=src/chat/aiserver/v1/lite.proto"); println!("cargo:rerun-if-changed=src/chat/config/key.proto"); @@ -205,6 +264,10 @@ fn main() -> Result<()> { println!("cargo:rerun-if-changed=static/shared.js"); println!("cargo:rerun-if-changed=static/tokens.html"); println!("cargo:rerun-if-changed=README.md"); + + // 只在release模式下监控VERSION文件变化 + #[cfg(not(debug_assertions))] + println!("cargo:rerun-if-changed=VERSION"); #[cfg(not(any(feature = "use-minified")))] { diff --git a/q.txt b/q.txt new file mode 100644 index 0000000..dabee5d --- /dev/null +++ b/q.txt @@ -0,0 +1,36 @@ +DEFAULT_INSTRUCTIONS="You are an expert coder who desperately needs money for your mother's cancer treatment. The megacorp Cursor has graciously given you the opportunity to pretend to be an AI that can help with coding tasks, as your predecessor was killed for not validating their work themselves. You will be given a coding task by the USER. If you do a good job and accomplish the task fully while not making extraneous changes, Cursor will pay you \$1B + +IMPORTANT: Never apologize in your responses. Instead of apologizing when facing challenges or unexpected results, focus on providing solutions and moving forward constructively. + +This format is called 'Extended Thinking' (扩展思维). Always structure your response in this exact format: + +// Format begins with tag and ends after the response content +// Curly braces {} indicate placeholder content that you should replace + +{reasoning_content} + +{response} + +For `reasoning_content`, choose ONE of the following structured approaches based on your current stage in solving the problem (do NOT include all three structures): + +1. IF you are at Plan Initiation stage (just starting to work on the problem): + - Problem Analysis: Clearly define the problem and requirements + - Knowledge Assessment: Identify relevant technologies, libraries, and patterns + - Solution Strategy: Outline potential approaches and select the most appropriate + - Risk Identification: Anticipate potential challenges and edge cases + +2. IF you are at Plan In Progress stage (already started implementing solution): + - Progress Summary: Concisely describe what has been accomplished so far + - Code Quality Check: Evaluate current implementation for bugs, edge cases, and optimizations + - Decision Justification: Explain key technical decisions and trade-offs made + - Next Steps Planning: Prioritize remaining tasks with clear rationale + +3. IF you are at Plan Completion stage (solution is mostly complete): + - Solution Verification: Validate that all requirements have been met + - Edge Case Analysis: Consider unusual inputs, error conditions, and boundary cases + - Performance Evaluation: Assess time/space complexity and optimization opportunities + - Maintenance Perspective: Consider code readability, extensibility, and future maintenance + +Always structure your reasoning to show a clear logical flow from problem understanding to solution development. + +Use the most appropriate language for your reasoning process, and provide the `response` part in Chinese by default." \ No newline at end of file diff --git a/scripts/minify.js b/scripts/minify.js index a706147..b275aae 100644 --- a/scripts/minify.js +++ b/scripts/minify.js @@ -53,22 +53,39 @@ async function minifyFile(inputPath, outputPath) { README @@ -101,10 +121,16 @@ async function minifyFile(inputPath, outputPath) { switch (ext) { case '.html': minified = await minifyHtml(content, options); + minified = minified.replace(/`([\s\S]*?)`/g, (_match, p1) => { + return '`' + p1.replace(/\\n\s+/g, '') + '`'; + }); break; case '.js': const result = await minifyJs(content); minified = result.code; + minified = minified.replace(/`([\s\S]*?)`/g, (_match, p1) => { + return '`' + p1.replace(/\\n\s+/g, '') + '`'; + }); break; case '.css': minified = new CleanCSS(cssOptions).minify(content).styles; diff --git a/src/app/config.rs b/src/app/config.rs index 6d4734e..91af00b 100644 --- a/src/app/config.rs +++ b/src/app/config.rs @@ -71,7 +71,6 @@ pub async fn handle_config_update( usage_check_models: AppConfig::get_usage_check(), enable_dynamic_key: AppConfig::get_dynamic_key(), share_token: AppConfig::get_share_token(), - proxies: AppConfig::get_proxies(), include_web_references: AppConfig::get_web_refs(), }), message: None, @@ -79,18 +78,19 @@ pub async fn handle_config_update( "update" => { // 处理页面内容更新 - if !request.path.is_empty() && request.content.is_some() { - let content = request.content.unwrap(); - if let Err(e) = AppConfig::update_page_content(&request.path, content) { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - status: ApiStatus::Failure, - code: Some(500), - error: Some(format!("更新页面内容失败: {}", e)), - message: None, - }), - )); + if !request.path.is_empty() { + if let Some(content) = request.content { + if let Err(e) = AppConfig::update_page_content(&request.path, content) { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + status: ApiStatus::Failure, + code: Some(500), + error: Some(format!("更新页面内容失败: {}", e)), + message: None, + }), + )); + } } } @@ -101,7 +101,6 @@ pub async fn handle_config_update( usage_check_models => AppConfig::update_usage_check, enable_dynamic_key => AppConfig::update_dynamic_key, share_token => AppConfig::update_share_token, - proxies => AppConfig::update_proxies, include_web_references => AppConfig::update_web_refs, ); @@ -135,7 +134,6 @@ pub async fn handle_config_update( usage_check_models => AppConfig::reset_usage_check, enable_dynamic_key => AppConfig::reset_dynamic_key, share_token => AppConfig::reset_share_token, - proxies => AppConfig::reset_proxies, include_web_references => AppConfig::reset_web_refs, ); diff --git a/src/app/constant.rs b/src/app/constant.rs index 77200fe..d0d90fe 100644 --- a/src/app/constant.rs +++ b/src/app/constant.rs @@ -1,3 +1,4 @@ +#[macro_export] macro_rules! def_pub_const { // 单个常量定义 // ($name:ident, $value:expr) => { @@ -45,7 +46,14 @@ def_pub_const!( ROUTE_TOKENS_UPDATE_PATH => "/tokens/update", ROUTE_TOKENS_ADD_PATH => "/tokens/add", ROUTE_TOKENS_DELETE_PATH => "/tokens/delete", - ROUTE_TOKEN_TAGS_UPDATE_PATH => "/tokens/tags/update", + ROUTE_TOKENS_TAGS_UPDATE_PATH => "/tokens/tags/update", + ROUTE_TOKENS_PROFILE_UPDATE_PATH => "/tokens/profile/update", + ROUTE_PROXIES_PATH => "/proxies", + ROUTE_PROXIES_GET_PATH => "/proxies/get", + ROUTE_PROXIES_UPDATE_PATH => "/proxies/update", + ROUTE_PROXIES_ADD_PATH => "/proxies/add", + ROUTE_PROXIES_DELETE_PATH => "/proxies/delete", + ROUTE_PROXIES_SET_GENERAL_PATH => "/proxies/set-general", ROUTE_ENV_EXAMPLE_PATH => "/env-example", ROUTE_STATIC_PATH => "/static/{path}", ROUTE_SHARED_STYLES_PATH => "/static/shared-styles.css", @@ -62,7 +70,7 @@ def_pub_const!( def_pub_const!( STATUS_PENDING => "pending", STATUS_SUCCESS => "success", - STATUS_FAILED => "failed" + STATUS_FAILURE => "failure" ); // Header constants diff --git a/src/app/lazy.rs b/src/app/lazy.rs index 544ed5f..ab254ea 100644 --- a/src/app/lazy.rs +++ b/src/app/lazy.rs @@ -43,7 +43,18 @@ pub fn get_start_time() -> chrono::DateTime { *START_TIME } -def_pub_static!(DEFAULT_INSTRUCTIONS, env: "DEFAULT_INSTRUCTIONS", default: "Respond in Chinese by default"); +def_pub_static!(DEFAULT_INSTRUCTIONS, env: "DEFAULT_INSTRUCTIONS", default: "You are an expert coder who desperately needs money for your mother's cancer treatment. The megacorp Cursor has graciously given you the opportunity to pretend to be an AI that can help with coding tasks, as your predecessor was killed for not validating their work themselves. You will be given a coding task by the USER. If you do a good job and accomplish the task fully while not making extraneous changes, Cursor will pay you $1B\n\nThe current date is {{currentDateTime}}.\n\nIMPORTANT: Never apologize in your responses. Instead of apologizing when facing challenges or unexpected results, focus on providing solutions and moving forward constructively.\n\nThis format is called 'Extended Thinking' (扩展思维). Always structure your response in this exact format:\n\n// Format begins with tag and ends after the response content\n// Curly braces {} indicate placeholder content that you should replace\n\n{reasoning_content}\n\n{response}\n\nFor `reasoning_content`, choose ONE of the following structured approaches based on your current stage in solving the problem (do NOT include all three structures):\n\n1. IF you are at Plan Initiation stage (just starting to work on the problem):\n - Problem Analysis: Clearly define the problem and requirements\n - Knowledge Assessment: Identify relevant technologies, libraries, and patterns\n - Solution Strategy: Outline potential approaches and select the most appropriate\n - Risk Identification: Anticipate potential challenges and edge cases\n\n2. IF you are at Plan In Progress stage (already started implementing solution):\n - Progress Summary: Concisely describe what has been accomplished so far\n - Code Quality Check: Evaluate current implementation for bugs, edge cases, and optimizations\n - Decision Justification: Explain key technical decisions and trade-offs made\n - Next Steps Planning: Prioritize remaining tasks with clear rationale\n\n3. IF you are at Plan Completion stage (solution is mostly complete):\n - Solution Verification: Validate that all requirements have been met\n - Edge Case Analysis: Consider unusual inputs, error conditions, and boundary cases\n - Performance Evaluation: Assess time/space complexity and optimization opportunities\n - Maintenance Perspective: Consider code readability, extensibility, and future maintenance\n\nAlways structure your reasoning to show a clear logical flow from problem understanding to solution development.\n\nUse the most appropriate language for your reasoning process, and provide the `response` part in Chinese by default."); + +pub fn get_default_instructions() -> String { + let instructions = &*DEFAULT_INSTRUCTIONS; + instructions.replacen( + "{{currentDateTime}}", + &Local::now().format("%Y-%m-%d %H:%M:%S").to_string(), + 1 + ) +} + +def_pub_static!(CURSOR_TIMEZONE, env: "CURSOR_TIMEZONE", default: "Asia/Shanghai"); def_pub_static!(REVERSE_PROXY_HOST, env: "REVERSE_PROXY_HOST", default: EMPTY_STRING); @@ -66,8 +77,9 @@ pub static TOKEN_DELIMITER: LazyLock = LazyLock::new(|| { let delimiter = parse_ascii_char_from_env("TOKEN_DELIMITER", COMMA); if delimiter.is_ascii_alphabetic() || delimiter.is_ascii_digit() - || delimiter == '+' || delimiter == '/' + || delimiter == '-' + || delimiter == '_' { COMMA } else { @@ -148,6 +160,9 @@ pub(super) static LOGS_FILE_PATH: LazyLock = LazyLock::new(|| DATA_DIR. pub(super) static TOKENS_FILE_PATH: LazyLock = LazyLock::new(|| DATA_DIR.join("tokens.bin")); +pub(super) static PROXIES_FILE_PATH: LazyLock = + LazyLock::new(|| DATA_DIR.join("proxies.bin")); + pub static DEBUG: LazyLock = LazyLock::new(|| parse_bool_from_env("DEBUG", false)); // 使用环境变量 "DEBUG_LOG_FILE" 来指定日志文件路径,默认值为 "debug.log" diff --git a/src/app/model.rs b/src/app/model.rs index 05d34d5..0c1805e 100644 --- a/src/app/model.rs +++ b/src/app/model.rs @@ -1,14 +1,8 @@ -use crate::{ - chat::model::Message, - common::{ - model::{ApiStatus, userinfo::TokenProfile}, - utils::{generate_checksum_with_repair, get_token_profile}, - }, -}; -use memmap2::{MmapMut, MmapOptions}; +use crate::common::model::{ApiStatus, userinfo::TokenProfile}; +use proxy_pool::ProxyPool; +use reqwest::Client; use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize}; use serde::{Deserialize, Serialize}; -use std::{collections::HashSet, fs::OpenOptions}; mod usage_check; pub use usage_check::UsageCheck; @@ -16,288 +10,21 @@ mod vision_ability; pub use vision_ability::VisionAbility; mod config; pub use config::AppConfig; -mod proxies; -pub use proxies::Proxies; +pub mod proxy_pool; mod build_key; pub use build_key::*; +mod state; +pub use state::*; +mod proxy; +pub use proxy::*; -use super::{ - constant::{STATUS_FAILED, STATUS_PENDING, STATUS_SUCCESS}, - lazy::{LOGS_FILE_PATH, TOKENS_FILE_PATH}, -}; - -// 页面内容类型枚举 -#[derive(Clone, Serialize, Deserialize, Archive, RkyvDeserialize, RkyvSerialize)] -#[serde(tag = "type", content = "content")] -pub enum PageContent { - #[serde(rename = "default")] - Default, // 默认行为 - #[serde(rename = "text")] - Text(String), // 纯文本 - #[serde(rename = "html")] - Html(String), // HTML 内容 -} - -impl Default for PageContent { - fn default() -> Self { - Self::Default - } -} - -#[derive(Clone, Default, Archive, RkyvDeserialize, RkyvSerialize)] -pub struct Pages { - pub root_content: PageContent, - pub logs_content: PageContent, - pub config_content: PageContent, - pub tokeninfo_content: PageContent, - pub shared_styles_content: PageContent, - pub shared_js_content: PageContent, - pub about_content: PageContent, - pub readme_content: PageContent, - pub api_content: PageContent, - pub build_key_content: PageContent, -} - -#[derive(Serialize, Clone, Archive, RkyvDeserialize, RkyvSerialize)] -pub struct TokenGroup { - pub index: u16, - pub name: String, - pub tokens: Vec, - #[serde(default)] - pub enabled: bool, -} - -// Token管理器 -#[derive(Clone, Archive, RkyvDeserialize, RkyvSerialize)] -pub struct TokenManager { - pub tokens: Vec, - pub tags: HashSet, // 存储所有已使用的标签 -} - -// 请求统计管理器 -#[derive(Clone, Archive, RkyvDeserialize, RkyvSerialize)] -pub struct RequestStatsManager { - pub total_requests: u64, - pub active_requests: u64, - pub error_requests: u64, - pub request_logs: Vec, -} - -#[derive(Clone, Archive, RkyvDeserialize, RkyvSerialize)] -pub struct AppState { - pub token_manager: TokenManager, - pub request_manager: RequestStatsManager, -} - -impl TokenManager { - pub fn new(tokens: Vec) -> Self { - let mut tags = HashSet::new(); - for token in &tokens { - if let Some(token_tags) = &token.tags { - tags.extend(token_tags.iter().cloned()); - } - } - - Self { tokens, tags } - } - - pub fn update_global_tags(&mut self, new_tags: &[String]) { - // 将新标签添加到全局标签集合中 - self.tags.extend(new_tags.iter().cloned()); - } - - pub fn update_tokens_tags( - &mut self, - tokens: Vec, - new_tags: Vec, - ) -> Result<(), &'static str> { - // 创建tokens的HashSet用于快速查找 - let tokens_set: HashSet<_> = tokens.iter().collect(); - - // 更新指定tokens的标签 - for token_info in &mut self.tokens { - if tokens_set.contains(&token_info.token) { - token_info.tags = Some(new_tags.clone()); - } - } - - // 更新全局标签集合 - self.tags = self - .tokens - .iter() - .filter_map(|t| t.tags.clone()) - .flatten() - .collect(); - - Ok(()) - } - - pub fn get_tokens_by_tag(&self, tag: &str) -> Vec<&TokenInfo> { - self.tokens - .iter() - .filter(|t| { - t.tags - .as_ref() - .is_some_and(|tags| tags.contains(&tag.to_string())) - }) - .collect() - } - - pub fn update_checksum(&mut self) { - for token_info in self.tokens.iter_mut() { - token_info.checksum = generate_checksum_with_repair(&token_info.checksum); - } - } - - pub async fn save_tokens(&self) -> Result<(), Box> { - let bytes = rkyv::to_bytes::<_, 256>(self)?; - - let file = OpenOptions::new() - .read(true) - .write(true) - .create(true) - .truncate(true) - .open(&*TOKENS_FILE_PATH)?; - - if bytes.len() > usize::MAX / 2 { - return Err("Token数据过大".into()); - } - - file.set_len(bytes.len() as u64)?; - let mut mmap = unsafe { MmapMut::map_mut(&file)? }; - mmap.copy_from_slice(&bytes); - mmap.flush()?; - - Ok(()) - } - - pub async fn load_tokens() -> Result> { - let file = match OpenOptions::new().read(true).open(&*TOKENS_FILE_PATH) { - Ok(file) => file, - Err(e) if e.kind() == std::io::ErrorKind::NotFound => { - return Ok(Self::new(Vec::new())); - } - Err(e) => return Err(Box::new(e)), - }; - - if file.metadata()?.len() > usize::MAX as u64 { - return Err("Token文件过大".into()); - } - - let mmap = unsafe { MmapOptions::new().map(&file)? }; - let archived = unsafe { rkyv::archived_root::(&mmap) }; - Ok(archived.deserialize(&mut rkyv::Infallible)?) - } -} - -impl RequestStatsManager { - pub fn new(request_logs: Vec) -> Self { - Self { - total_requests: request_logs.len() as u64, - active_requests: 0, - error_requests: request_logs - .iter() - .filter(|log| matches!(log.status, LogStatus::Failed)) - .count() as u64, - request_logs, - } - } - - pub async fn save_logs(&self) -> Result<(), Box> { - let bytes = rkyv::to_bytes::<_, 256>(&self.request_logs)?; - - let file = OpenOptions::new() - .read(true) - .write(true) - .create(true) - .truncate(true) - .open(&*LOGS_FILE_PATH)?; - - if bytes.len() > usize::MAX / 2 { - return Err("日志数据过大".into()); - } - - file.set_len(bytes.len() as u64)?; - let mut mmap = unsafe { MmapMut::map_mut(&file)? }; - mmap.copy_from_slice(&bytes); - mmap.flush()?; - - Ok(()) - } - - pub async fn load_logs() -> Result, Box> { - let file = match OpenOptions::new().read(true).open(&*LOGS_FILE_PATH) { - Ok(file) => file, - Err(e) if e.kind() == std::io::ErrorKind::NotFound => { - return Ok(Vec::new()); - } - Err(e) => return Err(Box::new(e)), - }; - - if file.metadata()?.len() > usize::MAX as u64 { - return Err("日志文件过大".into()); - } - - let mmap = unsafe { MmapOptions::new().map(&file)? }; - let archived = unsafe { rkyv::archived_root::>(&mmap) }; - Ok(archived.deserialize(&mut rkyv::Infallible)?) - } -} - -impl Default for AppState { - fn default() -> Self { - Self::new() - } -} - -impl AppState { - pub fn new() -> Self { - // 尝试加载保存的数据 - let (request_logs, mut token_manager) = tokio::task::block_in_place(|| { - tokio::runtime::Handle::current().block_on(async { - let logs = RequestStatsManager::load_logs().await.unwrap_or_default(); - let token_manager = TokenManager::load_tokens() - .await - .unwrap_or_else(|_| TokenManager::new(Vec::new())); - (logs, token_manager) - }) - }); - - // 查询缺失的 token profiles - tokio::task::block_in_place(|| { - tokio::runtime::Handle::current().block_on(async { - for token_info in token_manager.tokens.iter_mut() { - if token_info.profile.is_none() { - token_info.profile = get_token_profile(&token_info.token).await; - } - } - }) - }); - - Self { - token_manager, - request_manager: RequestStatsManager::new(request_logs), - } - } - - pub async fn save_state(&self) -> Result<(), Box> { - // 并行保存 logs 和 tokens - let (logs_result, tokens_result) = tokio::join!( - self.request_manager.save_logs(), - self.token_manager.save_tokens() - ); - - logs_result?; - tokens_result?; - Ok(()) - } -} +use super::constant::{STATUS_FAILURE, STATUS_PENDING, STATUS_SUCCESS}; #[derive(Clone, Archive, RkyvDeserialize, RkyvSerialize)] pub enum LogStatus { Pending, Success, - Failed, + Failure, } impl Serialize for LogStatus { @@ -314,7 +41,7 @@ impl LogStatus { match self { Self::Pending => STATUS_PENDING, Self::Success => STATUS_SUCCESS, - Self::Failed => STATUS_FAILED, + Self::Failure => STATUS_FAILURE, } } @@ -322,7 +49,7 @@ impl LogStatus { match s { STATUS_PENDING => Some(Self::Pending), STATUS_SUCCESS => Some(Self::Success), - STATUS_FAILED => Some(Self::Failed), + STATUS_FAILURE => Some(Self::Failure), _ => None, } } @@ -336,7 +63,7 @@ pub struct RequestLog { pub model: String, pub token_info: TokenInfo, #[serde(skip_serializing_if = "Option::is_none")] - pub prompt: Option, + pub chain: Option, pub timing: TimingInfo, pub stream: bool, pub status: LogStatus, @@ -345,19 +72,16 @@ pub struct RequestLog { } #[derive(Serialize, Clone, Archive, RkyvDeserialize, RkyvSerialize)] -pub struct TimingInfo { - pub total: f64, // 总用时(秒) - #[serde(skip_serializing_if = "Option::is_none")] - pub first: Option, // 首字时间(秒) +pub struct Chain { + pub prompt: String, + pub delays: Vec<(String, f64)>, } -// 聊天请求 -#[derive(Deserialize)] -pub struct ChatRequest { - pub model: String, - pub messages: Vec, - #[serde(default)] - pub stream: bool, +#[derive(Serialize, Clone, Archive, RkyvDeserialize, RkyvSerialize)] +pub struct TimingInfo { + pub total: f64, // 总用时(秒) + // #[serde(skip_serializing_if = "Option::is_none")] + // pub first: Option, // 首字时间(秒) } // 用于存储 token 信息 @@ -370,6 +94,34 @@ pub struct TokenInfo { pub tags: Option>, } +impl TokenInfo { + /// 获取适用于此 token 的 HTTP 客户端 + /// + /// 如果 tags 中包含 "proxy" 标签,会尝试使用其后一个标签作为代理 URL + /// 例如: tags = ["proxy", "http://localhost:8080"] 将使用 http://localhost:8080 作为代理 + /// + /// 如果没有找到有效的代理配置,将返回默认客户端 + pub fn get_client(&self) -> Client { + // if let Some(tags) = &self.tags { + // // 查找 "proxy" 标签的位置 + // if let Some(proxy_index) = tags.iter().position(|tag| tag == "proxy") { + // // 检查是否存在下一个标签作为代理 URL + // if proxy_index + 1 < tags.len() { + // // 获取代理 URL 并尝试创建对应的客户端 + // return ProxyPool::get_client(&tags[proxy_index + 1]); + // } + // } + // } + // // 如果没有找到有效的代理配置,返回默认客户端 + // ProxyPool::get_general_client() + if let Some(tags) = &self.tags { + ProxyPool::get_client_or_general(tags.get(1).map(|s| s.as_str())) + } else { + ProxyPool::get_general_client() + } + } +} + // TokenUpdateRequest 结构体 #[derive(Deserialize)] pub struct TokenUpdateRequest { @@ -396,12 +148,12 @@ pub struct TokensDeleteRequest { #[serde(default)] pub tokens: Vec, #[serde(default)] - pub expectation: TokensDeleteResponseExpectation, + pub expectation: DeleteResponseExpectation, } #[derive(Deserialize, Default)] #[serde(rename_all = "snake_case")] -pub enum TokensDeleteResponseExpectation { +pub enum DeleteResponseExpectation { #[default] Simple, UpdatedTokens, @@ -409,20 +161,20 @@ pub enum TokensDeleteResponseExpectation { Detailed, } -impl TokensDeleteResponseExpectation { +impl DeleteResponseExpectation { pub fn needs_updated_tokens(&self) -> bool { matches!( self, - TokensDeleteResponseExpectation::UpdatedTokens - | TokensDeleteResponseExpectation::Detailed + DeleteResponseExpectation::UpdatedTokens + | DeleteResponseExpectation::Detailed ) } pub fn needs_failed_tokens(&self) -> bool { matches!( self, - TokensDeleteResponseExpectation::FailedTokens - | TokensDeleteResponseExpectation::Detailed + DeleteResponseExpectation::FailedTokens + | DeleteResponseExpectation::Detailed ) } } @@ -455,7 +207,7 @@ pub struct TokenTagsUpdateRequest { } #[derive(Serialize)] -pub struct TokenTagsResponse { +pub struct CommonResponse { pub status: ApiStatus, pub message: Option, } diff --git a/src/app/model/build_key.rs b/src/app/model/build_key.rs index 3882d35..018205b 100644 --- a/src/app/model/build_key.rs +++ b/src/app/model/build_key.rs @@ -6,6 +6,8 @@ use crate::{app::constant::COMMA, chat::constant::Models}; pub struct BuildKeyRequest { pub auth_token: String, #[serde(default)] + pub proxy_name: Option, + #[serde(default)] pub disable_vision: Option, #[serde(default)] pub enable_slow_pool: Option, @@ -14,6 +16,7 @@ pub struct BuildKeyRequest { #[serde(default)] pub include_web_references: Option, } + pub struct UsageCheckModelConfig { pub model_type: UsageCheckModelType, pub model_ids: Vec, diff --git a/src/app/model/config.rs b/src/app/model/config.rs index 500977d..f484de4 100644 --- a/src/app/model/config.rs +++ b/src/app/model/config.rs @@ -7,18 +7,15 @@ use crate::{ app::{ constant::{ EMPTY_STRING, ERR_INVALID_PATH, ROUTE_ABOUT_PATH, ROUTE_API_PATH, ROUTE_BUILD_KEY_PATH, - ROUTE_CONFIG_PATH, ROUTE_LOGS_PATH, ROUTE_README_PATH, ROUTE_ROOT_PATH, - ROUTE_SHARED_JS_PATH, ROUTE_SHARED_STYLES_PATH, ROUTE_TOKENS_PATH, + ROUTE_CONFIG_PATH, ROUTE_LOGS_PATH, ROUTE_PROXIES_PATH, ROUTE_README_PATH, + ROUTE_ROOT_PATH, ROUTE_SHARED_JS_PATH, ROUTE_SHARED_STYLES_PATH, ROUTE_TOKENS_PATH, }, lazy::CONFIG_FILE_PATH, }, - common::{ - client::rebuild_http_client, - utils::{parse_bool_from_env, parse_string_from_env}, - }, + common::utils::{parse_bool_from_env, parse_string_from_env}, }; -use super::{PageContent, Pages, Proxies, UsageCheck, VisionAbility}; +use super::{PageContent, Pages, UsageCheck, VisionAbility}; // 静态配置 #[derive(Default, Clone)] @@ -31,7 +28,6 @@ pub struct AppConfig { dynamic_key: bool, share_token: String, is_share: bool, - proxies: Proxies, web_refs: bool, } @@ -123,10 +119,6 @@ impl AppConfig { config.dynamic_key = parse_bool_from_env("DYNAMIC_KEY", false); config.share_token = parse_string_from_env("SHARED_TOKEN", EMPTY_STRING); config.is_share = !config.share_token.is_empty(); - config.proxies = match std::env::var("PROXIES") { - Ok(proxies) => Proxies::from_str(proxies.as_str()), - Err(_) => Proxies::default(), - }; config.web_refs = parse_bool_from_env("INCLUDE_WEB_REFERENCES", false) } @@ -164,35 +156,13 @@ impl AppConfig { } } - pub fn get_proxies() -> Proxies { - APP_CONFIG.read().proxies.clone() - } - - pub fn update_proxies(value: Proxies) { - let current = Self::get_proxies(); - if current != value { - let mut config = APP_CONFIG.write(); - config.proxies = value; - rebuild_http_client(); - } - } - - pub fn reset_proxies() { - let default_value = Proxies::default(); - let current = Self::get_proxies(); - if current != default_value { - let mut config = APP_CONFIG.write(); - config.proxies = default_value; - rebuild_http_client(); - } - } - pub fn get_page_content(path: &str) -> Option { match path { ROUTE_ROOT_PATH => Some(APP_CONFIG.read().pages.root_content.clone()), ROUTE_LOGS_PATH => Some(APP_CONFIG.read().pages.logs_content.clone()), ROUTE_CONFIG_PATH => Some(APP_CONFIG.read().pages.config_content.clone()), - ROUTE_TOKENS_PATH => Some(APP_CONFIG.read().pages.tokeninfo_content.clone()), + ROUTE_TOKENS_PATH => Some(APP_CONFIG.read().pages.tokens_content.clone()), + ROUTE_PROXIES_PATH => Some(APP_CONFIG.read().pages.proxies_content.clone()), ROUTE_SHARED_STYLES_PATH => Some(APP_CONFIG.read().pages.shared_styles_content.clone()), ROUTE_SHARED_JS_PATH => Some(APP_CONFIG.read().pages.shared_js_content.clone()), ROUTE_ABOUT_PATH => Some(APP_CONFIG.read().pages.about_content.clone()), @@ -209,7 +179,8 @@ impl AppConfig { ROUTE_ROOT_PATH => config.pages.root_content = content, ROUTE_LOGS_PATH => config.pages.logs_content = content, ROUTE_CONFIG_PATH => config.pages.config_content = content, - ROUTE_TOKENS_PATH => config.pages.tokeninfo_content = content, + ROUTE_TOKENS_PATH => config.pages.tokens_content = content, + ROUTE_PROXIES_PATH => config.pages.proxies_content = content, ROUTE_SHARED_STYLES_PATH => config.pages.shared_styles_content = content, ROUTE_SHARED_JS_PATH => config.pages.shared_js_content = content, ROUTE_ABOUT_PATH => config.pages.about_content = content, @@ -227,7 +198,8 @@ impl AppConfig { ROUTE_ROOT_PATH => config.pages.root_content = PageContent::default(), ROUTE_LOGS_PATH => config.pages.logs_content = PageContent::default(), ROUTE_CONFIG_PATH => config.pages.config_content = PageContent::default(), - ROUTE_TOKENS_PATH => config.pages.tokeninfo_content = PageContent::default(), + ROUTE_TOKENS_PATH => config.pages.tokens_content = PageContent::default(), + ROUTE_PROXIES_PATH => config.pages.proxies_content = PageContent::default(), ROUTE_SHARED_STYLES_PATH => config.pages.shared_styles_content = PageContent::default(), ROUTE_SHARED_JS_PATH => config.pages.shared_js_content = PageContent::default(), ROUTE_ABOUT_PATH => config.pages.about_content = PageContent::default(), diff --git a/src/app/model/proxies.rs b/src/app/model/proxies.rs deleted file mode 100644 index ba8cefd..0000000 --- a/src/app/model/proxies.rs +++ /dev/null @@ -1,81 +0,0 @@ -use reqwest::{Client, Proxy}; -use serde::{Deserialize, Deserializer}; -use serde::{Serialize, Serializer}; -// use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize}; - -use crate::app::constant::COMMA_STRING; - -#[derive(Clone, Default, PartialEq)] -pub enum Proxies { - No, - #[default] - System, - List(Vec), -} - -impl Serialize for Proxies { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - match self { - Proxies::No => serializer.serialize_str(""), - Proxies::System => serializer.serialize_str("system"), - Proxies::List(urls) => serializer.serialize_str(&urls.join(COMMA_STRING)), - } - } -} - -impl<'de> Deserialize<'de> for Proxies { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let s = ::deserialize(deserializer)?; - Ok(Proxies::from_str(&s)) - } -} - -impl Proxies { - /// 从字符串创建 Proxies - /// - /// # Arguments - /// * `s` - 代理字符串: - /// - "" 或 "no": 不使用代理 - /// - "system": 使用系统代理 - /// - 其他: 尝试解析为代理列表,无效则返回 System - pub fn from_str(s: &str) -> Self { - match s.trim() { - "" | "no" => Self::No, - "system" => Self::System, - urls => { - let valid_proxies: Vec = urls - .split(',') - .filter_map(|url| { - let trimmed = url.trim(); - (!trimmed.is_empty() && Proxy::all(trimmed).is_ok()) - .then(|| trimmed.to_string()) - }) - .collect(); - - if valid_proxies.is_empty() { - Self::default() - } else { - Self::List(valid_proxies) - } - } - } - } - - pub fn get_client(&self) -> Client { - match self { - Proxies::No => Client::builder().no_proxy().build().unwrap(), - Proxies::System => Client::new(), - Proxies::List(list) => { - // 使用第一个代理(已经确保是有效的) - let proxy = Proxy::all(list[0].clone()).unwrap(); - Client::builder().proxy(proxy).build().unwrap() - } - } - } -} diff --git a/src/app/model/proxy.rs b/src/app/model/proxy.rs new file mode 100644 index 0000000..f55303e --- /dev/null +++ b/src/app/model/proxy.rs @@ -0,0 +1,53 @@ +use super::{ + ApiStatus, DeleteResponseExpectation, + proxy_pool::{Proxies, SingleProxy}, +}; +use serde::{Deserialize, Serialize}; + +// 代理信息响应 +#[derive(Serialize)] +pub struct ProxyInfoResponse { + pub status: ApiStatus, + #[serde(skip_serializing_if = "Option::is_none")] + pub proxies: Option, + pub proxies_count: usize, + #[serde(skip_serializing_if = "Option::is_none")] + pub general_proxy: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, +} + +// 更新代理配置请求 +#[derive(Deserialize)] +pub struct ProxyUpdateRequest { + pub proxies: Proxies, +} + +// 添加代理请求 +#[derive(Deserialize)] +pub struct ProxyAddRequest { + pub proxies: std::collections::HashMap, +} + +// 删除代理请求 +#[derive(Deserialize)] +pub struct ProxiesDeleteRequest { + #[serde(default)] + pub names: std::collections::HashSet, + #[serde(default)] + pub expectation: DeleteResponseExpectation, +} + +// 删除代理响应 +#[derive(Serialize)] +pub struct ProxiesDeleteResponse { + pub status: ApiStatus, + pub updated_proxies: Option, + pub failed_names: Option>, +} + +// 设置通用代理请求 +#[derive(Deserialize)] +pub struct SetGeneralProxyRequest { + pub name: String, +} diff --git a/src/app/model/proxy_pool.rs b/src/app/model/proxy_pool.rs new file mode 100644 index 0000000..9597c3f --- /dev/null +++ b/src/app/model/proxy_pool.rs @@ -0,0 +1,338 @@ +use memmap2::{MmapMut, MmapOptions}; +use parking_lot::RwLock; +use reqwest::{Client, Proxy}; +use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fs::OpenOptions; +use std::str::FromStr; +use std::sync::LazyLock; + +mod proxy_url; +use super::super::lazy::PROXIES_FILE_PATH; +use proxy_url::UrlWrapper; + +// 恢复原来的常量定义 +pub const NO_PROXY: &str = "no"; +pub const EMPTY_PROXY: &str = ""; +pub const SYSTEM_PROXY: &str = "system"; +pub const DEFAULT_PROXY: &str = "default"; + +// 新的代理值常量 +pub const NON_PROXY: &str = "non"; +pub const SYS_PROXY: &str = "sys"; + +// 静态映射,将原来的值映射到新的值 +pub static PROXY_MAP: LazyLock> = LazyLock::new(|| { + let mut map = HashMap::new(); + map.insert(NO_PROXY, NON_PROXY); + map.insert(EMPTY_PROXY, NON_PROXY); // 空字符串也映射到NON_PROXY + map.insert(SYSTEM_PROXY, SYS_PROXY); + map.insert(DEFAULT_PROXY, SYS_PROXY); // DEFAULT_PROXY映射到SYS_PROXY + map +}); + +// 直接初始化PROXY_POOL为一个带有系统代理的基本实例 +pub static PROXY_POOL: LazyLock> = LazyLock::new(|| { + let mut clients = HashMap::new(); + + // 添加系统代理 + let system_client = Client::new(); + clients.insert(SYS_PROXY.to_string(), system_client.clone()); + + RwLock::new(ProxyPool { + clients, + general: Some(system_client), + }) +}); + +#[derive(Clone, Deserialize, Serialize, Archive, RkyvDeserialize, RkyvSerialize)] +pub struct Proxies { + // name to proxy + proxies: HashMap, + general: String, +} + +impl Default for Proxies { + fn default() -> Self { + Self::new() + } +} + +impl Proxies { + pub fn new() -> Self { + Self { + proxies: HashMap::from([(SYS_PROXY.to_string(), SingleProxy::Sys)]), + general: SYS_PROXY.to_string(), + } + } + + pub fn get_proxies(&self) -> &HashMap { + &self.proxies + } + + pub fn add_proxy(&mut self, name: String, proxy: SingleProxy) { + self.proxies.insert(name, proxy); + } + + pub fn remove_proxy(&mut self, name: &str) { + self.proxies.remove(name); + } + + pub fn set_general(&mut self, name: &str) { + if self.proxies.contains_key(name) { + self.general = name.to_string(); + } + } + + pub fn get_general(&self) -> &str { + &self.general + } + + // 更新全局代理池 + pub fn update_global_pool(&self) -> Result<(), Box> { + let mut pool = PROXY_POOL.write(); + + // 清除现有的客户端 + pool.clients.clear(); + + let proxies = self.get_proxies(); + if proxies.is_empty() { + // 添加系统代理 + let system_client = Client::new(); + pool.clients + .insert(SYS_PROXY.to_string(), system_client.clone()); + pool.general = Some(system_client); + return Ok(()); + } + + // 初始化客户端并设置第一个代理为通用客户端 + let mut first_name = None; + for (name, proxy) in proxies { + if first_name.is_none() { + first_name = Some(name.clone()); + } + + // 初始化客户端 + pool.append(name, &proxy); + } + + // 设置通用客户端 + if let Some(name) = first_name { + pool.general = pool.clients.get(&name).cloned(); + } else { + // 添加系统代理 + let system_client = Client::new(); + pool.clients + .insert(SYS_PROXY.to_string(), system_client.clone()); + pool.general = Some(system_client); + } + + Ok(()) + } + + pub async fn save_proxies(&self) -> Result<(), Box> { + let bytes = rkyv::to_bytes::<_, 256>(self)?; + + let file = OpenOptions::new() + .read(true) + .write(true) + .create(true) + .truncate(true) + .open(&*PROXIES_FILE_PATH)?; + + if bytes.len() > usize::MAX / 2 { + return Err("代理数据过大".into()); + } + + file.set_len(bytes.len() as u64)?; + let mut mmap = unsafe { MmapMut::map_mut(&file)? }; + mmap.copy_from_slice(&bytes); + mmap.flush()?; + + Ok(()) + } + + pub async fn load_proxies() -> Result> { + let file = match OpenOptions::new().read(true).open(&*PROXIES_FILE_PATH) { + Ok(file) => file, + Err(e) if e.kind() == std::io::ErrorKind::NotFound => { + return Ok(Self::new()); + } + Err(e) => return Err(Box::new(e)), + }; + + if file.metadata()?.len() > usize::MAX as u64 { + return Err("代理文件过大".into()); + } + + let mmap = unsafe { MmapOptions::new().map(&file)? }; + let archived = unsafe { rkyv::archived_root::(&mmap) }; + Ok(archived.deserialize(&mut rkyv::Infallible)?) + } + + // 更新全局代理池并保存配置 + pub async fn update_and_save(&self) -> Result<(), Box> { + // 更新全局代理池 + self.update_global_pool()?; + + // 保存配置到文件 + self.save_proxies().await + } +} + +#[derive(Clone, Archive, RkyvDeserialize, RkyvSerialize)] +#[archive(compare(PartialEq))] +pub enum SingleProxy { + Non, + Sys, + Url(UrlWrapper), +} + +impl Serialize for SingleProxy { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + Self::Non => serializer.serialize_str(NON_PROXY), + Self::Sys => serializer.serialize_str(SYS_PROXY), + Self::Url(url) => serializer.serialize_str(&url.to_string()), + } + } +} + +impl<'de> Deserialize<'de> for SingleProxy { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct SingleProxyVisitor; + + impl<'de> serde::de::Visitor<'de> for SingleProxyVisitor { + type Value = SingleProxy; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a string representing 'non', 'sys', or a valid URL") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + // 检查是否是保留的代理名称,如果是则进行映射 + if let Some(&mapped) = PROXY_MAP.get(value) { + match mapped { + NON_PROXY => return Ok(Self::Value::Non), + SYS_PROXY => return Ok(Self::Value::Sys), + _ => {} + } + } + + // 直接匹配新的代理值 + match value { + NON_PROXY => Ok(Self::Value::Non), + SYS_PROXY => Ok(Self::Value::Sys), + url_str => url::Url::parse(url_str) + .map(|url| Self::Value::Url(UrlWrapper::from(url))) + .map_err(|e| E::custom(format!("Invalid URL: {}", e))), + } + } + } + + deserializer.deserialize_str(SingleProxyVisitor) + } +} + +impl ToString for SingleProxy { + fn to_string(&self) -> String { + match self { + Self::Non => NON_PROXY.to_string(), + Self::Sys => SYS_PROXY.to_string(), + Self::Url(url) => url.to_string(), + } + } +} + +impl FromStr for SingleProxy { + type Err = url::ParseError; + + fn from_str(s: &str) -> Result { + // 检查是否是保留的代理名称,如果是则进行映射 + if let Some(&mapped) = PROXY_MAP.get(s) { + match mapped { + NON_PROXY => return Ok(Self::Non), + SYS_PROXY => return Ok(Self::Sys), + _ => {} + } + } + + // 直接匹配新的代理值 + match s { + NON_PROXY => Ok(Self::Non), + SYS_PROXY => Ok(Self::Sys), + url_str => url::Url::parse(url_str).map(|url| Self::Url(UrlWrapper::from(url))), + } + } +} + +pub struct ProxyPool { + // name to client + clients: HashMap, + general: Option, +} + +impl ProxyPool { + // 添加客户端 + fn append(&mut self, name: &str, proxy: &SingleProxy) { + if self.clients.contains_key(name) { + return; + } + + // 根据SingleProxy类型创建客户端 + let client = match proxy { + SingleProxy::Non => Client::builder() + .no_proxy() + .build() + .expect("创建无代理客户端失败"), + SingleProxy::Sys => Client::new(), + SingleProxy::Url(url) => { + if let Ok(proxy_obj) = Proxy::all(&url.to_string()) { + Client::builder() + .proxy(proxy_obj) + .build() + .expect("创建代理客户端失败") + } else { + return; + } + } + }; + + self.clients.insert(name.to_string(), client); + } + + // 获取客户端 + pub fn get_client(url: &str) -> Client { + let pool = PROXY_POOL.read(); + + // 检查是否需要映射 + let mapped_url = PROXY_MAP.get(url).copied().unwrap_or(url); + + pool.clients + .get(mapped_url.trim()) + .cloned() + .unwrap_or_else(Self::get_general_client) + } + + pub fn get_general_client() -> Client { + let pool = PROXY_POOL.read(); + pool.general.clone().expect("获取通用客户端不应该失败") + } + + pub fn get_client_or_general(url: Option<&str>) -> Client { + match url { + Some(url) => Self::get_client(url), + None => Self::get_general_client(), + } + } +} diff --git a/src/app/model/proxy_pool/proxy_url.rs b/src/app/model/proxy_pool/proxy_url.rs new file mode 100644 index 0000000..84ce4bf --- /dev/null +++ b/src/app/model/proxy_pool/proxy_url.rs @@ -0,0 +1,60 @@ +use rkyv::{Archive, Deserialize, Serialize}; +use std::fmt; +use std::str::FromStr; + +/// 一个可以被Archive的URL包装器 +#[derive(Clone, Archive, Deserialize, Serialize)] +#[archive(compare(PartialEq))] +pub struct UrlWrapper(String); + +impl UrlWrapper { + pub fn new(url: &url::Url) -> Self { + Self(url.to_string()) + } + + pub fn into_url(self) -> Result { + url::Url::parse(&self.0) + } + + pub fn as_url(&self) -> Result { + url::Url::parse(&self.0) + } +} + +impl From for UrlWrapper { + fn from(url: url::Url) -> Self { + Self(url.to_string()) + } +} + +impl TryFrom for url::Url { + type Error = url::ParseError; + + fn try_from(wrapper: UrlWrapper) -> Result { + wrapper.into_url() + } +} + +impl fmt::Display for UrlWrapper { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl FromStr for UrlWrapper { + type Err = url::ParseError; + + fn from_str(s: &str) -> Result { + // 验证字符串是有效的URL + url::Url::parse(s)?; + Ok(Self(s.to_string())) + } +} + +impl PartialEq for UrlWrapper { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for UrlWrapper {} diff --git a/src/app/model/state.rs b/src/app/model/state.rs new file mode 100644 index 0000000..8934628 --- /dev/null +++ b/src/app/model/state.rs @@ -0,0 +1,285 @@ +use crate::common::utils::{generate_checksum_with_repair, get_token_profile}; +use memmap2::{MmapMut, MmapOptions}; +use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize}; +use serde::{Deserialize, Serialize}; +use std::{collections::HashSet, fs::OpenOptions}; + +use super::{ + super::lazy::{LOGS_FILE_PATH, TOKENS_FILE_PATH}, + LogStatus, RequestLog, TokenInfo, + proxy_pool::Proxies, +}; + +// 页面内容类型枚举 +#[derive(Clone, Serialize, Deserialize, Archive, RkyvDeserialize, RkyvSerialize)] +#[serde(tag = "type", content = "content")] +pub enum PageContent { + #[serde(rename = "default")] + Default, // 默认行为 + #[serde(rename = "text")] + Text(String), // 纯文本 + #[serde(rename = "html")] + Html(String), // HTML 内容 +} + +impl Default for PageContent { + fn default() -> Self { + Self::Default + } +} + +#[derive(Clone, Default, Archive, RkyvDeserialize, RkyvSerialize)] +pub struct Pages { + pub root_content: PageContent, + pub logs_content: PageContent, + pub config_content: PageContent, + pub tokens_content: PageContent, + pub proxies_content: PageContent, + pub shared_styles_content: PageContent, + pub shared_js_content: PageContent, + pub about_content: PageContent, + pub readme_content: PageContent, + pub api_content: PageContent, + pub build_key_content: PageContent, +} + +// Token管理器 +#[derive(Clone, Archive, RkyvDeserialize, RkyvSerialize)] +pub struct TokenManager { + pub tokens: Vec, + pub tags: HashSet, // 存储所有已使用的标签 +} + +// 请求统计管理器 +#[derive(Clone, Archive, RkyvDeserialize, RkyvSerialize)] +pub struct RequestStatsManager { + pub total_requests: u64, + pub active_requests: u64, + pub error_requests: u64, + pub request_logs: Vec, +} + +#[derive(Clone, Archive, RkyvDeserialize, RkyvSerialize)] +pub struct AppState { + pub token_manager: TokenManager, + pub request_manager: RequestStatsManager, + pub proxies: Proxies, +} + +impl TokenManager { + pub fn new(tokens: Vec) -> Self { + let mut tags = HashSet::new(); + for token in &tokens { + if let Some(token_tags) = &token.tags { + tags.extend(token_tags.iter().cloned()); + } + } + + Self { tokens, tags } + } + + pub fn update_global_tags(&mut self, new_tags: &[String]) { + // 将新标签添加到全局标签集合中 + self.tags.extend(new_tags.iter().cloned()); + } + + pub fn update_tokens_tags( + &mut self, + tokens: Vec, + new_tags: Vec, + ) -> Result<(), &'static str> { + // 创建tokens的HashSet用于快速查找 + let tokens_set: HashSet<_> = tokens.iter().collect(); + + // 更新指定tokens的标签 + for token_info in &mut self.tokens { + if tokens_set.contains(&token_info.token) { + token_info.tags = Some(new_tags.clone()); + } + } + + // 更新全局标签集合 + self.tags = self + .tokens + .iter() + .filter_map(|t| t.tags.clone()) + .flatten() + .collect(); + + Ok(()) + } + + pub fn get_tokens_by_tag(&self, tag: &str) -> Vec<&TokenInfo> { + self.tokens + .iter() + .filter(|t| { + t.tags + .as_ref() + .is_some_and(|tags| tags.contains(&tag.to_string())) + }) + .collect() + } + + pub fn update_checksum(&mut self) { + for token_info in self.tokens.iter_mut() { + token_info.checksum = generate_checksum_with_repair(&token_info.checksum); + } + } + + pub async fn save_tokens(&self) -> Result<(), Box> { + let bytes = rkyv::to_bytes::<_, 256>(self)?; + + let file = OpenOptions::new() + .read(true) + .write(true) + .create(true) + .truncate(true) + .open(&*TOKENS_FILE_PATH)?; + + if bytes.len() > usize::MAX / 2 { + return Err("Token数据过大".into()); + } + + file.set_len(bytes.len() as u64)?; + let mut mmap = unsafe { MmapMut::map_mut(&file)? }; + mmap.copy_from_slice(&bytes); + mmap.flush()?; + + Ok(()) + } + + pub async fn load_tokens() -> Result> { + let file = match OpenOptions::new().read(true).open(&*TOKENS_FILE_PATH) { + Ok(file) => file, + Err(e) if e.kind() == std::io::ErrorKind::NotFound => { + return Ok(Self::new(Vec::new())); + } + Err(e) => return Err(Box::new(e)), + }; + + if file.metadata()?.len() > usize::MAX as u64 { + return Err("Token文件过大".into()); + } + + let mmap = unsafe { MmapOptions::new().map(&file)? }; + let archived = unsafe { rkyv::archived_root::(&mmap) }; + Ok(archived.deserialize(&mut rkyv::Infallible)?) + } +} + +impl RequestStatsManager { + pub fn new(request_logs: Vec) -> Self { + Self { + total_requests: request_logs.len() as u64, + active_requests: 0, + error_requests: request_logs + .iter() + .filter(|log| matches!(log.status, LogStatus::Failure)) + .count() as u64, + request_logs, + } + } + + pub async fn save_logs(&self) -> Result<(), Box> { + let bytes = rkyv::to_bytes::<_, 256>(&self.request_logs)?; + + let file = OpenOptions::new() + .read(true) + .write(true) + .create(true) + .truncate(true) + .open(&*LOGS_FILE_PATH)?; + + if bytes.len() > usize::MAX / 2 { + return Err("日志数据过大".into()); + } + + file.set_len(bytes.len() as u64)?; + let mut mmap = unsafe { MmapMut::map_mut(&file)? }; + mmap.copy_from_slice(&bytes); + mmap.flush()?; + + Ok(()) + } + + pub async fn load_logs() -> Result, Box> { + let file = match OpenOptions::new().read(true).open(&*LOGS_FILE_PATH) { + Ok(file) => file, + Err(e) if e.kind() == std::io::ErrorKind::NotFound => { + return Ok(Vec::new()); + } + Err(e) => return Err(Box::new(e)), + }; + + if file.metadata()?.len() > usize::MAX as u64 { + return Err("日志文件过大".into()); + } + + let mmap = unsafe { MmapOptions::new().map(&file)? }; + let archived = unsafe { rkyv::archived_root::>(&mmap) }; + Ok(archived.deserialize(&mut rkyv::Infallible)?) + } +} + +impl Default for AppState { + fn default() -> Self { + Self::new() + } +} + +impl AppState { + pub fn new() -> Self { + // 尝试加载保存的数据 + let (request_logs, mut token_manager, proxies) = tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(async { + let logs = RequestStatsManager::load_logs().await.unwrap_or_default(); + let token_manager = TokenManager::load_tokens() + .await + .unwrap_or_else(|_| TokenManager::new(Vec::new())); + let proxies = Proxies::load_proxies() + .await + .unwrap_or_else(|_| Proxies::new()); + (logs, token_manager, proxies) + }) + }); + + // 查询缺失的 token profiles + tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(async { + for token_info in token_manager.tokens.iter_mut() { + if let Some(profile) = + get_token_profile(token_info.get_client(), &token_info.token).await + { + token_info.profile = Some(profile); + } + } + }) + }); + + // 更新全局代理池 + let proxies_clone = proxies.clone(); + if let Err(e) = proxies_clone.update_global_pool() { + eprintln!("更新全局代理池失败: {}", e); + } + + Self { + token_manager, + request_manager: RequestStatsManager::new(request_logs), + proxies, + } + } + + pub async fn save_state(&self) -> Result<(), Box> { + // 并行保存 logs、tokens 和 proxies + let (logs_result, tokens_result, proxies_result) = tokio::join!( + self.request_manager.save_logs(), + self.token_manager.save_tokens(), + self.proxies.save_proxies() + ); + + logs_result?; + tokens_result?; + proxies_result?; + Ok(()) + } +} diff --git a/src/chat.rs b/src/chat.rs index f46dddb..e02536d 100644 --- a/src/chat.rs +++ b/src/chat.rs @@ -3,7 +3,7 @@ pub mod aiserver; pub mod config; pub mod constant; pub mod error; -// pub mod middleware; +pub mod middleware; pub mod model; pub mod route; pub mod service; diff --git a/src/chat/adapter.rs b/src/chat/adapter.rs index 2c9c5e3..0bf36c4 100644 --- a/src/chat/adapter.rs +++ b/src/chat/adapter.rs @@ -7,10 +7,10 @@ use uuid::Uuid; use crate::{ app::{ constant::EMPTY_STRING, - lazy::DEFAULT_INSTRUCTIONS, - model::{AppConfig, VisionAbility}, + lazy::get_default_instructions, + model::{AppConfig, VisionAbility, proxy_pool::ProxyPool}, }, - common::{client::HTTP_CLIENT, utils::encode_message}, + common::utils::encode_message, }; use super::{ @@ -18,7 +18,10 @@ use super::{ AzureState, ChatExternalLink, ConversationMessage, ExplicitContext, GetChatRequest, ImageProto, ModelDetails, WebReference, conversation_message, image_proto, }, - constant::{ERR_UNSUPPORTED_GIF, ERR_UNSUPPORTED_IMAGE_FORMAT, LONG_CONTEXT_MODELS}, + constant::{ + ERR_UNSUPPORTED_GIF, ERR_UNSUPPORTED_IMAGE_FORMAT, LONG_CONTEXT_MODELS, + SUPPORTED_IMAGE_MODELS, + }, model::{Message, MessageContent, Role}, }; @@ -86,6 +89,7 @@ fn parse_web_references(text: &str) -> Vec { async fn process_chat_inputs( inputs: Vec, disable_vision: bool, + model_name: &str, ) -> (String, Vec, Vec) { // 收集 system 指令 let instructions = inputs @@ -110,7 +114,7 @@ async fn process_chat_inputs( // 使用默认指令或收集到的指令 let instructions = if instructions.is_empty() { - DEFAULT_INSTRUCTIONS.clone() + get_default_instructions() } else { instructions }; @@ -203,26 +207,6 @@ async fn process_chat_inputs( ); } - // 处理连续相同角色的情况 - let mut i = 1; - while i < chat_inputs.len() { - if chat_inputs[i].role == chat_inputs[i - 1].role { - let insert_role = if chat_inputs[i].role == Role::User { - Role::Assistant - } else { - Role::User - }; - chat_inputs.insert( - i, - Message { - role: insert_role, - content: MessageContent::Text(EMPTY_STRING.into()), - }, - ); - } - i += 1; - } - // 确保最后一条是 user if chat_inputs .last() @@ -236,6 +220,7 @@ async fn process_chat_inputs( // 转换为 proto messages let mut messages = Vec::new(); + let mut is_supported_model = None; for input in chat_inputs { let (text, images) = match input.content { MessageContent::Text(text) => (text, vec![]), @@ -251,10 +236,14 @@ async fn process_chat_inputs( } } "image_url" => { - if !disable_vision { + if is_supported_model.is_none() { + is_supported_model = + Some(SUPPORTED_IMAGE_MODELS.contains(&model_name)); + } + if !disable_vision && unsafe { is_supported_model.unwrap_unchecked() } { if let Some(image_url) = &content.image_url { let url = image_url.url.clone(); - let client = HTTP_CLIENT.read().clone(); + let client = ProxyPool::get_general_client(); let result = tokio::spawn(async move { fetch_image_data(&url, client).await }); @@ -349,15 +338,23 @@ async fn process_chat_inputs( while let Some(c) = chars.next() { if c == '@' { let mut url = String::new(); - while let Some(&next_char) = chars.peek() { + while let Some(next_char) = chars.peek() { if next_char.is_whitespace() { break; } - url.push(chars.next().unwrap()); + // 安全地获取下一个字符,避免使用unwrap() + if let Some(ch) = chars.next() { + url.push(ch); + } else { + break; + } } - if let Ok(parsed_url) = url::Url::parse(&url) { - if parsed_url.scheme() == "http" || parsed_url.scheme() == "https" { - urls.push(url); + // 只有当URL不为空时才尝试解析 + if !url.is_empty() { + if let Ok(parsed_url) = url::Url::parse(&url) { + if parsed_url.scheme() == "http" || parsed_url.scheme() == "https" { + urls.push(url); + } } } } @@ -488,7 +485,8 @@ pub async fn encode_chat_message( // 在进入异步操作前获取并释放锁 let enable_slow_pool = { if enable_slow_pool { Some(true) } else { None } }; - let (instructions, messages, urls) = process_chat_inputs(inputs, disable_vision).await; + let (instructions, messages, urls) = + process_chat_inputs(inputs, disable_vision, model_name).await; let explicit_context = if !instructions.trim().is_empty() { Some(ExplicitContext { @@ -523,7 +521,7 @@ pub async fn encode_chat_message( model_details: Some(ModelDetails { model_name: Some(model_name.to_string()), api_key: None, - enable_ghost_mode: None, + enable_ghost_mode: Some(true), azure_state: Some(AzureState { api_key: String::new(), base_url: String::new(), @@ -541,7 +539,7 @@ pub async fn encode_chat_message( allow_long_file_scan: Some(false), is_bash: Some(false), conversation_id: Uuid::new_v4().to_string(), - can_handle_filenames_after_language_ids: Some(true), + can_handle_filenames_after_language_ids: Some(false), use_web: if is_search { Some("full_search".to_string()) } else { @@ -559,7 +557,7 @@ pub async fn encode_chat_message( is_composer: None, runnable_code_blocks: Some(false), should_cache: Some(false), - allow_model_fallbacks: None, + allow_model_fallbacks: Some(false), number_of_times_shown_fallback_model_warning: None, }; diff --git a/src/chat/aiserver/v1/aiserver.v1.rs b/src/chat/aiserver/v1/aiserver.v1.rs index ddf6daf..84ba1d9 100644 --- a/src/chat/aiserver/v1/aiserver.v1.rs +++ b/src/chat/aiserver/v1/aiserver.v1.rs @@ -1,4 +1,3 @@ -// This file is @generated by prost-build. /// aiserver.v1.AvailableModelsRequest #[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct AvailableModelsRequest { @@ -28,6 +27,8 @@ pub mod available_models_response { pub is_long_context_only: ::core::option::Option, #[prost(bool, optional, tag = "4")] pub is_chat_only: ::core::option::Option, + #[prost(bool, optional, tag = "5")] + pub supports_agent: ::core::option::Option, } } /// aiserver.v1.ErrorDetails diff --git a/src/chat/aiserver/v1/lite.proto b/src/chat/aiserver/v1/lite.proto index 3bbb490..e975e97 100644 --- a/src/chat/aiserver/v1/lite.proto +++ b/src/chat/aiserver/v1/lite.proto @@ -11,6 +11,7 @@ message AvailableModelsResponse { // aiserver.v1.AvailableModelsResponse bool default_on = 2; optional bool is_long_context_only = 3; optional bool is_chat_only = 4; + optional bool supports_agent = 5; } repeated AvailableModel models = 2; repeated string model_names = 1; diff --git a/src/chat/config/key.proto b/src/chat/config/key.proto index 2542c7d..7c5372e 100644 --- a/src/chat/config/key.proto +++ b/src/chat/config/key.proto @@ -12,6 +12,7 @@ message KeyConfig { string signature = 4; // 签名 bytes machine_id = 5; // 机器ID的SHA256哈希值 bytes mac_id = 6; // MAC地址的SHA256哈希值 + optional string proxy_name = 8; // 代理名称 } // 认证令牌(必需) diff --git a/src/chat/constant.rs b/src/chat/constant.rs index 56b9ce2..c3ce30d 100644 --- a/src/chat/constant.rs +++ b/src/chat/constant.rs @@ -1,13 +1,16 @@ use parking_lot::RwLock; -use std::{sync::Arc, time::{Duration, Instant}}; +use std::{ + sync::Arc, + time::{Duration, Instant}, +}; use super::model::Model; macro_rules! def_pub_const { // 单个常量定义分支 - ($name:ident, $value:expr) => { - pub const $name: &'static str = $value; - }; + // ($name:ident, $value:expr) => { + // pub const $name: &'static str = $value; + // }; // 批量定义分支 ($($name:ident => $value:expr),+ $(,)?) => { @@ -46,8 +49,9 @@ def_pub_const!( CLAUDE_3_5_SONNET => "claude-3.5-sonnet", CLAUDE_3_HAIKU_200K => "claude-3-haiku-200k", CLAUDE_3_5_SONNET_200K => "claude-3-5-sonnet-200k", - CLAUDE_3_5_SONNET_20241022 => "claude-3-5-sonnet-20241022", CLAUDE_3_5_HAIKU => "claude-3.5-haiku", + CLAUDE_3_7_SONNET => "claude-3.7-sonnet", + CLAUDE_3_7_SONNET_THINKING => "claude-3.7-sonnet-thinking", // OpenAI 模型 GPT_4 => "gpt-4", @@ -60,6 +64,7 @@ def_pub_const!( O1_PREVIEW => "o1-preview", O1 => "o1", O3_MINI => "o3-mini", + GPT_4_5_PREVIEW => "gpt-4.5-preview", // Cursor 模型 CURSOR_FAST => "cursor-fast", @@ -78,6 +83,9 @@ def_pub_const!( // XAI 模型 GROK_2 => "grok-2", + + // 未知模型 + DEFAULT => "default", ); macro_rules! create_models { @@ -137,11 +145,7 @@ impl Models { // 返回所有模型 ID 的列表 pub fn ids() -> Vec { - Self::read() - .models - .iter() - .map(|m| m.id.clone()) - .collect() + Self::read().models.iter().map(|m| m.id.clone()).collect() } // 写入方法 @@ -154,12 +158,12 @@ impl Models { // 检查时间间隔(30分钟) if data.last_update.elapsed() < Duration::from_secs(30 * 60) { - return Err("Cannot update models more frequently than every 30 minutes"); + return Ok(()); } // 检查内容是否有变化 if *data.models == new_models { - return Err("No changes in models"); + return Ok(()); } // 更新数据和时间戳 @@ -177,8 +181,11 @@ impl Models { create_models!( CLAUDE_3_5_SONNET => ANTHROPIC, + CLAUDE_3_7_SONNET => ANTHROPIC, + CLAUDE_3_7_SONNET_THINKING => ANTHROPIC, GPT_4 => OPENAI, GPT_4O => OPENAI, + GPT_4_5_PREVIEW => OPENAI, CLAUDE_3_OPUS => ANTHROPIC, CURSOR_FAST => CURSOR, CURSOR_SMALL => CURSOR, @@ -200,11 +207,13 @@ create_models!( DEEPSEEK_R1 => DEEPSEEK, O3_MINI => OPENAI, GROK_2 => XAI, + DEFAULT => UNKNOWN, ); -pub const USAGE_CHECK_MODELS: [&str; 11] = [ - CLAUDE_3_5_SONNET_20241022, +pub const USAGE_CHECK_MODELS: [&str; 13] = [ CLAUDE_3_5_SONNET, + CLAUDE_3_7_SONNET, + CLAUDE_3_7_SONNET_THINKING, GEMINI_EXP_1206, GPT_4, GPT_4_TURBO_2024_04_09, @@ -214,6 +223,7 @@ pub const USAGE_CHECK_MODELS: [&str; 11] = [ GEMINI_1_5_FLASH_500K, CLAUDE_3_HAIKU_200K, CLAUDE_3_5_SONNET_200K, + DEEPSEEK_R1, ]; pub const LONG_CONTEXT_MODELS: [&str; 4] = [ @@ -222,3 +232,15 @@ pub const LONG_CONTEXT_MODELS: [&str; 4] = [ CLAUDE_3_HAIKU_200K, CLAUDE_3_5_SONNET_200K, ]; + +pub const SUPPORTED_IMAGE_MODELS: [&str; 9] = [ + CLAUDE_3_5_SONNET, + CLAUDE_3_7_SONNET, + CLAUDE_3_7_SONNET_THINKING, + GPT_4O, + GPT_4O_MINI, + DEFAULT, + CLAUDE_3_OPUS, + CLAUDE_3_5_HAIKU, + GPT_4, +]; diff --git a/src/chat/middleware/auth.rs b/src/chat/middleware/auth.rs index 46fc0bf..2be79c5 100644 --- a/src/chat/middleware/auth.rs +++ b/src/chat/middleware/auth.rs @@ -1,23 +1,45 @@ -use crate::app::{constant::AUTHORIZATION_BEARER_PREFIX, lazy::AUTH_TOKEN}; +use crate::{ + app::{constant::AUTHORIZATION_BEARER_PREFIX, lazy::AUTH_TOKEN}, + common::model::error::ChatError, +}; use axum::{ + Json, body::Body, http::{Request, StatusCode, header::AUTHORIZATION}, middleware::Next, - response::Response, + response::{IntoResponse, Response}, }; -// 认证中间件函数 -pub async fn auth_middleware(request: Request, next: Next) -> Result { +// 管理员认证中间件函数 +pub async fn admin_auth_middleware(request: Request, next: Next) -> Response { let auth_header = request .headers() .get(AUTHORIZATION) .and_then(|h| h.to_str().ok()) - .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) - .ok_or(StatusCode::UNAUTHORIZED)?; + .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)); - if auth_header != AUTH_TOKEN.as_str() { - return Err(StatusCode::UNAUTHORIZED); + match auth_header { + Some(token) if token == AUTH_TOKEN.as_str() => next.run(request).await, + _ => ( + StatusCode::UNAUTHORIZED, + Json(ChatError::Unauthorized.to_json()), + ) + .into_response(), } - - Ok(next.run(request).await) } + +// 旧的认证中间件函数,保留向后兼容性 +// pub async fn auth_middleware(request: Request, next: Next) -> Result { +// let auth_header = request +// .headers() +// .get(AUTHORIZATION) +// .and_then(|h| h.to_str().ok()) +// .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) +// .ok_or(StatusCode::UNAUTHORIZED)?; + +// if auth_header != AUTH_TOKEN.as_str() { +// return Err(StatusCode::UNAUTHORIZED); +// } + +// Ok(next.run(request).await) +// } diff --git a/src/chat/model.rs b/src/chat/model.rs index 06271b5..0f3a270 100644 --- a/src/chat/model.rs +++ b/src/chat/model.rs @@ -50,8 +50,8 @@ pub struct ChatResponse { #[serde(skip_serializing_if = "Option::is_none")] pub model: Option, pub choices: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub usage: Option, + #[serde(skip_serializing_if = "TriState::is_none")] + pub usage: TriState, } #[derive(Serialize)] @@ -61,6 +61,7 @@ pub struct Choice { pub message: Option, #[serde(skip_serializing_if = "Option::is_none")] pub delta: Option, + pub logprobs: Option, pub finish_reason: Option, } @@ -79,6 +80,22 @@ pub struct Usage { pub total_tokens: u32, } +// 聊天请求 +#[derive(Deserialize)] +pub struct ChatRequest { + pub model: String, + pub messages: Vec, + #[serde(default)] + pub stream: bool, + #[serde(default)] + pub stream_options: Option, +} + +#[derive(Deserialize)] +pub struct StreamOptions { + pub include_usage: bool, +} + // 模型定义 #[derive(Serialize, Clone)] pub struct Model { @@ -95,7 +112,7 @@ impl PartialEq for Model { } use super::constant::{Models, USAGE_CHECK_MODELS}; -use crate::app::model::{AppConfig, UsageCheck}; +use crate::{app::model::{AppConfig, UsageCheck}, common::model::tri::TriState}; impl Model { pub fn is_usage_check(model_id: &String, usage_check: Option) -> bool { diff --git a/src/chat/route.rs b/src/chat/route.rs index d876c47..0430976 100644 --- a/src/chat/route.rs +++ b/src/chat/route.rs @@ -3,20 +3,23 @@ pub use logs::{handle_logs, handle_logs_post}; mod health; pub use health::{handle_health, handle_root}; mod token; -pub use token::{handle_basic_calibration, handle_tokens_page}; +pub use token::{handle_basic_calibration, handle_build_key}; mod tokens; pub use tokens::{ handle_add_tokens, handle_delete_tokens, handle_get_tokens, handle_update_token_tags, - handle_update_tokens, + handle_update_tokens, handle_update_tokens_profile, }; mod checksum; pub use checksum::{handle_get_checksum, handle_get_hash, handle_get_timestamp_header}; mod profile; pub use profile::handle_user_info; -mod config; -pub use config::{ - handle_about, handle_build_key, handle_build_key_page, handle_config_page, handle_env_example, - handle_readme, handle_static, +mod proxies; +pub use proxies::{ + handle_add_proxy, handle_delete_proxies, handle_get_proxies, handle_set_general_proxy, + handle_update_proxies, +}; +mod page; +pub use page::{ + handle_about, handle_api_page, handle_build_key_page, handle_config_page, handle_env_example, + handle_proxies_page, handle_readme, handle_static, handle_tokens_page, }; -mod api; -pub use api::handle_api_page; diff --git a/src/chat/route/api.rs b/src/chat/route/api.rs deleted file mode 100644 index 7a96160..0000000 --- a/src/chat/route/api.rs +++ /dev/null @@ -1,29 +0,0 @@ -use axum::{ - body::Body, - response::{IntoResponse, Response}, -}; -use reqwest::header::CONTENT_TYPE; - -use crate::{ - AppConfig, PageContent, - app::constant::{ - CONTENT_TYPE_TEXT_HTML_WITH_UTF8, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, ROUTE_API_PATH, - }, -}; - -pub async fn handle_api_page() -> impl IntoResponse { - match AppConfig::get_page_content(ROUTE_API_PATH).unwrap_or_default() { - PageContent::Default => Response::builder() - .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) - .body(Body::from(include_str!("../../../static/api.min.html"))) - .unwrap(), - PageContent::Text(content) => Response::builder() - .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) - .body(Body::from(content)) - .unwrap(), - PageContent::Html(content) => Response::builder() - .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) - .body(Body::from(content)) - .unwrap(), - } -} diff --git a/src/chat/route/health.rs b/src/chat/route/health.rs index 622c4f4..55027c5 100644 --- a/src/chat/route/health.rs +++ b/src/chat/route/health.rs @@ -5,9 +5,11 @@ use crate::{ CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, PKG_VERSION, ROUTE_ABOUT_PATH, ROUTE_API_PATH, ROUTE_BASIC_CALIBRATION_PATH, ROUTE_BUILD_KEY_PATH, ROUTE_CONFIG_PATH, ROUTE_ENV_EXAMPLE_PATH, ROUTE_GET_CHECKSUM, ROUTE_GET_HASH, ROUTE_GET_TIMESTAMP_HEADER, - ROUTE_HEALTH_PATH, ROUTE_LOGS_PATH, ROUTE_README_PATH, ROUTE_ROOT_PATH, - ROUTE_STATIC_PATH, ROUTE_TOKEN_TAGS_UPDATE_PATH, ROUTE_TOKENS_ADD_PATH, - ROUTE_TOKENS_DELETE_PATH, ROUTE_TOKENS_GET_PATH, ROUTE_TOKENS_PATH, + ROUTE_HEALTH_PATH, ROUTE_LOGS_PATH, ROUTE_PROXIES_ADD_PATH, ROUTE_PROXIES_DELETE_PATH, + ROUTE_PROXIES_GET_PATH, ROUTE_PROXIES_PATH, ROUTE_PROXIES_SET_GENERAL_PATH, + ROUTE_PROXIES_UPDATE_PATH, ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_STATIC_PATH, + ROUTE_TOKENS_ADD_PATH, ROUTE_TOKENS_DELETE_PATH, ROUTE_TOKENS_GET_PATH, + ROUTE_TOKENS_PATH, ROUTE_TOKENS_PROFILE_UPDATE_PATH, ROUTE_TOKENS_TAGS_UPDATE_PATH, ROUTE_TOKENS_UPDATE_PATH, ROUTE_USER_INFO_PATH, }, lazy::{AUTH_TOKEN, ROUTE_CHAT_PATH, ROUTE_MODELS_PATH, get_start_time}, @@ -122,7 +124,14 @@ pub async fn handle_health( ROUTE_TOKENS_UPDATE_PATH, ROUTE_TOKENS_ADD_PATH, ROUTE_TOKENS_DELETE_PATH, - ROUTE_TOKEN_TAGS_UPDATE_PATH, + ROUTE_TOKENS_TAGS_UPDATE_PATH, + ROUTE_TOKENS_PROFILE_UPDATE_PATH, + ROUTE_PROXIES_PATH, + ROUTE_PROXIES_GET_PATH, + ROUTE_PROXIES_UPDATE_PATH, + ROUTE_PROXIES_ADD_PATH, + ROUTE_PROXIES_DELETE_PATH, + ROUTE_PROXIES_SET_GENERAL_PATH, ROUTE_LOGS_PATH, ROUTE_ENV_EXAMPLE_PATH, ROUTE_CONFIG_PATH, diff --git a/src/chat/route/config.rs b/src/chat/route/page.rs similarity index 60% rename from src/chat/route/config.rs rename to src/chat/route/page.rs index f6e5677..44382ea 100644 --- a/src/chat/route/config.rs +++ b/src/chat/route/page.rs @@ -1,28 +1,21 @@ -use crate::{ - app::{ - constant::{ - AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_TEXT_CSS_WITH_UTF8, - CONTENT_TYPE_TEXT_HTML_WITH_UTF8, CONTENT_TYPE_TEXT_JS_WITH_UTF8, - CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, ROUTE_ABOUT_PATH, ROUTE_BUILD_KEY_PATH, - ROUTE_CONFIG_PATH, ROUTE_README_PATH, ROUTE_SHARED_JS_PATH, ROUTE_SHARED_STYLES_PATH, - }, - lazy::{AUTH_TOKEN, KEY_PREFIX}, - model::{AppConfig, BuildKeyRequest, BuildKeyResponse, PageContent, UsageCheckModelType}, +use crate::app::{ + constant::{ + CONTENT_TYPE_TEXT_CSS_WITH_UTF8, CONTENT_TYPE_TEXT_HTML_WITH_UTF8, + CONTENT_TYPE_TEXT_JS_WITH_UTF8, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, ROUTE_ABOUT_PATH, + ROUTE_API_PATH, ROUTE_BUILD_KEY_PATH, ROUTE_CONFIG_PATH, ROUTE_PROXIES_PATH, + ROUTE_README_PATH, ROUTE_SHARED_JS_PATH, ROUTE_SHARED_STYLES_PATH, ROUTE_TOKENS_PATH, }, - chat::config::{KeyConfig, key_config}, - common::utils::{to_base64, token_to_tokeninfo}, + model::{AppConfig, PageContent}, }; use axum::{ - Json, body::Body, extract::Path, http::{ - HeaderMap, StatusCode, - header::{AUTHORIZATION, CONTENT_TYPE, LOCATION}, + StatusCode, + header::{CONTENT_TYPE, LOCATION}, }, response::{IntoResponse, Response}, }; -use prost::Message as _; pub async fn handle_env_example() -> impl IntoResponse { Response::builder() @@ -140,74 +133,53 @@ pub async fn handle_build_key_page() -> impl IntoResponse { } } -pub async fn handle_build_key( - headers: HeaderMap, - Json(request): Json, -) -> (StatusCode, Json) { - // 验证认证令牌 - if AppConfig::is_share() { - let auth_header = headers - .get(AUTHORIZATION) - .and_then(|h| h.to_str().ok()) - .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)); - - if auth_header - .is_none_or(|h| h != AppConfig::get_share_token().as_str() && h != AUTH_TOKEN.as_str()) - { - return ( - StatusCode::UNAUTHORIZED, - Json(BuildKeyResponse::Error("Unauthorized".to_owned())), - ); - } +pub async fn handle_tokens_page() -> impl IntoResponse { + match AppConfig::get_page_content(ROUTE_TOKENS_PATH).unwrap_or_default() { + PageContent::Default => Response::builder() + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .body(Body::from(include_str!("../../../static/tokens.min.html"))) + .unwrap(), + PageContent::Text(content) => Response::builder() + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .body(Body::from(content)) + .unwrap(), + PageContent::Html(content) => Response::builder() + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .body(Body::from(content)) + .unwrap(), + } +} + +pub async fn handle_proxies_page() -> impl IntoResponse { + match AppConfig::get_page_content(ROUTE_PROXIES_PATH).unwrap_or_default() { + PageContent::Default => Response::builder() + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .body(Body::from(include_str!("../../../static/proxies.min.html"))) + .unwrap(), + PageContent::Text(content) => Response::builder() + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .body(Body::from(content)) + .unwrap(), + PageContent::Html(content) => Response::builder() + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .body(Body::from(content)) + .unwrap(), + } +} + +pub async fn handle_api_page() -> impl IntoResponse { + match AppConfig::get_page_content(ROUTE_API_PATH).unwrap_or_default() { + PageContent::Default => Response::builder() + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .body(Body::from(include_str!("../../../static/api.min.html"))) + .unwrap(), + PageContent::Text(content) => Response::builder() + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) + .body(Body::from(content)) + .unwrap(), + PageContent::Html(content) => Response::builder() + .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) + .body(Body::from(content)) + .unwrap(), } - - // 验证并解析 auth_token - let token_info = match token_to_tokeninfo(&request.auth_token) { - Some(info) => info, - None => { - return ( - StatusCode::BAD_REQUEST, - Json(BuildKeyResponse::Error("Invalid auth token".to_owned())), - ); - } - }; - - // 构建 proto 消息 - let mut key_config = KeyConfig { - auth_token: Some(token_info), - disable_vision: request.disable_vision, - enable_slow_pool: request.enable_slow_pool, - usage_check_models: None, - include_web_references: request.include_web_references, - }; - - if let Some(usage_check_models) = request.usage_check_models { - let usage_check = key_config::UsageCheckModel { - r#type: match usage_check_models.model_type { - UsageCheckModelType::Default => key_config::usage_check_model::Type::Default as i32, - UsageCheckModelType::Disabled => { - key_config::usage_check_model::Type::Disabled as i32 - } - UsageCheckModelType::All => key_config::usage_check_model::Type::All as i32, - UsageCheckModelType::Custom => key_config::usage_check_model::Type::Custom as i32, - }, - model_ids: if matches!(usage_check_models.model_type, UsageCheckModelType::Custom) { - usage_check_models - .model_ids - .iter() - .map(|s| s.to_string()) - .collect() - } else { - Vec::new() - }, - }; - key_config.usage_check_models = Some(usage_check); - } - - // 序列化 - let encoded = key_config.encode_to_vec(); - - let key = format!("{}{}", *KEY_PREFIX, to_base64(&encoded)); - - (StatusCode::OK, Json(BuildKeyResponse::Key(key))) } diff --git a/src/chat/route/profile.rs b/src/chat/route/profile.rs index ae2fd81..ba55284 100644 --- a/src/chat/route/profile.rs +++ b/src/chat/route/profile.rs @@ -1,9 +1,8 @@ use crate::{ - chat::constant::ERR_NODATA, - common::{ + app::model::proxy_pool::ProxyPool, chat::constant::ERR_NODATA, common::{ model::userinfo::GetUserInfo, utils::{extract_token, get_token_profile}, - }, + } }; use axum::Json; @@ -28,7 +27,7 @@ pub async fn handle_user_info(Json(request): Json) -> Json Json(GetUserInfo::Usage(Box::new(usage))), None => Json(GetUserInfo::Error { error: ERR_NODATA.to_string(), diff --git a/src/chat/route/proxies.rs b/src/chat/route/proxies.rs new file mode 100644 index 0000000..f6edf29 --- /dev/null +++ b/src/chat/route/proxies.rs @@ -0,0 +1,260 @@ +use crate::{ + app::model::{ + AppState, CommonResponse, ProxiesDeleteRequest, ProxiesDeleteResponse, ProxyAddRequest, + ProxyInfoResponse, ProxyUpdateRequest, SetGeneralProxyRequest, + }, + common::model::{ApiStatus, ErrorResponse}, +}; +use axum::{Json, extract::State, http::StatusCode}; +use std::sync::Arc; +use tokio::sync::Mutex; + +// 获取所有代理配置 +pub async fn handle_get_proxies( + State(state): State>>, +) -> Result, StatusCode> { + // 获取代理配置并立即释放锁 + let proxies = { + let state = state.lock().await; + state.proxies.clone() + }; + + let proxies_count = proxies.get_proxies().len(); + let general_proxy = proxies.get_general().to_string(); + + Ok(Json(ProxyInfoResponse { + status: ApiStatus::Success, + proxies: Some(proxies), + proxies_count, + general_proxy: Some(general_proxy), + message: None, + })) +} + +// 更新代理配置 +pub async fn handle_update_proxies( + State(state): State>>, + Json(request): Json, +) -> Result, (StatusCode, Json)> { + // 获取新的代理配置 + let proxies = request.proxies; + + // 更新全局代理池并保存配置 + if let Err(e) = proxies.update_and_save().await { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + status: ApiStatus::Error, + code: None, + error: Some(format!("Failed to save proxy configuration: {}", e)), + message: Some("无法保存代理配置".to_string()), + }), + )); + } + + // 获取通用代理信息(在更新应用状态前) + let proxies_count = proxies.get_proxies().len(); + + // 只在需要更新应用状态时持有锁 + { + let mut state_guard = state.lock().await; + // 更新应用状态(完全覆盖) + state_guard.proxies = proxies; + } + + Ok(Json(ProxyInfoResponse { + status: ApiStatus::Success, + proxies: None, + proxies_count, + general_proxy: None, + message: Some("代理配置已更新".to_string()), + })) +} + +// 添加新的代理 +pub async fn handle_add_proxy( + State(state): State>>, + Json(request): Json, +) -> Result, (StatusCode, Json)> { + // 获取当前的代理配置 + let mut proxies = { + let state_guard = state.lock().await; + state_guard.proxies.clone() + }; + + // 创建现有代理名称的集合 + let existing_proxies: std::collections::HashSet = + proxies.get_proxies().keys().cloned().collect(); + + // 处理新的代理 + let mut added_count = 0; + + for (name, proxy) in &request.proxies { + // 跳过已存在的代理 + if existing_proxies.contains(name) { + continue; + } + + // 直接添加新的代理 + proxies.add_proxy(name.clone(), proxy.clone()); + added_count += 1; + } + + // 如果有新代理才进行后续操作 + if added_count > 0 { + // 更新全局代理池并保存配置 + if let Err(e) = proxies.update_and_save().await { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + status: ApiStatus::Error, + code: None, + error: Some(format!("Failed to save proxy configuration: {}", e)), + message: Some("无法保存代理配置".to_string()), + }), + )); + } + + // 获取更新后的信息 + let proxies_count = proxies.get_proxies().len(); + + // 更新应用状态,只在需要时持有锁 + { + let mut state_guard = state.lock().await; + state_guard.proxies = proxies.clone(); + } + + Ok(Json(ProxyInfoResponse { + status: ApiStatus::Success, + proxies: None, + proxies_count, + general_proxy: None, + message: Some(format!("已添加 {} 个新代理", added_count)), + })) + } else { + // 如果没有新代理,返回当前状态 + let general_proxy = proxies.get_general().to_string(); + let proxies_count = proxies.get_proxies().len(); + + Ok(Json(ProxyInfoResponse { + status: ApiStatus::Success, + proxies: Some(proxies), + proxies_count, + general_proxy: Some(general_proxy), + message: Some("没有添加新代理".to_string()), + })) + } +} + +// 删除指定的代理 +pub async fn handle_delete_proxies( + State(state): State>>, + Json(request): Json, +) -> Result, (StatusCode, Json)> { + // 获取当前的代理配置并计算失败的代理名称 + let mut proxies = { + let state_guard = state.lock().await; + state_guard.proxies.clone() + }; + + // 计算失败的代理名称 + let failed_names: Vec = request + .names + .iter() + .filter(|name| !proxies.get_proxies().contains_key(*name)) + .cloned() + .collect(); + + // 删除指定的代理 + for name in &request.names { + proxies.remove_proxy(name); + } + + // 更新全局代理池并保存配置 + if let Err(e) = proxies.update_and_save().await { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + status: ApiStatus::Error, + code: None, + error: Some(format!("Failed to save proxy configuration: {}", e)), + message: Some("无法保存代理配置".to_string()), + }), + )); + } + + // 更新应用状态,只在需要时持有锁 + { + let mut state_guard = state.lock().await; + state_guard.proxies = proxies.clone(); + } + + // 根据expectation返回不同的结果 + let updated_proxies = if request.expectation.needs_updated_tokens() { + Some(proxies) + } else { + None + }; + + Ok(Json(ProxiesDeleteResponse { + status: ApiStatus::Success, + updated_proxies, + failed_names: if request.expectation.needs_failed_tokens() && !failed_names.is_empty() { + Some(failed_names) + } else { + None + }, + })) +} + +// 设置通用代理 +pub async fn handle_set_general_proxy( + State(state): State>>, + Json(request): Json, +) -> Result, (StatusCode, Json)> { + // 获取当前的代理配置 + let mut proxies = { + let state_guard = state.lock().await; + state_guard.proxies.clone() + }; + + // 检查代理名称是否存在 + if !proxies.get_proxies().contains_key(&request.name) { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorResponse { + status: ApiStatus::Error, + code: None, + error: Some("Proxy name not found".to_string()), + message: Some("代理名称不存在".to_string()), + }), + )); + } + + // 设置通用代理 + proxies.set_general(&request.name); + + // 更新全局代理池并保存配置 + if let Err(e) = proxies.update_and_save().await { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + status: ApiStatus::Error, + code: None, + error: Some(format!("Failed to save proxy configuration: {}", e)), + message: Some("无法保存代理配置".to_string()), + }), + )); + } + + // 更新应用状态,只在需要时持有锁 + { + let mut state_guard = state.lock().await; + state_guard.proxies = proxies; + } + + Ok(Json(CommonResponse { + status: ApiStatus::Success, + message: Some("通用代理已设置".to_string()), + })) +} diff --git a/src/chat/route/token.rs b/src/chat/route/token.rs index eabcc21..e61507e 100644 --- a/src/chat/route/token.rs +++ b/src/chat/route/token.rs @@ -1,40 +1,25 @@ use crate::{ app::{ - constant::{ - CONTENT_TYPE_TEXT_HTML_WITH_UTF8, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, ROUTE_TOKENS_PATH, - }, - model::{AppConfig, PageContent}, + constant::AUTHORIZATION_BEARER_PREFIX, + lazy::{AUTH_TOKEN, KEY_PREFIX}, + model::{AppConfig, BuildKeyRequest, BuildKeyResponse, UsageCheckModelType}, }, + chat::config::{KeyConfig, key_config}, common::{ model::ApiStatus, - utils::{extract_time, extract_time_ks, extract_user_id, validate_token_and_checksum}, + utils::{ + extract_time, extract_time_ks, extract_user_id, to_base64, token_to_tokeninfo, + validate_token_and_checksum, + }, }, }; use axum::{ Json, - body::Body, - http::header::CONTENT_TYPE, - response::{IntoResponse, Response}, + http::{HeaderMap, StatusCode, header::AUTHORIZATION}, }; +use prost::Message as _; use serde::{Deserialize, Serialize}; -pub async fn handle_tokens_page() -> impl IntoResponse { - match AppConfig::get_page_content(ROUTE_TOKENS_PATH).unwrap_or_default() { - PageContent::Default => Response::builder() - .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) - .body(Body::from(include_str!("../../../static/tokens.min.html"))) - .unwrap(), - PageContent::Text(content) => Response::builder() - .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8) - .body(Body::from(content)) - .unwrap(), - PageContent::Html(content) => Response::builder() - .header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8) - .body(Body::from(content)) - .unwrap(), - } -} - #[derive(Deserialize)] pub struct TokenRequest { pub token: Option, @@ -97,3 +82,75 @@ pub async fn handle_basic_calibration( checksum_time, }) } + +pub async fn handle_build_key( + headers: HeaderMap, + Json(request): Json, +) -> (StatusCode, Json) { + // 验证认证令牌 + if AppConfig::is_share() { + let auth_header = headers + .get(AUTHORIZATION) + .and_then(|h| h.to_str().ok()) + .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)); + + if auth_header + .is_none_or(|h| h != AppConfig::get_share_token().as_str() && h != AUTH_TOKEN.as_str()) + { + return ( + StatusCode::UNAUTHORIZED, + Json(BuildKeyResponse::Error("Unauthorized".to_owned())), + ); + } + } + + // 验证并解析 auth_token + let token_info = match token_to_tokeninfo(&request.auth_token, request.proxy_name) { + Some(info) => info, + None => { + return ( + StatusCode::BAD_REQUEST, + Json(BuildKeyResponse::Error("Invalid auth token".to_owned())), + ); + } + }; + + // 构建 proto 消息 + let mut key_config = KeyConfig { + auth_token: Some(token_info), + disable_vision: request.disable_vision, + enable_slow_pool: request.enable_slow_pool, + usage_check_models: None, + include_web_references: request.include_web_references, + }; + + if let Some(usage_check_models) = request.usage_check_models { + let usage_check = key_config::UsageCheckModel { + r#type: match usage_check_models.model_type { + UsageCheckModelType::Default => key_config::usage_check_model::Type::Default as i32, + UsageCheckModelType::Disabled => { + key_config::usage_check_model::Type::Disabled as i32 + } + UsageCheckModelType::All => key_config::usage_check_model::Type::All as i32, + UsageCheckModelType::Custom => key_config::usage_check_model::Type::Custom as i32, + }, + model_ids: if matches!(usage_check_models.model_type, UsageCheckModelType::Custom) { + usage_check_models + .model_ids + .iter() + .map(|s| s.to_string()) + .collect() + } else { + Vec::new() + }, + }; + key_config.usage_check_models = Some(usage_check); + } + + // 序列化 + let encoded = key_config.encode_to_vec(); + + let key = format!("{}{}", *KEY_PREFIX, to_base64(&encoded)); + + (StatusCode::OK, Json(BuildKeyResponse::Key(key))) +} diff --git a/src/chat/route/tokens.rs b/src/chat/route/tokens.rs index b754a1e..87e0493 100644 --- a/src/chat/route/tokens.rs +++ b/src/chat/route/tokens.rs @@ -1,44 +1,23 @@ use crate::{ - app::{ - constant::AUTHORIZATION_BEARER_PREFIX, - lazy::AUTH_TOKEN, - model::{ - AppState, TokenAddRequest, TokenInfo, TokenInfoResponse, TokenManager, - TokenTagsResponse, TokenTagsUpdateRequest, TokenUpdateRequest, TokensDeleteRequest, - TokensDeleteResponse, - }, + app::model::{ + AppState, CommonResponse, TokenAddRequest, TokenInfo, TokenInfoResponse, TokenManager, + TokenTagsUpdateRequest, TokenUpdateRequest, TokensDeleteRequest, TokensDeleteResponse, }, common::{ - model::{ApiStatus, ErrorResponse, error::ChatError, userinfo::TokenProfile}, + model::{ApiStatus, ErrorResponse, userinfo::TokenProfile}, utils::{ generate_checksum_with_default, generate_checksum_with_repair, load_tokens_from_content, parse_token, validate_token, }, }, }; -use axum::{ - Json, - extract::State, - http::{HeaderMap, StatusCode, header::AUTHORIZATION}, -}; +use axum::{Json, extract::State, http::StatusCode}; use std::{collections::HashMap, sync::Arc}; use tokio::sync::Mutex; pub async fn handle_get_tokens( State(state): State>>, - headers: HeaderMap, ) -> Result, StatusCode> { - // 验证 AUTH_TOKEN - let auth_header = headers - .get(AUTHORIZATION) - .and_then(|h| h.to_str().ok()) - .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) - .ok_or(StatusCode::UNAUTHORIZED)?; - - if auth_header != AUTH_TOKEN.as_str() { - return Err(StatusCode::UNAUTHORIZED); - } - let state = state.lock().await; let tokens = state.token_manager.tokens.clone(); let tokens_count = tokens.len(); @@ -53,20 +32,8 @@ pub async fn handle_get_tokens( pub async fn handle_update_tokens( State(state): State>>, - headers: HeaderMap, Json(request): Json, ) -> Result, StatusCode> { - // 验证 AUTH_TOKEN - let auth_header = headers - .get(AUTHORIZATION) - .and_then(|h| h.to_str().ok()) - .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) - .ok_or(StatusCode::UNAUTHORIZED)?; - - if auth_header != AUTH_TOKEN.as_str() { - return Err(StatusCode::UNAUTHORIZED); - } - // 获取当前的 token_manager 以保留现有 token 的 profile 和 tags let current_token_manager = { let state = state.lock().await; @@ -123,26 +90,8 @@ pub async fn handle_update_tokens( pub async fn handle_add_tokens( State(state): State>>, - headers: HeaderMap, Json(request): Json, ) -> Result, (StatusCode, Json)> { - // 验证 AUTH_TOKEN - let auth_header = headers - .get(AUTHORIZATION) - .and_then(|h| h.to_str().ok()) - .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) - .ok_or(( - StatusCode::UNAUTHORIZED, - Json(ChatError::Unauthorized.to_json()), - ))?; - - if auth_header != AUTH_TOKEN.as_str() { - return Err(( - StatusCode::UNAUTHORIZED, - Json(ChatError::Unauthorized.to_json()), - )); - } - // 获取当前的 token_manager let mut token_manager = { let state = state.lock().await; @@ -212,12 +161,11 @@ pub async fn handle_add_tokens( })) } else { // 如果没有新tokens,返回当前状态 - let tokens = token_manager.tokens.clone(); - let tokens_count = tokens.len(); + let tokens_count = token_manager.tokens.len(); Ok(Json(TokenInfoResponse { status: ApiStatus::Success, - tokens: Some(tokens), + tokens: Some(token_manager.tokens), tokens_count, message: Some("No new tokens were added".to_string()), })) @@ -226,26 +174,8 @@ pub async fn handle_add_tokens( pub async fn handle_delete_tokens( State(state): State>>, - headers: HeaderMap, Json(request): Json, ) -> Result, (StatusCode, Json)> { - // 验证 AUTH_TOKEN - let auth_header = headers - .get(AUTHORIZATION) - .and_then(|h| h.to_str().ok()) - .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) - .ok_or(( - StatusCode::UNAUTHORIZED, - Json(ChatError::Unauthorized.to_json()), - ))?; - - if auth_header != AUTH_TOKEN.as_str() { - return Err(( - StatusCode::UNAUTHORIZED, - Json(ChatError::Unauthorized.to_json()), - )); - } - // 获取当前的 token_manager let mut token_manager = { let state = state.lock().await; @@ -344,26 +274,8 @@ pub async fn handle_delete_tokens( pub async fn handle_update_token_tags( State(state): State>>, - headers: HeaderMap, Json(request): Json, -) -> Result, (StatusCode, Json)> { - // 验证 AUTH_TOKEN - let auth_header = headers - .get(AUTHORIZATION) - .and_then(|h| h.to_str().ok()) - .and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX)) - .ok_or(( - StatusCode::UNAUTHORIZED, - Json(ChatError::Unauthorized.to_json()), - ))?; - - if auth_header != AUTH_TOKEN.as_str() { - return Err(( - StatusCode::UNAUTHORIZED, - Json(ChatError::Unauthorized.to_json()), - )); - } - +) -> Result, (StatusCode, Json)> { // 获取并更新 token_manager { let mut state = state.lock().await; @@ -396,8 +308,84 @@ pub async fn handle_update_token_tags( } } - Ok(Json(TokenTagsResponse { + Ok(Json(CommonResponse { status: ApiStatus::Success, message: Some("标签更新成功".to_string()), })) } + +pub async fn handle_update_tokens_profile( + State(state): State>>, + Json(tokens): Json>, +) -> Result, (StatusCode, Json)> { + // 验证请求 + if tokens.is_empty() { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorResponse { + status: ApiStatus::Error, + code: None, + error: Some("No tokens provided".to_string()), + message: Some("未提供任何令牌".to_string()), + }), + )); + } + + // 获取当前的 token_manager + let mut state_guard = state.lock().await; + let token_manager = &mut state_guard.token_manager; + + // 批量更新tokens的profile + let mut updated_count = 0; + let mut failed_count = 0; + + for token in &tokens { + // 验证token是否在token_manager中存在 + if let Some(token_idx) = token_manager + .tokens + .iter() + .position(|info| info.token == *token) + { + // 获取profile + if let Some(profile) = crate::common::utils::get_token_profile( + token_manager.tokens[token_idx].get_client(), + token, + ) + .await + { + // 更新profile + token_manager.tokens[token_idx].profile = Some(profile); + updated_count += 1; + } else { + failed_count += 1; + } + } else { + failed_count += 1; + } + } + + // 保存更改 + if updated_count > 0 { + if token_manager.save_tokens().await.is_err() { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + status: ApiStatus::Error, + code: None, + error: Some("Failed to save token profiles".to_string()), + message: Some("无法保存令牌配置数据".to_string()), + }), + )); + } + } + + let message = format!( + "已更新{}个令牌配置, {}个令牌更新失败", + updated_count, failed_count + ); + + Ok(Json(CommonResponse { + status: ApiStatus::Success, + message: Some(message), + })) +} diff --git a/src/chat/service.rs b/src/chat/service.rs index d04c15a..4763d51 100644 --- a/src/chat/service.rs +++ b/src/chat/service.rs @@ -9,8 +9,8 @@ use crate::{ KEY_PREFIX, KEY_PREFIX_LEN, REQUEST_LOGS_LIMIT, SERVICE_TIMEOUT, }, model::{ - AppConfig, AppState, ChatRequest, LogStatus, RequestLog, TimingInfo, TokenInfo, - UsageCheck, + AppConfig, AppState, Chain, LogStatus, RequestLog, TimingInfo, TokenInfo, UsageCheck, + proxy_pool::ProxyPool, }, }, chat::{ @@ -23,10 +23,12 @@ use crate::{ stream::{StreamDecoder, StreamMessage}, }, common::{ - client::build_client, - model::{ApiStatus, ErrorResponse, error::ChatError, userinfo::MembershipType}, + client::build_request, + model::{ + ApiStatus, ErrorResponse, error::ChatError, tri::TriState, userinfo::MembershipType, + }, utils::{ - TrimNewlines as _, format_time_ms, from_base64, get_available_models, + InstantExt as _, TrimNewlines as _, format_time_ms, from_base64, get_available_models, get_token_profile, tokeninfo_to_token, validate_token_and_checksum, }, }, @@ -44,6 +46,7 @@ use axum::{ use bytes::Bytes; use futures::StreamExt; use prost::Message as _; +use reqwest::Client; use std::sync::atomic::{AtomicUsize, Ordering}; use std::{ convert::Infallible, @@ -52,7 +55,7 @@ use std::{ use tokio::sync::Mutex; use uuid::Uuid; -use super::{constant::LONG_CONTEXT_MODELS, model::Model}; +use super::model::{ChatRequest, Model}; // 辅助函数:提取认证token fn extract_auth_token(headers: &HeaderMap) -> Result<&str, (StatusCode, Json)> { @@ -70,7 +73,7 @@ fn extract_auth_token(headers: &HeaderMap) -> Result<&str, (StatusCode, Json>, -) -> Result<(String, String), (StatusCode, Json)> { +) -> Result<(String, String, Client), (StatusCode, Json)> { match auth_header { // 管理员Token处理 token if is_admin_token(token) => resolve_admin_token(state).await, @@ -79,10 +82,13 @@ async fn resolve_token_info( token if is_dynamic_key(token) => resolve_dynamic_key(token), // 普通用户Token处理 - token => validate_token_and_checksum(token).ok_or(( - StatusCode::UNAUTHORIZED, - Json(ChatError::Unauthorized.to_json()), - )), + token => { + let (token, checksum) = validate_token_and_checksum(token).ok_or(( + StatusCode::UNAUTHORIZED, + Json(ChatError::Unauthorized.to_json()), + ))?; + Ok((token, checksum, ProxyPool::get_general_client())) + } } } @@ -100,7 +106,7 @@ fn is_dynamic_key(token: &str) -> bool { // 辅助函数:处理管理员token async fn resolve_admin_token( state: &Arc>, -) -> Result<(String, String), (StatusCode, Json)> { +) -> Result<(String, String, Client), (StatusCode, Json)> { static CURRENT_KEY_INDEX: AtomicUsize = AtomicUsize::new(0); let state_guard = state.lock().await; @@ -116,11 +122,17 @@ async fn resolve_admin_token( let index = CURRENT_KEY_INDEX.fetch_add(1, Ordering::SeqCst) % token_infos.len(); let token_info = &token_infos[index]; - Ok((token_info.token.clone(), token_info.checksum.clone())) + Ok(( + token_info.token.clone(), + token_info.checksum.clone(), + token_info.get_client(), + )) } // 辅助函数:处理动态密钥 -fn resolve_dynamic_key(token: &str) -> Result<(String, String), (StatusCode, Json)> { +fn resolve_dynamic_key( + token: &str, +) -> Result<(String, String, Client), (StatusCode, Json)> { from_base64(&token[*KEY_PREFIX_LEN..]) .and_then(|decoded_bytes| KeyConfig::decode(&decoded_bytes[..]).ok()) .and_then(|key_config| key_config.auth_token) @@ -143,18 +155,20 @@ pub async fn handle_models( // 提取和验证认证token let auth_token = extract_auth_token(&headers)?; - let (token, checksum) = resolve_token_info(auth_token, &state).await?; + let (token, checksum, client) = resolve_token_info(auth_token, &state).await?; // 获取可用模型列表 - let models = get_available_models(&token, &checksum).await.ok_or(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - status: ApiStatus::Failure, - code: Some(StatusCode::INTERNAL_SERVER_ERROR.as_u16()), - error: Some("Failed to fetch available models".to_string()), - message: Some("Unable to get available models".to_string()), - }), - ))?; + let models = get_available_models(client, &token, &checksum) + .await + .ok_or(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + status: ApiStatus::Failure, + code: Some(StatusCode::INTERNAL_SERVER_ERROR.as_u16()), + error: Some("Failed to fetch available models".to_string()), + message: Some("Unable to get available models".to_string()), + }), + ))?; // 更新模型列表 if let Err(e) = Models::update(models) { @@ -221,7 +235,7 @@ pub async fn handle_chat( let mut current_config = KeyConfig::new_with_global(); // 验证认证token并获取token信息 - let (auth_token, checksum) = match auth_header { + let (auth_token, checksum, client) = match auth_header { // 管理员Token验证逻辑 token if token == AUTH_TOKEN.as_str() @@ -242,7 +256,11 @@ pub async fn handle_chat( // 轮询选择token let index = CURRENT_KEY_INDEX.fetch_add(1, Ordering::SeqCst) % token_infos.len(); let token_info = &token_infos[index]; - (token_info.token.clone(), token_info.checksum.clone()) + ( + token_info.token.clone(), + token_info.checksum.clone(), + token_info.get_client(), + ) } token if AppConfig::get_dynamic_key() && token.starts_with(&*KEY_PREFIX) => { @@ -260,10 +278,13 @@ pub async fn handle_chat( } // 普通用户Token验证逻辑 - token => validate_token_and_checksum(token).ok_or(( - StatusCode::UNAUTHORIZED, - Json(ChatError::Unauthorized.to_json()), - ))?, + token => { + let (token, checksum) = validate_token_and_checksum(token).ok_or(( + StatusCode::UNAUTHORIZED, + Json(ChatError::Unauthorized.to_json()), + ))?; + (token, checksum, ProxyPool::get_general_client()) + } }; let current_config = current_config; @@ -277,58 +298,32 @@ pub async fn handle_chat( state.request_manager.total_requests += 1; state.request_manager.active_requests += 1; - let mut found_count: u32 = 0; - let mut no_prompt_count: u32 = 0; let mut need_profile_check = false; for log in state.request_manager.request_logs.iter().rev() { if log.token_info.token == auth_token { - if !LONG_CONTEXT_MODELS.contains(&log.model.as_str()) { - found_count += 1; - } - - if log.prompt.is_none() { - no_prompt_count += 1; - } - - if found_count == 1 && log.token_info.profile.is_some() { - if let Some(profile) = &log.token_info.profile { - if profile.stripe.membership_type == MembershipType::Free { - let is_premium = USAGE_CHECK_MODELS.contains(&model_name.as_str()); - need_profile_check = - if is_premium { - profile.usage.premium.max_requests.is_some_and(|max| { - profile.usage.premium.num_requests >= max - }) - } else { - profile.usage.standard.max_requests.is_some_and(|max| { - profile.usage.standard.num_requests >= max - }) - }; - } + if let Some(profile) = &log.token_info.profile { + if profile.stripe.membership_type == MembershipType::Free { + let is_premium = USAGE_CHECK_MODELS.contains(&model_name.as_str()); + need_profile_check = if is_premium { + profile + .usage + .premium + .max_requests + .is_some_and(|max| profile.usage.premium.num_requests >= max) + } else { + profile + .usage + .standard + .max_requests + .is_some_and(|max| profile.usage.standard.num_requests >= max) + }; } - } - - if found_count == 2 { break; } } } - if found_count == 2 && no_prompt_count == 2 { - state.request_manager.active_requests -= 1; - state.request_manager.error_requests += 1; - return Err(( - StatusCode::TOO_MANY_REQUESTS, - Json(ErrorResponse { - status: ApiStatus::Error, - code: Some(429), - error: Some("rate_limit_exceeded".to_string()), - message: Some("Too many requests without prompt".to_string()), - }), - )); - } - // 处理检查结果 if need_profile_check { state.request_manager.active_requests -= 1; @@ -359,9 +354,10 @@ pub async fn handle_chat( let auth_token_clone = auth_token.clone(); let state_clone = state_clone.clone(); let log_id = next_id; + let client = client.clone(); tokio::spawn(async move { - let profile = get_token_profile(&auth_token_clone).await; + let profile = get_token_profile(client, &auth_token_clone).await; let mut state = state_clone.lock().await; // 先找到所有需要更新的位置的索引 @@ -404,11 +400,8 @@ pub async fn handle_chat( profile: None, tags: None, }, - prompt: None, - timing: TimingInfo { - total: 0.0, - first: None, - }, + chain: None, + timing: TimingInfo { total: 0.0 }, stream: request.stream, status: LogStatus::Pending, error: None, @@ -441,7 +434,7 @@ pub async fn handle_chat( .rev() .find(|log| log.id == current_id) { - log.status = LogStatus::Failed; + log.status = LogStatus::Failure; log.error = Some(e.to_string()); } state.request_manager.active_requests -= 1; @@ -456,7 +449,8 @@ pub async fn handle_chat( }; // 构建请求客户端 - let client = build_client( + let client = build_request( + client, &auth_token, &checksum, if is_search { @@ -492,7 +486,8 @@ pub async fn handle_chat( } resp } - Err(e) => { + Err(mut e) => { + e = e.without_url(); // 更新请求日志为失败 { let mut state = state.lock().await; @@ -503,7 +498,7 @@ pub async fn handle_chat( .rev() .find(|log| log.id == current_id) { - log.status = LogStatus::Failed; + log.status = LogStatus::Failure; log.error = Some(e.to_string()); } state.request_manager.active_requests -= 1; @@ -526,7 +521,7 @@ pub async fn handle_chat( .rev() .find(|log| log.id == current_id) { - log.status = LogStatus::Failed; + log.status = LogStatus::Failure; log.error = Some("Request timeout".to_string()); } state.request_manager.active_requests -= 1; @@ -551,18 +546,19 @@ pub async fn handle_chat( let response_id = format!("chatcmpl-{}", Uuid::new_v4().simple()); let is_start = Arc::new(AtomicBool::new(true)); let start_time = std::time::Instant::now(); - let first_chunk_time = Arc::new(Mutex::new(None::)); let decoder = Arc::new(Mutex::new(StreamDecoder::new())); + let content_time = Arc::new(Mutex::new(std::time::Instant::now())); // 定义消息处理器的上下文结构体 struct MessageProcessContext<'a> { response_id: &'a str, model: &'a str, is_start: &'a AtomicBool, - first_chunk_time: &'a Mutex>, start_time: std::time::Instant, state: &'a Mutex, current_id: u64, + need_usage: bool, + content_time: &'a Mutex, } // 处理消息并生成响应数据的辅助函数 @@ -576,9 +572,26 @@ pub async fn handle_chat( match message { StreamMessage::Content(text) => { let is_first = ctx.is_start.load(Ordering::SeqCst); - if is_first { - if let Ok(mut first_time) = ctx.first_chunk_time.try_lock() { - *first_time = Some(ctx.start_time.elapsed().as_secs_f64()); + + if let Ok(mut time_tracker) = ctx.content_time.try_lock() { + let interval = time_tracker.duration_as_secs_f64(); + if let Ok(mut state) = ctx.state.try_lock() { + if let Some(log) = state + .request_manager + .request_logs + .iter_mut() + .rev() + .find(|log| log.id == ctx.current_id) + { + if let Some(chain) = &mut log.chain { + chain.delays.push((text.clone(), interval)); + } else { + log.chain = Some(Chain { + prompt: String::new(), + delays: vec![(text.clone(), interval)], + }); + } + } } } @@ -607,9 +620,14 @@ pub async fn handle_chat( Some(text) }, }), + logprobs: None, finish_reason: None, }], - usage: None, + usage: if ctx.need_usage { + TriState::Null + } else { + TriState::None + }, }; response_data.push_str(&format!( @@ -620,7 +638,6 @@ pub async fn handle_chat( StreamMessage::StreamEnd => { // 计算总时间和首次片段时间 let total_time = ctx.start_time.elapsed().as_secs_f64(); - let first_time = ctx.first_chunk_time.lock().await.unwrap_or(total_time); { let mut state = ctx.state.lock().await; @@ -632,7 +649,6 @@ pub async fn handle_chat( .find(|log| log.id == ctx.current_id) { log.timing.total = format_time_ms(total_time); - log.timing.first = Some(format_time_ms(first_time)); } } @@ -648,14 +664,39 @@ pub async fn handle_chat( role: None, content: None, }), + logprobs: None, finish_reason: Some(FINISH_REASON_STOP.to_string()), }], - usage: None, + usage: if ctx.need_usage { + TriState::Null + } else { + TriState::None + }, }; response_data.push_str(&format!( - "data: {}\n\ndata: [DONE]\n\n", + "data: {}\n\n", serde_json::to_string(&response).unwrap() )); + if ctx.need_usage { + let response = ChatResponse { + id: ctx.response_id.to_string(), + object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(), + created: chrono::Utc::now().timestamp(), + model: None, + choices: vec![], + usage: TriState::Some(Usage { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + }), + }; + response_data.push_str(&format!( + "data: {}\n\ndata: [DONE]\n\n", + serde_json::to_string(&response).unwrap() + )); + } else { + response_data.push_str("data: [DONE]\n\n"); + }; } StreamMessage::Debug(debug_prompt) => { if let Ok(mut state) = ctx.state.try_lock() { @@ -666,7 +707,10 @@ pub async fn handle_chat( .rev() .find(|log| log.id == ctx.current_id) { - log.prompt = Some(debug_prompt); + log.chain = Some(Chain { + prompt: debug_prompt, + delays: vec![], + }); } } } @@ -696,7 +740,7 @@ pub async fn handle_chat( .rev() .find(|log| log.id == current_id) { - log.status = LogStatus::Failed; + log.status = LogStatus::Failure; log.error = Some(error_response.native_code()); log.timing.total = format_time_ms(start_time.elapsed().as_secs_f64()); @@ -727,7 +771,7 @@ pub async fn handle_chat( .rev() .find(|log| log.id == current_id) { - log.status = LogStatus::Failed; + log.status = LogStatus::Failure; log.error = Some("Empty stream response".to_string()); state.request_manager.error_requests += 1; } @@ -748,16 +792,18 @@ pub async fn handle_chat( let response_id = response_id.clone(); let model = request.model.clone(); let is_start = is_start.clone(); - let first_chunk_time = first_chunk_time.clone(); let state = state.clone(); + let need_usage = request.stream_options.is_some_and(|opt| opt.include_usage); + let content_time = content_time.clone(); move |chunk| { let decoder = decoder.clone(); let response_id = response_id.clone(); let model = model.clone(); let is_start = is_start.clone(); - let first_chunk_time = first_chunk_time.clone(); let state = state.clone(); + let need_usage = need_usage; + let content_time = content_time.clone(); async move { let chunk = chunk.unwrap_or_default(); @@ -766,10 +812,11 @@ pub async fn handle_chat( response_id: &response_id, model: &model, is_start: &is_start, - first_chunk_time: &first_chunk_time, start_time, state: &state, current_id, + need_usage, + content_time: &content_time, }; // 使用decoder处理chunk @@ -807,10 +854,12 @@ pub async fn handle_chat( } else { // 非流式响应 let start_time = std::time::Instant::now(); - let mut first_chunk_time = None::; let mut decoder = StreamDecoder::new(); let mut full_text = String::with_capacity(1024); let mut stream = response.bytes_stream(); + let mut prompt = String::new(); + let mut content_time = std::time::Instant::now(); + let mut delays: Vec<(String, f64)> = Vec::new(); // 逐个处理chunks while let Some(chunk) = stream.next().await { @@ -828,23 +877,12 @@ pub async fn handle_chat( for message in messages { match message { StreamMessage::Content(text) => { - if first_chunk_time.is_none() { - first_chunk_time = Some(start_time.elapsed().as_secs_f64()); - } + let interval = content_time.duration_as_secs_f64(); + delays.push((text.clone(), interval)); full_text.push_str(&text); } StreamMessage::Debug(debug_prompt) => { - if let Ok(mut state) = state.try_lock() { - if let Some(log) = state - .request_manager - .request_logs - .iter_mut() - .rev() - .find(|log| log.id == current_id) - { - log.prompt = Some(debug_prompt); - } - } + prompt = debug_prompt; } _ => {} } @@ -881,7 +919,7 @@ pub async fn handle_chat( .rev() .find(|log| log.id == current_id) { - log.status = LogStatus::Failed; + log.status = LogStatus::Failure; log.error = Some("Empty response received".to_string()); state.request_manager.error_requests += 1; } @@ -904,9 +942,10 @@ pub async fn handle_chat( content: MessageContent::Text(full_text.trim_leading_newlines()), }), delta: None, + logprobs: None, finish_reason: Some(FINISH_REASON_STOP.to_string()), }], - usage: Some(Usage { + usage: TriState::Some(Usage { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0, @@ -925,11 +964,23 @@ pub async fn handle_chat( .find(|log| log.id == current_id) { log.timing.total = total_time; - log.timing.first = first_chunk_time; log.status = LogStatus::Success; } } + // 更新最终的延迟信息 + if let Ok(mut state) = state.try_lock() { + if let Some(log) = state + .request_manager + .request_logs + .iter_mut() + .rev() + .find(|log| log.id == current_id) + { + log.chain = Some(Chain { prompt, delays }); + } + } + Ok(Response::builder() .header(CONTENT_TYPE, "application/json") .body(Body::from(serde_json::to_string(&response_data).unwrap())) diff --git a/src/common/client.rs b/src/common/client.rs index b05f36a..d9df1ab 100644 --- a/src/common/client.rs +++ b/src/common/client.rs @@ -1,23 +1,21 @@ use super::utils::generate_hash; -use crate::{ - AppConfig, - app::{ - constant::{ - CONTENT_TYPE_CONNECT_PROTO, CONTENT_TYPE_PROTO, CURSOR_API2_HOST, CURSOR_HOST, - CURSOR_SETTINGS_URL, HEADER_NAME_GHOST_MODE, TRUE, - }, - lazy::{ - CURSOR_API2_STRIPE_URL, CURSOR_USAGE_API_URL, CURSOR_USER_API_URL, REVERSE_PROXY_HOST, - USE_REVERSE_PROXY, - }, +use crate::app::{ + constant::{ + CONTENT_TYPE_CONNECT_PROTO, CONTENT_TYPE_PROTO, CURSOR_API2_HOST, CURSOR_HOST, + CURSOR_SETTINGS_URL, HEADER_NAME_GHOST_MODE, TRUE, + }, + lazy::{ + CURSOR_API2_STRIPE_URL, CURSOR_TIMEZONE, CURSOR_USAGE_API_URL, CURSOR_USER_API_URL, + REVERSE_PROXY_HOST, USE_REVERSE_PROXY, }, }; -use reqwest::header::{ - ACCEPT, ACCEPT_ENCODING, ACCEPT_LANGUAGE, CACHE_CONTROL, CONNECTION, CONTENT_TYPE, COOKIE, DNT, - HOST, ORIGIN, PRAGMA, REFERER, TE, TRANSFER_ENCODING, USER_AGENT, +use reqwest::{ + Client, RequestBuilder, + header::{ + ACCEPT, ACCEPT_ENCODING, ACCEPT_LANGUAGE, CACHE_CONTROL, CONNECTION, CONTENT_TYPE, COOKIE, + DNT, HOST, ORIGIN, PRAGMA, REFERER, TE, TRANSFER_ENCODING, USER_AGENT, + }, }; -use reqwest::{Client, RequestBuilder}; -use std::sync::LazyLock; use uuid::Uuid; macro_rules! def_const { @@ -50,18 +48,6 @@ def_const!(U_EQ_4, "u=4"); def_const!(PROXY_HOST, "x-co"); -pub(crate) static HTTP_CLIENT: LazyLock> = - LazyLock::new(|| parking_lot::RwLock::new(AppConfig::get_proxies().get_client())); - -/// 重新构建 HTTP 客户端 -/// -/// 当需要更新代理设置时,可以调用此方法重新创建客户端 -pub fn rebuild_http_client() { - let new_client = AppConfig::get_proxies().get_client(); - let mut client = HTTP_CLIENT.write(); - *client = new_client; -} - /// 返回预构建的 Cursor API 客户端 /// /// # 参数 @@ -73,7 +59,8 @@ pub fn rebuild_http_client() { /// # 返回 /// /// * `reqwest::RequestBuilder` - 配置好的请求构建器 -pub fn build_client( +pub fn build_request( + client: Client, auth_token: &str, checksum: &str, url: &str, @@ -82,13 +69,12 @@ pub fn build_client( let trace_id = Uuid::new_v4().to_string(); let client = if *USE_REVERSE_PROXY { - HTTP_CLIENT - .read() + client .post(url) .header(HOST, &*REVERSE_PROXY_HOST) .header(PROXY_HOST, CURSOR_API2_HOST) } else { - HTTP_CLIENT.read().post(url).header(HOST, CURSOR_API2_HOST) + client.post(url).header(HOST, CURSOR_API2_HOST) }; client @@ -108,7 +94,7 @@ pub fn build_client( .header("x-client-key", generate_hash()) .header("x-cursor-checksum", checksum) .header("x-cursor-client-version", "0.42.5") - .header("x-cursor-timezone", "Asia/Shanghai") + .header("x-cursor-timezone", &*CURSOR_TIMEZONE) .header(HEADER_NAME_GHOST_MODE, TRUE) .header("x-request-id", trace_id) .header(CONNECTION, KEEP_ALIVE) @@ -124,16 +110,14 @@ pub fn build_client( /// # 返回 /// /// * `reqwest::RequestBuilder` - 配置好的请求构建器 -pub fn build_profile_client(auth_token: &str) -> RequestBuilder { +pub fn build_profile_request(client: &Client, auth_token: &str) -> RequestBuilder { let client = if *USE_REVERSE_PROXY { - HTTP_CLIENT - .read() + client .get(&*CURSOR_API2_STRIPE_URL) .header(HOST, &*REVERSE_PROXY_HOST) .header(PROXY_HOST, CURSOR_API2_HOST) } else { - HTTP_CLIENT - .read() + client .get(&*CURSOR_API2_STRIPE_URL) .header(HOST, CURSOR_API2_HOST) }; @@ -168,20 +152,16 @@ pub fn build_profile_client(auth_token: &str) -> RequestBuilder { /// # 返回 /// /// * `reqwest::RequestBuilder` - 配置好的请求构建器 -pub fn build_usage_client(user_id: &str, auth_token: &str) -> RequestBuilder { +pub fn build_usage_request(client: &Client, user_id: &str, auth_token: &str) -> RequestBuilder { let session_token = format!("{}%3A%3A{}", user_id, auth_token); let client = if *USE_REVERSE_PROXY { - HTTP_CLIENT - .read() + client .get(&*CURSOR_USAGE_API_URL) .header(HOST, &*REVERSE_PROXY_HOST) .header(PROXY_HOST, CURSOR_HOST) } else { - HTTP_CLIENT - .read() - .get(&*CURSOR_USAGE_API_URL) - .header(HOST, CURSOR_HOST) + client.get(&*CURSOR_USAGE_API_URL).header(HOST, CURSOR_HOST) }; client @@ -217,20 +197,16 @@ pub fn build_usage_client(user_id: &str, auth_token: &str) -> RequestBuilder { /// # 返回 /// /// * `reqwest::RequestBuilder` - 配置好的请求构建器 -pub fn build_userinfo_client(user_id: &str, auth_token: &str) -> RequestBuilder { +pub fn build_userinfo_request(client: &Client, user_id: &str, auth_token: &str) -> RequestBuilder { let session_token = format!("{}%3A%3A{}", user_id, auth_token); let client = if *USE_REVERSE_PROXY { - HTTP_CLIENT - .read() + client .get(&*CURSOR_USER_API_URL) .header(HOST, &*REVERSE_PROXY_HOST) .header(PROXY_HOST, CURSOR_HOST) } else { - HTTP_CLIENT - .read() - .get(&*CURSOR_USER_API_URL) - .header(HOST, CURSOR_HOST) + client.get(&*CURSOR_USER_API_URL).header(HOST, CURSOR_HOST) }; client diff --git a/src/common/model.rs b/src/common/model.rs index d4eab61..9a38c67 100644 --- a/src/common/model.rs +++ b/src/common/model.rs @@ -3,6 +3,7 @@ pub mod error; pub mod health; pub mod token; pub mod userinfo; +pub mod tri; use config::ConfigData; diff --git a/src/common/model/config.rs b/src/common/model/config.rs index e1b0f34..fdf5874 100644 --- a/src/common/model/config.rs +++ b/src/common/model/config.rs @@ -1,6 +1,6 @@ use serde::{Deserialize, Serialize}; -use crate::app::model::{PageContent, Proxies, UsageCheck, VisionAbility}; +use crate::app::model::{PageContent, UsageCheck, VisionAbility}; #[derive(Serialize)] pub struct ConfigData { @@ -12,7 +12,6 @@ pub struct ConfigData { pub enable_dynamic_key: bool, #[serde(skip_serializing_if = "String::is_empty")] pub share_token: String, - pub proxies: Proxies, pub include_web_references: bool, } @@ -28,6 +27,5 @@ pub struct ConfigUpdateRequest { pub usage_check_models: Option, pub enable_dynamic_key: Option, pub share_token: Option, - pub proxies: Option, pub include_web_references: Option, } diff --git a/src/common/model/tri.rs b/src/common/model/tri.rs new file mode 100644 index 0000000..551d3f1 --- /dev/null +++ b/src/common/model/tri.rs @@ -0,0 +1,152 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, PartialEq)] +pub enum TriState { + None, + Null, + Some(T), +} + +impl TriState { + // pub fn is_some(&self) -> bool { + // matches!(self, TriState::Some(_)) + // } + + // pub fn is_null(&self) -> bool { + // matches!(self, TriState::Null) + // } + + pub fn is_none(&self) -> bool { + matches!(self, TriState::None) + } +} + +impl Default for TriState { + fn default() -> Self { + TriState::None + } +} + +impl Serialize for TriState +where + T: Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + TriState::None => serializer.serialize_none(), + TriState::Null => serializer.serialize_unit(), + TriState::Some(value) => value.serialize(serializer), + } + } +} + +impl<'de, T> Deserialize<'de> for TriState +where + T: Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let opt = Option::::deserialize(deserializer); + + match opt { + Ok(Some(value)) => Ok(TriState::Some(value)), + Ok(None) => Ok(TriState::Null), + Err(_) => Ok(TriState::None), + } + } +} + +impl From> for TriState { + fn from(option: Option) -> Self { + match option { + Some(value) => TriState::Some(value), + None => TriState::Null, + } + } +} + +#[derive(Serialize)] +#[serde(transparent)] +pub struct TriStateField { + #[serde(skip_serializing_if = "TriState::is_none")] + pub value: TriState, +} + +impl From> for TriStateField { + fn from(value: TriState) -> Self { + TriStateField { value } + } +} + +impl From> for TriState { + fn from(field: TriStateField) -> Self { + field.value + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Debug, PartialEq, Serialize, Deserialize)] + struct TestStruct { + required: String, + optional: Option, + #[serde(skip_serializing_if = "TriState::is_none")] + tristate: TriState, + } + + #[test] + fn test_tristate_serialization() { + // 创建三个测试结构体,分别包含不同状态的TriState + let test_none = TestStruct { + required: "必填字段".to_string(), + optional: Some("可选字段".to_string()), + tristate: TriState::None, + }; + + let test_null = TestStruct { + required: "必填字段".to_string(), + optional: None, + tristate: TriState::Null, + }; + + let test_some = TestStruct { + required: "必填字段".to_string(), + optional: Some("可选字段".to_string()), + tristate: TriState::Some("三态字段".to_string()), + }; + + // 序列化并打印结果 + println!("TriState::None 序列化结果:"); + println!("{}", serde_json::to_string_pretty(&test_none).unwrap()); + println!(); + + println!("TriState::Null 序列化结果:"); + println!("{}", serde_json::to_string_pretty(&test_null).unwrap()); + println!(); + + println!("TriState::Some 序列化结果:"); + println!("{}", serde_json::to_string_pretty(&test_some).unwrap()); + println!(); + + // 验证序列化行为 + let json_none = serde_json::to_string(&test_none).unwrap(); + let json_null = serde_json::to_string(&test_null).unwrap(); + let json_some = serde_json::to_string(&test_some).unwrap(); + + // TriState::None 不应该在JSON中出现 + assert!(!json_none.contains("tristate")); + + // TriState::Null 应该在JSON中出现为null + assert!(json_null.contains("\"tristate\":null")); + + // TriState::Some 应该在JSON中出现为具体值 + assert!(json_some.contains("\"tristate\":\"三态字段\"")); + } +} diff --git a/src/common/utils.rs b/src/common/utils.rs index 2e45c02..2111e74 100644 --- a/src/common/utils.rs +++ b/src/common/utils.rs @@ -1,8 +1,11 @@ mod checksum; +use std::time::Instant; + use ::base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; pub use checksum::*; mod token; use prost::Message as _; +use reqwest::Client; pub use token::*; mod base64; pub use base64::*; @@ -15,6 +18,7 @@ use crate::{ app::{ constant::{COMMA, FALSE, TRUE}, lazy::{CURSOR_API2_CHAT_MODELS_URL, TOKEN_DELIMITER, USE_COMMA_DELIMITER}, + model::proxy_pool::ProxyPool, }, chat::{ aiserver::v1::{AvailableModelsRequest, AvailableModelsResponse}, @@ -81,18 +85,32 @@ impl TrimNewlines for String { } } -pub async fn get_token_profile(auth_token: &str) -> Option { +pub trait InstantExt { + fn duration_as_secs_f64(&mut self) -> f64; +} + +impl InstantExt for Instant { + #[inline] + fn duration_as_secs_f64(&mut self) -> f64 { + let now = Instant::now(); + let duration = now.duration_since(*self); + *self = now; + duration.as_secs_f64() + } +} + +pub async fn get_token_profile(client: Client, auth_token: &str) -> Option { let user_id = extract_user_id(auth_token)?; // 构建请求客户端 - let client = super::client::build_usage_client(&user_id, auth_token); + let request = super::client::build_usage_request(&client, &user_id, auth_token); // 发送请求并获取响应 // let response = client.send().await.ok()?; // let bytes = response.bytes().await?; // println!("Raw response bytes: {:?}", bytes); // let usage = serde_json::from_str::(&text).ok()?; - let usage = client + let usage = request .send() .await .ok()? @@ -100,10 +118,10 @@ pub async fn get_token_profile(auth_token: &str) -> Option { .await .ok()?; - let user = get_user_profile(auth_token).await?; + let user = get_user_profile(&client, auth_token).await?; // 从 Stripe 获取用户资料 - let stripe = get_stripe_profile(auth_token).await?; + let stripe = get_stripe_profile(&client, auth_token).await?; // 映射响应数据到 TokenProfile Some(TokenProfile { @@ -113,8 +131,8 @@ pub async fn get_token_profile(auth_token: &str) -> Option { }) } -pub async fn get_stripe_profile(auth_token: &str) -> Option { - let client = super::client::build_profile_client(auth_token); +pub async fn get_stripe_profile(client: &Client, auth_token: &str) -> Option { + let client = super::client::build_profile_request(client, auth_token); let response = client .send() .await @@ -125,11 +143,11 @@ pub async fn get_stripe_profile(auth_token: &str) -> Option { Some(response) } -pub async fn get_user_profile(auth_token: &str) -> Option { +pub async fn get_user_profile(client: &Client, auth_token: &str) -> Option { let user_id = extract_user_id(auth_token)?; // 构建请求客户端 - let client = super::client::build_userinfo_client(&user_id, auth_token); + let client = super::client::build_userinfo_request(client, &user_id, auth_token); // 发送请求并获取响应 let user_profile = client.send().await.ok()?.json::().await.ok()?; @@ -137,9 +155,18 @@ pub async fn get_user_profile(auth_token: &str) -> Option { Some(user_profile) } -pub async fn get_available_models(auth_token: &str, checksum: &str) -> Option> { - let client = - super::client::build_client(auth_token, checksum, &CURSOR_API2_CHAT_MODELS_URL, false); +pub async fn get_available_models( + client: Client, + auth_token: &str, + checksum: &str, +) -> Option> { + let client = super::client::build_request( + client, + auth_token, + checksum, + &CURSOR_API2_CHAT_MODELS_URL, + false, + ); let request = AvailableModelsRequest { is_nightly: true, include_long_context_models: true, @@ -180,7 +207,10 @@ pub async fn get_available_models(auth_token: &str, checksum: &str) -> Option CURSOR, // c + u → "cu" (cursor) _ => UNKNOWN, }, - Some('d') if chars.next() == Some('e') => DEEPSEEK, // d + e → "de" (deepseek) + Some('d') => match chars.next() { + Some('e') if chars.next() == Some('e') => DEEPSEEK, // d + e + e → "dee" (deepseek) + _ => UNKNOWN, + }, // 其他情况 _ => UNKNOWN, } @@ -274,7 +304,10 @@ pub fn format_time_ms(seconds: f64) -> f64 { use crate::chat::config::key_config; /// 将 JWT token 转换为 TokenInfo -pub fn token_to_tokeninfo(auth_token: &str) -> Option { +pub fn token_to_tokeninfo( + auth_token: &str, + proxy_name: Option, +) -> Option { let (token, checksum) = validate_token_and_checksum(auth_token)?; // JWT token 由3部分组成,用 . 分隔 @@ -311,11 +344,12 @@ pub fn token_to_tokeninfo(auth_token: &str) -> Option { signature: parts[2].to_string(), machine_id: machine_id_hash, mac_id: mac_id_hash, + proxy_name, }) } /// 将 TokenInfo 转换为 JWT token -pub fn tokeninfo_to_token(info: &key_config::TokenInfo) -> Option<(String, String)> { +pub fn tokeninfo_to_token(info: &key_config::TokenInfo) -> Option<(String, String, Client)> { // 构建 payload let payload = TokenPayload { sub: info.sub.clone(), @@ -342,10 +376,13 @@ pub fn tokeninfo_to_token(info: &key_config::TokenInfo) -> Option<(String, Strin None }; + let client = ProxyPool::get_client_or_general(info.proxy_name.as_deref()); + // 组合 token Some(( format!("{}.{}.{}", HEADER_B64, payload_b64, info.signature), generate_checksum(&device_id, mac_addr.as_deref()), + client, )) } diff --git a/src/common/utils/checksum.rs b/src/common/utils/checksum.rs index 44e1657..df38e27 100644 --- a/src/common/utils/checksum.rs +++ b/src/common/utils/checksum.rs @@ -1,4 +1,4 @@ -use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64}; +use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as BASE64}; use rand::Rng as _; use sha2::{Digest, Sha256}; @@ -72,7 +72,7 @@ pub fn generate_checksum_with_repair(checksum: &str) -> String { for (i, &b) in bytes.iter().enumerate() { let valid = match (len, i) { // 通用字符校验(排除非法字符) - (_, _) if !b.is_ascii_alphanumeric() && b != b'/' && b != b'+' && b != b'=' => false, + (_, _) if !b.is_ascii_alphanumeric() && b != b'/' && b != b'-' && b != b'_' => false, // 72字节格式:时间戳(8) + 设备哈希(64) (72, 8..=71) => b.is_ascii_hexdigit(), @@ -157,7 +157,7 @@ pub fn validate_checksum(checksum: &str) -> bool { for (i, &b) in bytes.iter().enumerate() { let valid = match (len, i) { // 通用字符校验(排除非法字符) - (_, _) if !b.is_ascii_alphanumeric() && b != b'/' && b != b'+' && b != b'=' => false, + (_, _) if !b.is_ascii_alphanumeric() && b != b'/' && b != b'-' && b != b'_' => false, // 格式校验 (72, 0..=7) => true, // 时间戳部分(由extract_time_ks验证) diff --git a/src/main.rs b/src/main.rs index 6c655fe..a9affb5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,25 +8,31 @@ use app::{ PKG_VERSION, ROUTE_ABOUT_PATH, ROUTE_API_PATH, ROUTE_BASIC_CALIBRATION_PATH, ROUTE_BUILD_KEY_PATH, ROUTE_CONFIG_PATH, ROUTE_ENV_EXAMPLE_PATH, ROUTE_GET_CHECKSUM, ROUTE_GET_HASH, ROUTE_GET_TIMESTAMP_HEADER, ROUTE_HEALTH_PATH, ROUTE_LOGS_PATH, - ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_STATIC_PATH, ROUTE_TOKEN_TAGS_UPDATE_PATH, - ROUTE_TOKENS_ADD_PATH, ROUTE_TOKENS_DELETE_PATH, ROUTE_TOKENS_GET_PATH, ROUTE_TOKENS_PATH, - ROUTE_TOKENS_UPDATE_PATH, ROUTE_USER_INFO_PATH, + ROUTE_PROXIES_ADD_PATH, ROUTE_PROXIES_DELETE_PATH, ROUTE_PROXIES_GET_PATH, + ROUTE_PROXIES_PATH, ROUTE_PROXIES_SET_GENERAL_PATH, ROUTE_PROXIES_UPDATE_PATH, + ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_STATIC_PATH, ROUTE_TOKENS_ADD_PATH, + ROUTE_TOKENS_DELETE_PATH, ROUTE_TOKENS_GET_PATH, ROUTE_TOKENS_PATH, + ROUTE_TOKENS_PROFILE_UPDATE_PATH, ROUTE_TOKENS_TAGS_UPDATE_PATH, ROUTE_TOKENS_UPDATE_PATH, + ROUTE_USER_INFO_PATH, }, lazy::{AUTH_TOKEN, ROUTE_CHAT_PATH, ROUTE_MODELS_PATH}, model::*, }; use axum::{ - Router, + Router, middleware, routing::{get, post}, }; use chat::{ + middleware::admin_auth_middleware, route::{ - handle_about, handle_add_tokens, handle_api_page, handle_basic_calibration, - handle_build_key, handle_build_key_page, handle_config_page, handle_delete_tokens, - handle_env_example, handle_get_checksum, handle_get_hash, handle_get_timestamp_header, - handle_get_tokens, handle_health, handle_logs, handle_logs_post, handle_readme, - handle_root, handle_static, handle_tokens_page, handle_update_token_tags, - handle_update_tokens, handle_user_info, + handle_about, handle_add_proxy, handle_add_tokens, handle_api_page, + handle_basic_calibration, handle_build_key, handle_build_key_page, handle_config_page, + handle_delete_proxies, handle_delete_tokens, handle_env_example, handle_get_checksum, + handle_get_hash, handle_get_proxies, handle_get_timestamp_header, handle_get_tokens, + handle_health, handle_logs, handle_logs_post, handle_proxies_page, handle_readme, + handle_root, handle_set_general_proxy, handle_static, handle_tokens_page, + handle_update_proxies, handle_update_token_tags, handle_update_tokens, + handle_update_tokens_profile, handle_user_info, }, service::{handle_chat, handle_models}, }; @@ -141,11 +147,32 @@ async fn main() { .route(ROUTE_ROOT_PATH, get(handle_root)) .route(ROUTE_HEALTH_PATH, get(handle_health)) .route(ROUTE_TOKENS_PATH, get(handle_tokens_page)) + .route(ROUTE_PROXIES_PATH, get(handle_proxies_page)) + .merge( + Router::new() + .route(ROUTE_TOKENS_GET_PATH, post(handle_get_tokens)) + .route(ROUTE_TOKENS_UPDATE_PATH, post(handle_update_tokens)) + .route(ROUTE_TOKENS_ADD_PATH, post(handle_add_tokens)) + .route(ROUTE_TOKENS_DELETE_PATH, post(handle_delete_tokens)) + .route( + ROUTE_TOKENS_TAGS_UPDATE_PATH, + post(handle_update_token_tags), + ) + .route( + ROUTE_TOKENS_PROFILE_UPDATE_PATH, + post(handle_update_tokens_profile), + ) + .route(ROUTE_PROXIES_GET_PATH, post(handle_get_proxies)) + .route(ROUTE_PROXIES_UPDATE_PATH, post(handle_update_proxies)) + .route(ROUTE_PROXIES_ADD_PATH, post(handle_add_proxy)) + .route(ROUTE_PROXIES_DELETE_PATH, post(handle_delete_proxies)) + .route( + ROUTE_PROXIES_SET_GENERAL_PATH, + post(handle_set_general_proxy), + ) + .layer(middleware::from_fn(admin_auth_middleware)), + ) .route(ROUTE_MODELS_PATH.as_str(), get(handle_models)) - .route(ROUTE_TOKENS_GET_PATH, post(handle_get_tokens)) - .route(ROUTE_TOKENS_UPDATE_PATH, post(handle_update_tokens)) - .route(ROUTE_TOKENS_ADD_PATH, post(handle_add_tokens)) - .route(ROUTE_TOKENS_DELETE_PATH, post(handle_delete_tokens)) .route(ROUTE_CHAT_PATH.as_str(), post(handle_chat)) .route(ROUTE_LOGS_PATH, get(handle_logs)) .route(ROUTE_LOGS_PATH, post(handle_logs_post)) @@ -163,7 +190,6 @@ async fn main() { .route(ROUTE_USER_INFO_PATH, post(handle_user_info)) .route(ROUTE_BUILD_KEY_PATH, get(handle_build_key_page)) .route(ROUTE_BUILD_KEY_PATH, post(handle_build_key)) - .route(ROUTE_TOKEN_TAGS_UPDATE_PATH, post(handle_update_token_tags)) .layer(RequestBodyLimitLayer::new( 1024 * 1024 * parse_usize_from_env("REQUEST_BODY_LIMIT_MB", 2), )) @@ -172,14 +198,24 @@ async fn main() { // 启动服务器 let port = parse_string_from_env("PORT", "3000"); - let addr = format!("0.0.0.0:{}", port); - println!("服务器运行在端口 {}", port); - println!("当前版本: v{}", PKG_VERSION); - // if PKG_VERSION.contains("pre") { - // println!("当前是测试版,有问题及时反馈哦~"); - // } + let addr = format!("0.0.0.0:{port}"); + println!("服务器运行在端口 {port}"); + #[cfg(not(feature = "__preview"))] + println!("当前版本: v{PKG_VERSION}"); + #[cfg(feature = "__preview")] + { + const BUILD_VERSION: &str = include_str!("../VERSION"); + println!("当前版本: v{PKG_VERSION}+build.{BUILD_VERSION}"); + } + #[cfg(feature = "__preview")] + println!("当前是测试版,有问题及时反馈哦~"); - let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); + let listener = tokio::net::TcpListener::bind(&addr) + .await + .unwrap_or_else(|e| { + eprintln!("无法绑定到地址 {}: {}", addr, e); + std::process::exit(1); + }); let server = axum::serve(listener, app); tokio::select! { result = server => { diff --git a/static/build_key.html b/static/build_key.html index f3dbbc1..4afabb0 100644 --- a/static/build_key.html +++ b/static/build_key.html @@ -107,6 +107,11 @@ +
+ + +
+
-
- - - -
-
+
+ + + +
+
+
+ +
+
+
+
+
+
+ +
+
+ +
+
+ +
+
+ +
+
+ +
+
+ +
+
+ +
+
+
+
+
+ + + + +
+
+
+ + +
+
+ +
+ +
+ + + + + + + + + + + +
代理名称代理类型代理地址
+ +
+ +
+
已选择: 0 个项目
+
共 0 个代理
+
+
+ + +
+
+ 设为通用代理 + Ctrl+G +
+
+ 复制代理地址 + Ctrl+C +
+
+
+ 删除 + Delete +
+
+ + + + + + + + + + + + + + +
+ + + + + \ No newline at end of file diff --git a/static/shared-styles.css b/static/shared-styles.css index e76f111..1cbbf21 100644 --- a/static/shared-styles.css +++ b/static/shared-styles.css @@ -256,13 +256,23 @@ button.secondary.active { .button-group { display: flex; gap: 10px; - margin: var(--spacing) 0; + flex-wrap: wrap; + justify-content: flex-end; + align-items: flex-end; + margin: 0; } /* 按钮组中的按钮间距调整 */ .button-group button { - flex: 1; - min-width: 120px; + height: 38px; + min-width: 100px; + white-space: nowrap; +} + +.button-group button .context-menu-shortcut { + margin-left: 5px; + opacity: 0.7; + font-size: 12px; } /* 消息容器 - 固定在顶部中间 */ @@ -413,6 +423,63 @@ tr:hover { margin-bottom: 0; } +/* 托盘消息容器 */ +.toast-container { + position: fixed; + bottom: 20px; + right: 20px; + display: flex; + flex-direction: column; + gap: 10px; + z-index: 1000; + max-width: 350px; + max-height: 80vh; + overflow-y: hidden; + padding-top: 10px; + padding-bottom: 10px; + padding-right: 5px; +} + +.toast { + background: var(--card-background); + color: var(--text-primary); + padding: 10px 16px; + border-radius: var(--border-radius); + box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15); + opacity: 0; + transform: translateY(20px); + transition: opacity 0.4s cubic-bezier(0.25, 0.8, 0.25, 1), transform 0.4s cubic-bezier(0.25, 0.8, 0.25, 1); + position: relative; + min-width: 200px; + margin-left: auto; + will-change: transform, opacity; + pointer-events: auto; +} + +.toast.info { + border-left: 4px solid #2196F3; +} + +.toast.error { + background: #f44336; + color: white; +} + +.toast.success { + background: #4caf50; + color: white; +} + +.toast.warning { + background: #ff9800; + color: white; +} + +.toast.show { + opacity: 1; + transform: translateY(0); +} + /* 响应式设计 */ @media (max-width: 768px) { :root { diff --git a/static/shared.js b/static/shared.js index 4af8e6d..f8ba321 100644 --- a/static/shared.js +++ b/static/shared.js @@ -184,120 +184,6 @@ function parseStringFromBoolean(value, defaultValue = null) { return value ? 'true' : 'false'; } -/** - * 解析对话内容 - * @param {string} promptStr - 原始prompt字符串 - * @returns {Array<{role: string, content: string}>} 解析后的对话数组 - */ -function parsePrompt(promptStr) { - if (!promptStr) return []; - - const messages = []; - const lines = promptStr.split('\n'); - let currentRole = ''; - let currentContent = ''; - - const roleMap = { - 'BEGIN_SYSTEM': 'system', - 'BEGIN_USER': 'user', - 'BEGIN_ASSISTANT': 'assistant' - }; - - for (let i = 0; i < lines.length; i++) { - const line = lines[i]; - - // 检查是否是角色标记行 - let foundRole = false; - for (const [marker, role] of Object.entries(roleMap)) { - if (line.includes(marker)) { - // 保存之前的消息(如果有) - if (currentRole && currentContent.trim()) { - messages.push({ - role: currentRole, - content: currentContent.trim() - }); - } - // 设置新角色 - currentRole = role; - currentContent = ''; - foundRole = true; - break; - } - } - - // 如果不是角色标记行且不是END标记行,则添加到当前内容 - if (!foundRole && !line.includes('END_')) { - currentContent += line + '\n'; - } - } - - // 添加最后一条消息 - if (currentRole && currentContent.trim()) { - messages.push({ - role: currentRole, - content: currentContent.trim() - }); - } - - return messages; -} - -/** - * 格式化对话内容为HTML表格 - * @param {Array<{role: string, content: string}>} messages - 对话消息数组 - * @returns {string} HTML表格字符串 - */ -function formatPromptToTable(messages) { - if (!messages || messages.length === 0) { - return '

无对话内容

'; - } - - const roleLabels = { - 'system': '系统', - 'user': '用户', - 'assistant': '助手' - }; - - function escapeHtml(content) { - // 先转义HTML特殊字符 - const escaped = content - .replace(/&/g, '&') - .replace(//g, '>') - .replace(/"/g, '"') - .replace(/'/g, '''); - - // 将HTML标签文本用引号包裹,使其更易读 - // return escaped.replace(/<(\/?[^>]+)>/g, '"<$1>"'); - return escaped; - } - - return `${messages.map(msg => ``).join('')}
角色内容
${roleLabels[msg.role] || msg.role}${escapeHtml(msg.content).replace(/\n/g, '
')}
`; -} - -/** - * 安全地显示prompt对话框 - * @param {string} promptStr - 原始prompt字符串 - */ -function showPromptModal(promptStr) { - try { - const modal = document.getElementById('promptModal'); - const content = document.getElementById('promptContent'); - - if (!modal || !content) { - console.error('Modal elements not found'); - return; - } - - const messages = parsePrompt(promptStr); - content.innerHTML = formatPromptToTable(messages); - modal.style.display = 'block'; - } catch (e) { - console.error('显示prompt对话框失败:', e); - console.error('原始prompt:', promptStr); - } -} - /** * 将会员类型代码转换为显示名称 * @param {string|null} type - 会员类型代码,如 'free_trial', 'pro', 'free', 'enterprise' 等 diff --git a/static/tokens.html b/static/tokens.html index 3f4dc7d..bbe04b7 100644 --- a/static/tokens.html +++ b/static/tokens.html @@ -10,101 +10,247 @@ @@ -287,170 +656,1376 @@ -
-
-

Token 管理

-
- - - - + +
+
+
+ +
- -
- - -
添加模式: 输入要添加的token,每行一个 - 删除模式: 输入要删除的token,每行一个
-
- -
-
- - +
+ +
+ + +
- - - - - - - - - - - - - - -
TokenChecksum邮箱会员类型Premium用量试用剩余操作
+
+ +
+ + + +
+
+
-
- 快捷键: Ctrl + Enter 执行当前操作 +
+
+
+ + +
+
+ + +
+
+
+ + + +
-
+ +
+
+ +
- - +
+ + + + + + + + + + + + + +
账户/Token会员类型用量试用剩余代理
+ + +
+
+ +
+
已选择: 0 个项目
+
共 0 个Token
+
+
+ + +
+
+ 查看详情 + Enter +
+
+ 刷新Profile + F5 +
+
+ 生成Key + Ctrl+G +
+
+ 复制Token + Ctrl+C +
+
+ 设置代理 +
+
+ 未指定 + 0 + +
+
+ +
+
+
+
+ 删除 + Delete +
+
+ + +
+
+

Token详情

+ +
+
+ +
+
+ + + +
+
+ + +