mirror of
https://github.com/wisdgod/cursor-api.git
synced 2025-09-27 11:02:09 +08:00
0.1.3-rc.5.2.5
This commit is contained in:
@@ -131,3 +131,9 @@ GENERAL_TIMEZONE=Asia/Shanghai
|
||||
|
||||
# 使用内嵌的Claude.ai官方提示词作为默认提示词,如果是claude-开头的模型优先级大于DEFAULT_INSTRUCTIONS
|
||||
USE_OFFICIAL_CLAUDE_PROMPTS=false
|
||||
|
||||
# 真实额度(由于Cursor服务本身的问题,需要等待约5秒;由于架构原因,流式可能有bug),否则全零
|
||||
REAL_USAGE=false
|
||||
|
||||
# 安全哈希,checksum生成更慢
|
||||
SAFE_HASH=true
|
||||
|
@@ -2,7 +2,7 @@ cargo-features = ["profile-rustflags", "trim-paths"]
|
||||
|
||||
[package]
|
||||
name = "cursor-api"
|
||||
version = "0.1.3-rc.5.2.4"
|
||||
version = "0.1.3-rc.5.2.5"
|
||||
edition = "2024"
|
||||
authors = ["wisdgod <nav@wisdgod.com>"]
|
||||
description = "OpenAI format compatibility layer for the Cursor API"
|
||||
@@ -25,6 +25,7 @@ flate2 = { version = "1", default-features = false, features = ["rust_backend"]
|
||||
futures = { version = "^0.3", default-features = false, features = ["std"] }
|
||||
gif = { version = "^0.13", default-features = false, features = ["std"] }
|
||||
hex = { version = "^0.4", default-features = false, features = ["std"] }
|
||||
http = "1"
|
||||
image = { version = "^0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] }
|
||||
lasso = { version = "^0.7", features = ["inline-more", "multi-threaded"] }
|
||||
memmap2 = "^0.9"
|
||||
@@ -34,11 +35,10 @@ paste = "^1.0"
|
||||
prost = "^0.13"
|
||||
prost-types = "^0.13"
|
||||
rand = { version = "^0.9", default-features = false, features = ["thread_rng"] }
|
||||
regex = { version = "^1.11", default-features = false, features = ["std", "perf"] }
|
||||
reqwest = { version = "^0.12", default-features = false, features = ["gzip", "brotli", "json", "stream", "socks", "__tls", "charset", "rustls-tls-native-roots", "macos-system-configuration"] }
|
||||
reqwest = { version = "^0.12", default-features = false, features = ["gzip", "brotli", "json", "stream", "socks", "__tls", "charset", "rustls-tls-webpki-roots", "macos-system-configuration"] }
|
||||
rkyv = { version = "^0.7", default-features = false, features = ["alloc", "std", "bytecheck", "size_64", "validation", "std"] }
|
||||
serde = { version = "^1.0", default-features = false, features = ["std", "derive", "rc"] }
|
||||
serde_json = { package = "sonic-rs", version = "^0.4" }
|
||||
serde_json = { package = "sonic-rs", version = "0.5" }
|
||||
# serde_json = "^1.0"
|
||||
sha2 = { version = "^0.10", default-features = false }
|
||||
sysinfo = { version = "^0.34", default-features = false, features = ["system"] }
|
||||
|
47
README.md
47
README.md
@@ -80,6 +80,46 @@ deepseek-v3
|
||||
deepseek-r1
|
||||
o3-mini
|
||||
grok-2
|
||||
deepseek-v3.1
|
||||
grok-3-beta
|
||||
grok-3-mini-beta
|
||||
gpt-4.1
|
||||
```
|
||||
|
||||
支持思考:
|
||||
```
|
||||
claude-3.7-sonnet-thinking
|
||||
claude-3.7-sonnet-thinking-max
|
||||
o1-mini
|
||||
o1-preview
|
||||
o1
|
||||
gemini-2.5-pro-exp-03-25
|
||||
gemini-2.5-pro-max
|
||||
gemini-2.0-flash-thinking-exp
|
||||
deepseek-r1
|
||||
o3-mini
|
||||
```
|
||||
|
||||
支持图像:
|
||||
```
|
||||
claude-3.5-sonnet
|
||||
claude-3.7-sonnet
|
||||
claude-3.7-sonnet-thinking
|
||||
claude-3.7-sonnet-max
|
||||
claude-3.7-sonnet-thinking-max
|
||||
gpt-4
|
||||
gpt-4o
|
||||
gpt-4.5-preview
|
||||
claude-3-opus
|
||||
gpt-4-turbo-2024-04-09
|
||||
gpt-4o-128k
|
||||
claude-3-haiku-200k
|
||||
claude-3-5-sonnet-200k
|
||||
gpt-4o-mini
|
||||
claude-3.5-haiku
|
||||
gemini-2.5-pro-exp-03-25
|
||||
gemini-2.5-pro-max
|
||||
gpt-4.1
|
||||
```
|
||||
|
||||
## 接口说明
|
||||
@@ -961,9 +1001,12 @@ string
|
||||
}
|
||||
],
|
||||
"delays": [
|
||||
"string",
|
||||
[
|
||||
"string",
|
||||
number
|
||||
[
|
||||
number, // chars count
|
||||
number // time
|
||||
]
|
||||
]
|
||||
],
|
||||
"usage": { // optional
|
||||
|
@@ -1,3 +1,6 @@
|
||||
mod header;
|
||||
pub use header::*;
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! def_pub_const {
|
||||
// 单个常量定义
|
||||
@@ -78,27 +81,12 @@ def_pub_const!(
|
||||
STATUS_FAILURE => "failure"
|
||||
);
|
||||
|
||||
// Header constants
|
||||
def_pub_const!(
|
||||
HEADER_NAME_GHOST_MODE => "x-ghost-mode"
|
||||
);
|
||||
|
||||
// Boolean constants
|
||||
def_pub_const!(
|
||||
TRUE => "true",
|
||||
FALSE => "false"
|
||||
);
|
||||
|
||||
// Content type constants
|
||||
def_pub_const!(
|
||||
CONTENT_TYPE_PROTO => "application/proto",
|
||||
CONTENT_TYPE_CONNECT_PROTO => "application/connect+proto",
|
||||
CONTENT_TYPE_TEXT_HTML_WITH_UTF8 => "text/html;charset=utf-8",
|
||||
CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8 => "text/plain;charset=utf-8",
|
||||
CONTENT_TYPE_TEXT_CSS_WITH_UTF8 => "text/css;charset=utf-8",
|
||||
CONTENT_TYPE_TEXT_JS_WITH_UTF8 => "text/javascript;charset=utf-8"
|
||||
);
|
||||
|
||||
// Authorization constants
|
||||
def_pub_const!(
|
||||
AUTHORIZATION_BEARER_PREFIX => "Bearer "
|
||||
|
80
src/app/constant/header.rs
Normal file
80
src/app/constant/header.rs
Normal file
@@ -0,0 +1,80 @@
|
||||
/// Generates one `header_name_<name>()` accessor for every `name => "literal"` pair.
///
/// Each generated function is `pub(crate)` and owns a function-local `OnceLock`.
/// The first call fills it through `HeaderName::from_static`; every call returns a
/// `'static` reference to the stored header name.
macro_rules! def_header_name {
    ($($name:ident => $value:expr),+ $(,)?) => {
        $(paste::paste! {
            #[inline]
            pub(crate) fn [<header_name_ $name>]() -> &'static http::header::HeaderName {
                use http::header::HeaderName;
                use std::sync::OnceLock;

                // One lazily initialised cell per generated accessor.
                static CELL: OnceLock<HeaderName> = OnceLock::new();
                CELL.get_or_init(|| HeaderName::from_static($value))
            }
        })+
    };
}
|
||||
|
||||
/// Generates one `header_value_<name>()` accessor for every `name => "literal"` pair.
///
/// Each generated function is `pub` and owns a function-local `OnceLock`.
/// The first call fills it through `HeaderValue::from_static`; every call returns a
/// `'static` reference to the stored header value.
macro_rules! def_header_value {
    ($($name:ident => $value:expr),+ $(,)?) => {
        $(paste::paste! {
            #[inline]
            pub fn [<header_value_ $name>]() -> &'static http::header::HeaderValue {
                // Named after what it stores: this cell holds a header *value*
                // (the previous `HEADER_NAME` identifier was a copy-paste misnomer).
                static HEADER_VALUE: std::sync::OnceLock<http::header::HeaderValue> =
                    std::sync::OnceLock::new();
                HEADER_VALUE.get_or_init(|| http::header::HeaderValue::from_static($value))
            }
        })+
    };
}
|
||||
|
||||
// Shared header values. Every entry below expands to a `header_value_<name>()`
// accessor; entries are grouped by theme only for readability.
def_header_value!(
    // Single-token values
    one => "1",
    u_eq_0 => "u=0",

    // Content negotiation
    accept => "*/*",
    language => "en-US",
    encoding => "gzip",
    encodings => "gzip,br",
    gzip_deflate => "gzip, deflate",

    // Connection and caching
    keep_alive => "keep-alive",
    trailers => "trailers",
    chunked => "chunked",
    no_cache => "no-cache",
    no_cache_revalidate => "no-cache, must-revalidate",

    // Fetch metadata and origin
    empty => "empty",
    cors => "cors",
    same_origin => "same-origin",
    cross_site => "cross-site",
    vscode_origin => "vscode-file://vscode-app",

    // Client identification (user agents and client hints)
    ua_win => "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
    ua_cursor => "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Cursor/0.42.5 Chrome/124.0.6367.243 Electron/30.4.0 Safari/537.36",
    connect_es => "connect-es/1.6.1",
    not_a_brand => "\"Not-A.Brand\";v=\"99\", \"Chromium\";v=\"124\"",
    mobile_no => "?0",
    windows => "\"Windows\"",

    // Content types
    json => "application/json",
    event_stream => "text/event-stream",
    proto => "application/proto",
    connect_proto => "application/connect+proto",
    text_html_utf8 => "text/html;charset=utf-8",
    text_plain_utf8 => "text/plain;charset=utf-8",
    text_css_utf8 => "text/css;charset=utf-8",
    text_js_utf8 => "text/javascript;charset=utf-8"
);
|
||||
|
||||
// Custom (non-standard) header names. Every entry below expands to a
// `header_name_<name>()` accessor.
def_header_name!(
    // Carries the real upstream host when a request goes through a reverse proxy
    proxy_host => "x-co",

    // Connect protocol
    connect_accept_encoding => "connect-accept-encoding",
    connect_protocol_version => "connect-protocol-version",

    // Cursor / tracing headers
    ghost_mode => "x-ghost-mode",
    amzn_trace_id => "x-amzn-trace-id",
    client_key => "x-client-key",
    cursor_checksum => "x-cursor-checksum",
    cursor_client_version => "x-cursor-client-version",
    cursor_timezone => "x-cursor-timezone",
    request_id => "x-request-id",

    // Client hints
    sec_ch_ua => "sec-ch-ua",
    sec_ch_ua_mobile => "sec-ch-ua-mobile",
    sec_ch_ua_platform => "sec-ch-ua-platform",

    // Fetch metadata
    sec_fetch_dest => "sec-fetch-dest",
    sec_fetch_mode => "sec-fetch-mode",
    sec_fetch_site => "sec-fetch-site",
    sec_gpc => "sec-gpc",

    priority => "priority",
);
|
@@ -11,20 +11,11 @@ use tokio::sync::{Mutex, OnceCell};
|
||||
macro_rules! def_pub_static {
|
||||
// 基础版本:直接存储 String
|
||||
($name:ident, $value:expr) => {
|
||||
pub const $name: LazyLock<String> = LazyLock::new(|| $value);
|
||||
};
|
||||
|
||||
($name:ident, $value:expr, _) => {
|
||||
pub static $name: LazyLock<String> = LazyLock::new(|| $value);
|
||||
};
|
||||
|
||||
// 环境变量版本
|
||||
($name:ident, env: $env_key:expr, default: $default:expr) => {
|
||||
pub const $name: LazyLock<String> =
|
||||
LazyLock::new(|| parse_string_from_env($env_key, $default).trim().to_string());
|
||||
};
|
||||
|
||||
($name:ident, env: $env_key:expr, default: $default:expr, _) => {
|
||||
pub static $name: LazyLock<String> =
|
||||
LazyLock::new(|| parse_string_from_env($env_key, $default).trim().to_string());
|
||||
};
|
||||
@@ -41,14 +32,13 @@ macro_rules! def_pub_static {
|
||||
// }
|
||||
|
||||
def_pub_static!(ROUTE_PREFIX, env: "ROUTE_PREFIX", default: EMPTY_STRING);
|
||||
def_pub_static!(AUTH_TOKEN, env: "AUTH_TOKEN", default: EMPTY_STRING, _);
|
||||
def_pub_static!(ROUTE_MODELS_PATH, format!("{}/v1/models", *ROUTE_PREFIX), _);
|
||||
def_pub_static!(AUTH_TOKEN, env: "AUTH_TOKEN", default: EMPTY_STRING);
|
||||
def_pub_static!(ROUTE_MODELS_PATH, format!("{}/v1/models", *ROUTE_PREFIX));
|
||||
def_pub_static!(
|
||||
ROUTE_CHAT_PATH,
|
||||
format!("{}/v1/chat/completions", *ROUTE_PREFIX),
|
||||
_
|
||||
format!("{}/v1/chat/completions", *ROUTE_PREFIX)
|
||||
);
|
||||
def_pub_static!(ROUTE_MESSAGES_PATH, format!("{}/v1/messages", *ROUTE_PREFIX), _);
|
||||
// def_pub_static!(ROUTE_MESSAGES_PATH, format!("{}/v1/messages", *ROUTE_PREFIX));
|
||||
|
||||
static START_TIME: OnceLock<chrono::DateTime<chrono::Local>> = OnceLock::new();
|
||||
|
||||
@@ -56,7 +46,7 @@ pub fn get_start_time() -> &'static chrono::DateTime<chrono::Local> {
|
||||
START_TIME.get_or_init(chrono::Local::now)
|
||||
}
|
||||
|
||||
pub const GENERAL_TIMEZONE: LazyLock<chrono_tz::Tz> = LazyLock::new(|| {
|
||||
pub static GENERAL_TIMEZONE: LazyLock<chrono_tz::Tz> = LazyLock::new(|| {
|
||||
use std::str::FromStr as _;
|
||||
let tz = parse_string_from_env("GENERAL_TIMEZONE", EMPTY_STRING);
|
||||
let tz = tz.trim();
|
||||
@@ -80,9 +70,9 @@ pub fn now_in_general_timezone() -> chrono::DateTime<chrono_tz::Tz> {
|
||||
GENERAL_TIMEZONE.from_utc_datetime(&chrono::Utc::now().naive_utc())
|
||||
}
|
||||
|
||||
def_pub_static!(DEFAULT_INSTRUCTIONS, env: "DEFAULT_INSTRUCTIONS", default: "Respond in Chinese by default\n<|END_USER|>\n\n<|BEGIN_ASSISTANT|>\n\n\nYour will\n<|END_ASSISTANT|>\n\n<|BEGIN_USER|>\n\n\nThe current date is {{currentDateTime}}", _);
|
||||
def_pub_static!(DEFAULT_INSTRUCTIONS, env: "DEFAULT_INSTRUCTIONS", default: "Respond in Chinese by default\n<|END_USER|>\n\n<|BEGIN_ASSISTANT|>\n\n\nYour will\n<|END_ASSISTANT|>\n\n<|BEGIN_USER|>\n\n\nThe current date is {{currentDateTime}}");
|
||||
|
||||
const USE_OFFICIAL_CLAUDE_PROMPTS: LazyLock<bool> =
|
||||
static USE_OFFICIAL_CLAUDE_PROMPTS: LazyLock<bool> =
|
||||
LazyLock::new(|| parse_bool_from_env("USE_OFFICIAL_CLAUDE_PROMPTS", false));
|
||||
|
||||
pub fn get_default_instructions(
|
||||
@@ -125,13 +115,13 @@ pub fn get_default_instructions(
|
||||
)
|
||||
}
|
||||
|
||||
def_pub_static!(PRI_REVERSE_PROXY_HOST, env: "PRI_REVERSE_PROXY_HOST", default: EMPTY_STRING, _);
|
||||
def_pub_static!(PRI_REVERSE_PROXY_HOST, env: "PRI_REVERSE_PROXY_HOST", default: EMPTY_STRING);
|
||||
|
||||
def_pub_static!(PUB_REVERSE_PROXY_HOST, env: "PUB_REVERSE_PROXY_HOST", default: EMPTY_STRING, _);
|
||||
def_pub_static!(PUB_REVERSE_PROXY_HOST, env: "PUB_REVERSE_PROXY_HOST", default: EMPTY_STRING);
|
||||
|
||||
const DEFAULT_KEY_PREFIX: &str = "sk-";
|
||||
|
||||
pub const KEY_PREFIX: LazyLock<String> = LazyLock::new(|| {
|
||||
pub static KEY_PREFIX: LazyLock<String> = LazyLock::new(|| {
|
||||
let value = parse_string_from_env("KEY_PREFIX", DEFAULT_KEY_PREFIX)
|
||||
.trim()
|
||||
.to_string();
|
||||
@@ -142,9 +132,9 @@ pub const KEY_PREFIX: LazyLock<String> = LazyLock::new(|| {
|
||||
}
|
||||
});
|
||||
|
||||
pub const KEY_PREFIX_LEN: LazyLock<usize> = LazyLock::new(|| KEY_PREFIX.len());
|
||||
pub static KEY_PREFIX_LEN: LazyLock<usize> = LazyLock::new(|| KEY_PREFIX.len());
|
||||
|
||||
pub const TOKEN_DELIMITER: LazyLock<char> = LazyLock::new(|| {
|
||||
pub static TOKEN_DELIMITER: LazyLock<char> = LazyLock::new(|| {
|
||||
let delimiter = parse_ascii_char_from_env("TOKEN_DELIMITER", COMMA);
|
||||
if delimiter.is_ascii_alphabetic()
|
||||
|| delimiter.is_ascii_digit()
|
||||
@@ -158,7 +148,7 @@ pub const TOKEN_DELIMITER: LazyLock<char> = LazyLock::new(|| {
|
||||
}
|
||||
});
|
||||
|
||||
pub const USE_COMMA_DELIMITER: LazyLock<bool> = LazyLock::new(|| {
|
||||
pub static USE_COMMA_DELIMITER: LazyLock<bool> = LazyLock::new(|| {
|
||||
let enable = parse_bool_from_env("USE_COMMA_DELIMITER", true);
|
||||
if enable && *TOKEN_DELIMITER == COMMA {
|
||||
false
|
||||
@@ -167,10 +157,10 @@ pub const USE_COMMA_DELIMITER: LazyLock<bool> = LazyLock::new(|| {
|
||||
}
|
||||
});
|
||||
|
||||
pub const USE_PRI_REVERSE_PROXY: LazyLock<bool> =
|
||||
pub static USE_PRI_REVERSE_PROXY: LazyLock<bool> =
|
||||
LazyLock::new(|| !PRI_REVERSE_PROXY_HOST.is_empty());
|
||||
|
||||
pub const USE_PUB_REVERSE_PROXY: LazyLock<bool> =
|
||||
pub static USE_PUB_REVERSE_PROXY: LazyLock<bool> =
|
||||
LazyLock::new(|| !PUB_REVERSE_PROXY_HOST.is_empty());
|
||||
|
||||
macro_rules! def_cursor_api_url {
|
||||
@@ -257,18 +247,18 @@ static DATA_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
|
||||
path
|
||||
});
|
||||
|
||||
pub(super) const CONFIG_FILE_PATH: LazyLock<PathBuf> =
|
||||
pub(super) static CONFIG_FILE_PATH: LazyLock<PathBuf> =
|
||||
LazyLock::new(|| DATA_DIR.join("config.bin"));
|
||||
|
||||
pub(super) const LOGS_FILE_PATH: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("logs.bin"));
|
||||
pub(super) static LOGS_FILE_PATH: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("logs.bin"));
|
||||
|
||||
pub(super) const TOKENS_FILE_PATH: LazyLock<PathBuf> =
|
||||
pub(super) static TOKENS_FILE_PATH: LazyLock<PathBuf> =
|
||||
LazyLock::new(|| DATA_DIR.join("tokens.bin"));
|
||||
|
||||
pub(super) const PROXIES_FILE_PATH: LazyLock<PathBuf> =
|
||||
pub(super) static PROXIES_FILE_PATH: LazyLock<PathBuf> =
|
||||
LazyLock::new(|| DATA_DIR.join("proxies.bin"));
|
||||
|
||||
pub const DEBUG: LazyLock<bool> = LazyLock::new(|| parse_bool_from_env("DEBUG", false));
|
||||
pub static DEBUG: LazyLock<bool> = LazyLock::new(|| parse_bool_from_env("DEBUG", false));
|
||||
|
||||
// 使用环境变量 "DEBUG_LOG_FILE" 来指定日志文件路径,默认值为 "debug.log"
|
||||
static DEBUG_LOG_FILE: LazyLock<String> =
|
||||
@@ -318,19 +308,42 @@ macro_rules! debug_println {
|
||||
};
|
||||
}
|
||||
|
||||
pub const REQUEST_LOGS_LIMIT: LazyLock<usize> =
|
||||
LazyLock::new(|| std::cmp::min(parse_usize_from_env("REQUEST_LOGS_LIMIT", 100), 100000));
|
||||
// 请求日志相关常量
|
||||
const DEFAULT_REQUEST_LOGS_LIMIT: usize = 100;
|
||||
const MAX_REQUEST_LOGS_LIMIT: usize = 100000;
|
||||
|
||||
pub const IS_NO_REQUEST_LOGS: LazyLock<bool> = LazyLock::new(|| *REQUEST_LOGS_LIMIT == 0);
|
||||
pub const IS_UNLIMITED_REQUEST_LOGS: LazyLock<bool> =
|
||||
LazyLock::new(|| *REQUEST_LOGS_LIMIT == 100000);
|
||||
|
||||
pub const TCP_KEEPALIVE: LazyLock<u64> = LazyLock::new(|| {
|
||||
let keepalive = parse_usize_from_env("TCP_KEEPALIVE", 90);
|
||||
u64::try_from(keepalive).map(|t| t.min(600)).unwrap_or(90)
|
||||
pub static REQUEST_LOGS_LIMIT: LazyLock<usize> = LazyLock::new(|| {
|
||||
std::cmp::min(
|
||||
parse_usize_from_env("REQUEST_LOGS_LIMIT", DEFAULT_REQUEST_LOGS_LIMIT),
|
||||
MAX_REQUEST_LOGS_LIMIT,
|
||||
)
|
||||
});
|
||||
|
||||
pub const SERVICE_TIMEOUT: LazyLock<u64> = LazyLock::new(|| {
|
||||
let timeout = parse_usize_from_env("SERVICE_TIMEOUT", 30);
|
||||
u64::try_from(timeout).map(|t| t.min(600)).unwrap_or(30)
|
||||
/// `true` when `REQUEST_LOGS_LIMIT` resolved to 0.
pub static IS_NO_REQUEST_LOGS: LazyLock<bool> = LazyLock::new(|| *REQUEST_LOGS_LIMIT == 0);
|
||||
/// `true` when `REQUEST_LOGS_LIMIT` equals `MAX_REQUEST_LOGS_LIMIT`, i.e. the
/// configured value hit (or exceeded and was clamped to) the cap.
// NOTE(review): the name suggests consumers treat the cap as "unlimited" — confirm at call sites.
pub static IS_UNLIMITED_REQUEST_LOGS: LazyLock<bool> =
|
||||
LazyLock::new(|| *REQUEST_LOGS_LIMIT == MAX_REQUEST_LOGS_LIMIT);
|
||||
|
||||
// TCP 和超时相关常量
|
||||
const DEFAULT_TCP_KEEPALIVE: usize = 90;
|
||||
const MAX_TCP_KEEPALIVE: u64 = 600;
|
||||
|
||||
pub static TCP_KEEPALIVE: LazyLock<u64> = LazyLock::new(|| {
|
||||
let keepalive = parse_usize_from_env("TCP_KEEPALIVE", DEFAULT_TCP_KEEPALIVE);
|
||||
u64::try_from(keepalive)
|
||||
.map(|t| t.min(MAX_TCP_KEEPALIVE))
|
||||
.unwrap_or(DEFAULT_TCP_KEEPALIVE as u64)
|
||||
});
|
||||
|
||||
const DEFAULT_SERVICE_TIMEOUT: usize = 30;
|
||||
const MAX_SERVICE_TIMEOUT: u64 = 600;
|
||||
|
||||
pub static SERVICE_TIMEOUT: LazyLock<u64> = LazyLock::new(|| {
|
||||
let timeout = parse_usize_from_env("SERVICE_TIMEOUT", DEFAULT_SERVICE_TIMEOUT);
|
||||
u64::try_from(timeout)
|
||||
.map(|t| t.min(MAX_SERVICE_TIMEOUT))
|
||||
.unwrap_or(DEFAULT_SERVICE_TIMEOUT as u64)
|
||||
});
|
||||
|
||||
/// Flag parsed from the `REAL_USAGE` environment key (default `false`).
/// Per the example env file: report real usage figures (costs roughly a 5 second
/// wait and may misbehave with streaming); when off, usage is reported as all zeros.
pub static REAL_USAGE: LazyLock<bool> = LazyLock::new(|| parse_bool_from_env("REAL_USAGE", false));
|
||||
|
||||
/// Flag parsed from the `SAFE_HASH` environment key (default `true`).
/// Per the example env file: use the safe hash, which makes checksum generation slower.
pub static SAFE_HASH: LazyLock<bool> = LazyLock::new(|| parse_bool_from_env("SAFE_HASH", true));
|
||||
|
@@ -84,7 +84,8 @@ pub struct RequestLog {
|
||||
pub struct Chain {
|
||||
#[serde(skip_serializing_if = "Prompt::is_none")]
|
||||
pub prompt: Prompt,
|
||||
pub delays: Vec<(String, f64)>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub delays: Option<(String, Vec<(u32, f32)>)>,
|
||||
#[serde(skip_serializing_if = "OptionUsage::is_none")]
|
||||
pub usage: OptionUsage,
|
||||
}
|
||||
|
@@ -45,7 +45,7 @@ impl<'de> Deserialize<'de> for UsageCheckModelConfig {
|
||||
.split(COMMA)
|
||||
.filter_map(|model| {
|
||||
let model = model.trim();
|
||||
Models::find_id(model)
|
||||
Models::find_id(model).map(|m| m.id)
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
|
@@ -135,7 +135,7 @@ impl From<super::Prompt> for PromptHelper {
|
||||
#[derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize)]
|
||||
pub struct ChainHelper {
|
||||
pub prompt: PromptHelper,
|
||||
pub delays: Vec<(String, f64)>,
|
||||
pub delays: Option<(String, Vec<(u32, f32)>)>,
|
||||
pub usage: super::OptionUsage,
|
||||
}
|
||||
impl From<ChainHelper> for super::Chain {
|
||||
|
@@ -1,17 +1,19 @@
|
||||
mod proxy_url;
|
||||
|
||||
use crate::app::lazy::{PROXIES_FILE_PATH, SERVICE_TIMEOUT, TCP_KEEPALIVE};
|
||||
use memmap2::{MmapMut, MmapOptions};
|
||||
use parking_lot::RwLock;
|
||||
use reqwest::{Client, Proxy};
|
||||
use proxy_url::StringUrl;
|
||||
use reqwest::Client;
|
||||
use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::fs::OpenOptions;
|
||||
use std::str::FromStr;
|
||||
use std::sync::LazyLock;
|
||||
use std::time::Duration;
|
||||
mod proxy_url;
|
||||
use proxy_url::StringUrl;
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
fs::OpenOptions,
|
||||
str::FromStr,
|
||||
sync::LazyLock,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
// 新的代理值常量
|
||||
pub const NON_PROXY: &str = "non";
|
||||
@@ -21,10 +23,8 @@ pub const SYS_PROXY: &str = "sys";
|
||||
pub static PROXY_POOL: LazyLock<RwLock<ProxyPool>> = LazyLock::new(|| {
|
||||
let system_client = Client::builder()
|
||||
.https_only(true)
|
||||
.http1_only()
|
||||
.tcp_keepalive(Duration::from_secs(*TCP_KEEPALIVE))
|
||||
.timeout(Duration::from_secs(*SERVICE_TIMEOUT))
|
||||
.http1_title_case_headers()
|
||||
.connect_timeout(Duration::from_secs(*SERVICE_TIMEOUT))
|
||||
.build()
|
||||
.expect("创建默认系统客户端失败");
|
||||
|
||||
@@ -109,10 +109,8 @@ impl Proxies {
|
||||
SingleProxy::Non,
|
||||
Client::builder()
|
||||
.https_only(true)
|
||||
.http1_only()
|
||||
.tcp_keepalive(Duration::from_secs(*TCP_KEEPALIVE))
|
||||
.timeout(Duration::from_secs(*SERVICE_TIMEOUT))
|
||||
.http1_title_case_headers()
|
||||
.connect_timeout(Duration::from_secs(*SERVICE_TIMEOUT))
|
||||
.no_proxy()
|
||||
.build()
|
||||
.expect("创建无代理客户端失败"),
|
||||
@@ -123,29 +121,23 @@ impl Proxies {
|
||||
SingleProxy::Sys,
|
||||
Client::builder()
|
||||
.https_only(true)
|
||||
.http1_only()
|
||||
.tcp_keepalive(Duration::from_secs(*TCP_KEEPALIVE))
|
||||
.timeout(Duration::from_secs(*SERVICE_TIMEOUT))
|
||||
.http1_title_case_headers()
|
||||
.connect_timeout(Duration::from_secs(*SERVICE_TIMEOUT))
|
||||
.build()
|
||||
.expect("创建默认客户端失败"),
|
||||
);
|
||||
}
|
||||
SingleProxy::Url(url) => {
|
||||
if let Ok(proxy_obj) = Proxy::all(url.to_string()) {
|
||||
pool.clients.insert(
|
||||
(*proxy).clone(),
|
||||
Client::builder()
|
||||
.https_only(true)
|
||||
.http1_only()
|
||||
.tcp_keepalive(Duration::from_secs(*TCP_KEEPALIVE))
|
||||
.timeout(Duration::from_secs(*SERVICE_TIMEOUT))
|
||||
.http1_title_case_headers()
|
||||
.proxy(proxy_obj)
|
||||
.build()
|
||||
.expect("创建代理客户端失败"),
|
||||
);
|
||||
}
|
||||
pool.clients.insert(
|
||||
(*proxy).clone(),
|
||||
Client::builder()
|
||||
.https_only(true)
|
||||
.tcp_keepalive(Duration::from_secs(*TCP_KEEPALIVE))
|
||||
.connect_timeout(Duration::from_secs(*SERVICE_TIMEOUT))
|
||||
.proxy(url.as_proxy().expect("创建代理对象失败"))
|
||||
.build()
|
||||
.expect("创建代理客户端失败"),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -3,7 +3,7 @@ use memmap2::{MmapMut, MmapOptions};
|
||||
use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
collections::{HashMap, HashSet, VecDeque},
|
||||
fs::OpenOptions,
|
||||
};
|
||||
|
||||
@@ -59,7 +59,7 @@ pub struct RequestStatsManager {
|
||||
pub total_requests: u64,
|
||||
pub active_requests: u64,
|
||||
pub error_requests: u64,
|
||||
pub request_logs: Vec<RequestLog>,
|
||||
pub request_logs: VecDeque<RequestLog>,
|
||||
}
|
||||
|
||||
pub struct AppState {
|
||||
@@ -180,7 +180,7 @@ impl TokenManager {
|
||||
}
|
||||
|
||||
impl RequestStatsManager {
|
||||
pub fn new(request_logs: Vec<RequestLog>) -> Self {
|
||||
pub fn new(request_logs: VecDeque<RequestLog>) -> Self {
|
||||
Self {
|
||||
total_requests: request_logs.len() as u64,
|
||||
active_requests: 0,
|
||||
@@ -220,11 +220,11 @@ impl RequestStatsManager {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn load_logs() -> Result<Vec<RequestLog>, Box<dyn std::error::Error>> {
|
||||
pub async fn load_logs() -> Result<VecDeque<RequestLog>, Box<dyn std::error::Error>> {
|
||||
let file = match OpenOptions::new().read(true).open(&*LOGS_FILE_PATH) {
|
||||
Ok(file) => file,
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
|
||||
return Ok(Vec::new());
|
||||
return Ok(VecDeque::new());
|
||||
}
|
||||
Err(e) => return Err(Box::new(e)),
|
||||
};
|
||||
|
@@ -25,6 +25,7 @@ impl UsageCheck {
|
||||
.model_ids
|
||||
.iter()
|
||||
.filter_map(|id| Models::find_id(id))
|
||||
.map(|m| m.id)
|
||||
.collect();
|
||||
if models.is_empty() {
|
||||
Self::None
|
||||
@@ -125,6 +126,7 @@ impl<'de> Deserialize<'de> for UsageCheck {
|
||||
let model = model.trim();
|
||||
Models::find_id(model)
|
||||
})
|
||||
.map(|m| m.id)
|
||||
.collect();
|
||||
|
||||
if models.is_empty() {
|
||||
@@ -153,6 +155,7 @@ impl UsageCheck {
|
||||
let model = model.trim();
|
||||
Models::find_id(model)
|
||||
})
|
||||
.map(|m| m.id)
|
||||
.collect();
|
||||
|
||||
if models.is_empty() {
|
||||
|
@@ -1,7 +1,20 @@
|
||||
use crate::app::{
|
||||
constant::{
|
||||
CONTENT_TYPE_CONNECT_PROTO, CONTENT_TYPE_PROTO, CURSOR_API2_HOST, CURSOR_HOST,
|
||||
CURSOR_SETTINGS_URL, HEADER_NAME_GHOST_MODE, TRUE,
|
||||
CURSOR_API2_HOST, CURSOR_HOST, CURSOR_SETTINGS_URL, TRUE, header_name_amzn_trace_id,
|
||||
header_name_client_key, header_name_connect_accept_encoding,
|
||||
header_name_connect_protocol_version, header_name_cursor_checksum,
|
||||
header_name_cursor_client_version, header_name_cursor_timezone, header_name_ghost_mode,
|
||||
header_name_priority, header_name_proxy_host, header_name_request_id,
|
||||
header_name_sec_ch_ua, header_name_sec_ch_ua_mobile, header_name_sec_ch_ua_platform,
|
||||
header_name_sec_fetch_dest, header_name_sec_fetch_mode, header_name_sec_fetch_site,
|
||||
header_name_sec_gpc, header_value_accept, header_value_chunked, header_value_connect_es,
|
||||
header_value_connect_proto, header_value_cors, header_value_cross_site, header_value_empty,
|
||||
header_value_encoding, header_value_encodings, header_value_gzip_deflate,
|
||||
header_value_keep_alive, header_value_language, header_value_mobile_no,
|
||||
header_value_no_cache, header_value_not_a_brand, header_value_one, header_value_proto,
|
||||
header_value_same_origin, header_value_trailers, header_value_u_eq_0,
|
||||
header_value_ua_cursor, header_value_ua_win, header_value_vscode_origin,
|
||||
header_value_windows,
|
||||
},
|
||||
lazy::{
|
||||
PRI_REVERSE_PROXY_HOST, PUB_REVERSE_PROXY_HOST, USE_PRI_REVERSE_PROXY,
|
||||
@@ -10,7 +23,7 @@ use crate::app::{
|
||||
},
|
||||
};
|
||||
use reqwest::{
|
||||
Client, RequestBuilder,
|
||||
Client, Method, RequestBuilder,
|
||||
header::{
|
||||
ACCEPT, ACCEPT_ENCODING, ACCEPT_LANGUAGE, CACHE_CONTROL, CONNECTION, CONTENT_LENGTH,
|
||||
CONTENT_TYPE, COOKIE, DNT, HOST, ORIGIN, PRAGMA, REFERER, TE, TRANSFER_ENCODING,
|
||||
@@ -18,56 +31,30 @@ use reqwest::{
|
||||
},
|
||||
};
|
||||
|
||||
macro_rules! def_const {
|
||||
($name:ident, $value:expr) => {
|
||||
const $name: &'static str = $value;
|
||||
};
|
||||
}
|
||||
|
||||
def_const!(SEC_FETCH_DEST, "sec-fetch-dest");
|
||||
def_const!(SEC_FETCH_MODE, "sec-fetch-mode");
|
||||
def_const!(SEC_FETCH_SITE, "sec-fetch-site");
|
||||
def_const!(SEC_GPC, "sec-gpc");
|
||||
def_const!(PRIORITY, "priority");
|
||||
|
||||
def_const!(ONE, "1");
|
||||
def_const!(ENCODINGS, "gzip,br");
|
||||
def_const!(VALUE_ACCEPT, "*/*");
|
||||
def_const!(VALUE_LANGUAGE, "en-US");
|
||||
def_const!(EMPTY, "empty");
|
||||
def_const!(CORS, "cors");
|
||||
def_const!(NO_CACHE, "no-cache");
|
||||
def_const!(
|
||||
UA_WIN,
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
|
||||
);
|
||||
def_const!(SAME_ORIGIN, "same-origin");
|
||||
def_const!(KEEP_ALIVE, "keep-alive");
|
||||
def_const!(TRAILERS, "trailers");
|
||||
def_const!(U_EQ_4, "u=4");
|
||||
def_const!(U_EQ_0, "u=0");
|
||||
|
||||
def_const!(PROXY_HOST, "x-co");
|
||||
|
||||
#[inline]
|
||||
fn get_client_and_host_post<'a>(
|
||||
fn get_client_and_host<'a>(
|
||||
client: &Client,
|
||||
method: Method,
|
||||
url: &'a str,
|
||||
is_pri: bool,
|
||||
real_host: &'a str,
|
||||
) -> (RequestBuilder, &'a str) {
|
||||
if is_pri && *USE_PRI_REVERSE_PROXY {
|
||||
(
|
||||
client.post(url).header(PROXY_HOST, real_host),
|
||||
client
|
||||
.request(method, url)
|
||||
.header(header_name_proxy_host(), real_host),
|
||||
PRI_REVERSE_PROXY_HOST.as_str(),
|
||||
)
|
||||
} else if !is_pri && *USE_PUB_REVERSE_PROXY {
|
||||
(
|
||||
client.post(url).header(PROXY_HOST, real_host),
|
||||
client
|
||||
.request(method, url)
|
||||
.header(header_name_proxy_host(), real_host),
|
||||
PUB_REVERSE_PROXY_HOST.as_str(),
|
||||
)
|
||||
} else {
|
||||
(client.post(url), real_host)
|
||||
(client.request(method, url), real_host)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,54 +82,43 @@ pub(crate) struct AiServiceRequest<'a> {
|
||||
///
|
||||
/// * `reqwest::RequestBuilder` - 配置好的请求构建器
|
||||
pub fn build_request(req: AiServiceRequest) -> RequestBuilder {
|
||||
let (client, host) =
|
||||
get_client_and_host_post(&req.client, req.url, req.is_pri, CURSOR_API2_HOST);
|
||||
let (builder, host) = get_client_and_host(
|
||||
&req.client,
|
||||
Method::POST,
|
||||
req.url,
|
||||
req.is_pri,
|
||||
CURSOR_API2_HOST,
|
||||
);
|
||||
|
||||
client
|
||||
builder
|
||||
.header(
|
||||
CONTENT_TYPE,
|
||||
if req.is_stream {
|
||||
CONTENT_TYPE_CONNECT_PROTO
|
||||
header_value_connect_proto()
|
||||
} else {
|
||||
CONTENT_TYPE_PROTO
|
||||
header_value_proto()
|
||||
},
|
||||
)
|
||||
.bearer_auth(req.auth_token)
|
||||
.header("connect-accept-encoding", ENCODINGS)
|
||||
.header("connect-protocol-version", ONE)
|
||||
.header(USER_AGENT, "connect-es/1.6.1")
|
||||
.header("x-amzn-trace-id", format!("Root={}", req.trace_id))
|
||||
.header("x-client-key", req.client_key)
|
||||
.header("x-cursor-checksum", req.checksum)
|
||||
.header("x-cursor-client-version", "0.42.5")
|
||||
.header("x-cursor-timezone", req.timezone)
|
||||
.header(HEADER_NAME_GHOST_MODE, TRUE)
|
||||
.header("x-request-id", req.trace_id)
|
||||
.header(
|
||||
header_name_connect_accept_encoding(),
|
||||
header_value_encoding(),
|
||||
)
|
||||
.header(header_name_connect_protocol_version(), header_value_one())
|
||||
.header(USER_AGENT, header_value_connect_es())
|
||||
.header(
|
||||
header_name_amzn_trace_id(),
|
||||
format!("Root={}", req.trace_id),
|
||||
)
|
||||
.header(header_name_client_key(), req.client_key)
|
||||
.header(header_name_cursor_checksum(), req.checksum)
|
||||
.header(header_name_cursor_client_version(), "0.42.5")
|
||||
.header(header_name_cursor_timezone(), req.timezone)
|
||||
.header(header_name_ghost_mode(), TRUE)
|
||||
.header(header_name_request_id(), req.trace_id)
|
||||
.header(HOST, host)
|
||||
.header(CONNECTION, KEEP_ALIVE)
|
||||
.header(TRANSFER_ENCODING, "chunked")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn get_client_and_host<'a>(
|
||||
client: &Client,
|
||||
url: &'a str,
|
||||
is_pri: bool,
|
||||
real_host: &'a str,
|
||||
) -> (RequestBuilder, &'a str) {
|
||||
if is_pri && *USE_PRI_REVERSE_PROXY {
|
||||
(
|
||||
client.get(url).header(PROXY_HOST, real_host),
|
||||
PRI_REVERSE_PROXY_HOST.as_str(),
|
||||
)
|
||||
} else if !is_pri && *USE_PUB_REVERSE_PROXY {
|
||||
(
|
||||
client.get(url).header(PROXY_HOST, real_host),
|
||||
PUB_REVERSE_PROXY_HOST.as_str(),
|
||||
)
|
||||
} else {
|
||||
(client.get(url), real_host)
|
||||
}
|
||||
.header(CONNECTION, header_value_keep_alive())
|
||||
.header(TRANSFER_ENCODING, header_value_chunked())
|
||||
}
|
||||
|
||||
/// 返回预构建的获取 Stripe 账户信息的 Cursor API 客户端
|
||||
@@ -155,32 +131,30 @@ fn get_client_and_host<'a>(
|
||||
///
|
||||
/// * `reqwest::RequestBuilder` - 配置好的请求构建器
|
||||
pub fn build_profile_request(client: &Client, auth_token: &str, is_pri: bool) -> RequestBuilder {
|
||||
let (client, host) = get_client_and_host(
|
||||
let (builder, host) = get_client_and_host(
|
||||
client,
|
||||
Method::GET,
|
||||
cursor_api2_stripe_url(is_pri),
|
||||
is_pri,
|
||||
CURSOR_API2_HOST,
|
||||
);
|
||||
|
||||
client
|
||||
builder
|
||||
.header(HOST, host)
|
||||
.header("sec-ch-ua", "\"Not-A.Brand\";v=\"99\", \"Chromium\";v=\"124\"")
|
||||
.header(HEADER_NAME_GHOST_MODE, TRUE)
|
||||
.header("sec-ch-ua-mobile", "?0")
|
||||
.header(header_name_sec_ch_ua(), header_value_not_a_brand())
|
||||
.header(header_name_ghost_mode(), TRUE)
|
||||
.header(header_name_sec_ch_ua_mobile(), header_value_mobile_no())
|
||||
.bearer_auth(auth_token)
|
||||
.header(
|
||||
USER_AGENT,
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Cursor/0.42.5 Chrome/124.0.6367.243 Electron/30.4.0 Safari/537.36",
|
||||
)
|
||||
.header("sec-ch-ua-platform", "\"Windows\"")
|
||||
.header(ACCEPT, VALUE_ACCEPT)
|
||||
.header(ORIGIN, "vscode-file://vscode-app")
|
||||
.header(SEC_FETCH_SITE, "cross-site")
|
||||
.header(SEC_FETCH_MODE, CORS)
|
||||
.header(SEC_FETCH_DEST, EMPTY)
|
||||
.header(ACCEPT_ENCODING, ENCODINGS)
|
||||
.header(ACCEPT_LANGUAGE, VALUE_LANGUAGE)
|
||||
.header(PRIORITY, "u=1, i")
|
||||
.header(USER_AGENT, header_value_ua_cursor())
|
||||
.header(header_name_sec_ch_ua_platform(), header_value_windows())
|
||||
.header(ACCEPT, header_value_accept())
|
||||
.header(ORIGIN, header_value_vscode_origin())
|
||||
.header(header_name_sec_fetch_site(), header_value_cross_site())
|
||||
.header(header_name_sec_fetch_mode(), header_value_cors())
|
||||
.header(header_name_sec_fetch_dest(), header_value_empty())
|
||||
.header(ACCEPT_ENCODING, header_value_encodings())
|
||||
.header(ACCEPT_LANGUAGE, header_value_language())
|
||||
.header(header_name_priority(), header_value_u_eq_0())
|
||||
}
|
||||
|
||||
/// 返回预构建的获取使用情况的 Cursor API 客户端
|
||||
@@ -199,26 +173,31 @@ pub fn build_usage_request(
|
||||
auth_token: &str,
|
||||
is_pri: bool,
|
||||
) -> RequestBuilder {
|
||||
let (client, host) =
|
||||
get_client_and_host(client, cursor_usage_api_url(is_pri), is_pri, CURSOR_HOST);
|
||||
let (client, host) = get_client_and_host(
|
||||
client,
|
||||
Method::GET,
|
||||
cursor_usage_api_url(is_pri),
|
||||
is_pri,
|
||||
CURSOR_HOST,
|
||||
);
|
||||
|
||||
client
|
||||
.header(HOST, host)
|
||||
.header(USER_AGENT, UA_WIN)
|
||||
.header(ACCEPT, VALUE_ACCEPT)
|
||||
.header(ACCEPT_LANGUAGE, VALUE_LANGUAGE)
|
||||
.header(ACCEPT_ENCODING, ENCODINGS)
|
||||
.header(USER_AGENT, header_value_ua_win())
|
||||
.header(ACCEPT, header_value_accept())
|
||||
.header(ACCEPT_LANGUAGE, header_value_language())
|
||||
.header(ACCEPT_ENCODING, header_value_encodings())
|
||||
.header(REFERER, CURSOR_SETTINGS_URL)
|
||||
.header(DNT, ONE)
|
||||
.header(SEC_GPC, ONE)
|
||||
.header(SEC_FETCH_DEST, EMPTY)
|
||||
.header(SEC_FETCH_MODE, CORS)
|
||||
.header(SEC_FETCH_SITE, SAME_ORIGIN)
|
||||
.header(CONNECTION, KEEP_ALIVE)
|
||||
.header(PRAGMA, NO_CACHE)
|
||||
.header(CACHE_CONTROL, NO_CACHE)
|
||||
.header(TE, TRAILERS)
|
||||
.header(PRIORITY, U_EQ_4)
|
||||
.header(DNT, header_value_one())
|
||||
.header(header_name_sec_gpc(), header_value_one())
|
||||
.header(header_name_sec_fetch_dest(), header_value_empty())
|
||||
.header(header_name_sec_fetch_mode(), header_value_cors())
|
||||
.header(header_name_sec_fetch_site(), header_value_same_origin())
|
||||
.header(CONNECTION, header_value_keep_alive())
|
||||
.header(PRAGMA, header_value_no_cache())
|
||||
.header(CACHE_CONTROL, header_value_no_cache())
|
||||
.header(TE, header_value_trailers())
|
||||
.header(header_name_priority(), header_value_u_eq_0())
|
||||
.header(
|
||||
COOKIE,
|
||||
format!("WorkosCursorSessionToken={user_id}%3A%3A{auth_token}"),
|
||||
@@ -242,26 +221,31 @@ pub fn build_userinfo_request(
|
||||
auth_token: &str,
|
||||
is_pri: bool,
|
||||
) -> RequestBuilder {
|
||||
let (client, host) =
|
||||
get_client_and_host(client, cursor_user_api_url(is_pri), is_pri, CURSOR_HOST);
|
||||
let (client, host) = get_client_and_host(
|
||||
client,
|
||||
Method::GET,
|
||||
cursor_user_api_url(is_pri),
|
||||
is_pri,
|
||||
CURSOR_HOST,
|
||||
);
|
||||
|
||||
client
|
||||
.header(HOST, host)
|
||||
.header(USER_AGENT, UA_WIN)
|
||||
.header(ACCEPT, VALUE_ACCEPT)
|
||||
.header(ACCEPT_LANGUAGE, VALUE_LANGUAGE)
|
||||
.header(ACCEPT_ENCODING, ENCODINGS)
|
||||
.header(USER_AGENT, header_value_ua_win())
|
||||
.header(ACCEPT, header_value_accept())
|
||||
.header(ACCEPT_LANGUAGE, header_value_language())
|
||||
.header(ACCEPT_ENCODING, header_value_encodings())
|
||||
.header(REFERER, CURSOR_SETTINGS_URL)
|
||||
.header(DNT, ONE)
|
||||
.header(SEC_GPC, ONE)
|
||||
.header(SEC_FETCH_DEST, EMPTY)
|
||||
.header(SEC_FETCH_MODE, CORS)
|
||||
.header(SEC_FETCH_SITE, SAME_ORIGIN)
|
||||
.header(CONNECTION, KEEP_ALIVE)
|
||||
.header(PRAGMA, NO_CACHE)
|
||||
.header(CACHE_CONTROL, NO_CACHE)
|
||||
.header(TE, TRAILERS)
|
||||
.header(PRIORITY, U_EQ_4)
|
||||
.header(DNT, header_value_one())
|
||||
.header(header_name_sec_gpc(), header_value_one())
|
||||
.header(header_name_sec_fetch_dest(), header_value_empty())
|
||||
.header(header_name_sec_fetch_mode(), header_value_cors())
|
||||
.header(header_name_sec_fetch_site(), header_value_same_origin())
|
||||
.header(CONNECTION, header_value_keep_alive())
|
||||
.header(PRAGMA, header_value_no_cache())
|
||||
.header(CACHE_CONTROL, header_value_no_cache())
|
||||
.header(TE, header_value_trailers())
|
||||
.header(header_name_priority(), header_value_u_eq_0())
|
||||
.header(
|
||||
COOKIE,
|
||||
format!("WorkosCursorSessionToken={user_id}%3A%3A{auth_token}"),
|
||||
@@ -276,8 +260,9 @@ pub fn build_token_upgrade_request(
|
||||
auth_token: &str,
|
||||
is_pri: bool,
|
||||
) -> RequestBuilder {
|
||||
let (client, host) = get_client_and_host_post(
|
||||
let (client, host) = get_client_and_host(
|
||||
client,
|
||||
Method::POST,
|
||||
cursor_token_upgrade_url(is_pri),
|
||||
is_pri,
|
||||
CURSOR_HOST,
|
||||
@@ -287,10 +272,10 @@ pub fn build_token_upgrade_request(
|
||||
|
||||
client
|
||||
.header(HOST, host)
|
||||
.header(USER_AGENT, UA_WIN)
|
||||
.header(ACCEPT, VALUE_ACCEPT)
|
||||
.header(ACCEPT_LANGUAGE, VALUE_LANGUAGE)
|
||||
.header(ACCEPT_ENCODING, ENCODINGS)
|
||||
.header(USER_AGENT, header_value_ua_win())
|
||||
.header(ACCEPT, header_value_accept())
|
||||
.header(ACCEPT_LANGUAGE, header_value_language())
|
||||
.header(ACCEPT_ENCODING, header_value_encodings())
|
||||
.header(
|
||||
REFERER,
|
||||
format!(
|
||||
@@ -299,16 +284,16 @@ pub fn build_token_upgrade_request(
|
||||
)
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.header(CONTENT_LENGTH, body.len())
|
||||
.header(DNT, ONE)
|
||||
.header(SEC_GPC, ONE)
|
||||
.header(SEC_FETCH_DEST, EMPTY)
|
||||
.header(SEC_FETCH_MODE, CORS)
|
||||
.header(SEC_FETCH_SITE, SAME_ORIGIN)
|
||||
.header(CONNECTION, KEEP_ALIVE)
|
||||
.header(PRAGMA, NO_CACHE)
|
||||
.header(CACHE_CONTROL, NO_CACHE)
|
||||
.header(TE, TRAILERS)
|
||||
.header(PRIORITY, U_EQ_0)
|
||||
.header(DNT, header_value_one())
|
||||
.header(header_name_sec_gpc(), header_value_one())
|
||||
.header(header_name_sec_fetch_dest(), header_value_empty())
|
||||
.header(header_name_sec_fetch_mode(), header_value_cors())
|
||||
.header(header_name_sec_fetch_site(), header_value_same_origin())
|
||||
.header(CONNECTION, header_value_keep_alive())
|
||||
.header(PRAGMA, header_value_no_cache())
|
||||
.header(CACHE_CONTROL, header_value_no_cache())
|
||||
.header(TE, header_value_trailers())
|
||||
.header(header_name_priority(), header_value_u_eq_0())
|
||||
.header(
|
||||
COOKIE,
|
||||
format!("WorkosCursorSessionToken={user_id}%3A%3A{auth_token}"),
|
||||
@@ -324,17 +309,18 @@ pub fn build_token_poll_request(
|
||||
) -> RequestBuilder {
|
||||
let (client, host) = get_client_and_host(
|
||||
client,
|
||||
Method::GET,
|
||||
cursor_token_poll_url(is_pri),
|
||||
is_pri,
|
||||
CURSOR_API2_HOST,
|
||||
);
|
||||
client
|
||||
.header(HOST, host)
|
||||
.header(ACCEPT_ENCODING, "gzip, deflate")
|
||||
.header(ACCEPT_LANGUAGE, "en-US")
|
||||
.header(USER_AGENT, "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Cursor/0.48.2 Chrome/132.0.6834.210 Electron/34.3.4 Safari/537.36")
|
||||
.header(ORIGIN, "vscode-file://vscode-app")
|
||||
.header(HEADER_NAME_GHOST_MODE, TRUE)
|
||||
.header(ACCEPT, "*/*")
|
||||
.header(ACCEPT_ENCODING, header_value_gzip_deflate())
|
||||
.header(ACCEPT_LANGUAGE, header_value_language())
|
||||
.header(USER_AGENT, header_value_ua_cursor())
|
||||
.header(ORIGIN, header_value_vscode_origin())
|
||||
.header(header_name_ghost_mode(), TRUE)
|
||||
.header(ACCEPT, header_value_accept())
|
||||
.query(&[("uuid", uuid), ("verifier", verifier)])
|
||||
}
|
||||
|
@@ -7,8 +7,7 @@ pub struct HealthCheckResponse {
|
||||
pub status: ApiStatus,
|
||||
pub version: &'static str,
|
||||
pub uptime: i64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stats: Option<SystemStats>,
|
||||
pub stats: SystemStats,
|
||||
pub models: Vec<&'static str>,
|
||||
pub endpoints: &'static [&'static str],
|
||||
}
|
||||
|
@@ -1,5 +1,4 @@
|
||||
mod checksum;
|
||||
use std::time::Instant;
|
||||
|
||||
use ::base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
|
||||
pub use checksum::*;
|
||||
@@ -30,8 +29,8 @@ use crate::{
|
||||
},
|
||||
config::key_config,
|
||||
constant::{
|
||||
ANTHROPIC, CREATED, CURSOR, DEEPSEEK, GOOGLE, MODEL_OBJECT, OPENAI, UNKNOWN, XAI,
|
||||
calculate_display_name_v3,
|
||||
ANTHROPIC, CREATED, CURSOR, DEEPSEEK, DEFAULT, GOOGLE, MODEL_OBJECT, OPENAI, UNKNOWN,
|
||||
XAI, calculate_display_name_v3,
|
||||
},
|
||||
model::{Model, Usage},
|
||||
},
|
||||
@@ -111,20 +110,6 @@ impl TrimNewlines for String {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait InstantExt {
|
||||
fn duration_as_secs_f64(&mut self) -> f64;
|
||||
}
|
||||
|
||||
impl InstantExt for Instant {
|
||||
#[inline]
|
||||
fn duration_as_secs_f64(&mut self) -> f64 {
|
||||
let now = Instant::now();
|
||||
let duration = now.duration_since(*self);
|
||||
*self = now;
|
||||
duration.as_secs_f64()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_token_profile(
|
||||
client: Client,
|
||||
auth_token: &str,
|
||||
@@ -261,6 +246,12 @@ pub async fn get_available_models(
|
||||
}
|
||||
};
|
||||
let display_name = calculate_display_name_v3(&model.name);
|
||||
let is_thinking = model.supports_thinking();
|
||||
let is_image = if model.name.as_str() == DEFAULT {
|
||||
true
|
||||
} else {
|
||||
model.supports_images()
|
||||
};
|
||||
|
||||
Model {
|
||||
id: crate::leak::intern_string(model.name),
|
||||
@@ -268,6 +259,8 @@ pub async fn get_available_models(
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by,
|
||||
is_thinking,
|
||||
is_image,
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
@@ -276,9 +269,9 @@ pub async fn get_available_models(
|
||||
|
||||
pub async fn get_token_usage(
|
||||
client: Client,
|
||||
auth_token: &str,
|
||||
checksum: &str,
|
||||
client_key: &str,
|
||||
auth_token: String,
|
||||
checksum: String,
|
||||
client_key: String,
|
||||
timezone: &'static str,
|
||||
is_pri: bool,
|
||||
usage_uuid: String,
|
||||
@@ -287,9 +280,9 @@ pub async fn get_token_usage(
|
||||
let trace_id = uuid::Uuid::new_v4().to_string();
|
||||
let client = super::client::build_request(super::client::AiServiceRequest {
|
||||
client,
|
||||
auth_token,
|
||||
checksum,
|
||||
client_key,
|
||||
auth_token: &auth_token,
|
||||
checksum: &checksum,
|
||||
client_key: &client_key,
|
||||
url: cursor_api2_token_usage_url(is_pri),
|
||||
is_stream: false,
|
||||
timezone,
|
||||
@@ -449,13 +442,13 @@ pub fn token_to_tokeninfo(
|
||||
}
|
||||
|
||||
/// 将 TokenInfo 转换为 JWT token
|
||||
pub fn tokeninfo_to_token(mut info: key_config::TokenInfo) -> Option<(String, String, Client)> {
|
||||
pub fn tokeninfo_to_token(info: key_config::TokenInfo) -> Option<(String, String, Client)> {
|
||||
// 构建 payload
|
||||
let payload = TokenPayload {
|
||||
sub: std::mem::take(&mut info.sub),
|
||||
sub: info.sub,
|
||||
exp: info.end,
|
||||
randomness: std::mem::take(&mut info.randomness),
|
||||
time: info.start.to_string(), // exp - 30000天
|
||||
randomness: info.randomness,
|
||||
time: info.start.to_string(),
|
||||
iss: ISSUER.to_string(),
|
||||
scope: SCOPE.to_string(),
|
||||
aud: AUDIENCE.to_string(),
|
||||
|
@@ -4,11 +4,11 @@ use sha2::{Digest, Sha256};
|
||||
#[inline]
|
||||
pub fn generate_hash() -> String {
|
||||
use rand::Rng as _;
|
||||
hex::encode(
|
||||
Sha256::new()
|
||||
.chain_update(rand::rng().random::<[u8; 32]>())
|
||||
.finalize(),
|
||||
)
|
||||
let mut v = rand::rng().random::<[u8; 32]>();
|
||||
if *crate::app::lazy::SAFE_HASH {
|
||||
v = Sha256::new().chain_update(v).finalize().into();
|
||||
}
|
||||
hex::encode(v)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
@@ -31,10 +31,19 @@ fn deobfuscate_bytes(bytes: &mut [u8]) {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate_timestamp_header() -> String {
|
||||
pub fn generate_timestamp_header() -> [u8; 8] {
|
||||
static CACHE: std::sync::LazyLock<parking_lot::Mutex<(u64, [u8; 8])>> =
|
||||
std::sync::LazyLock::new(|| parking_lot::Mutex::new((0, [0u8; 8])));
|
||||
let timestamp = super::now_secs() / 1_000;
|
||||
|
||||
let mut timestamp_bytes = vec![
|
||||
let mut guard = CACHE.lock();
|
||||
if guard.0 == timestamp {
|
||||
return guard.1;
|
||||
} else {
|
||||
guard.0 = timestamp;
|
||||
}
|
||||
|
||||
let mut timestamp_bytes = [
|
||||
((timestamp >> 8) & 0xFF) as u8,
|
||||
(0xFF & timestamp) as u8,
|
||||
((timestamp >> 24) & 0xFF) as u8,
|
||||
@@ -44,12 +53,16 @@ pub fn generate_timestamp_header() -> String {
|
||||
];
|
||||
|
||||
obfuscate_bytes(&mut timestamp_bytes);
|
||||
BASE64.encode(&timestamp_bytes)
|
||||
let mut result = [0u8; 8];
|
||||
let _ = BASE64.encode_slice(&timestamp_bytes, &mut result);
|
||||
guard.1 = result;
|
||||
result
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn generate_checksum(device_id: &str, mac_addr: Option<&str>) -> String {
|
||||
let encoded = generate_timestamp_header();
|
||||
let timestamp_header = generate_timestamp_header();
|
||||
let encoded = unsafe { str::from_utf8_unchecked(&timestamp_header) };
|
||||
match mac_addr {
|
||||
Some(mac) => format!("{encoded}{device_id}/{mac}"),
|
||||
None => format!("{encoded}{device_id}"),
|
||||
@@ -100,23 +113,23 @@ pub fn generate_checksum_with_repair(checksum: &str) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
let timestamp_header = generate_timestamp_header();
|
||||
let encoded = unsafe { str::from_utf8_unchecked(&timestamp_header) };
|
||||
|
||||
// 校验通过后构造结果
|
||||
match len {
|
||||
72 => format!(
|
||||
"{}{}/{}",
|
||||
generate_timestamp_header(),
|
||||
"{encoded}{}/{}",
|
||||
unsafe { std::str::from_utf8_unchecked(bytes.get_unchecked(8..)) },
|
||||
generate_hash()
|
||||
),
|
||||
129 => format!(
|
||||
"{}{}/{}",
|
||||
generate_timestamp_header(),
|
||||
"{encoded}{}/{}",
|
||||
unsafe { std::str::from_utf8_unchecked(bytes.get_unchecked(..64)) },
|
||||
unsafe { std::str::from_utf8_unchecked(bytes.get_unchecked(65..)) }
|
||||
),
|
||||
137 => format!(
|
||||
"{}{}/{}",
|
||||
generate_timestamp_header(),
|
||||
"{encoded}{}/{}",
|
||||
unsafe { std::str::from_utf8_unchecked(bytes.get_unchecked(8..72)) },
|
||||
unsafe { std::str::from_utf8_unchecked(bytes.get_unchecked(73..)) }
|
||||
),
|
||||
|
@@ -18,11 +18,8 @@ use super::{
|
||||
AzureState, ChatExternalLink, ConversationMessage, ExplicitContext, GetChatRequest,
|
||||
ImageProto, ModelDetails, WebReference, conversation_message, image_proto,
|
||||
},
|
||||
constant::{
|
||||
ERR_UNSUPPORTED_GIF, ERR_UNSUPPORTED_IMAGE_FORMAT, LONG_CONTEXT_MODELS,
|
||||
SUPPORTED_IMAGE_MODELS,
|
||||
},
|
||||
model::{Message, MessageContent, Role},
|
||||
constant::{ERR_UNSUPPORTED_GIF, ERR_UNSUPPORTED_IMAGE_FORMAT, LONG_CONTEXT_MODELS},
|
||||
model::{Message, MessageContent, Model, Role},
|
||||
};
|
||||
|
||||
fn parse_web_references(text: &str) -> Vec<WebReference> {
|
||||
@@ -90,7 +87,7 @@ async fn process_chat_inputs(
|
||||
inputs: Vec<Message>,
|
||||
now_with_tz: Option<chrono::DateTime<chrono_tz::Tz>>,
|
||||
disable_vision: bool,
|
||||
model: &str,
|
||||
model: Model,
|
||||
) -> (String, Vec<ConversationMessage>, Vec<String>) {
|
||||
// 收集 system 指令
|
||||
let instructions = inputs
|
||||
@@ -101,7 +98,7 @@ async fn process_chat_inputs(
|
||||
MessageContent::Vision(contents) => contents
|
||||
.iter()
|
||||
.filter_map(|content| {
|
||||
if content.rtype == "text" {
|
||||
if content.r#type == "text" {
|
||||
content.text.clone()
|
||||
} else {
|
||||
None
|
||||
@@ -114,9 +111,9 @@ async fn process_chat_inputs(
|
||||
.join("\n\n");
|
||||
|
||||
// 使用默认指令或收集到的指令
|
||||
let image_support = !disable_vision && SUPPORTED_IMAGE_MODELS.contains(&model);
|
||||
let image_support = !disable_vision && model.is_image;
|
||||
let instructions = if instructions.is_empty() {
|
||||
get_default_instructions(now_with_tz, model, image_support)
|
||||
get_default_instructions(now_with_tz, model.id, image_support)
|
||||
} else {
|
||||
instructions
|
||||
};
|
||||
@@ -239,7 +236,7 @@ async fn process_chat_inputs(
|
||||
let mut images = Vec::new();
|
||||
|
||||
for content in contents {
|
||||
match content.rtype.as_str() {
|
||||
match content.r#type.as_str() {
|
||||
"text" => {
|
||||
if let Some(text) = content.text {
|
||||
text_parts.push(text);
|
||||
@@ -493,7 +490,7 @@ async fn process_http_image(
|
||||
pub async fn encode_chat_message(
|
||||
inputs: Vec<Message>,
|
||||
now_with_tz: Option<chrono::DateTime<chrono_tz::Tz>>,
|
||||
model: &str,
|
||||
model: Model,
|
||||
disable_vision: bool,
|
||||
enable_slow_pool: bool,
|
||||
is_search: bool,
|
||||
@@ -525,7 +522,7 @@ pub async fn encode_chat_message(
|
||||
})
|
||||
.collect();
|
||||
|
||||
let long_context = AppConfig::get_long_context() || LONG_CONTEXT_MODELS.contains(&model);
|
||||
let long_context = AppConfig::get_long_context() || LONG_CONTEXT_MODELS.contains(&model.id);
|
||||
|
||||
let chat = GetChatRequest {
|
||||
current_file: None,
|
||||
@@ -535,7 +532,7 @@ pub async fn encode_chat_message(
|
||||
workspace_root_path: None,
|
||||
code_blocks: vec![],
|
||||
model_details: Some(ModelDetails {
|
||||
model_name: Some(model.to_string()),
|
||||
model_name: Some(model.id.to_string()),
|
||||
api_key: None,
|
||||
enable_ghost_mode: Some(true),
|
||||
azure_state: Some(AzureState {
|
||||
|
@@ -3,19 +3,19 @@
|
||||
pub struct KeyConfig {
|
||||
/// 认证令牌(必需)
|
||||
#[prost(message, optional, tag = "1")]
|
||||
pub auth_token: ::core::option::Option<key_config::TokenInfo>,
|
||||
pub auth_token: Option<key_config::TokenInfo>,
|
||||
/// 是否禁用图片处理能力
|
||||
#[prost(bool, optional, tag = "4")]
|
||||
pub disable_vision: ::core::option::Option<bool>,
|
||||
pub disable_vision: Option<bool>,
|
||||
/// 是否启用慢速池
|
||||
#[prost(bool, optional, tag = "5")]
|
||||
pub enable_slow_pool: ::core::option::Option<bool>,
|
||||
pub enable_slow_pool: Option<bool>,
|
||||
/// 使用量检查模型规则
|
||||
#[prost(message, optional, tag = "6")]
|
||||
pub usage_check_models: ::core::option::Option<key_config::UsageCheckModel>,
|
||||
pub usage_check_models: Option<key_config::UsageCheckModel>,
|
||||
/// 包含网络引用
|
||||
#[prost(bool, optional, tag = "7")]
|
||||
pub include_web_references: ::core::option::Option<bool>,
|
||||
pub include_web_references: Option<bool>,
|
||||
}
|
||||
/// Nested message and enum types in `KeyConfig`.
|
||||
pub mod key_config {
|
||||
@@ -24,7 +24,7 @@ pub mod key_config {
|
||||
pub struct TokenInfo {
|
||||
/// 用户标识符
|
||||
#[prost(string, tag = "1")]
|
||||
pub sub: ::prost::alloc::string::String,
|
||||
pub sub: String,
|
||||
/// 生成时间(Unix 时间戳)
|
||||
#[prost(int64, tag = "2")]
|
||||
pub start: i64,
|
||||
@@ -33,19 +33,19 @@ pub mod key_config {
|
||||
pub end: i64,
|
||||
/// 随机字符串
|
||||
#[prost(string, tag = "4")]
|
||||
pub randomness: ::prost::alloc::string::String,
|
||||
pub randomness: String,
|
||||
/// 签名
|
||||
#[prost(string, tag = "5")]
|
||||
pub signature: ::prost::alloc::string::String,
|
||||
pub signature: String,
|
||||
/// 机器ID的SHA256哈希值
|
||||
#[prost(bytes = "vec", tag = "6")]
|
||||
pub machine_id: ::prost::alloc::vec::Vec<u8>,
|
||||
pub machine_id: Vec<u8>,
|
||||
/// MAC地址的SHA256哈希值
|
||||
#[prost(bytes = "vec", tag = "7")]
|
||||
pub mac_id: ::prost::alloc::vec::Vec<u8>,
|
||||
pub mac_id: Vec<u8>,
|
||||
/// 代理名称
|
||||
#[prost(string, optional, tag = "8")]
|
||||
pub proxy_name: ::core::option::Option<::prost::alloc::string::String>,
|
||||
pub proxy_name: Option<String>,
|
||||
}
|
||||
/// 使用量检查模型规则
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
@@ -55,7 +55,7 @@ pub mod key_config {
|
||||
pub r#type: i32,
|
||||
/// 模型 ID 列表,当 type 为 TYPE_CUSTOM 时生效
|
||||
#[prost(string, repeated, tag = "2")]
|
||||
pub model_ids: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
|
||||
pub model_ids: Vec<String>,
|
||||
}
|
||||
/// Nested message and enum types in `UsageCheckModel`.
|
||||
pub mod usage_check_model {
|
||||
@@ -88,7 +88,7 @@ pub mod key_config {
|
||||
}
|
||||
}
|
||||
/// Creates an enum from field names used in the ProtoBuf definition.
|
||||
pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
|
||||
pub fn from_str_name(value: &str) -> Option<Self> {
|
||||
match value {
|
||||
"TYPE_DEFAULT" => Some(Self::Default),
|
||||
"TYPE_DISABLED" => Some(Self::Disabled),
|
||||
|
@@ -66,6 +66,7 @@ def_pub_const!(
|
||||
O1 => "o1",
|
||||
O3_MINI => "o3-mini",
|
||||
GPT_4_5_PREVIEW => "gpt-4.5-preview",
|
||||
GPT_4_1 => "gpt-4.1",
|
||||
|
||||
// Cursor 模型
|
||||
CURSOR_FAST => "cursor-fast",
|
||||
@@ -73,7 +74,6 @@ def_pub_const!(
|
||||
|
||||
// Google 模型
|
||||
GEMINI_1_5_FLASH_500K => "gemini-1.5-flash-500k",
|
||||
GEMINI_EXP_1206 => "gemini-exp-1206",
|
||||
GEMINI_2_0_PRO_EXP => "gemini-2.0-pro-exp",
|
||||
GEMINI_2_5_PRO_EXP_03_25 => "gemini-2.5-pro-exp-03-25",
|
||||
GEMINI_2_5_PRO_MAX => "gemini-2.5-pro-max",
|
||||
@@ -83,9 +83,12 @@ def_pub_const!(
|
||||
// Deepseek 模型
|
||||
DEEPSEEK_V3 => "deepseek-v3",
|
||||
DEEPSEEK_R1 => "deepseek-r1",
|
||||
DEEPSEEK_V3_1 => "deepseek-v3.1",
|
||||
|
||||
// XAI 模型
|
||||
GROK_2 => "grok-2",
|
||||
GROK_3_BETA => "grok-3-beta",
|
||||
GROK_3_MINI_BETA => "grok-3-mini-beta",
|
||||
|
||||
// 未知模型
|
||||
DEFAULT => "default",
|
||||
@@ -103,6 +106,8 @@ macro_rules! create_models {
|
||||
created: CREATED,
|
||||
object: MODEL_OBJECT,
|
||||
owned_by: $owner,
|
||||
is_thinking: SUPPORTED_THINKING_MODELS.contains(&$model),
|
||||
is_image: SUPPORTED_IMAGE_MODELS.contains(&$model),
|
||||
},
|
||||
)*
|
||||
]),
|
||||
@@ -139,12 +144,8 @@ impl Models {
|
||||
// }
|
||||
|
||||
// 查找模型并返回其 ID
|
||||
pub fn find_id(model: &str) -> Option<&'static str> {
|
||||
Self::read()
|
||||
.models
|
||||
.iter()
|
||||
.find(|m| m.id == model)
|
||||
.map(|m| m.id)
|
||||
pub fn find_id(model: &str) -> Option<Model> {
|
||||
Self::read().models.iter().find(|m| m.id == model).copied()
|
||||
}
|
||||
|
||||
// 返回所有模型 ID 的列表
|
||||
@@ -228,22 +229,21 @@ create_models!(
|
||||
DEEPSEEK_R1 => DEEPSEEK,
|
||||
O3_MINI => OPENAI,
|
||||
GROK_2 => XAI,
|
||||
DEEPSEEK_V3_1 => DEEPSEEK,
|
||||
GROK_3_BETA => XAI,
|
||||
GROK_3_MINI_BETA => XAI,
|
||||
GPT_4_1 => OPENAI,
|
||||
);
|
||||
|
||||
pub const USAGE_CHECK_MODELS: [&str; 13] = [
|
||||
CLAUDE_3_5_SONNET,
|
||||
CLAUDE_3_7_SONNET,
|
||||
CLAUDE_3_7_SONNET_THINKING,
|
||||
GEMINI_EXP_1206,
|
||||
GPT_4,
|
||||
GPT_4_TURBO_2024_04_09,
|
||||
GPT_4O,
|
||||
CLAUDE_3_5_HAIKU,
|
||||
GPT_4O_128K,
|
||||
GEMINI_1_5_FLASH_500K,
|
||||
CLAUDE_3_HAIKU_200K,
|
||||
CLAUDE_3_5_SONNET_200K,
|
||||
DEEPSEEK_R1,
|
||||
pub const FREE_MODELS: [&str; 8] = [
|
||||
CURSOR_FAST,
|
||||
CURSOR_SMALL,
|
||||
GPT_4O_MINI,
|
||||
GPT_3_5_TURBO,
|
||||
DEEPSEEK_V3,
|
||||
DEEPSEEK_V3_1,
|
||||
GROK_3_MINI_BETA,
|
||||
GPT_4_1,
|
||||
];
|
||||
|
||||
pub const LONG_CONTEXT_MODELS: [&str; 4] = [
|
||||
@@ -253,14 +253,37 @@ pub const LONG_CONTEXT_MODELS: [&str; 4] = [
|
||||
CLAUDE_3_5_SONNET_200K,
|
||||
];
|
||||
|
||||
pub const SUPPORTED_IMAGE_MODELS: [&str; 9] = [
|
||||
const SUPPORTED_THINKING_MODELS: [&str; 10] = [
|
||||
CLAUDE_3_7_SONNET_THINKING,
|
||||
CLAUDE_3_7_SONNET_THINKING_MAX,
|
||||
O1_MINI,
|
||||
O1_PREVIEW,
|
||||
O1,
|
||||
GEMINI_2_5_PRO_EXP_03_25,
|
||||
GEMINI_2_5_PRO_MAX,
|
||||
GEMINI_2_0_FLASH_THINKING_EXP,
|
||||
DEEPSEEK_R1,
|
||||
O3_MINI,
|
||||
];
|
||||
|
||||
const SUPPORTED_IMAGE_MODELS: [&str; 19] = [
|
||||
DEFAULT,
|
||||
CLAUDE_3_5_SONNET,
|
||||
CLAUDE_3_7_SONNET,
|
||||
CLAUDE_3_7_SONNET_THINKING,
|
||||
GPT_4O,
|
||||
GPT_4O_MINI,
|
||||
DEFAULT,
|
||||
CLAUDE_3_OPUS,
|
||||
CLAUDE_3_5_HAIKU,
|
||||
CLAUDE_3_7_SONNET_MAX,
|
||||
CLAUDE_3_7_SONNET_THINKING_MAX,
|
||||
GPT_4,
|
||||
GPT_4O,
|
||||
GPT_4_5_PREVIEW,
|
||||
CLAUDE_3_OPUS,
|
||||
GPT_4_TURBO_2024_04_09,
|
||||
GPT_4O_128K,
|
||||
CLAUDE_3_HAIKU_200K,
|
||||
CLAUDE_3_5_SONNET_200K,
|
||||
GPT_4O_MINI,
|
||||
CLAUDE_3_5_HAIKU,
|
||||
GEMINI_2_5_PRO_EXP_03_25,
|
||||
GEMINI_2_5_PRO_MAX,
|
||||
GPT_4_1,
|
||||
];
|
||||
|
@@ -1,6 +1,6 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde::{ser::SerializeStruct as _, Deserialize, Serialize};
|
||||
use serde::{Deserialize, Serialize, ser::SerializeStruct as _};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
@@ -12,7 +12,7 @@ pub enum MessageContent {
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct VisionMessageContent {
|
||||
#[serde(rename = "type")]
|
||||
pub rtype: String,
|
||||
pub r#type: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub text: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
@@ -53,8 +53,8 @@ pub enum Role {
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ChatResponse {
|
||||
pub id: String,
|
||||
pub struct ChatResponse<'a> {
|
||||
pub id: &'a str,
|
||||
pub object: &'static str,
|
||||
pub created: i64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
@@ -116,13 +116,16 @@ pub struct StreamOptions {
|
||||
pub include_usage: bool,
|
||||
}
|
||||
|
||||
// 模型定义
|
||||
/// 模型定义
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct Model {
|
||||
pub id: &'static str,
|
||||
pub display_name: &'static str,
|
||||
pub created: &'static i64,
|
||||
pub object: &'static str,
|
||||
pub owned_by: &'static str,
|
||||
pub is_thinking: bool,
|
||||
pub is_image: bool,
|
||||
}
|
||||
|
||||
impl Serialize for Model {
|
||||
@@ -130,7 +133,7 @@ impl Serialize for Model {
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let mut state = serializer.serialize_struct("Model", 7)?;
|
||||
let mut state = serializer.serialize_struct("Model", 9)?;
|
||||
|
||||
state.serialize_field("id", &self.id)?;
|
||||
state.serialize_field("display_name", &self.display_name)?;
|
||||
@@ -139,6 +142,8 @@ impl Serialize for Model {
|
||||
state.serialize_field("object", &self.object)?;
|
||||
state.serialize_field("type", &self.object)?;
|
||||
state.serialize_field("owned_by", &self.owned_by)?;
|
||||
state.serialize_field("supports_thinking", &self.is_thinking)?;
|
||||
state.serialize_field("supports_images", &self.is_image)?;
|
||||
|
||||
state.end()
|
||||
}
|
||||
@@ -150,19 +155,19 @@ impl PartialEq for Model {
|
||||
}
|
||||
}
|
||||
|
||||
use super::constant::{Models, USAGE_CHECK_MODELS};
|
||||
use super::constant::{FREE_MODELS, Models};
|
||||
use crate::{
|
||||
app::model::{AppConfig, UsageCheck},
|
||||
common::model::tri::TriState,
|
||||
};
|
||||
|
||||
impl Model {
|
||||
pub fn is_usage_check(model_id: &str, usage_check: Option<UsageCheck>) -> bool {
|
||||
pub fn is_usage_check(&self, usage_check: Option<UsageCheck>) -> bool {
|
||||
match usage_check.unwrap_or(AppConfig::get_usage_check()) {
|
||||
UsageCheck::None => false,
|
||||
UsageCheck::Default => USAGE_CHECK_MODELS.contains(&model_id),
|
||||
UsageCheck::Default => !FREE_MODELS.contains(&self.id),
|
||||
UsageCheck::All => true,
|
||||
UsageCheck::Custom(models) => models.contains(&model_id),
|
||||
UsageCheck::Custom(models) => models.contains(&self.id),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -6,7 +6,7 @@ use axum::{
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::{
|
||||
app::constant::CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8,
|
||||
app::constant::header_value_text_plain_utf8,
|
||||
common::utils::{
|
||||
generate_checksum_with_default, generate_checksum_with_repair, generate_hash,
|
||||
generate_timestamp_header,
|
||||
@@ -16,11 +16,7 @@ use crate::{
|
||||
pub async fn handle_get_hash() -> Response {
|
||||
let hash = generate_hash();
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
CONTENT_TYPE,
|
||||
CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8.parse().unwrap(),
|
||||
);
|
||||
let headers = HeaderMap::from_iter([(CONTENT_TYPE, header_value_text_plain_utf8().clone())]);
|
||||
|
||||
(headers, hash).into_response()
|
||||
}
|
||||
@@ -37,11 +33,7 @@ pub async fn handle_get_checksum(Query(query): Query<ChecksumQuery>) -> Response
|
||||
Some(checksum) => generate_checksum_with_repair(&checksum),
|
||||
};
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
CONTENT_TYPE,
|
||||
CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8.parse().unwrap(),
|
||||
);
|
||||
let headers = HeaderMap::from_iter([(CONTENT_TYPE, header_value_text_plain_utf8().clone())]);
|
||||
|
||||
(headers, checksum).into_response()
|
||||
}
|
||||
@@ -49,11 +41,7 @@ pub async fn handle_get_checksum(Query(query): Query<ChecksumQuery>) -> Response
|
||||
pub async fn handle_get_timestamp_header() -> Response {
|
||||
let timestamp_header = generate_timestamp_header();
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
CONTENT_TYPE,
|
||||
CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8.parse().unwrap(),
|
||||
);
|
||||
let headers = HeaderMap::from_iter([(CONTENT_TYPE, header_value_text_plain_utf8().clone())]);
|
||||
|
||||
(headers, timestamp_header).into_response()
|
||||
}
|
||||
|
@@ -1,22 +1,19 @@
|
||||
use crate::{
|
||||
app::{
|
||||
constant::{
|
||||
AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_TEXT_HTML_WITH_UTF8,
|
||||
CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, PKG_VERSION, ROUTE_ABOUT_PATH, ROUTE_API_PATH,
|
||||
ROUTE_BASIC_CALIBRATION_PATH, ROUTE_BUILD_KEY_PATH, ROUTE_CONFIG_PATH,
|
||||
ROUTE_ENV_EXAMPLE_PATH, ROUTE_GET_CHECKSUM, ROUTE_GET_HASH, ROUTE_GET_TIMESTAMP_HEADER,
|
||||
ROUTE_HEALTH_PATH, ROUTE_LOGS_PATH, ROUTE_PROXIES_ADD_PATH, ROUTE_PROXIES_DELETE_PATH,
|
||||
ROUTE_PROXIES_GET_PATH, ROUTE_PROXIES_PATH, ROUTE_PROXIES_SET_GENERAL_PATH,
|
||||
ROUTE_PROXIES_SET_PATH, ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_STATIC_PATH,
|
||||
ROUTE_TOKEN_UPGRADE_PATH, ROUTE_TOKENS_ADD_PATH, ROUTE_TOKENS_BY_TAG_GET_PATH,
|
||||
ROUTE_TOKENS_DELETE_PATH, ROUTE_TOKENS_GET_PATH, ROUTE_TOKENS_PATH,
|
||||
ROUTE_TOKENS_PROFILE_UPDATE_PATH, ROUTE_TOKENS_SET_PATH, ROUTE_TOKENS_STATUS_SET_PATH,
|
||||
ROUTE_TOKENS_TAGS_GET_PATH, ROUTE_TOKENS_TAGS_SET_PATH, ROUTE_TOKENS_UPGRADE_PATH,
|
||||
ROUTE_USER_INFO_PATH,
|
||||
},
|
||||
lazy::{
|
||||
AUTH_TOKEN, ROUTE_CHAT_PATH, ROUTE_MESSAGES_PATH, ROUTE_MODELS_PATH, get_start_time,
|
||||
PKG_VERSION, ROUTE_ABOUT_PATH, ROUTE_API_PATH, ROUTE_BASIC_CALIBRATION_PATH,
|
||||
ROUTE_BUILD_KEY_PATH, ROUTE_CONFIG_PATH, ROUTE_ENV_EXAMPLE_PATH, ROUTE_GET_CHECKSUM,
|
||||
ROUTE_GET_HASH, ROUTE_GET_TIMESTAMP_HEADER, ROUTE_HEALTH_PATH, ROUTE_LOGS_PATH,
|
||||
ROUTE_PROXIES_ADD_PATH, ROUTE_PROXIES_DELETE_PATH, ROUTE_PROXIES_GET_PATH,
|
||||
ROUTE_PROXIES_PATH, ROUTE_PROXIES_SET_GENERAL_PATH, ROUTE_PROXIES_SET_PATH,
|
||||
ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_STATIC_PATH, ROUTE_TOKEN_UPGRADE_PATH,
|
||||
ROUTE_TOKENS_ADD_PATH, ROUTE_TOKENS_BY_TAG_GET_PATH, ROUTE_TOKENS_DELETE_PATH,
|
||||
ROUTE_TOKENS_GET_PATH, ROUTE_TOKENS_PATH, ROUTE_TOKENS_PROFILE_UPDATE_PATH,
|
||||
ROUTE_TOKENS_SET_PATH, ROUTE_TOKENS_STATUS_SET_PATH, ROUTE_TOKENS_TAGS_GET_PATH,
|
||||
ROUTE_TOKENS_TAGS_SET_PATH, ROUTE_TOKENS_UPGRADE_PATH, ROUTE_USER_INFO_PATH,
|
||||
header_value_text_html_utf8, header_value_text_plain_utf8,
|
||||
},
|
||||
lazy::{ROUTE_CHAT_PATH, ROUTE_MODELS_PATH, get_start_time},
|
||||
model::{AppConfig, AppState, PageContent},
|
||||
},
|
||||
common::model::{
|
||||
@@ -30,12 +27,11 @@ use axum::{
|
||||
body::Body,
|
||||
extract::State,
|
||||
http::{
|
||||
HeaderMap, StatusCode,
|
||||
StatusCode,
|
||||
header::{CONTENT_TYPE, LOCATION},
|
||||
},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use reqwest::header::AUTHORIZATION;
|
||||
use std::sync::Arc;
|
||||
use sysinfo::{CpuRefreshKind, MemoryRefreshKind, RefreshKind, System};
|
||||
use tokio::sync::Mutex;
|
||||
@@ -48,20 +44,19 @@ pub async fn handle_root() -> impl IntoResponse {
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
PageContent::Text(content) => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_plain_utf8())
|
||||
.body(Body::from(content))
|
||||
.unwrap(),
|
||||
PageContent::Html(content) => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_html_utf8())
|
||||
.body(Body::from(content))
|
||||
.unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
static ENDPOINTS: std::sync::LazyLock<[&'static str; 34]> = std::sync::LazyLock::new(|| {
|
||||
static ENDPOINTS: std::sync::LazyLock<[&'static str; 33]> = std::sync::LazyLock::new(|| {
|
||||
[
|
||||
&*ROUTE_CHAT_PATH,
|
||||
&*ROUTE_MESSAGES_PATH,
|
||||
&*ROUTE_MODELS_PATH,
|
||||
ROUTE_TOKENS_PATH,
|
||||
ROUTE_TOKENS_GET_PATH,
|
||||
@@ -97,20 +92,12 @@ static ENDPOINTS: std::sync::LazyLock<[&'static str; 34]> = std::sync::LazyLock:
|
||||
]
|
||||
});
|
||||
|
||||
pub async fn handle_health(
|
||||
State(state): State<Arc<Mutex<AppState>>>,
|
||||
headers: HeaderMap,
|
||||
) -> Json<HealthCheckResponse> {
|
||||
pub async fn handle_health(State(state): State<Arc<Mutex<AppState>>>) -> Json<HealthCheckResponse> {
|
||||
let start_time = get_start_time();
|
||||
let uptime = (chrono::Local::now() - start_time).num_seconds();
|
||||
|
||||
// 先检查 headers 是否包含有效的认证信息
|
||||
let stats = if headers
|
||||
.get(AUTHORIZATION)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX))
|
||||
.is_some_and(|token| token == AUTH_TOKEN.as_str())
|
||||
{
|
||||
let stats = {
|
||||
// 只有在需要系统信息时才创建实例
|
||||
let mut sys = System::new_with_specifics(
|
||||
RefreshKind::nothing()
|
||||
@@ -135,7 +122,7 @@ pub async fn handle_health(
|
||||
|
||||
let state = state.lock().await;
|
||||
|
||||
Some(SystemStats {
|
||||
SystemStats {
|
||||
started: start_time.to_string(),
|
||||
total_requests: state.request_manager.total_requests,
|
||||
active_requests: state.request_manager.active_requests,
|
||||
@@ -147,9 +134,7 @@ pub async fn handle_health(
|
||||
usage: cpu_usage, // CPU 使用率(百分比)
|
||||
},
|
||||
},
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
Json(HealthCheckResponse {
|
||||
|
@@ -1,8 +1,8 @@
|
||||
use crate::{
|
||||
app::{
|
||||
constant::{
|
||||
AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_TEXT_HTML_WITH_UTF8,
|
||||
CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, ROUTE_LOGS_PATH,
|
||||
AUTHORIZATION_BEARER_PREFIX, ROUTE_LOGS_PATH, header_value_text_html_utf8,
|
||||
header_value_text_plain_utf8,
|
||||
},
|
||||
lazy::AUTH_TOKEN,
|
||||
model::{AppConfig, AppState, LogStatus, PageContent, RequestLog},
|
||||
@@ -30,15 +30,15 @@ use tokio::sync::Mutex;
|
||||
pub async fn handle_logs() -> impl IntoResponse {
|
||||
match AppConfig::get_page_content(ROUTE_LOGS_PATH).unwrap_or_default() {
|
||||
PageContent::Default => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_html_utf8())
|
||||
.body(Body::from(include_str!("../../../static/logs.min.html")))
|
||||
.unwrap(),
|
||||
PageContent::Text(content) => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_plain_utf8())
|
||||
.body(Body::from(content))
|
||||
.unwrap(),
|
||||
PageContent::Html(content) => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_html_utf8())
|
||||
.body(Body::from(content))
|
||||
.unwrap(),
|
||||
}
|
||||
@@ -92,7 +92,7 @@ pub async fn handle_logs_post(
|
||||
active: None,
|
||||
error: None,
|
||||
logs: Vec::new(),
|
||||
timestamp: Local::now().to_string(),
|
||||
timestamp: Local::now(),
|
||||
}));
|
||||
}
|
||||
}
|
||||
@@ -108,7 +108,7 @@ pub async fn handle_logs_post(
|
||||
active: None,
|
||||
error: None,
|
||||
logs: Vec::new(),
|
||||
timestamp: Local::now().to_string(),
|
||||
timestamp: Local::now(),
|
||||
}));
|
||||
}
|
||||
}
|
||||
@@ -117,123 +117,112 @@ pub async fn handle_logs_post(
|
||||
};
|
||||
|
||||
// 准备日志数据(管理员或特定用户的)
|
||||
let logs = if auth_header == auth_token {
|
||||
state.request_manager.request_logs.clone()
|
||||
} else {
|
||||
let mut iterator = Box::new(state.request_manager.request_logs.iter())
|
||||
as Box<dyn Iterator<Item = &RequestLog>>;
|
||||
if auth_header != auth_token {
|
||||
// 解析 token
|
||||
let token_part = extract_token(auth_header).ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
// 筛选符合条件的日志
|
||||
let filtered_logs: Vec<RequestLog> = state
|
||||
.request_manager
|
||||
.request_logs
|
||||
.iter()
|
||||
.filter(|log| log.token_info.token == token_part)
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
if filtered_logs.is_empty() {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
filtered_logs
|
||||
iterator = Box::new(iterator.filter(move |log| log.token_info.token == token_part));
|
||||
};
|
||||
|
||||
// 应用查询参数过滤
|
||||
let mut result_logs = logs;
|
||||
|
||||
// 按状态过滤
|
||||
if let Some(status) = &request.query.status {
|
||||
result_logs.retain(|log| log.status.as_str_name() == *status);
|
||||
iterator = Box::new(iterator.filter(move |log| log.status.as_str_name() == status));
|
||||
}
|
||||
|
||||
// 按模型过滤
|
||||
if let Some(model) = &request.query.model {
|
||||
result_logs.retain(|log| log.model.contains(model));
|
||||
iterator = Box::new(iterator.filter(move |log| log.model.contains(model)));
|
||||
}
|
||||
|
||||
// 按用户邮箱过滤
|
||||
if let Some(email) = &request.query.email {
|
||||
result_logs.retain(|log| {
|
||||
iterator = Box::new(iterator.filter(move |log| {
|
||||
log.token_info
|
||||
.profile
|
||||
.as_ref()
|
||||
.map(|p| p.user.email.contains(email))
|
||||
.unwrap_or(false)
|
||||
});
|
||||
}));
|
||||
}
|
||||
|
||||
// 按会员类型过滤
|
||||
if let Some(membership_type) = membership_enum {
|
||||
result_logs.retain(|log| {
|
||||
iterator = Box::new(iterator.filter(move |log| {
|
||||
log.token_info
|
||||
.profile
|
||||
.as_ref()
|
||||
.map(|p| p.stripe.membership_type == membership_type)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
}));
|
||||
}
|
||||
|
||||
// 按总耗时范围过滤
|
||||
if let Some(min_time) = request.query.min_total_time {
|
||||
result_logs.retain(|log| log.timing.total >= min_time);
|
||||
iterator = Box::new(iterator.filter(move |log| log.timing.total >= min_time));
|
||||
}
|
||||
|
||||
if let Some(max_time) = request.query.max_total_time {
|
||||
result_logs.retain(|log| log.timing.total <= max_time);
|
||||
iterator = Box::new(iterator.filter(move |log| log.timing.total <= max_time));
|
||||
}
|
||||
|
||||
// 按是否为流式请求过滤
|
||||
if let Some(stream) = request.query.stream {
|
||||
result_logs.retain(|log| log.stream == stream);
|
||||
iterator = Box::new(iterator.filter(move |log| log.stream == stream));
|
||||
}
|
||||
|
||||
// 按是否有错误过滤
|
||||
if let Some(has_error) = request.query.has_error {
|
||||
result_logs.retain(|log| log.error.is_some() == has_error);
|
||||
iterator = Box::new(iterator.filter(move |log| log.error.is_some() == has_error));
|
||||
}
|
||||
|
||||
// 按是否有chain过滤
|
||||
if let Some(has_chain) = request.query.has_chain {
|
||||
result_logs.retain(|log| log.chain.is_some() == has_chain);
|
||||
iterator = Box::new(iterator.filter(move |log| log.chain.is_some() == has_chain));
|
||||
}
|
||||
|
||||
// 按日期范围过滤
|
||||
if let Some(from_date) = request.query.from_date {
|
||||
result_logs.retain(|log| log.timestamp >= from_date);
|
||||
iterator = Box::new(iterator.filter(move |log| log.timestamp >= from_date));
|
||||
}
|
||||
|
||||
if let Some(to_date) = request.query.to_date {
|
||||
result_logs.retain(|log| log.timestamp <= to_date);
|
||||
iterator = Box::new(iterator.filter(move |log| log.timestamp <= to_date));
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
let total = result_logs.len() as u64;
|
||||
let filtered_log_refs: Vec<_> = iterator.collect();
|
||||
let total = filtered_log_refs.len() as u64;
|
||||
|
||||
// 应用分页
|
||||
if let Some(offset) = request.query.offset {
|
||||
result_logs = result_logs.into_iter().skip(offset).collect();
|
||||
}
|
||||
let paginated_log_refs = filtered_log_refs
|
||||
.into_iter()
|
||||
.skip(request.query.offset.unwrap_or(0))
|
||||
.take(request.query.limit.unwrap_or(usize::MAX));
|
||||
|
||||
if let Some(limit) = request.query.limit {
|
||||
result_logs = result_logs.into_iter().take(limit).collect();
|
||||
}
|
||||
let result_logs: Vec<RequestLog> = paginated_log_refs.cloned().collect();
|
||||
let active = if auth_header == auth_token {
|
||||
Some(state.request_manager.active_requests)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let error = if auth_header == auth_token {
|
||||
Some(state.request_manager.error_requests)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
drop(state);
|
||||
|
||||
Ok(Json(LogsResponse {
|
||||
status: ApiStatus::Success,
|
||||
total,
|
||||
active: if auth_header == auth_token {
|
||||
Some(state.request_manager.active_requests)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
error: if auth_header == auth_token {
|
||||
Some(state.request_manager.error_requests)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
active,
|
||||
error,
|
||||
logs: result_logs,
|
||||
timestamp: Local::now().to_string(),
|
||||
timestamp: Local::now(),
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -246,5 +235,5 @@ pub struct LogsResponse {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<u64>,
|
||||
pub logs: Vec<RequestLog>,
|
||||
pub timestamp: String,
|
||||
pub timestamp: DateTime<Local>,
|
||||
}
|
||||
|
@@ -1,9 +1,9 @@
|
||||
use crate::app::{
|
||||
constant::{
|
||||
CONTENT_TYPE_TEXT_CSS_WITH_UTF8, CONTENT_TYPE_TEXT_HTML_WITH_UTF8,
|
||||
CONTENT_TYPE_TEXT_JS_WITH_UTF8, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, ROUTE_ABOUT_PATH,
|
||||
ROUTE_API_PATH, ROUTE_BUILD_KEY_PATH, ROUTE_CONFIG_PATH, ROUTE_PROXIES_PATH,
|
||||
ROUTE_README_PATH, ROUTE_SHARED_JS_PATH, ROUTE_SHARED_STYLES_PATH, ROUTE_TOKENS_PATH,
|
||||
ROUTE_ABOUT_PATH, ROUTE_API_PATH, ROUTE_BUILD_KEY_PATH, ROUTE_CONFIG_PATH,
|
||||
ROUTE_PROXIES_PATH, ROUTE_README_PATH, ROUTE_SHARED_JS_PATH, ROUTE_SHARED_STYLES_PATH,
|
||||
ROUTE_TOKENS_PATH, header_value_text_css_utf8, header_value_text_html_utf8,
|
||||
header_value_text_js_utf8, header_value_text_plain_utf8,
|
||||
},
|
||||
model::{AppConfig, PageContent},
|
||||
};
|
||||
@@ -19,7 +19,7 @@ use axum::{
|
||||
|
||||
pub async fn handle_env_example() -> impl IntoResponse {
|
||||
Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_plain_utf8())
|
||||
.body(Body::from(include_str!("../../../.env.example")))
|
||||
.unwrap()
|
||||
}
|
||||
@@ -28,15 +28,15 @@ pub async fn handle_env_example() -> impl IntoResponse {
|
||||
pub async fn handle_config_page() -> impl IntoResponse {
|
||||
match AppConfig::get_page_content(ROUTE_CONFIG_PATH).unwrap_or_default() {
|
||||
PageContent::Default => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_html_utf8())
|
||||
.body(Body::from(include_str!("../../../static/config.min.html")))
|
||||
.unwrap(),
|
||||
PageContent::Text(content) => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_plain_utf8())
|
||||
.body(Body::from(content))
|
||||
.unwrap(),
|
||||
PageContent::Html(content) => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_html_utf8())
|
||||
.body(Body::from(content))
|
||||
.unwrap(),
|
||||
}
|
||||
@@ -47,13 +47,13 @@ pub async fn handle_static(Path(path): Path<String>) -> impl IntoResponse {
|
||||
"shared-styles.css" => {
|
||||
match AppConfig::get_page_content(ROUTE_SHARED_STYLES_PATH).unwrap_or_default() {
|
||||
PageContent::Default => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_css_utf8())
|
||||
.body(Body::from(include_str!(
|
||||
"../../../static/shared-styles.min.css"
|
||||
)))
|
||||
.unwrap(),
|
||||
PageContent::Text(content) | PageContent::Html(content) => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_css_utf8())
|
||||
.body(Body::from(content))
|
||||
.unwrap(),
|
||||
}
|
||||
@@ -61,13 +61,13 @@ pub async fn handle_static(Path(path): Path<String>) -> impl IntoResponse {
|
||||
"shared.js" => {
|
||||
match AppConfig::get_page_content(ROUTE_SHARED_JS_PATH).unwrap_or_default() {
|
||||
PageContent::Default => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_js_utf8())
|
||||
.body(Body::from(
|
||||
include_str!("../../../static/shared.min.js").to_string(),
|
||||
))
|
||||
.unwrap(),
|
||||
PageContent::Text(content) | PageContent::Html(content) => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_js_utf8())
|
||||
.body(Body::from(content))
|
||||
.unwrap(),
|
||||
}
|
||||
@@ -82,15 +82,15 @@ pub async fn handle_static(Path(path): Path<String>) -> impl IntoResponse {
|
||||
pub async fn handle_readme() -> impl IntoResponse {
|
||||
match AppConfig::get_page_content(ROUTE_README_PATH).unwrap_or_default() {
|
||||
PageContent::Default => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_html_utf8())
|
||||
.body(Body::from(include_str!("../../../static/readme.min.html")))
|
||||
.unwrap(),
|
||||
PageContent::Text(content) => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_plain_utf8())
|
||||
.body(Body::from(content))
|
||||
.unwrap(),
|
||||
PageContent::Html(content) => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_html_utf8())
|
||||
.body(Body::from(content))
|
||||
.unwrap(),
|
||||
}
|
||||
@@ -104,11 +104,11 @@ pub async fn handle_about() -> impl IntoResponse {
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
PageContent::Text(content) => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_plain_utf8())
|
||||
.body(Body::from(content))
|
||||
.unwrap(),
|
||||
PageContent::Html(content) => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_html_utf8())
|
||||
.body(Body::from(content))
|
||||
.unwrap(),
|
||||
}
|
||||
@@ -117,17 +117,17 @@ pub async fn handle_about() -> impl IntoResponse {
|
||||
pub async fn handle_build_key_page() -> impl IntoResponse {
|
||||
match AppConfig::get_page_content(ROUTE_BUILD_KEY_PATH).unwrap_or_default() {
|
||||
PageContent::Default => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_html_utf8())
|
||||
.body(Body::from(include_str!(
|
||||
"../../../static/build_key.min.html"
|
||||
)))
|
||||
.unwrap(),
|
||||
PageContent::Text(content) => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_plain_utf8())
|
||||
.body(Body::from(content))
|
||||
.unwrap(),
|
||||
PageContent::Html(content) => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_html_utf8())
|
||||
.body(Body::from(content))
|
||||
.unwrap(),
|
||||
}
|
||||
@@ -136,15 +136,15 @@ pub async fn handle_build_key_page() -> impl IntoResponse {
|
||||
pub async fn handle_tokens_page() -> impl IntoResponse {
|
||||
match AppConfig::get_page_content(ROUTE_TOKENS_PATH).unwrap_or_default() {
|
||||
PageContent::Default => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_html_utf8())
|
||||
.body(Body::from(include_str!("../../../static/tokens.min.html")))
|
||||
.unwrap(),
|
||||
PageContent::Text(content) => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_plain_utf8())
|
||||
.body(Body::from(content))
|
||||
.unwrap(),
|
||||
PageContent::Html(content) => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_html_utf8())
|
||||
.body(Body::from(content))
|
||||
.unwrap(),
|
||||
}
|
||||
@@ -153,15 +153,15 @@ pub async fn handle_tokens_page() -> impl IntoResponse {
|
||||
pub async fn handle_proxies_page() -> impl IntoResponse {
|
||||
match AppConfig::get_page_content(ROUTE_PROXIES_PATH).unwrap_or_default() {
|
||||
PageContent::Default => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_html_utf8())
|
||||
.body(Body::from(include_str!("../../../static/proxies.min.html")))
|
||||
.unwrap(),
|
||||
PageContent::Text(content) => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_plain_utf8())
|
||||
.body(Body::from(content))
|
||||
.unwrap(),
|
||||
PageContent::Html(content) => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_html_utf8())
|
||||
.body(Body::from(content))
|
||||
.unwrap(),
|
||||
}
|
||||
@@ -170,15 +170,15 @@ pub async fn handle_proxies_page() -> impl IntoResponse {
|
||||
pub async fn handle_api_page() -> impl IntoResponse {
|
||||
match AppConfig::get_page_content(ROUTE_API_PATH).unwrap_or_default() {
|
||||
PageContent::Default => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_html_utf8())
|
||||
.body(Body::from(include_str!("../../../static/api.min.html")))
|
||||
.unwrap(),
|
||||
PageContent::Text(content) => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_plain_utf8())
|
||||
.body(Body::from(content))
|
||||
.unwrap(),
|
||||
PageContent::Html(content) => Response::builder()
|
||||
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
|
||||
.header(CONTENT_TYPE, header_value_text_html_utf8())
|
||||
.body(Body::from(content))
|
||||
.unwrap(),
|
||||
}
|
||||
|
@@ -2,7 +2,8 @@ use crate::{
|
||||
app::{
|
||||
constant::{
|
||||
AUTHORIZATION_BEARER_PREFIX, FINISH_REASON_STOP, OBJECT_CHAT_COMPLETION,
|
||||
OBJECT_CHAT_COMPLETION_CHUNK,
|
||||
OBJECT_CHAT_COMPLETION_CHUNK, header_value_chunked, header_value_event_stream,
|
||||
header_value_json, header_value_keep_alive, header_value_no_cache_revalidate,
|
||||
},
|
||||
lazy::{
|
||||
AUTH_TOKEN, GENERAL_TIMEZONE, IS_NO_REQUEST_LOGS, IS_UNLIMITED_REQUEST_LOGS,
|
||||
@@ -26,12 +27,10 @@ use crate::{
|
||||
},
|
||||
core::{
|
||||
config::KeyConfig,
|
||||
constant::{Models, USAGE_CHECK_MODELS},
|
||||
constant::Models,
|
||||
error::StreamError,
|
||||
model::{
|
||||
ChatResponse, Choice, Delta, Message, MessageContent, ModelsResponse, Role, Usage,
|
||||
},
|
||||
stream::{StreamDecoder, StreamMessage},
|
||||
model::{ChatResponse, Choice, Delta, Message, MessageContent, ModelsResponse, Role},
|
||||
stream::decoder::{StreamDecoder, StreamMessage},
|
||||
},
|
||||
leak::intern_string,
|
||||
};
|
||||
@@ -51,17 +50,17 @@ use axum::{
|
||||
use bytes::Bytes;
|
||||
use futures::StreamExt;
|
||||
use prost::Message as _;
|
||||
use std::{borrow::Cow, sync::atomic::{AtomicUsize, Ordering}};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
};
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
sync::{Arc, atomic::AtomicBool},
|
||||
};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use super::model::{ChatRequest, Model};
|
||||
|
||||
const NO_CACHE: &str = "no-cache, must-revalidate";
|
||||
const KEEP_ALIVE: &str = "keep-alive";
|
||||
use super::{constant::FREE_MODELS, model::ChatRequest};
|
||||
|
||||
static CURRENT_KEY_INDEX: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
@@ -106,7 +105,7 @@ pub async fn handle_models(
|
||||
));
|
||||
}
|
||||
|
||||
let index = CURRENT_KEY_INDEX.fetch_add(1, Ordering::SeqCst) % token_infos.len();
|
||||
let index = CURRENT_KEY_INDEX.load(Ordering::Acquire) % token_infos.len();
|
||||
let token_info = &token_infos[index];
|
||||
is_pri = true;
|
||||
(
|
||||
@@ -317,19 +316,18 @@ pub async fn handle_chat(
|
||||
if log.token_info.token == auth_token {
|
||||
if let Some(profile) = &log.token_info.profile {
|
||||
if profile.stripe.membership_type == MembershipType::Free {
|
||||
let is_premium = USAGE_CHECK_MODELS.contains(&model);
|
||||
need_profile_check = if is_premium {
|
||||
profile
|
||||
.usage
|
||||
.premium
|
||||
.max_requests
|
||||
.is_some_and(|max| profile.usage.premium.num_requests >= max)
|
||||
} else {
|
||||
need_profile_check = if FREE_MODELS.contains(&model.id) {
|
||||
profile
|
||||
.usage
|
||||
.standard
|
||||
.max_requests
|
||||
.is_some_and(|max| profile.usage.standard.num_requests >= max)
|
||||
} else {
|
||||
profile
|
||||
.usage
|
||||
.premium
|
||||
.max_requests
|
||||
.is_some_and(|max| profile.usage.premium.num_requests >= max)
|
||||
};
|
||||
}
|
||||
break;
|
||||
@@ -350,15 +348,14 @@ pub async fn handle_chat(
|
||||
let next_id = state
|
||||
.request_manager
|
||||
.request_logs
|
||||
.last()
|
||||
.back()
|
||||
.map_or(1, |log| log.id + 1);
|
||||
current_id = next_id;
|
||||
|
||||
// 如果需要获取用户使用情况,创建后台任务获取profile
|
||||
if Model::is_usage_check(
|
||||
model,
|
||||
UsageCheck::from_proto(current_config.usage_check_models.as_ref()),
|
||||
) {
|
||||
if model.is_usage_check(UsageCheck::from_proto(
|
||||
current_config.usage_check_models.as_ref(),
|
||||
)) {
|
||||
let auth_token_clone = auth_token.clone();
|
||||
let state_clone = state_clone.clone();
|
||||
let log_id = next_id;
|
||||
@@ -398,7 +395,7 @@ pub async fn handle_chat(
|
||||
});
|
||||
}
|
||||
|
||||
state.request_manager.request_logs.push(RequestLog {
|
||||
state.request_manager.request_logs.push_back(RequestLog {
|
||||
id: next_id,
|
||||
timestamp: request_time,
|
||||
model: intern_string(request.model),
|
||||
@@ -546,19 +543,16 @@ pub async fn handle_chat(
|
||||
let is_start = Arc::new(AtomicBool::new(true));
|
||||
let start_time = std::time::Instant::now();
|
||||
let decoder = Arc::new(Mutex::new(StreamDecoder::new()));
|
||||
let is_usage_sent = Arc::new(AtomicBool::new(false));
|
||||
let need_usage = if request.stream_options.is_some_and(|opt| opt.include_usage) {
|
||||
Arc::new(Mutex::new(NeedUsage::Need {
|
||||
client,
|
||||
auth_token,
|
||||
checksum,
|
||||
client_key,
|
||||
timezone,
|
||||
is_pri,
|
||||
}))
|
||||
} else {
|
||||
Arc::new(Mutex::new(NeedUsage::None))
|
||||
};
|
||||
let need_usage = Arc::new(Mutex::new(NeedUsage::Need {
|
||||
is_need: request.stream_options.is_some_and(|opt| opt.include_usage),
|
||||
client,
|
||||
auth_token,
|
||||
checksum,
|
||||
client_key,
|
||||
timezone,
|
||||
is_pri,
|
||||
}));
|
||||
let usage_uuid = Arc::new(Mutex::new(None));
|
||||
|
||||
// 定义消息处理器的上下文结构体
|
||||
struct MessageProcessContext<'a> {
|
||||
@@ -568,15 +562,17 @@ pub async fn handle_chat(
|
||||
start_time: std::time::Instant,
|
||||
state: &'a Mutex<AppState>,
|
||||
current_id: u64,
|
||||
usage_uuid: &'a Mutex<Option<String>>,
|
||||
need_usage: &'a Mutex<NeedUsage>,
|
||||
is_usage_sent: &'a AtomicBool,
|
||||
created: i64,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
enum NeedUsage {
|
||||
#[default]
|
||||
None,
|
||||
Taked,
|
||||
Need {
|
||||
is_need: bool,
|
||||
client: reqwest::Client,
|
||||
auth_token: String,
|
||||
checksum: String,
|
||||
@@ -589,7 +585,10 @@ pub async fn handle_chat(
|
||||
impl NeedUsage {
|
||||
#[inline(always)]
|
||||
const fn is_need(&self) -> bool {
|
||||
matches!(*self, Self::Need { .. })
|
||||
match self {
|
||||
Self::Taked => false,
|
||||
Self::Need { is_need, .. } => *is_need,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
@@ -602,8 +601,8 @@ pub async fn handle_chat(
|
||||
async fn process_messages(
|
||||
messages: Vec<StreamMessage>,
|
||||
ctx: &MessageProcessContext<'_>,
|
||||
) -> String {
|
||||
let mut response_data = String::new();
|
||||
) -> Vec<u8> {
|
||||
let mut response_data = Vec::new();
|
||||
|
||||
for message in messages {
|
||||
match message {
|
||||
@@ -611,7 +610,7 @@ pub async fn handle_chat(
|
||||
let is_first = ctx.is_start.load(Ordering::Acquire);
|
||||
|
||||
let response = ChatResponse {
|
||||
id: ctx.response_id.to_string(),
|
||||
id: ctx.response_id,
|
||||
object: OBJECT_CHAT_COMPLETION_CHUNK,
|
||||
created: chrono::Utc::now().timestamp(),
|
||||
model: if is_first { Some(ctx.model) } else { None },
|
||||
@@ -641,65 +640,13 @@ pub async fn handle_chat(
|
||||
},
|
||||
};
|
||||
|
||||
response_data.push_str(&format!(
|
||||
"data: {}\n\n",
|
||||
serde_json::to_string(&response).unwrap()
|
||||
));
|
||||
response_data.extend_from_slice(b"data: ");
|
||||
response_data.extend_from_slice(&serde_json::to_vec(&response).unwrap());
|
||||
response_data.extend_from_slice(b"\n\n");
|
||||
}
|
||||
StreamMessage::Usage(usage_uuid) => {
|
||||
if !ctx.is_usage_sent.load(Ordering::Acquire) {
|
||||
if let NeedUsage::Need {
|
||||
client,
|
||||
auth_token,
|
||||
checksum,
|
||||
client_key,
|
||||
timezone,
|
||||
is_pri,
|
||||
} = ctx.need_usage.lock().await.take()
|
||||
{
|
||||
let usage = get_token_usage(
|
||||
client,
|
||||
&auth_token,
|
||||
&checksum,
|
||||
&client_key,
|
||||
timezone,
|
||||
is_pri,
|
||||
usage_uuid,
|
||||
)
|
||||
.await;
|
||||
if let Some(ref usage) = usage {
|
||||
let mut state = ctx.state.lock().await;
|
||||
if let Some(log) = state
|
||||
.request_manager
|
||||
.request_logs
|
||||
.iter_mut()
|
||||
.rev()
|
||||
.find(|log| log.id == ctx.current_id)
|
||||
{
|
||||
if let Some(chain) = &mut log.chain {
|
||||
chain.usage = OptionUsage::Uasge {
|
||||
input: usage.prompt_tokens,
|
||||
output: usage.completion_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let response = ChatResponse {
|
||||
id: ctx.response_id.to_string(),
|
||||
object: OBJECT_CHAT_COMPLETION_CHUNK,
|
||||
created: chrono::Utc::now().timestamp(),
|
||||
model: None,
|
||||
choices: vec![],
|
||||
usage: TriState::Some(usage.unwrap_or_default()),
|
||||
};
|
||||
response_data.push_str(&format!(
|
||||
"data: {}\n\n",
|
||||
serde_json::to_string(&response).unwrap()
|
||||
));
|
||||
ctx.is_usage_sent.store(true, Ordering::Release);
|
||||
}
|
||||
} else {
|
||||
crate::debug_println!("usage is sent, but find {usage_uuid}");
|
||||
if !usage_uuid.is_empty() && ctx.need_usage.lock().await.is_need() {
|
||||
*ctx.usage_uuid.lock().await = Some(usage_uuid);
|
||||
}
|
||||
}
|
||||
StreamMessage::StreamEnd => {
|
||||
@@ -720,7 +667,7 @@ pub async fn handle_chat(
|
||||
}
|
||||
|
||||
let response = ChatResponse {
|
||||
id: ctx.response_id.to_string(),
|
||||
id: ctx.response_id,
|
||||
object: OBJECT_CHAT_COMPLETION_CHUNK,
|
||||
created: chrono::Utc::now().timestamp(),
|
||||
model: None,
|
||||
@@ -740,31 +687,65 @@ pub async fn handle_chat(
|
||||
TriState::None
|
||||
},
|
||||
};
|
||||
response_data.push_str(&format!(
|
||||
"data: {}\n\n",
|
||||
serde_json::to_string(&response).unwrap()
|
||||
));
|
||||
if !ctx.is_usage_sent.load(Ordering::Acquire)
|
||||
&& ctx.need_usage.lock().await.is_need()
|
||||
{
|
||||
let response = ChatResponse {
|
||||
id: ctx.response_id.to_string(),
|
||||
object: OBJECT_CHAT_COMPLETION_CHUNK,
|
||||
created: chrono::Utc::now().timestamp(),
|
||||
model: None,
|
||||
choices: vec![],
|
||||
usage: TriState::Some(Usage {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0,
|
||||
}),
|
||||
response_data.extend_from_slice(b"data: ");
|
||||
response_data.extend_from_slice(&serde_json::to_vec(&response).unwrap());
|
||||
response_data.extend_from_slice(b"\n\n");
|
||||
if let Some(usage_uuid) = ctx.usage_uuid.lock().await.take() {
|
||||
if let NeedUsage::Need {
|
||||
is_need,
|
||||
client,
|
||||
auth_token,
|
||||
checksum,
|
||||
client_key,
|
||||
timezone,
|
||||
is_pri,
|
||||
} = ctx.need_usage.lock().await.take()
|
||||
{
|
||||
let usage = if *crate::app::lazy::REAL_USAGE {
|
||||
let usage = tokio::spawn(get_token_usage(
|
||||
client, auth_token, checksum, client_key, timezone, is_pri,
|
||||
usage_uuid,
|
||||
))
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
if let Some(ref usage) = usage {
|
||||
let mut state = ctx.state.lock().await;
|
||||
if let Some(log) = state
|
||||
.request_manager
|
||||
.request_logs
|
||||
.iter_mut()
|
||||
.rev()
|
||||
.find(|log| log.id == ctx.current_id)
|
||||
{
|
||||
if let Some(chain) = &mut log.chain {
|
||||
chain.usage = OptionUsage::Uasge {
|
||||
input: usage.prompt_tokens,
|
||||
output: usage.completion_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
usage
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if is_need {
|
||||
let response = ChatResponse {
|
||||
id: ctx.response_id,
|
||||
object: OBJECT_CHAT_COMPLETION_CHUNK,
|
||||
created: ctx.created,
|
||||
model: None,
|
||||
choices: vec![],
|
||||
usage: TriState::Some(usage.unwrap_or_default()),
|
||||
};
|
||||
response_data.extend_from_slice(b"data: ");
|
||||
response_data
|
||||
.extend_from_slice(&serde_json::to_vec(&response).unwrap());
|
||||
response_data.extend_from_slice(b"\n\n");
|
||||
}
|
||||
};
|
||||
response_data.push_str(&format!(
|
||||
"data: {}\n\n",
|
||||
serde_json::to_string(&response).unwrap()
|
||||
));
|
||||
ctx.is_usage_sent.store(true, Ordering::Release);
|
||||
};
|
||||
}
|
||||
}
|
||||
StreamMessage::Debug(debug_prompt) => {
|
||||
if let Ok(mut state) = ctx.state.try_lock() {
|
||||
@@ -781,7 +762,7 @@ pub async fn handle_chat(
|
||||
} else {
|
||||
log.chain = Some(Chain {
|
||||
prompt: Prompt::new(debug_prompt),
|
||||
delays: vec![],
|
||||
delays: None,
|
||||
usage: OptionUsage::None,
|
||||
});
|
||||
}
|
||||
@@ -866,6 +847,8 @@ pub async fn handle_chat(
|
||||
}
|
||||
}
|
||||
|
||||
let created = Arc::new(std::sync::OnceLock::new());
|
||||
|
||||
// 处理后续的stream
|
||||
let stream = stream
|
||||
.then({
|
||||
@@ -877,27 +860,29 @@ pub async fn handle_chat(
|
||||
let response_id = response_id.clone();
|
||||
let is_start = is_start.clone();
|
||||
let state = state.clone();
|
||||
let is_usage_sent = is_usage_sent.clone();
|
||||
let need_usage = need_usage.clone();
|
||||
let usage_uuid = usage_uuid.clone();
|
||||
let created = created.clone();
|
||||
|
||||
async move {
|
||||
let chunk = match chunk {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
crate::debug_println!("Find chunk error: {e}");
|
||||
Err(_) => {
|
||||
// crate::debug_println!("Find chunk error: {e}");
|
||||
return Ok::<_, Infallible>(Bytes::new());
|
||||
}
|
||||
};
|
||||
|
||||
let ctx = MessageProcessContext {
|
||||
response_id: &response_id,
|
||||
model,
|
||||
model: model.id,
|
||||
is_start: &is_start,
|
||||
start_time,
|
||||
state: &state,
|
||||
current_id,
|
||||
usage_uuid: &usage_uuid,
|
||||
need_usage: &need_usage,
|
||||
is_usage_sent: &is_usage_sent,
|
||||
created: *created.get_or_init(|| chrono::Utc::now().timestamp()),
|
||||
};
|
||||
|
||||
// 使用decoder处理chunk
|
||||
@@ -930,16 +915,16 @@ pub async fn handle_chat(
|
||||
}
|
||||
};
|
||||
|
||||
let mut response_data = String::new();
|
||||
let mut response_data = Vec::new();
|
||||
|
||||
if let Some(first_msg) = decoder.lock().await.take_first_result() {
|
||||
let first_response = process_messages(first_msg, &ctx).await;
|
||||
response_data.push_str(&first_response);
|
||||
response_data.extend_from_slice(&first_response);
|
||||
}
|
||||
|
||||
let current_response = process_messages(messages, &ctx).await;
|
||||
if !current_response.is_empty() {
|
||||
response_data.push_str(&current_response);
|
||||
response_data.extend_from_slice(&current_response);
|
||||
}
|
||||
|
||||
Ok(Bytes::from(response_data))
|
||||
@@ -970,10 +955,10 @@ pub async fn handle_chat(
|
||||
}));
|
||||
|
||||
Ok(Response::builder()
|
||||
.header(CACHE_CONTROL, NO_CACHE)
|
||||
.header(CONNECTION, KEEP_ALIVE)
|
||||
.header(CONTENT_TYPE, "text/event-stream")
|
||||
.header(TRANSFER_ENCODING, "chunked")
|
||||
.header(CACHE_CONTROL, header_value_no_cache_revalidate())
|
||||
.header(CONNECTION, header_value_keep_alive())
|
||||
.header(CONTENT_TYPE, header_value_event_stream())
|
||||
.header(TRANSFER_ENCODING, header_value_chunked())
|
||||
.body(Body::from_stream(stream))
|
||||
.unwrap())
|
||||
} else {
|
||||
@@ -1095,13 +1080,7 @@ pub async fn handle_chat(
|
||||
|
||||
let (usage1, usage2) = if !usage_uuid.is_empty() {
|
||||
let result = get_token_usage(
|
||||
client,
|
||||
&auth_token,
|
||||
&checksum,
|
||||
&client_key,
|
||||
timezone,
|
||||
is_pri,
|
||||
usage_uuid,
|
||||
client, auth_token, checksum, client_key, timezone, is_pri, usage_uuid,
|
||||
)
|
||||
.await;
|
||||
let result2 = match result {
|
||||
@@ -1117,10 +1096,10 @@ pub async fn handle_chat(
|
||||
};
|
||||
|
||||
let response_data = ChatResponse {
|
||||
id: format!("chatcmpl-{trace_id}"),
|
||||
id: &format!("chatcmpl-{trace_id}"),
|
||||
object: OBJECT_CHAT_COMPLETION,
|
||||
created: chrono::Utc::now().timestamp(),
|
||||
model: Some(model),
|
||||
model: Some(model.id),
|
||||
choices: vec![Choice {
|
||||
index: 0,
|
||||
message: Some(Message {
|
||||
@@ -1155,11 +1134,11 @@ pub async fn handle_chat(
|
||||
}
|
||||
}
|
||||
|
||||
let data = serde_json::to_string(&response_data).unwrap();
|
||||
let data = serde_json::to_vec(&response_data).unwrap();
|
||||
Ok(Response::builder()
|
||||
.header(CACHE_CONTROL, NO_CACHE)
|
||||
.header(CONNECTION, KEEP_ALIVE)
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.header(CACHE_CONTROL, header_value_no_cache_revalidate())
|
||||
.header(CONNECTION, header_value_keep_alive())
|
||||
.header(CONTENT_TYPE, header_value_json())
|
||||
.header(CONTENT_LENGTH, data.len())
|
||||
.body(Body::from(data))
|
||||
.unwrap())
|
||||
|
@@ -1,2 +1 @@
|
||||
mod decoder;
|
||||
pub use decoder::*;
|
||||
pub mod decoder;
|
||||
|
@@ -1,14 +1,28 @@
|
||||
use crate::common::utils::InstantExt as _;
|
||||
use crate::core::{
|
||||
aiserver::v1::{StreamChatResponse, WebReference},
|
||||
error::{ChatError, StreamError},
|
||||
};
|
||||
use bytes::{Buf, BytesMut};
|
||||
use flate2::read::GzDecoder;
|
||||
use prost::Message;
|
||||
use std::io::Read;
|
||||
use std::time::Instant;
|
||||
|
||||
// 解压gzip数据
|
||||
/// Extension trait for measuring consecutive intervals with a single stored instant.
pub trait InstantExt {
    /// Returns the seconds elapsed since the stored instant and resets the
    /// stored instant to the current time, so the next call measures from here.
    fn duration_as_secs_f32(&mut self) -> f32;
}

impl InstantExt for Instant {
    #[inline]
    fn duration_as_secs_f32(&mut self) -> f32 {
        let current = Instant::now();
        // Swap the new timestamp in and take the previous one out in one step.
        let previous = std::mem::replace(self, current);
        current.duration_since(previous).as_secs_f32()
    }
}
|
||||
|
||||
/// 解压gzip数据
|
||||
#[inline]
|
||||
fn decompress_gzip(data: &[u8]) -> Option<Vec<u8>> {
|
||||
if data.len() < 3
|
||||
@@ -31,32 +45,6 @@ fn decompress_gzip(data: &[u8]) -> Option<Vec<u8>> {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ToMarkdown {
|
||||
fn to_markdown(&self) -> String;
|
||||
}
|
||||
|
||||
impl ToMarkdown for Vec<WebReference> {
|
||||
#[inline]
|
||||
fn to_markdown(&self) -> String {
|
||||
if self.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
let mut result = String::from("WebReferences:\n");
|
||||
for (i, web_ref) in self.iter().enumerate() {
|
||||
result.push_str(&format!(
|
||||
"{}. [{}]({})<{}>\n",
|
||||
i + 1,
|
||||
web_ref.title,
|
||||
web_ref.url,
|
||||
web_ref.chunk
|
||||
));
|
||||
}
|
||||
result.push('\n');
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Clone)]
|
||||
pub enum StreamMessage {
|
||||
// 调试
|
||||
@@ -64,6 +52,7 @@ pub enum StreamMessage {
|
||||
// 网络引用
|
||||
WebReference(Vec<WebReference>),
|
||||
// 内容开始标志
|
||||
#[cfg(test)]
|
||||
ContentStart,
|
||||
// 消息内容
|
||||
Content(String),
|
||||
@@ -77,18 +66,35 @@ impl StreamMessage {
|
||||
#[inline]
|
||||
fn convert_web_ref_to_content(self) -> Self {
|
||||
match self {
|
||||
StreamMessage::WebReference(refs) => StreamMessage::Content(refs.to_markdown()),
|
||||
StreamMessage::WebReference(refs) => {
|
||||
if refs.is_empty() {
|
||||
return StreamMessage::Content(String::new());
|
||||
}
|
||||
|
||||
let mut result = String::from("WebReferences:\n");
|
||||
for (i, web_ref) in refs.iter().enumerate() {
|
||||
result.push_str(&format!(
|
||||
"{}. [{}]({})<{}>\n",
|
||||
i + 1,
|
||||
web_ref.title,
|
||||
web_ref.url,
|
||||
web_ref.chunk
|
||||
));
|
||||
}
|
||||
result.push('\n');
|
||||
StreamMessage::Content(result)
|
||||
}
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct StreamDecoder {
|
||||
// 主要数据缓冲区 (24字节)
|
||||
buffer: Vec<u8>,
|
||||
// 结果相关 (24字节 + 24字节)
|
||||
// 主要数据缓冲区
|
||||
buffer: BytesMut,
|
||||
// 结果相关 (24字节 + 48字节)
|
||||
first_result: Option<Vec<StreamMessage>>,
|
||||
content_delays: Vec<(String, f64)>,
|
||||
content_delays: Option<(String, Vec<(u32, f32)>)>,
|
||||
// 计数器和时间 (8字节 + 8字节)
|
||||
empty_stream_count: usize,
|
||||
last_content_time: Instant,
|
||||
@@ -101,9 +107,9 @@ pub struct StreamDecoder {
|
||||
impl StreamDecoder {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buffer: Vec::new(),
|
||||
buffer: BytesMut::new(),
|
||||
first_result: None,
|
||||
content_delays: Vec::new(),
|
||||
content_delays: None,
|
||||
empty_stream_count: 0,
|
||||
last_content_time: Instant::now(),
|
||||
first_result_ready: false,
|
||||
@@ -112,10 +118,12 @@ impl StreamDecoder {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get_empty_stream_count(&self) -> usize {
|
||||
self.empty_stream_count
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn reset_empty_stream_count(&mut self) {
|
||||
if self.empty_stream_count > 0 {
|
||||
crate::debug_println!(
|
||||
@@ -126,10 +134,7 @@ impl StreamDecoder {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn increment_empty_stream_count(&mut self) {
|
||||
self.empty_stream_count += 1;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn take_first_result(&mut self) -> Option<Vec<StreamMessage>> {
|
||||
if !self.buffer.is_empty() {
|
||||
return None;
|
||||
@@ -145,14 +150,17 @@ impl StreamDecoder {
|
||||
!self.buffer.is_empty()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn is_first_result_ready(&self) -> bool {
|
||||
self.first_result_ready
|
||||
}
|
||||
|
||||
pub fn take_content_delays(&mut self) -> Vec<(String, f64)> {
|
||||
#[inline]
|
||||
pub fn take_content_delays(&mut self) -> Option<(String, Vec<(u32, f32)>)> {
|
||||
std::mem::take(&mut self.content_delays)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn no_first_cache(mut self) -> Self {
|
||||
self.first_result_ready = true;
|
||||
self.first_result_taken = true;
|
||||
@@ -172,7 +180,7 @@ impl StreamDecoder {
|
||||
|
||||
if self.buffer.len() < 5 {
|
||||
if self.buffer.is_empty() {
|
||||
self.increment_empty_stream_count();
|
||||
self.empty_stream_count += 1;
|
||||
|
||||
return Err(StreamError::EmptyStream);
|
||||
}
|
||||
@@ -182,20 +190,68 @@ impl StreamDecoder {
|
||||
|
||||
self.reset_empty_stream_count();
|
||||
|
||||
let mut messages = Vec::new();
|
||||
let reserve = {
|
||||
let mut offset = 0;
|
||||
let mut count = 0;
|
||||
while offset + 5 <= self.buffer.len() {
|
||||
let msg_len: usize;
|
||||
|
||||
// SAFETY: The loop condition `offset + 5 <= self.buffer.len()` guarantees
|
||||
// that indices `offset` through `offset + 4` are within bounds.
|
||||
unsafe {
|
||||
msg_len = u32::from_be_bytes([
|
||||
*self.buffer.get_unchecked(offset + 1),
|
||||
*self.buffer.get_unchecked(offset + 2),
|
||||
*self.buffer.get_unchecked(offset + 3),
|
||||
*self.buffer.get_unchecked(offset + 4),
|
||||
]) as usize;
|
||||
}
|
||||
|
||||
if msg_len == 0 {
|
||||
offset += 5;
|
||||
continue;
|
||||
}
|
||||
|
||||
if offset + 5 + msg_len > self.buffer.len() {
|
||||
break;
|
||||
}
|
||||
|
||||
offset += 5 + msg_len;
|
||||
count += 1;
|
||||
}
|
||||
count
|
||||
};
|
||||
|
||||
if let Some(content_delays) = self.content_delays.as_mut() {
|
||||
content_delays.0.reserve(reserve);
|
||||
content_delays.1.reserve(reserve);
|
||||
} else {
|
||||
self.content_delays =
|
||||
Some((String::with_capacity(reserve), Vec::with_capacity(reserve)));
|
||||
}
|
||||
|
||||
let mut messages = Vec::with_capacity(reserve);
|
||||
let mut offset = 0;
|
||||
|
||||
while offset + 5 <= self.buffer.len() {
|
||||
let msg_type = self.buffer[offset];
|
||||
let msg_len = u32::from_be_bytes([
|
||||
self.buffer[offset + 1],
|
||||
self.buffer[offset + 2],
|
||||
self.buffer[offset + 3],
|
||||
self.buffer[offset + 4],
|
||||
]) as usize;
|
||||
let msg_type: u8;
|
||||
let msg_len: usize;
|
||||
|
||||
// SAFETY: The loop condition `offset + 5 <= self.buffer.len()` guarantees
|
||||
// that indices `offset` through `offset + 4` are within bounds.
|
||||
unsafe {
|
||||
msg_type = *self.buffer.get_unchecked(offset);
|
||||
msg_len = u32::from_be_bytes([
|
||||
*self.buffer.get_unchecked(offset + 1),
|
||||
*self.buffer.get_unchecked(offset + 2),
|
||||
*self.buffer.get_unchecked(offset + 3),
|
||||
*self.buffer.get_unchecked(offset + 4),
|
||||
]) as usize;
|
||||
}
|
||||
|
||||
if msg_len == 0 {
|
||||
offset += 5;
|
||||
#[cfg(test)]
|
||||
messages.push(StreamMessage::ContentStart);
|
||||
continue;
|
||||
}
|
||||
@@ -209,8 +265,13 @@ impl StreamDecoder {
|
||||
if let Some(msg) = self.process_message(msg_type, msg_data)? {
|
||||
if let StreamMessage::Content(content) = &msg {
|
||||
self.has_seen_content = true;
|
||||
let delay = self.last_content_time.duration_as_secs_f64();
|
||||
self.content_delays.push((content.clone(), delay));
|
||||
let delay = self.last_content_time.duration_as_secs_f32();
|
||||
if let Some(content_delays) = self.content_delays.as_mut() {
|
||||
content_delays.0.push_str(content);
|
||||
content_delays
|
||||
.1
|
||||
.push((content.chars().count() as u32, delay));
|
||||
}
|
||||
}
|
||||
if convert_web_ref {
|
||||
messages.push(msg.convert_web_ref_to_content());
|
||||
@@ -222,7 +283,7 @@ impl StreamDecoder {
|
||||
offset += 5 + msg_len;
|
||||
}
|
||||
|
||||
self.buffer.drain(..offset);
|
||||
self.buffer.advance(offset);
|
||||
|
||||
if !self.first_result_taken && !messages.is_empty() {
|
||||
if self.first_result.is_none() {
|
||||
|
54
src/leak.rs
54
src/leak.rs
@@ -7,12 +7,11 @@ struct StringPool {
|
||||
}
|
||||
|
||||
impl StringPool {
|
||||
// 驻留字符串
|
||||
/// 驻留字符串
|
||||
fn intern(&mut self, s: &str) -> &'static str {
|
||||
if let Some(&interned) = self.pool.get(s) {
|
||||
interned
|
||||
} else {
|
||||
// 如果字符串不存在,使用 Box::leak 将其泄漏,并添加到 pool 中
|
||||
let leaked: &'static str = Box::leak(Box::from(s));
|
||||
self.pool.insert(leaked);
|
||||
leaked
|
||||
@@ -20,60 +19,9 @@ impl StringPool {
|
||||
}
|
||||
}
|
||||
|
||||
// 全局 StringPool 实例
|
||||
static STRING_POOL: LazyLock<Mutex<StringPool>> =
|
||||
LazyLock::new(|| Mutex::new(StringPool::default()));
|
||||
|
||||
pub fn intern_string<S: AsRef<str>>(s: S) -> &'static str {
|
||||
STRING_POOL.lock().intern(s.as_ref())
|
||||
}
|
||||
|
||||
// #[derive(Clone, Copy, PartialEq, Eq, Hash)]
|
||||
// pub struct InternedString(&'static str);
|
||||
|
||||
// impl InternedString {
|
||||
// #[inline(always)]
|
||||
// pub fn as_str(&self) -> &'static str {
|
||||
// self.0
|
||||
// }
|
||||
// }
|
||||
|
||||
// impl Deref for InternedString {
|
||||
// type Target = str;
|
||||
|
||||
// #[inline(always)]
|
||||
// fn deref(&self) -> &'static Self::Target {
|
||||
// self.0
|
||||
// }
|
||||
// }
|
||||
|
||||
// impl Borrow<str> for InternedString {
|
||||
// #[inline(always)]
|
||||
// fn borrow(&self) -> &'static str {
|
||||
// self.0
|
||||
// }
|
||||
// }
|
||||
|
||||
// impl core::fmt::Debug for InternedString {
|
||||
// #[inline(always)]
|
||||
// fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
// self.0.fmt(f)
|
||||
// }
|
||||
// }
|
||||
|
||||
// impl core::fmt::Display for InternedString {
|
||||
// #[inline(always)]
|
||||
// fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
// self.0.fmt(f)
|
||||
// }
|
||||
// }
|
||||
|
||||
// impl serde::Serialize for InternedString {
|
||||
// #[inline]
|
||||
// fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
// where
|
||||
// S: serde::Serializer,
|
||||
// {
|
||||
// serializer.serialize_str(self.0)
|
||||
// }
|
||||
// }
|
||||
|
297
static/logs.html
297
static/logs.html
@@ -649,6 +649,7 @@
|
||||
<th>Token信息</th>
|
||||
<th>对话</th>
|
||||
<th>用时</th>
|
||||
<th>输入/输出</th>
|
||||
<th>流式响应</th>
|
||||
<th>状态</th>
|
||||
<th>错误信息</th>
|
||||
@@ -1043,20 +1044,16 @@
|
||||
updatePaginationControls();
|
||||
|
||||
tbody.innerHTML = data.logs.map(log => {
|
||||
// 预处理延迟数据以避免HTML解析问题
|
||||
let delaysData = '';
|
||||
if (log.chain && log.chain.delays) {
|
||||
// 为每个delay项创建安全的文本和数值对
|
||||
const safeDelays = log.chain.delays.map(item => {
|
||||
if (!Array.isArray(item) || item.length < 2) return ["", 0];
|
||||
|
||||
// 将延迟数据中的文本部分编码为Base64,避免HTML解析问题
|
||||
const textPart = typeof item[0] === 'string' ? btoa(encodeURIComponent(item[0])) : "";
|
||||
const delayPart = typeof item[1] === 'number' ? item[1] : 0;
|
||||
|
||||
return [textPart, delayPart];
|
||||
});
|
||||
delaysData = JSON.stringify(safeDelays);
|
||||
let delaysDataAttribute = '';
|
||||
if (log.chain && Array.isArray(log.chain.delays) && log.chain.delays.length === 2 && typeof log.chain.delays[0] === 'string' && Array.isArray(log.chain.delays[1])) {
|
||||
try {
|
||||
const originalDelaysJson = JSON.stringify(log.chain.delays);
|
||||
const escapedJsonString = escapeHtml(originalDelaysJson);
|
||||
delaysDataAttribute = `data-delays='${escapedJsonString}'`;
|
||||
} catch (e) {
|
||||
console.error('Failed to stringify delays data:', e, log.chain.delays);
|
||||
delaysDataAttribute = '';
|
||||
}
|
||||
}
|
||||
|
||||
return `<tr>
|
||||
@@ -1064,11 +1061,12 @@
|
||||
<td>${new Date(log.timestamp).toLocaleString()}</td>
|
||||
<td>${log.model}</td>
|
||||
<td><div class="token-info-tooltip"><button class="info-button" onclick='showTokenModal(${JSON.stringify(log.token_info)})'>查看详情<div class="tooltip-content">${formatSimpleTokenInfo(log.token_info)}</div></button></div></td>
|
||||
<td>${log.chain ? `<div class="token-info-tooltip prompt-preview"><button class="info-button view-conversation" data-prompt="${encodeURIComponent(JSON.stringify(log.chain.prompt))}" data-delays='${delaysData}'>查看对话<div class="tooltip-content">${formatDialogPreview(JSON.stringify(log.chain.prompt))}</div></button></div>` : '-'}</td>
|
||||
<td>${log.chain ? `<div class="token-info-tooltip prompt-preview"><button class="info-button view-conversation" data-prompt="${encodeURIComponent(JSON.stringify(log.chain.prompt))}" ${delaysDataAttribute}>查看对话<div class="tooltip-content">${formatDialogPreview(JSON.stringify(log.chain.prompt))}</div></button></div>` : '-'}</td>
|
||||
<td>${formatTiming(log.timing.total)}</td>
|
||||
<td>${formatUsage(log.chain?.usage)}</td>
|
||||
<td>${log.stream ? '是' : '否'}</td>
|
||||
<td>${log.status}</td>
|
||||
<td>${log.error.details || log.error || '-'}</td>
|
||||
<td>${typeof log.error === 'string' ? log.error : log.error?.error ?? '-'}</td>
|
||||
</tr>`;
|
||||
}).join('');
|
||||
|
||||
@@ -1372,6 +1370,36 @@
|
||||
return messages;
|
||||
}
|
||||
|
||||
/**
 * Escapes the HTML special characters in a string so it can be embedded
 * in markup or in a quoted attribute value.
 * @param {string} content - Raw text
 * @returns {string} Text with &, <, >, " and ' replaced by HTML entities
 */
function escapeHtml(content) {
  // '&' must be replaced first so the entities produced by the later
  // replacements are not escaped a second time.
  const escaped = content
    .replace(/&/g, '&amp;')
    .replace(/</g, '&lt;')
    .replace(/>/g, '&gt;')
    .replace(/"/g, '&quot;')
    .replace(/'/g, '&#039;');
  return escaped;
}
|
||||
|
||||
/**
 * Escapes HTML special characters and renders backslashes, newlines, tabs
 * and carriage returns as visible escape sequences (\\, \n, \t, \r).
 * @param {string} content - Raw text
 * @returns {string} HTML-safe text with control characters made visible
 */
function escapeHtmlAndControlChars(content) {
  // Escape HTML special characters first; '&' goes first so the entities
  // produced by the later replacements are not escaped again.
  let escaped = content
    .replace(/&/g, '&amp;')
    .replace(/</g, '&lt;')
    .replace(/>/g, '&gt;')
    .replace(/"/g, '&quot;')
    .replace(/'/g, '&#039;');

  // Then escape control characters; the backslash goes first so the
  // backslashes introduced for \n, \t and \r are not doubled.
  escaped = escaped
    .replace(/\\/g, '\\\\')
    .replace(/\n/g, '\\n')
    .replace(/\t/g, '\\t')
    .replace(/\r/g, '\\r');

  return escaped;
}
|
||||
|
||||
/**
|
||||
* 格式化对话内容为HTML表格
|
||||
* @param {Array<{role: string, content: string}>} messages - 对话消息数组
|
||||
@@ -1388,17 +1416,6 @@
|
||||
'assistant': '助手'
|
||||
};
|
||||
|
||||
/**
 * Escapes the HTML special characters in a string before it is inserted
 * into markup.
 * @param {string} content - Raw text
 * @returns {string} Text with &, <, >, " and ' replaced by HTML entities
 */
function escapeHtml(content) {
  // '&' must be replaced first so the entities produced by the later
  // replacements are not escaped a second time.
  const escaped = content
    .replace(/&/g, '&amp;')
    .replace(/</g, '&lt;')
    .replace(/>/g, '&gt;')
    .replace(/"/g, '&quot;')
    .replace(/'/g, '&#039;');
  return escaped;
}
|
||||
|
||||
return `<table class="message-table"><thead><tr><th>角色</th><th>内容</th></tr></thead><tbody>${messages.map(msg =>
|
||||
`<tr><td>${roleLabels[msg.role] || msg.role}</td><td>${escapeHtml(msg.content).replace(/\n/g, '<br>')}</td></tr>`
|
||||
).join('')}</tbody></table>`;
|
||||
@@ -1407,72 +1424,102 @@
|
||||
/**
|
||||
* 显示对话详情弹窗
|
||||
* @param {string} promptStr - 对话提示字符串
|
||||
* @param {Array} delays - 延迟数据数组
|
||||
* @param {Array} delaysTuple - 延迟数据数组
|
||||
*/
|
||||
function showConversationModal(promptStr, delays) {
|
||||
function showConversationModal(promptStr, delaysTuple) {
|
||||
try {
|
||||
const modal = document.getElementById('conversationModal');
|
||||
const dialogContent = document.getElementById('dialogContent');
|
||||
const delaysContent = document.getElementById('delaysContent');
|
||||
const tabPrompt = document.getElementById('tab-prompt');
|
||||
const tabDelays = document.getElementById('tab-delays');
|
||||
const delaysTableBody = document.querySelector('#delaysTable tbody');
|
||||
const delayChartContainer = document.querySelector('.delay-chart-container');
|
||||
|
||||
if (!modal || !dialogContent || !delaysContent || !tabPrompt || !tabDelays) {
|
||||
if (!modal || !dialogContent || !delaysContent || !tabPrompt || !tabDelays || !delaysTableBody || !delayChartContainer) {
|
||||
console.error('Modal elements not found');
|
||||
return;
|
||||
}
|
||||
|
||||
// 显示对话内容
|
||||
const messages = parsePrompt(promptStr);
|
||||
dialogContent.innerHTML = formatPromptToTable(messages);
|
||||
// 处理 Prompt
|
||||
try {
|
||||
const messages = parsePrompt(promptStr);
|
||||
dialogContent.innerHTML = formatPromptToTable(messages);
|
||||
} catch (e) {
|
||||
console.error('解析 Prompt 数据失败:', e);
|
||||
dialogContent.innerHTML = '<p>无法加载对话内容。</p>';
|
||||
}
|
||||
|
||||
// 处理延迟数据
|
||||
if (delays && delays.length > 0) {
|
||||
const delaysTableBody = document.querySelector('#delaysTable tbody');
|
||||
delaysTableBody.innerHTML = '';
|
||||
// 处理 Delays
|
||||
delaysTableBody.innerHTML = '';
|
||||
delayChartContainer.innerHTML = '<canvas id="delayChart"></canvas>';
|
||||
|
||||
let fullText = '';
|
||||
let delayPoints = [];
|
||||
let chartDataPoints = [{ time: 0, chars: 0, text: '' }];
|
||||
|
||||
if (delaysTuple) {
|
||||
if (Array.isArray(delaysTuple) && delaysTuple.length === 2 && typeof delaysTuple[0] === 'string' && Array.isArray(delaysTuple[1])) {
|
||||
fullText = delaysTuple[0];
|
||||
delayPoints = delaysTuple[1];
|
||||
} else {
|
||||
console.warn('Delays data format is incorrect:', delaysTuple);
|
||||
}
|
||||
}
|
||||
|
||||
if (delayPoints.length > 0) {
|
||||
let currentIndex = 0;
|
||||
let totalChars = 0;
|
||||
let totalTime = 0;
|
||||
|
||||
// 解码并显示延迟数据
|
||||
delays.forEach(([encodedText, delay], index) => {
|
||||
try {
|
||||
const text = encodedText ? decodeURIComponent(atob(encodedText)) : '';
|
||||
totalChars += text.length;
|
||||
totalTime += delay;
|
||||
const rate = text.length / delay;
|
||||
const avgRate = totalChars / totalTime;
|
||||
|
||||
const row = document.createElement('tr');
|
||||
row.innerHTML = `
|
||||
<td>${index + 1}</td>
|
||||
<td>${text.replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;').replace(/"/g, '&quot;').replace(/'/g, '&#039;')}</td>
|
||||
<td>${delay.toFixed(3)}</td>
|
||||
<td>${rate.toFixed(1)} (平均: ${avgRate.toFixed(1)})</td>
|
||||
`;
|
||||
delaysTableBody.appendChild(row);
|
||||
} catch (e) {
|
||||
console.error('处理延迟数据项失败:', e);
|
||||
delayPoints.forEach(([length, deltaTime], index) => {
|
||||
if (typeof length !== 'number' || typeof deltaTime !== 'number' || length < 0 || deltaTime < 0) {
|
||||
console.warn(`Skipping invalid delay point at index ${index}:`, [length, deltaTime]);
|
||||
return;
|
||||
}
|
||||
|
||||
const chunkText = fullText.substring(currentIndex, currentIndex + length);
|
||||
currentIndex += length;
|
||||
totalChars += length;
|
||||
totalTime += deltaTime;
|
||||
|
||||
const rate = deltaTime > 0 ? (length / deltaTime) : Infinity;
|
||||
const avgRate = totalTime > 0 ? (totalChars / totalTime) : Infinity;
|
||||
|
||||
const row = document.createElement('tr');
|
||||
row.innerHTML = `
|
||||
<td>${index + 1}</td>
|
||||
<td>${escapeHtmlAndControlChars(chunkText)}</td>
|
||||
<td>${deltaTime.toFixed(3)}</td>
|
||||
<td>${isFinite(rate) ? rate.toFixed(1) : 'N/A'} (平均: ${isFinite(avgRate) ? avgRate.toFixed(1) : 'N/A'})</td>
|
||||
`;
|
||||
delaysTableBody.appendChild(row);
|
||||
|
||||
chartDataPoints.push({
|
||||
time: totalTime,
|
||||
chars: totalChars,
|
||||
text: chunkText
|
||||
});
|
||||
});
|
||||
|
||||
// 初始化延迟图表
|
||||
initDelayChart(delays);
|
||||
initDelayChart(chartDataPoints);
|
||||
tabDelays.style.display = '';
|
||||
|
||||
} else {
|
||||
document.querySelector('#delaysTable tbody').innerHTML = '<tr><td colspan="2">无延迟数据</td></tr>';
|
||||
document.querySelector('.delay-chart-container').innerHTML = '<div style="text-align: center; padding: 20px;">无延迟数据可供分析</div>';
|
||||
delaysTableBody.innerHTML = '<tr><td colspan="4">无延迟数据</td></tr>';
|
||||
delayChartContainer.innerHTML = '<div style="text-align: center; padding: 20px;">无延迟数据可供分析</div>';
|
||||
tabDelays.style.display = 'none';
|
||||
}
|
||||
|
||||
// 设置标签切换事件
|
||||
// 设置标签页
|
||||
tabPrompt.onclick = () => setActiveTab('prompt');
|
||||
tabDelays.onclick = () => setActiveTab('delays');
|
||||
|
||||
// 设置默认激活的标签页
|
||||
setActiveTab('prompt');
|
||||
modal.style.display = 'block';
|
||||
|
||||
} catch (e) {
|
||||
console.error('显示对话详情失败:', e);
|
||||
console.error('原始prompt:', promptStr);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1506,10 +1553,14 @@
|
||||
|
||||
/**
|
||||
* 初始化延迟图表
|
||||
* @param {Array} delays - 延迟数组
|
||||
* @param {Array<{time: number, chars: number, text: string}>} chartDataPoints - 包含累计时间和字符数以及块文本的数据点数组
|
||||
*/
|
||||
function initDelayChart(delays) {
|
||||
if (!delays || delays.length <= 1) {
|
||||
function initDelayChart(chartDataPoints) {
|
||||
if (!chartDataPoints || chartDataPoints.length <= 1) {
|
||||
const container = document.querySelector('.delay-chart-container');
|
||||
if (container) {
|
||||
container.innerHTML = '<div style="text-align: center; padding: 20px;">延迟数据不足,无法绘制图表</div>';
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1519,71 +1570,35 @@
|
||||
return;
|
||||
}
|
||||
|
||||
// 销毁之前的图表(如果存在)
|
||||
const existingChart = Chart.getChart(ctx);
|
||||
if (existingChart) {
|
||||
existingChart.destroy();
|
||||
}
|
||||
|
||||
// 计算字符数和累计时间
|
||||
const rawDataPoints = [];
|
||||
let totalChars = 0;
|
||||
let accumulatedTime = 0;
|
||||
const maxPoints = 100;
|
||||
let sampledPoints = chartDataPoints;
|
||||
|
||||
// 解密所有数据点并计算累计时间
|
||||
for (let i = 0; i < delays.length; i++) {
|
||||
const [encodedText, delay] = delays[i];
|
||||
if (encodedText && typeof delay === 'number') {
|
||||
try {
|
||||
const text = decodeURIComponent(atob(encodedText));
|
||||
totalChars += text.length;
|
||||
accumulatedTime += delay; // 累加延迟时间
|
||||
rawDataPoints.push({
|
||||
time: accumulatedTime, // 使用累计时间
|
||||
chars: totalChars,
|
||||
text: text
|
||||
});
|
||||
} catch (e) {
|
||||
console.error('解码延迟数据失败:', e);
|
||||
}
|
||||
if (chartDataPoints.length > maxPoints) {
|
||||
sampledPoints = [];
|
||||
const interval = (chartDataPoints.length - 2) / (maxPoints - 2);
|
||||
|
||||
sampledPoints.push(chartDataPoints[0]);
|
||||
|
||||
for (let i = 1; i < maxPoints - 1; i++) {
|
||||
const rawIndex = Math.round(i * interval);
|
||||
const actualIndex = Math.min(rawIndex + 1, chartDataPoints.length - 2);
|
||||
sampledPoints.push(chartDataPoints[actualIndex]);
|
||||
}
|
||||
|
||||
sampledPoints.push(chartDataPoints[chartDataPoints.length - 1]);
|
||||
}
|
||||
|
||||
// 优化数据点密度
|
||||
const maxPoints = 10;
|
||||
let dataPoints = [];
|
||||
|
||||
if (rawDataPoints.length > maxPoints) {
|
||||
// 计算采样间隔
|
||||
const interval = Math.floor(rawDataPoints.length / maxPoints);
|
||||
|
||||
// 确保包含第一个点
|
||||
dataPoints.push(rawDataPoints[0]);
|
||||
|
||||
// 采样中间点,使用更大的间隔
|
||||
for (let i = interval; i < rawDataPoints.length - interval; i += interval) {
|
||||
// 在每个采样点周围计算平均值,使曲线更平滑
|
||||
const start = Math.max(0, i - Math.floor(interval / 2));
|
||||
const end = Math.min(rawDataPoints.length, i + Math.floor(interval / 2));
|
||||
const avgPoint = rawDataPoints[i];
|
||||
dataPoints.push(avgPoint);
|
||||
}
|
||||
|
||||
// 确保包含最后一个点
|
||||
if (dataPoints[dataPoints.length - 1] !== rawDataPoints[rawDataPoints.length - 1]) {
|
||||
dataPoints.push(rawDataPoints[rawDataPoints.length - 1]);
|
||||
}
|
||||
} else {
|
||||
dataPoints = rawDataPoints;
|
||||
}
|
||||
|
||||
// 创建新图表
|
||||
new Chart(ctx, {
|
||||
type: 'line',
|
||||
data: {
|
||||
datasets: [{
|
||||
label: '累计输出字符数',
|
||||
data: dataPoints.map(point => ({
|
||||
data: sampledPoints.map(point => ({
|
||||
x: point.time,
|
||||
y: point.chars
|
||||
})),
|
||||
@@ -1606,16 +1621,33 @@
|
||||
tooltip: {
|
||||
callbacks: {
|
||||
title: function (context) {
|
||||
return `用时: ${context[0].raw.x.toFixed(1)}秒`;
|
||||
const pointIndex = context[0].dataIndex;
|
||||
if (pointIndex < sampledPoints.length) {
|
||||
return `用时: ${sampledPoints[pointIndex].time.toFixed(1)}秒`;
|
||||
}
|
||||
return '';
|
||||
},
|
||||
label: function (context) {
|
||||
const point = context.raw;
|
||||
const rate = point.x > 0 ? (point.y / point.x).toFixed(1) : 0;
|
||||
return [
|
||||
`字符数: ${point.y}`,
|
||||
`平均速率: ${rate} 字符/秒`,
|
||||
`文本: ${dataPoints[context.dataIndex].text}`
|
||||
];
|
||||
const pointIndex = context.dataIndex;
|
||||
if (pointIndex < sampledPoints.length) {
|
||||
const point = sampledPoints[pointIndex];
|
||||
const prevPoint = pointIndex > 0 ? sampledPoints[pointIndex - 1] : { time: 0, chars: 0, text: '' };
|
||||
|
||||
const deltaTime = point.time - prevPoint.time;
|
||||
const deltaChars = point.chars - prevPoint.chars;
|
||||
const currentRate = deltaTime > 0 ? (deltaChars / deltaTime).toFixed(1) : 'N/A';
|
||||
const avgRate = point.time > 0 ? (point.chars / point.time).toFixed(1) : 'N/A';
|
||||
|
||||
const chunkText = point.text || '';
|
||||
|
||||
return [
|
||||
`累计字符: ${point.chars}`,
|
||||
`平均速率: ${avgRate} 字符/秒`,
|
||||
`当前块速率: ${currentRate} 字符/秒`,
|
||||
`当前块文本: ${escapeHtmlAndControlChars(chunkText.length > 50 ? chunkText.substring(0, 50) + '...' : chunkText)}`
|
||||
];
|
||||
}
|
||||
return '';
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1625,14 +1657,15 @@
|
||||
beginAtZero: true,
|
||||
title: {
|
||||
display: true,
|
||||
text: '字符数'
|
||||
text: '累计字符数'
|
||||
}
|
||||
},
|
||||
x: {
|
||||
type: 'linear',
|
||||
beginAtZero: true,
|
||||
title: {
|
||||
display: true,
|
||||
text: '用时(秒)'
|
||||
text: '累计用时(秒)'
|
||||
},
|
||||
ticks: {
|
||||
callback: function (value) {
|
||||
@@ -1644,6 +1677,16 @@
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Formats token usage as "input / output" text.
 * @param {Object} usage - Usage object of the form { input: number, output: number }
 * @returns {string} "<input> / <output>", or '-' when no usage object is given
 */
function formatUsage(usage) {
  return usage ? `${usage.input} / ${usage.output}` : '-';
}
|
||||
</script>
|
||||
</body>
|
||||
|
||||
|
Reference in New Issue
Block a user