0.1.3-rc.5.2.5

wisdgod
2025-04-16 00:59:07 +08:00
parent 5071997ffc
commit c3bfb3b66e
29 changed files with 1017 additions and 883 deletions

View File

@@ -131,3 +131,9 @@ GENERAL_TIMEZONE=Asia/Shanghai
# Use the embedded official Claude.ai prompts as the default prompts (for models whose name starts with claude-); takes priority over DEFAULT_INSTRUCTIONS
USE_OFFICIAL_CLAUDE_PROMPTS=false
# Real usage quota (waits about 5 seconds due to issues in the Cursor service itself; streaming may be buggy for architectural reasons); otherwise all zeros
REAL_USAGE=false
# Safe hash (checksum generation is slower)
SAFE_HASH=true

View File

@@ -2,7 +2,7 @@ cargo-features = ["profile-rustflags", "trim-paths"]
[package]
name = "cursor-api"
version = "0.1.3-rc.5.2.4"
version = "0.1.3-rc.5.2.5"
edition = "2024"
authors = ["wisdgod <nav@wisdgod.com>"]
description = "OpenAI format compatibility layer for the Cursor API"
@@ -25,6 +25,7 @@ flate2 = { version = "1", default-features = false, features = ["rust_backend"]
futures = { version = "^0.3", default-features = false, features = ["std"] }
gif = { version = "^0.13", default-features = false, features = ["std"] }
hex = { version = "^0.4", default-features = false, features = ["std"] }
http = "1"
image = { version = "^0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] }
lasso = { version = "^0.7", features = ["inline-more", "multi-threaded"] }
memmap2 = "^0.9"
@@ -34,11 +35,10 @@ paste = "^1.0"
prost = "^0.13"
prost-types = "^0.13"
rand = { version = "^0.9", default-features = false, features = ["thread_rng"] }
regex = { version = "^1.11", default-features = false, features = ["std", "perf"] }
reqwest = { version = "^0.12", default-features = false, features = ["gzip", "brotli", "json", "stream", "socks", "__tls", "charset", "rustls-tls-native-roots", "macos-system-configuration"] }
reqwest = { version = "^0.12", default-features = false, features = ["gzip", "brotli", "json", "stream", "socks", "__tls", "charset", "rustls-tls-webpki-roots", "macos-system-configuration"] }
rkyv = { version = "^0.7", default-features = false, features = ["alloc", "std", "bytecheck", "size_64", "validation", "std"] }
serde = { version = "^1.0", default-features = false, features = ["std", "derive", "rc"] }
serde_json = { package = "sonic-rs", version = "^0.4" }
serde_json = { package = "sonic-rs", version = "0.5" }
# serde_json = "^1.0"
sha2 = { version = "^0.10", default-features = false }
sysinfo = { version = "^0.34", default-features = false, features = ["system"] }

View File

@@ -80,6 +80,46 @@ deepseek-v3
deepseek-r1
o3-mini
grok-2
deepseek-v3.1
grok-3-beta
grok-3-mini-beta
gpt-4.1
```
Thinking supported:
```
claude-3.7-sonnet-thinking
claude-3.7-sonnet-thinking-max
o1-mini
o1-preview
o1
gemini-2.5-pro-exp-03-25
gemini-2.5-pro-max
gemini-2.0-flash-thinking-exp
deepseek-r1
o3-mini
```
Image input supported:
```
claude-3.5-sonnet
claude-3.7-sonnet
claude-3.7-sonnet-thinking
claude-3.7-sonnet-max
claude-3.7-sonnet-thinking-max
gpt-4
gpt-4o
gpt-4.5-preview
claude-3-opus
gpt-4-turbo-2024-04-09
gpt-4o-128k
claude-3-haiku-200k
claude-3-5-sonnet-200k
gpt-4o-mini
claude-3.5-haiku
gemini-2.5-pro-exp-03-25
gemini-2.5-pro-max
gpt-4.1
```
## API Reference
@@ -961,9 +1001,12 @@ string
}
],
"delays": [
"string",
[
"string",
number
[
number, // chars count
number // time
]
]
],
"usage": { // optional

View File

@@ -1,3 +1,6 @@
mod header;
pub use header::*;
#[macro_export]
macro_rules! def_pub_const {
// single constant definition
@@ -78,27 +81,12 @@ def_pub_const!(
STATUS_FAILURE => "failure"
);
// Header constants
def_pub_const!(
HEADER_NAME_GHOST_MODE => "x-ghost-mode"
);
// Boolean constants
def_pub_const!(
TRUE => "true",
FALSE => "false"
);
// Content type constants
def_pub_const!(
CONTENT_TYPE_PROTO => "application/proto",
CONTENT_TYPE_CONNECT_PROTO => "application/connect+proto",
CONTENT_TYPE_TEXT_HTML_WITH_UTF8 => "text/html;charset=utf-8",
CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8 => "text/plain;charset=utf-8",
CONTENT_TYPE_TEXT_CSS_WITH_UTF8 => "text/css;charset=utf-8",
CONTENT_TYPE_TEXT_JS_WITH_UTF8 => "text/javascript;charset=utf-8"
);
// Authorization constants
def_pub_const!(
AUTHORIZATION_BEARER_PREFIX => "Bearer "

View File

@@ -0,0 +1,80 @@
macro_rules! def_header_name {
($($name:ident => $value:expr),+ $(,)?) => {
$(paste::paste! {
#[inline]
pub(crate) fn [<header_name_ $name>]() -> &'static http::header::HeaderName {
static HEADER_NAME: std::sync::OnceLock<http::header::HeaderName> = std::sync::OnceLock::new();
HEADER_NAME.get_or_init(|| http::header::HeaderName::from_static($value))
}
})+
};
}
macro_rules! def_header_value {
($($name:ident => $value:expr),+ $(,)?) => {
$(paste::paste! {
#[inline]
pub fn [<header_value_ $name>]() -> &'static http::header::HeaderValue {
static HEADER_NAME: std::sync::OnceLock<http::header::HeaderValue> = std::sync::OnceLock::new();
HEADER_NAME.get_or_init(|| http::header::HeaderValue::from_static($value))
}
})+
};
}
def_header_value!(
one => "1",
encoding => "gzip",
encodings => "gzip,br",
accept => "*/*",
language => "en-US",
empty => "empty",
cors => "cors",
no_cache => "no-cache",
no_cache_revalidate => "no-cache, must-revalidate",
ua_win => "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
same_origin => "same-origin",
keep_alive => "keep-alive",
trailers => "trailers",
u_eq_0 => "u=0",
connect_es => "connect-es/1.6.1",
not_a_brand => "\"Not-A.Brand\";v=\"99\", \"Chromium\";v=\"124\"",
mobile_no => "?0",
windows => "\"Windows\"",
ua_cursor => "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Cursor/0.42.5 Chrome/124.0.6367.243 Electron/30.4.0 Safari/537.36",
vscode_origin => "vscode-file://vscode-app",
cross_site => "cross-site",
gzip_deflate => "gzip, deflate",
event_stream => "text/event-stream",
chunked => "chunked",
json => "application/json",
proto => "application/proto",
connect_proto => "application/connect+proto",
// Content type constants
text_html_utf8 => "text/html;charset=utf-8",
text_plain_utf8 => "text/plain;charset=utf-8",
text_css_utf8 => "text/css;charset=utf-8",
text_js_utf8 => "text/javascript;charset=utf-8"
);
def_header_name!(
proxy_host => "x-co",
connect_accept_encoding => "connect-accept-encoding",
connect_protocol_version => "connect-protocol-version",
ghost_mode => "x-ghost-mode",
amzn_trace_id => "x-amzn-trace-id",
client_key => "x-client-key",
cursor_checksum => "x-cursor-checksum",
cursor_client_version => "x-cursor-client-version",
cursor_timezone => "x-cursor-timezone",
request_id => "x-request-id",
sec_ch_ua => "sec-ch-ua",
sec_ch_ua_mobile => "sec-ch-ua-mobile",
sec_ch_ua_platform => "sec-ch-ua-platform",
sec_fetch_dest => "sec-fetch-dest",
sec_fetch_mode => "sec-fetch-mode",
sec_fetch_site => "sec-fetch-site",
sec_gpc => "sec-gpc",
priority => "priority",
);
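
The new `header.rs` replaces the old string constants (`HEADER_NAME_GHOST_MODE`, `CONTENT_TYPE_*`, ...) with macro-generated accessors that lazily build `http::HeaderName`/`HeaderValue` values once and then hand out `&'static` references. A minimal, self-contained sketch of the same pattern and of how a call site consumes it (the `header_value_true` helper below is illustrative only; the repo keeps passing its existing `TRUE` string constant), assuming the `http` crate this commit adds to Cargo.toml:

```
// Standalone sketch of the OnceLock-backed header accessors generated above.
use http::header::{HeaderName, HeaderValue};
use std::sync::OnceLock;

fn header_name_ghost_mode() -> &'static HeaderName {
    static NAME: OnceLock<HeaderName> = OnceLock::new();
    // Parsed and validated once on first use instead of at every call site.
    NAME.get_or_init(|| HeaderName::from_static("x-ghost-mode"))
}

// Illustrative value accessor (not generated in the repo).
fn header_value_true() -> &'static HeaderValue {
    static VALUE: OnceLock<HeaderValue> = OnceLock::new();
    VALUE.get_or_init(|| HeaderValue::from_static("true"))
}

fn main() {
    let mut map = http::HeaderMap::new();
    // &'static HeaderName implements IntoHeaderName, so no per-request parsing here.
    map.insert(header_name_ghost_mode(), header_value_true().clone());
    assert_eq!(map.get("x-ghost-mode").unwrap(), "true");
}
```

Compared with `.header("x-ghost-mode", "true")`, this skips repeated header-name parsing and turns an invalid name into a single panic at first use rather than a per-request failure.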

View File

@@ -11,20 +11,11 @@ use tokio::sync::{Mutex, OnceCell};
macro_rules! def_pub_static {
// basic variant: stores a String directly
($name:ident, $value:expr) => {
pub const $name: LazyLock<String> = LazyLock::new(|| $value);
};
($name:ident, $value:expr, _) => {
pub static $name: LazyLock<String> = LazyLock::new(|| $value);
};
// environment-variable variant
($name:ident, env: $env_key:expr, default: $default:expr) => {
pub const $name: LazyLock<String> =
LazyLock::new(|| parse_string_from_env($env_key, $default).trim().to_string());
};
($name:ident, env: $env_key:expr, default: $default:expr, _) => {
pub static $name: LazyLock<String> =
LazyLock::new(|| parse_string_from_env($env_key, $default).trim().to_string());
};
@@ -41,14 +32,13 @@ macro_rules! def_pub_static {
// }
def_pub_static!(ROUTE_PREFIX, env: "ROUTE_PREFIX", default: EMPTY_STRING);
def_pub_static!(AUTH_TOKEN, env: "AUTH_TOKEN", default: EMPTY_STRING, _);
def_pub_static!(ROUTE_MODELS_PATH, format!("{}/v1/models", *ROUTE_PREFIX), _);
def_pub_static!(AUTH_TOKEN, env: "AUTH_TOKEN", default: EMPTY_STRING);
def_pub_static!(ROUTE_MODELS_PATH, format!("{}/v1/models", *ROUTE_PREFIX));
def_pub_static!(
ROUTE_CHAT_PATH,
format!("{}/v1/chat/completions", *ROUTE_PREFIX),
_
format!("{}/v1/chat/completions", *ROUTE_PREFIX)
);
def_pub_static!(ROUTE_MESSAGES_PATH, format!("{}/v1/messages", *ROUTE_PREFIX), _);
// def_pub_static!(ROUTE_MESSAGES_PATH, format!("{}/v1/messages", *ROUTE_PREFIX));
static START_TIME: OnceLock<chrono::DateTime<chrono::Local>> = OnceLock::new();
@@ -56,7 +46,7 @@ pub fn get_start_time() -> &'static chrono::DateTime<chrono::Local> {
START_TIME.get_or_init(chrono::Local::now)
}
pub const GENERAL_TIMEZONE: LazyLock<chrono_tz::Tz> = LazyLock::new(|| {
pub static GENERAL_TIMEZONE: LazyLock<chrono_tz::Tz> = LazyLock::new(|| {
use std::str::FromStr as _;
let tz = parse_string_from_env("GENERAL_TIMEZONE", EMPTY_STRING);
let tz = tz.trim();
@@ -80,9 +70,9 @@ pub fn now_in_general_timezone() -> chrono::DateTime<chrono_tz::Tz> {
GENERAL_TIMEZONE.from_utc_datetime(&chrono::Utc::now().naive_utc())
}
def_pub_static!(DEFAULT_INSTRUCTIONS, env: "DEFAULT_INSTRUCTIONS", default: "Respond in Chinese by default\n<|END_USER|>\n\n<|BEGIN_ASSISTANT|>\n\n\nYour will\n<|END_ASSISTANT|>\n\n<|BEGIN_USER|>\n\n\nThe current date is {{currentDateTime}}", _);
def_pub_static!(DEFAULT_INSTRUCTIONS, env: "DEFAULT_INSTRUCTIONS", default: "Respond in Chinese by default\n<|END_USER|>\n\n<|BEGIN_ASSISTANT|>\n\n\nYour will\n<|END_ASSISTANT|>\n\n<|BEGIN_USER|>\n\n\nThe current date is {{currentDateTime}}");
const USE_OFFICIAL_CLAUDE_PROMPTS: LazyLock<bool> =
static USE_OFFICIAL_CLAUDE_PROMPTS: LazyLock<bool> =
LazyLock::new(|| parse_bool_from_env("USE_OFFICIAL_CLAUDE_PROMPTS", false));
pub fn get_default_instructions(
@@ -125,13 +115,13 @@ pub fn get_default_instructions(
)
}
def_pub_static!(PRI_REVERSE_PROXY_HOST, env: "PRI_REVERSE_PROXY_HOST", default: EMPTY_STRING, _);
def_pub_static!(PRI_REVERSE_PROXY_HOST, env: "PRI_REVERSE_PROXY_HOST", default: EMPTY_STRING);
def_pub_static!(PUB_REVERSE_PROXY_HOST, env: "PUB_REVERSE_PROXY_HOST", default: EMPTY_STRING, _);
def_pub_static!(PUB_REVERSE_PROXY_HOST, env: "PUB_REVERSE_PROXY_HOST", default: EMPTY_STRING);
const DEFAULT_KEY_PREFIX: &str = "sk-";
pub const KEY_PREFIX: LazyLock<String> = LazyLock::new(|| {
pub static KEY_PREFIX: LazyLock<String> = LazyLock::new(|| {
let value = parse_string_from_env("KEY_PREFIX", DEFAULT_KEY_PREFIX)
.trim()
.to_string();
@@ -142,9 +132,9 @@ pub const KEY_PREFIX: LazyLock<String> = LazyLock::new(|| {
}
});
pub const KEY_PREFIX_LEN: LazyLock<usize> = LazyLock::new(|| KEY_PREFIX.len());
pub static KEY_PREFIX_LEN: LazyLock<usize> = LazyLock::new(|| KEY_PREFIX.len());
pub const TOKEN_DELIMITER: LazyLock<char> = LazyLock::new(|| {
pub static TOKEN_DELIMITER: LazyLock<char> = LazyLock::new(|| {
let delimiter = parse_ascii_char_from_env("TOKEN_DELIMITER", COMMA);
if delimiter.is_ascii_alphabetic()
|| delimiter.is_ascii_digit()
@@ -158,7 +148,7 @@ pub const TOKEN_DELIMITER: LazyLock<char> = LazyLock::new(|| {
}
});
pub const USE_COMMA_DELIMITER: LazyLock<bool> = LazyLock::new(|| {
pub static USE_COMMA_DELIMITER: LazyLock<bool> = LazyLock::new(|| {
let enable = parse_bool_from_env("USE_COMMA_DELIMITER", true);
if enable && *TOKEN_DELIMITER == COMMA {
false
@@ -167,10 +157,10 @@ pub const USE_COMMA_DELIMITER: LazyLock<bool> = LazyLock::new(|| {
}
});
pub const USE_PRI_REVERSE_PROXY: LazyLock<bool> =
pub static USE_PRI_REVERSE_PROXY: LazyLock<bool> =
LazyLock::new(|| !PRI_REVERSE_PROXY_HOST.is_empty());
pub const USE_PUB_REVERSE_PROXY: LazyLock<bool> =
pub static USE_PUB_REVERSE_PROXY: LazyLock<bool> =
LazyLock::new(|| !PUB_REVERSE_PROXY_HOST.is_empty());
macro_rules! def_cursor_api_url {
@@ -257,18 +247,18 @@ static DATA_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
path
});
pub(super) const CONFIG_FILE_PATH: LazyLock<PathBuf> =
pub(super) static CONFIG_FILE_PATH: LazyLock<PathBuf> =
LazyLock::new(|| DATA_DIR.join("config.bin"));
pub(super) const LOGS_FILE_PATH: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("logs.bin"));
pub(super) static LOGS_FILE_PATH: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("logs.bin"));
pub(super) const TOKENS_FILE_PATH: LazyLock<PathBuf> =
pub(super) static TOKENS_FILE_PATH: LazyLock<PathBuf> =
LazyLock::new(|| DATA_DIR.join("tokens.bin"));
pub(super) const PROXIES_FILE_PATH: LazyLock<PathBuf> =
pub(super) static PROXIES_FILE_PATH: LazyLock<PathBuf> =
LazyLock::new(|| DATA_DIR.join("proxies.bin"));
pub const DEBUG: LazyLock<bool> = LazyLock::new(|| parse_bool_from_env("DEBUG", false));
pub static DEBUG: LazyLock<bool> = LazyLock::new(|| parse_bool_from_env("DEBUG", false));
// The "DEBUG_LOG_FILE" environment variable sets the log file path; defaults to "debug.log"
static DEBUG_LOG_FILE: LazyLock<String> =
@@ -318,19 +308,42 @@ macro_rules! debug_println {
};
}
pub const REQUEST_LOGS_LIMIT: LazyLock<usize> =
LazyLock::new(|| std::cmp::min(parse_usize_from_env("REQUEST_LOGS_LIMIT", 100), 100000));
// Request-log constants
const DEFAULT_REQUEST_LOGS_LIMIT: usize = 100;
const MAX_REQUEST_LOGS_LIMIT: usize = 100000;
pub const IS_NO_REQUEST_LOGS: LazyLock<bool> = LazyLock::new(|| *REQUEST_LOGS_LIMIT == 0);
pub const IS_UNLIMITED_REQUEST_LOGS: LazyLock<bool> =
LazyLock::new(|| *REQUEST_LOGS_LIMIT == 100000);
pub const TCP_KEEPALIVE: LazyLock<u64> = LazyLock::new(|| {
let keepalive = parse_usize_from_env("TCP_KEEPALIVE", 90);
u64::try_from(keepalive).map(|t| t.min(600)).unwrap_or(90)
pub static REQUEST_LOGS_LIMIT: LazyLock<usize> = LazyLock::new(|| {
std::cmp::min(
parse_usize_from_env("REQUEST_LOGS_LIMIT", DEFAULT_REQUEST_LOGS_LIMIT),
MAX_REQUEST_LOGS_LIMIT,
)
});
pub const SERVICE_TIMEOUT: LazyLock<u64> = LazyLock::new(|| {
let timeout = parse_usize_from_env("SERVICE_TIMEOUT", 30);
u64::try_from(timeout).map(|t| t.min(600)).unwrap_or(30)
pub static IS_NO_REQUEST_LOGS: LazyLock<bool> = LazyLock::new(|| *REQUEST_LOGS_LIMIT == 0);
pub static IS_UNLIMITED_REQUEST_LOGS: LazyLock<bool> =
LazyLock::new(|| *REQUEST_LOGS_LIMIT == MAX_REQUEST_LOGS_LIMIT);
// TCP and timeout constants
const DEFAULT_TCP_KEEPALIVE: usize = 90;
const MAX_TCP_KEEPALIVE: u64 = 600;
pub static TCP_KEEPALIVE: LazyLock<u64> = LazyLock::new(|| {
let keepalive = parse_usize_from_env("TCP_KEEPALIVE", DEFAULT_TCP_KEEPALIVE);
u64::try_from(keepalive)
.map(|t| t.min(MAX_TCP_KEEPALIVE))
.unwrap_or(DEFAULT_TCP_KEEPALIVE as u64)
});
const DEFAULT_SERVICE_TIMEOUT: usize = 30;
const MAX_SERVICE_TIMEOUT: u64 = 600;
pub static SERVICE_TIMEOUT: LazyLock<u64> = LazyLock::new(|| {
let timeout = parse_usize_from_env("SERVICE_TIMEOUT", DEFAULT_SERVICE_TIMEOUT);
u64::try_from(timeout)
.map(|t| t.min(MAX_SERVICE_TIMEOUT))
.unwrap_or(DEFAULT_SERVICE_TIMEOUT as u64)
});
pub static REAL_USAGE: LazyLock<bool> = LazyLock::new(|| parse_bool_from_env("REAL_USAGE", false));
pub static SAFE_HASH: LazyLock<bool> = LazyLock::new(|| parse_bool_from_env("SAFE_HASH", true));
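
Most of this hunk is a sweep from `pub const X: LazyLock<...>` to `pub static X: LazyLock<...>`. The distinction matters: a `const` item is inlined as a fresh value at every use site, so each dereference builds a new `LazyLock` and re-runs the initializer (re-reading the environment), while a `static` is a single shared cell that initializes exactly once. A small self-contained sketch of the difference:

```
// Why `static` vs `const` matters for LazyLock: the const is re-created (and
// re-initialized) at every use site, the static initializes exactly once.
use std::sync::LazyLock;
use std::sync::atomic::{AtomicUsize, Ordering};

static INIT_COUNT: AtomicUsize = AtomicUsize::new(0);

const PER_USE: LazyLock<usize> = LazyLock::new(|| INIT_COUNT.fetch_add(1, Ordering::SeqCst));
static SHARED: LazyLock<usize> = LazyLock::new(|| INIT_COUNT.fetch_add(1, Ordering::SeqCst));

fn main() {
    let _ = *PER_USE;
    let _ = *PER_USE; // each use is a new inlined LazyLock, so the closure runs again
    assert_eq!(INIT_COUNT.load(Ordering::SeqCst), 2);

    let _ = *SHARED;
    let _ = *SHARED; // same cell, the initializer already ran
    assert_eq!(INIT_COUNT.load(Ordering::SeqCst), 3);
}
```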

View File

@@ -84,7 +84,8 @@ pub struct RequestLog {
pub struct Chain {
#[serde(skip_serializing_if = "Prompt::is_none")]
pub prompt: Prompt,
pub delays: Vec<(String, f64)>,
#[serde(skip_serializing_if = "Option::is_none")]
pub delays: Option<(String, Vec<(u32, f32)>)>,
#[serde(skip_serializing_if = "OptionUsage::is_none")]
pub usage: OptionUsage,
}
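
`delays` changes from a flat `Vec<(String, f64)>` to an optional pair of the accumulated text plus `(chars count, time)` samples, matching the README hunk earlier in this commit. A minimal serialization sketch of the new shape (plain `serde_json` is used here for illustration; the repo aliases `serde_json` to `sonic-rs`, and the real `Chain` also carries `prompt` and `usage`):

```
// What the new `delays` field serializes to (tuples become JSON arrays).
use serde::Serialize;

#[derive(Serialize)]
struct ChainSketch {
    #[serde(skip_serializing_if = "Option::is_none")]
    delays: Option<(String, Vec<(u32, f32)>)>,
}

fn main() {
    let chain = ChainSketch {
        // (text, [(chars count, time), ...]) per the README hunk above
        delays: Some(("hello world".to_string(), vec![(5, 0.8), (11, 1.6)])),
    };
    // Prints: {"delays":["hello world",[[5,0.8],[11,1.6]]]}
    println!("{}", serde_json::to_string(&chain).unwrap());
}
```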

View File

@@ -45,7 +45,7 @@ impl<'de> Deserialize<'de> for UsageCheckModelConfig {
.split(COMMA)
.filter_map(|model| {
let model = model.trim();
Models::find_id(model)
Models::find_id(model).map(|m| m.id)
})
.collect()
};

View File

@@ -135,7 +135,7 @@ impl From<super::Prompt> for PromptHelper {
#[derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize)]
pub struct ChainHelper {
pub prompt: PromptHelper,
pub delays: Vec<(String, f64)>,
pub delays: Option<(String, Vec<(u32, f32)>)>,
pub usage: super::OptionUsage,
}
impl From<ChainHelper> for super::Chain {

View File

@@ -1,17 +1,19 @@
mod proxy_url;
use crate::app::lazy::{PROXIES_FILE_PATH, SERVICE_TIMEOUT, TCP_KEEPALIVE};
use memmap2::{MmapMut, MmapOptions};
use parking_lot::RwLock;
use reqwest::{Client, Proxy};
use proxy_url::StringUrl;
use reqwest::Client;
use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::collections::HashSet;
use std::fs::OpenOptions;
use std::str::FromStr;
use std::sync::LazyLock;
use std::time::Duration;
mod proxy_url;
use proxy_url::StringUrl;
use std::{
collections::{HashMap, HashSet},
fs::OpenOptions,
str::FromStr,
sync::LazyLock,
time::Duration,
};
// New proxy value constants
pub const NON_PROXY: &str = "non";
@@ -21,10 +23,8 @@ pub const SYS_PROXY: &str = "sys";
pub static PROXY_POOL: LazyLock<RwLock<ProxyPool>> = LazyLock::new(|| {
let system_client = Client::builder()
.https_only(true)
.http1_only()
.tcp_keepalive(Duration::from_secs(*TCP_KEEPALIVE))
.timeout(Duration::from_secs(*SERVICE_TIMEOUT))
.http1_title_case_headers()
.connect_timeout(Duration::from_secs(*SERVICE_TIMEOUT))
.build()
.expect("创建默认系统客户端失败");
@@ -109,10 +109,8 @@ impl Proxies {
SingleProxy::Non,
Client::builder()
.https_only(true)
.http1_only()
.tcp_keepalive(Duration::from_secs(*TCP_KEEPALIVE))
.timeout(Duration::from_secs(*SERVICE_TIMEOUT))
.http1_title_case_headers()
.connect_timeout(Duration::from_secs(*SERVICE_TIMEOUT))
.no_proxy()
.build()
.expect("创建无代理客户端失败"),
@@ -123,29 +121,23 @@ impl Proxies {
SingleProxy::Sys,
Client::builder()
.https_only(true)
.http1_only()
.tcp_keepalive(Duration::from_secs(*TCP_KEEPALIVE))
.timeout(Duration::from_secs(*SERVICE_TIMEOUT))
.http1_title_case_headers()
.connect_timeout(Duration::from_secs(*SERVICE_TIMEOUT))
.build()
.expect("创建默认客户端失败"),
);
}
SingleProxy::Url(url) => {
if let Ok(proxy_obj) = Proxy::all(url.to_string()) {
pool.clients.insert(
(*proxy).clone(),
Client::builder()
.https_only(true)
.http1_only()
.tcp_keepalive(Duration::from_secs(*TCP_KEEPALIVE))
.timeout(Duration::from_secs(*SERVICE_TIMEOUT))
.http1_title_case_headers()
.proxy(proxy_obj)
.build()
.expect("创建代理客户端失败"),
);
}
pool.clients.insert(
(*proxy).clone(),
Client::builder()
.https_only(true)
.tcp_keepalive(Duration::from_secs(*TCP_KEEPALIVE))
.connect_timeout(Duration::from_secs(*SERVICE_TIMEOUT))
.proxy(url.as_proxy().expect("创建代理对象失败"))
.build()
.expect("创建代理客户端失败"),
);
}
}
}

View File

@@ -3,7 +3,7 @@ use memmap2::{MmapMut, MmapOptions};
use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
use serde::{Deserialize, Serialize};
use std::{
collections::{HashMap, HashSet},
collections::{HashMap, HashSet, VecDeque},
fs::OpenOptions,
};
@@ -59,7 +59,7 @@ pub struct RequestStatsManager {
pub total_requests: u64,
pub active_requests: u64,
pub error_requests: u64,
pub request_logs: Vec<RequestLog>,
pub request_logs: VecDeque<RequestLog>,
}
pub struct AppState {
@@ -180,7 +180,7 @@ impl TokenManager {
}
impl RequestStatsManager {
pub fn new(request_logs: Vec<RequestLog>) -> Self {
pub fn new(request_logs: VecDeque<RequestLog>) -> Self {
Self {
total_requests: request_logs.len() as u64,
active_requests: 0,
@@ -220,11 +220,11 @@ impl RequestStatsManager {
Ok(())
}
pub async fn load_logs() -> Result<Vec<RequestLog>, Box<dyn std::error::Error>> {
pub async fn load_logs() -> Result<VecDeque<RequestLog>, Box<dyn std::error::Error>> {
let file = match OpenOptions::new().read(true).open(&*LOGS_FILE_PATH) {
Ok(file) => file,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
return Ok(Vec::new());
return Ok(VecDeque::new());
}
Err(e) => return Err(Box::new(e)),
};
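
Switching `request_logs` from `Vec` to `VecDeque` makes front eviction cheap; the likely use (my reading, not shown in this hunk) is dropping the oldest entry once `REQUEST_LOGS_LIMIT` is reached, since `pop_front`/`push_back` are O(1) while `Vec::remove(0)` shifts the whole buffer. A hedged sketch with an illustrative `push_log` helper:

```
// Illustrative capped log buffer; `RequestLog` and `push_log` here are
// stand-ins, not code from the repo.
use std::collections::VecDeque;

struct RequestLog {
    id: u64,
}

fn push_log(logs: &mut VecDeque<RequestLog>, log: RequestLog, limit: usize) {
    if limit == 0 {
        return; // logging disabled
    }
    if logs.len() >= limit {
        logs.pop_front(); // O(1) eviction of the oldest entry
    }
    logs.push_back(log);
}

fn main() {
    let mut logs = VecDeque::new();
    for id in 0..5 {
        push_log(&mut logs, RequestLog { id }, 3);
    }
    // Only the three most recent entries remain.
    assert_eq!(logs.iter().map(|l| l.id).collect::<Vec<_>>(), vec![2, 3, 4]);
}
```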

View File

@@ -25,6 +25,7 @@ impl UsageCheck {
.model_ids
.iter()
.filter_map(|id| Models::find_id(id))
.map(|m| m.id)
.collect();
if models.is_empty() {
Self::None
@@ -125,6 +126,7 @@ impl<'de> Deserialize<'de> for UsageCheck {
let model = model.trim();
Models::find_id(model)
})
.map(|m| m.id)
.collect();
if models.is_empty() {
@@ -153,6 +155,7 @@ impl UsageCheck {
let model = model.trim();
Models::find_id(model)
})
.map(|m| m.id)
.collect();
if models.is_empty() {

View File

@@ -1,7 +1,20 @@
use crate::app::{
constant::{
CONTENT_TYPE_CONNECT_PROTO, CONTENT_TYPE_PROTO, CURSOR_API2_HOST, CURSOR_HOST,
CURSOR_SETTINGS_URL, HEADER_NAME_GHOST_MODE, TRUE,
CURSOR_API2_HOST, CURSOR_HOST, CURSOR_SETTINGS_URL, TRUE, header_name_amzn_trace_id,
header_name_client_key, header_name_connect_accept_encoding,
header_name_connect_protocol_version, header_name_cursor_checksum,
header_name_cursor_client_version, header_name_cursor_timezone, header_name_ghost_mode,
header_name_priority, header_name_proxy_host, header_name_request_id,
header_name_sec_ch_ua, header_name_sec_ch_ua_mobile, header_name_sec_ch_ua_platform,
header_name_sec_fetch_dest, header_name_sec_fetch_mode, header_name_sec_fetch_site,
header_name_sec_gpc, header_value_accept, header_value_chunked, header_value_connect_es,
header_value_connect_proto, header_value_cors, header_value_cross_site, header_value_empty,
header_value_encoding, header_value_encodings, header_value_gzip_deflate,
header_value_keep_alive, header_value_language, header_value_mobile_no,
header_value_no_cache, header_value_not_a_brand, header_value_one, header_value_proto,
header_value_same_origin, header_value_trailers, header_value_u_eq_0,
header_value_ua_cursor, header_value_ua_win, header_value_vscode_origin,
header_value_windows,
},
lazy::{
PRI_REVERSE_PROXY_HOST, PUB_REVERSE_PROXY_HOST, USE_PRI_REVERSE_PROXY,
@@ -10,7 +23,7 @@ use crate::app::{
},
};
use reqwest::{
Client, RequestBuilder,
Client, Method, RequestBuilder,
header::{
ACCEPT, ACCEPT_ENCODING, ACCEPT_LANGUAGE, CACHE_CONTROL, CONNECTION, CONTENT_LENGTH,
CONTENT_TYPE, COOKIE, DNT, HOST, ORIGIN, PRAGMA, REFERER, TE, TRANSFER_ENCODING,
@@ -18,56 +31,30 @@ use reqwest::{
},
};
macro_rules! def_const {
($name:ident, $value:expr) => {
const $name: &'static str = $value;
};
}
def_const!(SEC_FETCH_DEST, "sec-fetch-dest");
def_const!(SEC_FETCH_MODE, "sec-fetch-mode");
def_const!(SEC_FETCH_SITE, "sec-fetch-site");
def_const!(SEC_GPC, "sec-gpc");
def_const!(PRIORITY, "priority");
def_const!(ONE, "1");
def_const!(ENCODINGS, "gzip,br");
def_const!(VALUE_ACCEPT, "*/*");
def_const!(VALUE_LANGUAGE, "en-US");
def_const!(EMPTY, "empty");
def_const!(CORS, "cors");
def_const!(NO_CACHE, "no-cache");
def_const!(
UA_WIN,
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
);
def_const!(SAME_ORIGIN, "same-origin");
def_const!(KEEP_ALIVE, "keep-alive");
def_const!(TRAILERS, "trailers");
def_const!(U_EQ_4, "u=4");
def_const!(U_EQ_0, "u=0");
def_const!(PROXY_HOST, "x-co");
#[inline]
fn get_client_and_host_post<'a>(
fn get_client_and_host<'a>(
client: &Client,
method: Method,
url: &'a str,
is_pri: bool,
real_host: &'a str,
) -> (RequestBuilder, &'a str) {
if is_pri && *USE_PRI_REVERSE_PROXY {
(
client.post(url).header(PROXY_HOST, real_host),
client
.request(method, url)
.header(header_name_proxy_host(), real_host),
PRI_REVERSE_PROXY_HOST.as_str(),
)
} else if !is_pri && *USE_PUB_REVERSE_PROXY {
(
client.post(url).header(PROXY_HOST, real_host),
client
.request(method, url)
.header(header_name_proxy_host(), real_host),
PUB_REVERSE_PROXY_HOST.as_str(),
)
} else {
(client.post(url), real_host)
(client.request(method, url), real_host)
}
}
@@ -95,54 +82,43 @@ pub(crate) struct AiServiceRequest<'a> {
///
/// * reqwest::RequestBuilder - the configured request builder
pub fn build_request(req: AiServiceRequest) -> RequestBuilder {
let (client, host) =
get_client_and_host_post(&req.client, req.url, req.is_pri, CURSOR_API2_HOST);
let (builder, host) = get_client_and_host(
&req.client,
Method::POST,
req.url,
req.is_pri,
CURSOR_API2_HOST,
);
client
builder
.header(
CONTENT_TYPE,
if req.is_stream {
CONTENT_TYPE_CONNECT_PROTO
header_value_connect_proto()
} else {
CONTENT_TYPE_PROTO
header_value_proto()
},
)
.bearer_auth(req.auth_token)
.header("connect-accept-encoding", ENCODINGS)
.header("connect-protocol-version", ONE)
.header(USER_AGENT, "connect-es/1.6.1")
.header("x-amzn-trace-id", format!("Root={}", req.trace_id))
.header("x-client-key", req.client_key)
.header("x-cursor-checksum", req.checksum)
.header("x-cursor-client-version", "0.42.5")
.header("x-cursor-timezone", req.timezone)
.header(HEADER_NAME_GHOST_MODE, TRUE)
.header("x-request-id", req.trace_id)
.header(
header_name_connect_accept_encoding(),
header_value_encoding(),
)
.header(header_name_connect_protocol_version(), header_value_one())
.header(USER_AGENT, header_value_connect_es())
.header(
header_name_amzn_trace_id(),
format!("Root={}", req.trace_id),
)
.header(header_name_client_key(), req.client_key)
.header(header_name_cursor_checksum(), req.checksum)
.header(header_name_cursor_client_version(), "0.42.5")
.header(header_name_cursor_timezone(), req.timezone)
.header(header_name_ghost_mode(), TRUE)
.header(header_name_request_id(), req.trace_id)
.header(HOST, host)
.header(CONNECTION, KEEP_ALIVE)
.header(TRANSFER_ENCODING, "chunked")
}
#[inline]
fn get_client_and_host<'a>(
client: &Client,
url: &'a str,
is_pri: bool,
real_host: &'a str,
) -> (RequestBuilder, &'a str) {
if is_pri && *USE_PRI_REVERSE_PROXY {
(
client.get(url).header(PROXY_HOST, real_host),
PRI_REVERSE_PROXY_HOST.as_str(),
)
} else if !is_pri && *USE_PUB_REVERSE_PROXY {
(
client.get(url).header(PROXY_HOST, real_host),
PUB_REVERSE_PROXY_HOST.as_str(),
)
} else {
(client.get(url), real_host)
}
.header(CONNECTION, header_value_keep_alive())
.header(TRANSFER_ENCODING, header_value_chunked())
}
/// Returns a pre-built Cursor API request for fetching Stripe account information
@@ -155,32 +131,30 @@ fn get_client_and_host<'a>(
///
/// * reqwest::RequestBuilder - the configured request builder
pub fn build_profile_request(client: &Client, auth_token: &str, is_pri: bool) -> RequestBuilder {
let (client, host) = get_client_and_host(
let (builder, host) = get_client_and_host(
client,
Method::GET,
cursor_api2_stripe_url(is_pri),
is_pri,
CURSOR_API2_HOST,
);
client
builder
.header(HOST, host)
.header("sec-ch-ua", "\"Not-A.Brand\";v=\"99\", \"Chromium\";v=\"124\"")
.header(HEADER_NAME_GHOST_MODE, TRUE)
.header("sec-ch-ua-mobile", "?0")
.header(header_name_sec_ch_ua(), header_value_not_a_brand())
.header(header_name_ghost_mode(), TRUE)
.header(header_name_sec_ch_ua_mobile(), header_value_mobile_no())
.bearer_auth(auth_token)
.header(
USER_AGENT,
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Cursor/0.42.5 Chrome/124.0.6367.243 Electron/30.4.0 Safari/537.36",
)
.header("sec-ch-ua-platform", "\"Windows\"")
.header(ACCEPT, VALUE_ACCEPT)
.header(ORIGIN, "vscode-file://vscode-app")
.header(SEC_FETCH_SITE, "cross-site")
.header(SEC_FETCH_MODE, CORS)
.header(SEC_FETCH_DEST, EMPTY)
.header(ACCEPT_ENCODING, ENCODINGS)
.header(ACCEPT_LANGUAGE, VALUE_LANGUAGE)
.header(PRIORITY, "u=1, i")
.header(USER_AGENT, header_value_ua_cursor())
.header(header_name_sec_ch_ua_platform(), header_value_windows())
.header(ACCEPT, header_value_accept())
.header(ORIGIN, header_value_vscode_origin())
.header(header_name_sec_fetch_site(), header_value_cross_site())
.header(header_name_sec_fetch_mode(), header_value_cors())
.header(header_name_sec_fetch_dest(), header_value_empty())
.header(ACCEPT_ENCODING, header_value_encodings())
.header(ACCEPT_LANGUAGE, header_value_language())
.header(header_name_priority(), header_value_u_eq_0())
}
/// Returns a pre-built Cursor API request for fetching usage information
@@ -199,26 +173,31 @@ pub fn build_usage_request(
auth_token: &str,
is_pri: bool,
) -> RequestBuilder {
let (client, host) =
get_client_and_host(client, cursor_usage_api_url(is_pri), is_pri, CURSOR_HOST);
let (client, host) = get_client_and_host(
client,
Method::GET,
cursor_usage_api_url(is_pri),
is_pri,
CURSOR_HOST,
);
client
.header(HOST, host)
.header(USER_AGENT, UA_WIN)
.header(ACCEPT, VALUE_ACCEPT)
.header(ACCEPT_LANGUAGE, VALUE_LANGUAGE)
.header(ACCEPT_ENCODING, ENCODINGS)
.header(USER_AGENT, header_value_ua_win())
.header(ACCEPT, header_value_accept())
.header(ACCEPT_LANGUAGE, header_value_language())
.header(ACCEPT_ENCODING, header_value_encodings())
.header(REFERER, CURSOR_SETTINGS_URL)
.header(DNT, ONE)
.header(SEC_GPC, ONE)
.header(SEC_FETCH_DEST, EMPTY)
.header(SEC_FETCH_MODE, CORS)
.header(SEC_FETCH_SITE, SAME_ORIGIN)
.header(CONNECTION, KEEP_ALIVE)
.header(PRAGMA, NO_CACHE)
.header(CACHE_CONTROL, NO_CACHE)
.header(TE, TRAILERS)
.header(PRIORITY, U_EQ_4)
.header(DNT, header_value_one())
.header(header_name_sec_gpc(), header_value_one())
.header(header_name_sec_fetch_dest(), header_value_empty())
.header(header_name_sec_fetch_mode(), header_value_cors())
.header(header_name_sec_fetch_site(), header_value_same_origin())
.header(CONNECTION, header_value_keep_alive())
.header(PRAGMA, header_value_no_cache())
.header(CACHE_CONTROL, header_value_no_cache())
.header(TE, header_value_trailers())
.header(header_name_priority(), header_value_u_eq_0())
.header(
COOKIE,
format!("WorkosCursorSessionToken={user_id}%3A%3A{auth_token}"),
@@ -242,26 +221,31 @@ pub fn build_userinfo_request(
auth_token: &str,
is_pri: bool,
) -> RequestBuilder {
let (client, host) =
get_client_and_host(client, cursor_user_api_url(is_pri), is_pri, CURSOR_HOST);
let (client, host) = get_client_and_host(
client,
Method::GET,
cursor_user_api_url(is_pri),
is_pri,
CURSOR_HOST,
);
client
.header(HOST, host)
.header(USER_AGENT, UA_WIN)
.header(ACCEPT, VALUE_ACCEPT)
.header(ACCEPT_LANGUAGE, VALUE_LANGUAGE)
.header(ACCEPT_ENCODING, ENCODINGS)
.header(USER_AGENT, header_value_ua_win())
.header(ACCEPT, header_value_accept())
.header(ACCEPT_LANGUAGE, header_value_language())
.header(ACCEPT_ENCODING, header_value_encodings())
.header(REFERER, CURSOR_SETTINGS_URL)
.header(DNT, ONE)
.header(SEC_GPC, ONE)
.header(SEC_FETCH_DEST, EMPTY)
.header(SEC_FETCH_MODE, CORS)
.header(SEC_FETCH_SITE, SAME_ORIGIN)
.header(CONNECTION, KEEP_ALIVE)
.header(PRAGMA, NO_CACHE)
.header(CACHE_CONTROL, NO_CACHE)
.header(TE, TRAILERS)
.header(PRIORITY, U_EQ_4)
.header(DNT, header_value_one())
.header(header_name_sec_gpc(), header_value_one())
.header(header_name_sec_fetch_dest(), header_value_empty())
.header(header_name_sec_fetch_mode(), header_value_cors())
.header(header_name_sec_fetch_site(), header_value_same_origin())
.header(CONNECTION, header_value_keep_alive())
.header(PRAGMA, header_value_no_cache())
.header(CACHE_CONTROL, header_value_no_cache())
.header(TE, header_value_trailers())
.header(header_name_priority(), header_value_u_eq_0())
.header(
COOKIE,
format!("WorkosCursorSessionToken={user_id}%3A%3A{auth_token}"),
@@ -276,8 +260,9 @@ pub fn build_token_upgrade_request(
auth_token: &str,
is_pri: bool,
) -> RequestBuilder {
let (client, host) = get_client_and_host_post(
let (client, host) = get_client_and_host(
client,
Method::POST,
cursor_token_upgrade_url(is_pri),
is_pri,
CURSOR_HOST,
@@ -287,10 +272,10 @@ pub fn build_token_upgrade_request(
client
.header(HOST, host)
.header(USER_AGENT, UA_WIN)
.header(ACCEPT, VALUE_ACCEPT)
.header(ACCEPT_LANGUAGE, VALUE_LANGUAGE)
.header(ACCEPT_ENCODING, ENCODINGS)
.header(USER_AGENT, header_value_ua_win())
.header(ACCEPT, header_value_accept())
.header(ACCEPT_LANGUAGE, header_value_language())
.header(ACCEPT_ENCODING, header_value_encodings())
.header(
REFERER,
format!(
@@ -299,16 +284,16 @@ pub fn build_token_upgrade_request(
)
.header(CONTENT_TYPE, "application/json")
.header(CONTENT_LENGTH, body.len())
.header(DNT, ONE)
.header(SEC_GPC, ONE)
.header(SEC_FETCH_DEST, EMPTY)
.header(SEC_FETCH_MODE, CORS)
.header(SEC_FETCH_SITE, SAME_ORIGIN)
.header(CONNECTION, KEEP_ALIVE)
.header(PRAGMA, NO_CACHE)
.header(CACHE_CONTROL, NO_CACHE)
.header(TE, TRAILERS)
.header(PRIORITY, U_EQ_0)
.header(DNT, header_value_one())
.header(header_name_sec_gpc(), header_value_one())
.header(header_name_sec_fetch_dest(), header_value_empty())
.header(header_name_sec_fetch_mode(), header_value_cors())
.header(header_name_sec_fetch_site(), header_value_same_origin())
.header(CONNECTION, header_value_keep_alive())
.header(PRAGMA, header_value_no_cache())
.header(CACHE_CONTROL, header_value_no_cache())
.header(TE, header_value_trailers())
.header(header_name_priority(), header_value_u_eq_0())
.header(
COOKIE,
format!("WorkosCursorSessionToken={user_id}%3A%3A{auth_token}"),
@@ -324,17 +309,18 @@ pub fn build_token_poll_request(
) -> RequestBuilder {
let (client, host) = get_client_and_host(
client,
Method::GET,
cursor_token_poll_url(is_pri),
is_pri,
CURSOR_API2_HOST,
);
client
.header(HOST, host)
.header(ACCEPT_ENCODING, "gzip, deflate")
.header(ACCEPT_LANGUAGE, "en-US")
.header(USER_AGENT, "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Cursor/0.48.2 Chrome/132.0.6834.210 Electron/34.3.4 Safari/537.36")
.header(ORIGIN, "vscode-file://vscode-app")
.header(HEADER_NAME_GHOST_MODE, TRUE)
.header(ACCEPT, "*/*")
.header(ACCEPT_ENCODING, header_value_gzip_deflate())
.header(ACCEPT_LANGUAGE, header_value_language())
.header(USER_AGENT, header_value_ua_cursor())
.header(ORIGIN, header_value_vscode_origin())
.header(header_name_ghost_mode(), TRUE)
.header(ACCEPT, header_value_accept())
.query(&[("uuid", uuid), ("verifier", verifier)])
}

View File

@@ -7,8 +7,7 @@ pub struct HealthCheckResponse {
pub status: ApiStatus,
pub version: &'static str,
pub uptime: i64,
#[serde(skip_serializing_if = "Option::is_none")]
pub stats: Option<SystemStats>,
pub stats: SystemStats,
pub models: Vec<&'static str>,
pub endpoints: &'static [&'static str],
}

View File

@@ -1,5 +1,4 @@
mod checksum;
use std::time::Instant;
use ::base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
pub use checksum::*;
@@ -30,8 +29,8 @@ use crate::{
},
config::key_config,
constant::{
ANTHROPIC, CREATED, CURSOR, DEEPSEEK, GOOGLE, MODEL_OBJECT, OPENAI, UNKNOWN, XAI,
calculate_display_name_v3,
ANTHROPIC, CREATED, CURSOR, DEEPSEEK, DEFAULT, GOOGLE, MODEL_OBJECT, OPENAI, UNKNOWN,
XAI, calculate_display_name_v3,
},
model::{Model, Usage},
},
@@ -111,20 +110,6 @@ impl TrimNewlines for String {
}
}
pub trait InstantExt {
fn duration_as_secs_f64(&mut self) -> f64;
}
impl InstantExt for Instant {
#[inline]
fn duration_as_secs_f64(&mut self) -> f64 {
let now = Instant::now();
let duration = now.duration_since(*self);
*self = now;
duration.as_secs_f64()
}
}
pub async fn get_token_profile(
client: Client,
auth_token: &str,
@@ -261,6 +246,12 @@ pub async fn get_available_models(
}
};
let display_name = calculate_display_name_v3(&model.name);
let is_thinking = model.supports_thinking();
let is_image = if model.name.as_str() == DEFAULT {
true
} else {
model.supports_images()
};
Model {
id: crate::leak::intern_string(model.name),
@@ -268,6 +259,8 @@ pub async fn get_available_models(
created: CREATED,
object: MODEL_OBJECT,
owned_by,
is_thinking,
is_image,
}
})
.collect(),
@@ -276,9 +269,9 @@ pub async fn get_available_models(
pub async fn get_token_usage(
client: Client,
auth_token: &str,
checksum: &str,
client_key: &str,
auth_token: String,
checksum: String,
client_key: String,
timezone: &'static str,
is_pri: bool,
usage_uuid: String,
@@ -287,9 +280,9 @@ pub async fn get_token_usage(
let trace_id = uuid::Uuid::new_v4().to_string();
let client = super::client::build_request(super::client::AiServiceRequest {
client,
auth_token,
checksum,
client_key,
auth_token: &auth_token,
checksum: &checksum,
client_key: &client_key,
url: cursor_api2_token_usage_url(is_pri),
is_stream: false,
timezone,
@@ -449,13 +442,13 @@ pub fn token_to_tokeninfo(
}
/// Convert a TokenInfo into a JWT token
pub fn tokeninfo_to_token(mut info: key_config::TokenInfo) -> Option<(String, String, Client)> {
pub fn tokeninfo_to_token(info: key_config::TokenInfo) -> Option<(String, String, Client)> {
// build the payload
let payload = TokenPayload {
sub: std::mem::take(&mut info.sub),
sub: info.sub,
exp: info.end,
randomness: std::mem::take(&mut info.randomness),
time: info.start.to_string(), // exp - 30000 days
randomness: info.randomness,
time: info.start.to_string(),
iss: ISSUER.to_string(),
scope: SCOPE.to_string(),
aud: AUDIENCE.to_string(),

View File

@@ -4,11 +4,11 @@ use sha2::{Digest, Sha256};
#[inline]
pub fn generate_hash() -> String {
use rand::Rng as _;
hex::encode(
Sha256::new()
.chain_update(rand::rng().random::<[u8; 32]>())
.finalize(),
)
let mut v = rand::rng().random::<[u8; 32]>();
if *crate::app::lazy::SAFE_HASH {
v = Sha256::new().chain_update(v).finalize().into();
}
hex::encode(v)
}
#[inline]
@@ -31,10 +31,19 @@ fn deobfuscate_bytes(bytes: &mut [u8]) {
}
}
pub fn generate_timestamp_header() -> String {
pub fn generate_timestamp_header() -> [u8; 8] {
static CACHE: std::sync::LazyLock<parking_lot::Mutex<(u64, [u8; 8])>> =
std::sync::LazyLock::new(|| parking_lot::Mutex::new((0, [0u8; 8])));
let timestamp = super::now_secs() / 1_000;
let mut timestamp_bytes = vec![
let mut guard = CACHE.lock();
if guard.0 == timestamp {
return guard.1;
} else {
guard.0 = timestamp;
}
let mut timestamp_bytes = [
((timestamp >> 8) & 0xFF) as u8,
(0xFF & timestamp) as u8,
((timestamp >> 24) & 0xFF) as u8,
@@ -44,12 +53,16 @@ pub fn generate_timestamp_header() -> String {
];
obfuscate_bytes(&mut timestamp_bytes);
BASE64.encode(&timestamp_bytes)
let mut result = [0u8; 8];
let _ = BASE64.encode_slice(&timestamp_bytes, &mut result);
guard.1 = result;
result
}
#[inline]
pub fn generate_checksum(device_id: &str, mac_addr: Option<&str>) -> String {
let encoded = generate_timestamp_header();
let timestamp_header = generate_timestamp_header();
let encoded = unsafe { str::from_utf8_unchecked(&timestamp_header) };
match mac_addr {
Some(mac) => format!("{encoded}{device_id}/{mac}"),
None => format!("{encoded}{device_id}"),
@@ -100,23 +113,23 @@ pub fn generate_checksum_with_repair(checksum: &str) -> String {
}
}
let timestamp_header = generate_timestamp_header();
let encoded = unsafe { str::from_utf8_unchecked(&timestamp_header) };
// build the result once validation has passed
match len {
72 => format!(
"{}{}/{}",
generate_timestamp_header(),
"{encoded}{}/{}",
unsafe { std::str::from_utf8_unchecked(bytes.get_unchecked(8..)) },
generate_hash()
),
129 => format!(
"{}{}/{}",
generate_timestamp_header(),
"{encoded}{}/{}",
unsafe { std::str::from_utf8_unchecked(bytes.get_unchecked(..64)) },
unsafe { std::str::from_utf8_unchecked(bytes.get_unchecked(65..)) }
),
137 => format!(
"{}{}/{}",
generate_timestamp_header(),
"{encoded}{}/{}",
unsafe { std::str::from_utf8_unchecked(bytes.get_unchecked(8..72)) },
unsafe { std::str::from_utf8_unchecked(bytes.get_unchecked(73..)) }
),

View File

@@ -18,11 +18,8 @@ use super::{
AzureState, ChatExternalLink, ConversationMessage, ExplicitContext, GetChatRequest,
ImageProto, ModelDetails, WebReference, conversation_message, image_proto,
},
constant::{
ERR_UNSUPPORTED_GIF, ERR_UNSUPPORTED_IMAGE_FORMAT, LONG_CONTEXT_MODELS,
SUPPORTED_IMAGE_MODELS,
},
model::{Message, MessageContent, Role},
constant::{ERR_UNSUPPORTED_GIF, ERR_UNSUPPORTED_IMAGE_FORMAT, LONG_CONTEXT_MODELS},
model::{Message, MessageContent, Model, Role},
};
fn parse_web_references(text: &str) -> Vec<WebReference> {
@@ -90,7 +87,7 @@ async fn process_chat_inputs(
inputs: Vec<Message>,
now_with_tz: Option<chrono::DateTime<chrono_tz::Tz>>,
disable_vision: bool,
model: &str,
model: Model,
) -> (String, Vec<ConversationMessage>, Vec<String>) {
// collect system instructions
let instructions = inputs
@@ -101,7 +98,7 @@ async fn process_chat_inputs(
MessageContent::Vision(contents) => contents
.iter()
.filter_map(|content| {
if content.rtype == "text" {
if content.r#type == "text" {
content.text.clone()
} else {
None
@@ -114,9 +111,9 @@ async fn process_chat_inputs(
.join("\n\n");
// use the default instructions or the collected ones
let image_support = !disable_vision && SUPPORTED_IMAGE_MODELS.contains(&model);
let image_support = !disable_vision && model.is_image;
let instructions = if instructions.is_empty() {
get_default_instructions(now_with_tz, model, image_support)
get_default_instructions(now_with_tz, model.id, image_support)
} else {
instructions
};
@@ -239,7 +236,7 @@ async fn process_chat_inputs(
let mut images = Vec::new();
for content in contents {
match content.rtype.as_str() {
match content.r#type.as_str() {
"text" => {
if let Some(text) = content.text {
text_parts.push(text);
@@ -493,7 +490,7 @@ async fn process_http_image(
pub async fn encode_chat_message(
inputs: Vec<Message>,
now_with_tz: Option<chrono::DateTime<chrono_tz::Tz>>,
model: &str,
model: Model,
disable_vision: bool,
enable_slow_pool: bool,
is_search: bool,
@@ -525,7 +522,7 @@ pub async fn encode_chat_message(
})
.collect();
let long_context = AppConfig::get_long_context() || LONG_CONTEXT_MODELS.contains(&model);
let long_context = AppConfig::get_long_context() || LONG_CONTEXT_MODELS.contains(&model.id);
let chat = GetChatRequest {
current_file: None,
@@ -535,7 +532,7 @@ pub async fn encode_chat_message(
workspace_root_path: None,
code_blocks: vec![],
model_details: Some(ModelDetails {
model_name: Some(model.to_string()),
model_name: Some(model.id.to_string()),
api_key: None,
enable_ghost_mode: Some(true),
azure_state: Some(AzureState {

View File

@@ -3,19 +3,19 @@
pub struct KeyConfig {
/// Authentication token (required)
#[prost(message, optional, tag = "1")]
pub auth_token: ::core::option::Option<key_config::TokenInfo>,
pub auth_token: Option<key_config::TokenInfo>,
/// Whether to disable image processing
#[prost(bool, optional, tag = "4")]
pub disable_vision: ::core::option::Option<bool>,
pub disable_vision: Option<bool>,
/// Whether to enable the slow pool
#[prost(bool, optional, tag = "5")]
pub enable_slow_pool: ::core::option::Option<bool>,
pub enable_slow_pool: Option<bool>,
/// Usage-check model rules
#[prost(message, optional, tag = "6")]
pub usage_check_models: ::core::option::Option<key_config::UsageCheckModel>,
pub usage_check_models: Option<key_config::UsageCheckModel>,
/// Include web references
#[prost(bool, optional, tag = "7")]
pub include_web_references: ::core::option::Option<bool>,
pub include_web_references: Option<bool>,
}
/// Nested message and enum types in `KeyConfig`.
pub mod key_config {
@@ -24,7 +24,7 @@ pub mod key_config {
pub struct TokenInfo {
/// User identifier
#[prost(string, tag = "1")]
pub sub: ::prost::alloc::string::String,
pub sub: String,
/// Creation time (Unix timestamp)
#[prost(int64, tag = "2")]
pub start: i64,
@@ -33,19 +33,19 @@ pub mod key_config {
pub end: i64,
/// Random string
#[prost(string, tag = "4")]
pub randomness: ::prost::alloc::string::String,
pub randomness: String,
/// Signature
#[prost(string, tag = "5")]
pub signature: ::prost::alloc::string::String,
pub signature: String,
/// SHA-256 hash of the machine ID
#[prost(bytes = "vec", tag = "6")]
pub machine_id: ::prost::alloc::vec::Vec<u8>,
pub machine_id: Vec<u8>,
/// SHA-256 hash of the MAC address
#[prost(bytes = "vec", tag = "7")]
pub mac_id: ::prost::alloc::vec::Vec<u8>,
pub mac_id: Vec<u8>,
/// Proxy name
#[prost(string, optional, tag = "8")]
pub proxy_name: ::core::option::Option<::prost::alloc::string::String>,
pub proxy_name: Option<String>,
}
/// Usage-check model rules
#[derive(Clone, PartialEq, ::prost::Message)]
@@ -55,7 +55,7 @@ pub mod key_config {
pub r#type: i32,
/// List of model IDs; effective when type is TYPE_CUSTOM
#[prost(string, repeated, tag = "2")]
pub model_ids: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
pub model_ids: Vec<String>,
}
/// Nested message and enum types in `UsageCheckModel`.
pub mod usage_check_model {
@@ -88,7 +88,7 @@ pub mod key_config {
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
pub fn from_str_name(value: &str) -> Option<Self> {
match value {
"TYPE_DEFAULT" => Some(Self::Default),
"TYPE_DISABLED" => Some(Self::Disabled),

View File

@@ -66,6 +66,7 @@ def_pub_const!(
O1 => "o1",
O3_MINI => "o3-mini",
GPT_4_5_PREVIEW => "gpt-4.5-preview",
GPT_4_1 => "gpt-4.1",
// Cursor models
CURSOR_FAST => "cursor-fast",
@@ -73,7 +74,6 @@ def_pub_const!(
// Google models
GEMINI_1_5_FLASH_500K => "gemini-1.5-flash-500k",
GEMINI_EXP_1206 => "gemini-exp-1206",
GEMINI_2_0_PRO_EXP => "gemini-2.0-pro-exp",
GEMINI_2_5_PRO_EXP_03_25 => "gemini-2.5-pro-exp-03-25",
GEMINI_2_5_PRO_MAX => "gemini-2.5-pro-max",
@@ -83,9 +83,12 @@ def_pub_const!(
// Deepseek models
DEEPSEEK_V3 => "deepseek-v3",
DEEPSEEK_R1 => "deepseek-r1",
DEEPSEEK_V3_1 => "deepseek-v3.1",
// XAI models
GROK_2 => "grok-2",
GROK_3_BETA => "grok-3-beta",
GROK_3_MINI_BETA => "grok-3-mini-beta",
// Unknown models
DEFAULT => "default",
@@ -103,6 +106,8 @@ macro_rules! create_models {
created: CREATED,
object: MODEL_OBJECT,
owned_by: $owner,
is_thinking: SUPPORTED_THINKING_MODELS.contains(&$model),
is_image: SUPPORTED_IMAGE_MODELS.contains(&$model),
},
)*
]),
@@ -139,12 +144,8 @@ impl Models {
// }
// Look up a model and return its ID
pub fn find_id(model: &str) -> Option<&'static str> {
Self::read()
.models
.iter()
.find(|m| m.id == model)
.map(|m| m.id)
pub fn find_id(model: &str) -> Option<Model> {
Self::read().models.iter().find(|m| m.id == model).copied()
}
// Return a list of all model IDs
@@ -228,22 +229,21 @@ create_models!(
DEEPSEEK_R1 => DEEPSEEK,
O3_MINI => OPENAI,
GROK_2 => XAI,
DEEPSEEK_V3_1 => DEEPSEEK,
GROK_3_BETA => XAI,
GROK_3_MINI_BETA => XAI,
GPT_4_1 => OPENAI,
);
pub const USAGE_CHECK_MODELS: [&str; 13] = [
CLAUDE_3_5_SONNET,
CLAUDE_3_7_SONNET,
CLAUDE_3_7_SONNET_THINKING,
GEMINI_EXP_1206,
GPT_4,
GPT_4_TURBO_2024_04_09,
GPT_4O,
CLAUDE_3_5_HAIKU,
GPT_4O_128K,
GEMINI_1_5_FLASH_500K,
CLAUDE_3_HAIKU_200K,
CLAUDE_3_5_SONNET_200K,
DEEPSEEK_R1,
pub const FREE_MODELS: [&str; 8] = [
CURSOR_FAST,
CURSOR_SMALL,
GPT_4O_MINI,
GPT_3_5_TURBO,
DEEPSEEK_V3,
DEEPSEEK_V3_1,
GROK_3_MINI_BETA,
GPT_4_1,
];
pub const LONG_CONTEXT_MODELS: [&str; 4] = [
@@ -253,14 +253,37 @@ pub const LONG_CONTEXT_MODELS: [&str; 4] = [
CLAUDE_3_5_SONNET_200K,
];
pub const SUPPORTED_IMAGE_MODELS: [&str; 9] = [
const SUPPORTED_THINKING_MODELS: [&str; 10] = [
CLAUDE_3_7_SONNET_THINKING,
CLAUDE_3_7_SONNET_THINKING_MAX,
O1_MINI,
O1_PREVIEW,
O1,
GEMINI_2_5_PRO_EXP_03_25,
GEMINI_2_5_PRO_MAX,
GEMINI_2_0_FLASH_THINKING_EXP,
DEEPSEEK_R1,
O3_MINI,
];
const SUPPORTED_IMAGE_MODELS: [&str; 19] = [
DEFAULT,
CLAUDE_3_5_SONNET,
CLAUDE_3_7_SONNET,
CLAUDE_3_7_SONNET_THINKING,
GPT_4O,
GPT_4O_MINI,
DEFAULT,
CLAUDE_3_OPUS,
CLAUDE_3_5_HAIKU,
CLAUDE_3_7_SONNET_MAX,
CLAUDE_3_7_SONNET_THINKING_MAX,
GPT_4,
GPT_4O,
GPT_4_5_PREVIEW,
CLAUDE_3_OPUS,
GPT_4_TURBO_2024_04_09,
GPT_4O_128K,
CLAUDE_3_HAIKU_200K,
CLAUDE_3_5_SONNET_200K,
GPT_4O_MINI,
CLAUDE_3_5_HAIKU,
GEMINI_2_5_PRO_EXP_03_25,
GEMINI_2_5_PRO_MAX,
GPT_4_1,
];
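
`Models::find_id` now returns the whole `Copy`-able `Model` instead of just its id, and `create_models!` bakes `is_thinking`/`is_image` in from the two capability arrays, so callers no longer consult `SUPPORTED_*` lists themselves. A self-contained sketch of the lookup (a small local table stands in for the real lazily-built model list):

```
// Illustrative stand-in for the new find_id: return the whole Model by value.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Model {
    id: &'static str,
    is_thinking: bool,
    is_image: bool,
}

const MODELS: &[Model] = &[
    Model { id: "gpt-4.1", is_thinking: false, is_image: true },
    Model { id: "deepseek-r1", is_thinking: true, is_image: false },
];

fn find_id(model: &str) -> Option<Model> {
    MODELS.iter().find(|m| m.id == model).copied()
}

fn main() {
    let m = find_id("gpt-4.1").expect("known model");
    // Capability checks come straight off the Model now.
    assert!(m.is_image && !m.is_thinking);
    assert_eq!(find_id("unknown"), None);
}
```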

View File

@@ -1,6 +1,6 @@
use std::sync::Arc;
use serde::{ser::SerializeStruct as _, Deserialize, Serialize};
use serde::{Deserialize, Serialize, ser::SerializeStruct as _};
#[derive(Serialize, Deserialize)]
#[serde(untagged)]
@@ -12,7 +12,7 @@ pub enum MessageContent {
#[derive(Serialize, Deserialize)]
pub struct VisionMessageContent {
#[serde(rename = "type")]
pub rtype: String,
pub r#type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
@@ -53,8 +53,8 @@ pub enum Role {
}
#[derive(Serialize)]
pub struct ChatResponse {
pub id: String,
pub struct ChatResponse<'a> {
pub id: &'a str,
pub object: &'static str,
pub created: i64,
#[serde(skip_serializing_if = "Option::is_none")]
@@ -116,13 +116,16 @@ pub struct StreamOptions {
pub include_usage: bool,
}
// Model definition
/// Model definition
#[derive(Clone, Copy)]
pub struct Model {
pub id: &'static str,
pub display_name: &'static str,
pub created: &'static i64,
pub object: &'static str,
pub owned_by: &'static str,
pub is_thinking: bool,
pub is_image: bool,
}
impl Serialize for Model {
@@ -130,7 +133,7 @@ impl Serialize for Model {
where
S: serde::Serializer,
{
let mut state = serializer.serialize_struct("Model", 7)?;
let mut state = serializer.serialize_struct("Model", 9)?;
state.serialize_field("id", &self.id)?;
state.serialize_field("display_name", &self.display_name)?;
@@ -139,6 +142,8 @@ impl Serialize for Model {
state.serialize_field("object", &self.object)?;
state.serialize_field("type", &self.object)?;
state.serialize_field("owned_by", &self.owned_by)?;
state.serialize_field("supports_thinking", &self.is_thinking)?;
state.serialize_field("supports_images", &self.is_image)?;
state.end()
}
@@ -150,19 +155,19 @@ impl PartialEq for Model {
}
}
use super::constant::{Models, USAGE_CHECK_MODELS};
use super::constant::{FREE_MODELS, Models};
use crate::{
app::model::{AppConfig, UsageCheck},
common::model::tri::TriState,
};
impl Model {
pub fn is_usage_check(model_id: &str, usage_check: Option<UsageCheck>) -> bool {
pub fn is_usage_check(&self, usage_check: Option<UsageCheck>) -> bool {
match usage_check.unwrap_or(AppConfig::get_usage_check()) {
UsageCheck::None => false,
UsageCheck::Default => USAGE_CHECK_MODELS.contains(&model_id),
UsageCheck::Default => !FREE_MODELS.contains(&self.id),
UsageCheck::All => true,
UsageCheck::Custom(models) => models.contains(&model_id),
UsageCheck::Custom(models) => models.contains(&self.id),
}
}
}
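
`is_usage_check` becomes a method on `Model`, and the `Default` policy flips from the old `USAGE_CHECK_MODELS` allowlist to "check everything that is not in `FREE_MODELS`". A condensed sketch of the new rule as a free function (the `AppConfig`/`TriState` plumbing is omitted; the `gpt-3.5-turbo` literal is assumed, since the `GPT_3_5_TURBO` value is not shown in this diff):

```
// Condensed view of the new UsageCheck::Default rule: any non-free model is checked.
const FREE_MODELS: &[&str] = &[
    "cursor-fast", "cursor-small", "gpt-4o-mini", "gpt-3.5-turbo",
    "deepseek-v3", "deepseek-v3.1", "grok-3-mini-beta", "gpt-4.1",
];

enum UsageCheck {
    None,
    Default,
    All,
    Custom(Vec<&'static str>),
}

fn is_usage_check(model_id: &'static str, usage_check: UsageCheck) -> bool {
    match usage_check {
        UsageCheck::None => false,
        // Previously: USAGE_CHECK_MODELS.contains(&model_id) (explicit allowlist).
        UsageCheck::Default => !FREE_MODELS.contains(&model_id),
        UsageCheck::All => true,
        UsageCheck::Custom(models) => models.contains(&model_id),
    }
}

fn main() {
    assert!(is_usage_check("claude-3.7-sonnet", UsageCheck::Default));
    assert!(!is_usage_check("gpt-4.1", UsageCheck::Default)); // in FREE_MODELS
    assert!(!is_usage_check("claude-3.7-sonnet", UsageCheck::None));
    assert!(is_usage_check("cursor-fast", UsageCheck::All));
    assert!(is_usage_check("o1", UsageCheck::Custom(vec!["o1"])));
}
```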

View File

@@ -6,7 +6,7 @@ use axum::{
use serde::Deserialize;
use crate::{
app::constant::CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8,
app::constant::header_value_text_plain_utf8,
common::utils::{
generate_checksum_with_default, generate_checksum_with_repair, generate_hash,
generate_timestamp_header,
@@ -16,11 +16,7 @@ use crate::{
pub async fn handle_get_hash() -> Response {
let hash = generate_hash();
let mut headers = HeaderMap::new();
headers.insert(
CONTENT_TYPE,
CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8.parse().unwrap(),
);
let headers = HeaderMap::from_iter([(CONTENT_TYPE, header_value_text_plain_utf8().clone())]);
(headers, hash).into_response()
}
@@ -37,11 +33,7 @@ pub async fn handle_get_checksum(Query(query): Query<ChecksumQuery>) -> Response
Some(checksum) => generate_checksum_with_repair(&checksum),
};
let mut headers = HeaderMap::new();
headers.insert(
CONTENT_TYPE,
CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8.parse().unwrap(),
);
let headers = HeaderMap::from_iter([(CONTENT_TYPE, header_value_text_plain_utf8().clone())]);
(headers, checksum).into_response()
}
@@ -49,11 +41,7 @@ pub async fn handle_get_checksum(Query(query): Query<ChecksumQuery>) -> Response
pub async fn handle_get_timestamp_header() -> Response {
let timestamp_header = generate_timestamp_header();
let mut headers = HeaderMap::new();
headers.insert(
CONTENT_TYPE,
CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8.parse().unwrap(),
);
let headers = HeaderMap::from_iter([(CONTENT_TYPE, header_value_text_plain_utf8().clone())]);
(headers, timestamp_header).into_response()
}

View File

@@ -1,22 +1,19 @@
use crate::{
app::{
constant::{
AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_TEXT_HTML_WITH_UTF8,
CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, PKG_VERSION, ROUTE_ABOUT_PATH, ROUTE_API_PATH,
ROUTE_BASIC_CALIBRATION_PATH, ROUTE_BUILD_KEY_PATH, ROUTE_CONFIG_PATH,
ROUTE_ENV_EXAMPLE_PATH, ROUTE_GET_CHECKSUM, ROUTE_GET_HASH, ROUTE_GET_TIMESTAMP_HEADER,
ROUTE_HEALTH_PATH, ROUTE_LOGS_PATH, ROUTE_PROXIES_ADD_PATH, ROUTE_PROXIES_DELETE_PATH,
ROUTE_PROXIES_GET_PATH, ROUTE_PROXIES_PATH, ROUTE_PROXIES_SET_GENERAL_PATH,
ROUTE_PROXIES_SET_PATH, ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_STATIC_PATH,
ROUTE_TOKEN_UPGRADE_PATH, ROUTE_TOKENS_ADD_PATH, ROUTE_TOKENS_BY_TAG_GET_PATH,
ROUTE_TOKENS_DELETE_PATH, ROUTE_TOKENS_GET_PATH, ROUTE_TOKENS_PATH,
ROUTE_TOKENS_PROFILE_UPDATE_PATH, ROUTE_TOKENS_SET_PATH, ROUTE_TOKENS_STATUS_SET_PATH,
ROUTE_TOKENS_TAGS_GET_PATH, ROUTE_TOKENS_TAGS_SET_PATH, ROUTE_TOKENS_UPGRADE_PATH,
ROUTE_USER_INFO_PATH,
},
lazy::{
AUTH_TOKEN, ROUTE_CHAT_PATH, ROUTE_MESSAGES_PATH, ROUTE_MODELS_PATH, get_start_time,
PKG_VERSION, ROUTE_ABOUT_PATH, ROUTE_API_PATH, ROUTE_BASIC_CALIBRATION_PATH,
ROUTE_BUILD_KEY_PATH, ROUTE_CONFIG_PATH, ROUTE_ENV_EXAMPLE_PATH, ROUTE_GET_CHECKSUM,
ROUTE_GET_HASH, ROUTE_GET_TIMESTAMP_HEADER, ROUTE_HEALTH_PATH, ROUTE_LOGS_PATH,
ROUTE_PROXIES_ADD_PATH, ROUTE_PROXIES_DELETE_PATH, ROUTE_PROXIES_GET_PATH,
ROUTE_PROXIES_PATH, ROUTE_PROXIES_SET_GENERAL_PATH, ROUTE_PROXIES_SET_PATH,
ROUTE_README_PATH, ROUTE_ROOT_PATH, ROUTE_STATIC_PATH, ROUTE_TOKEN_UPGRADE_PATH,
ROUTE_TOKENS_ADD_PATH, ROUTE_TOKENS_BY_TAG_GET_PATH, ROUTE_TOKENS_DELETE_PATH,
ROUTE_TOKENS_GET_PATH, ROUTE_TOKENS_PATH, ROUTE_TOKENS_PROFILE_UPDATE_PATH,
ROUTE_TOKENS_SET_PATH, ROUTE_TOKENS_STATUS_SET_PATH, ROUTE_TOKENS_TAGS_GET_PATH,
ROUTE_TOKENS_TAGS_SET_PATH, ROUTE_TOKENS_UPGRADE_PATH, ROUTE_USER_INFO_PATH,
header_value_text_html_utf8, header_value_text_plain_utf8,
},
lazy::{ROUTE_CHAT_PATH, ROUTE_MODELS_PATH, get_start_time},
model::{AppConfig, AppState, PageContent},
},
common::model::{
@@ -30,12 +27,11 @@ use axum::{
body::Body,
extract::State,
http::{
HeaderMap, StatusCode,
StatusCode,
header::{CONTENT_TYPE, LOCATION},
},
response::{IntoResponse, Response},
};
use reqwest::header::AUTHORIZATION;
use std::sync::Arc;
use sysinfo::{CpuRefreshKind, MemoryRefreshKind, RefreshKind, System};
use tokio::sync::Mutex;
@@ -48,20 +44,19 @@ pub async fn handle_root() -> impl IntoResponse {
.body(Body::empty())
.unwrap(),
PageContent::Text(content) => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_plain_utf8())
.body(Body::from(content))
.unwrap(),
PageContent::Html(content) => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_html_utf8())
.body(Body::from(content))
.unwrap(),
}
}
static ENDPOINTS: std::sync::LazyLock<[&'static str; 34]> = std::sync::LazyLock::new(|| {
static ENDPOINTS: std::sync::LazyLock<[&'static str; 33]> = std::sync::LazyLock::new(|| {
[
&*ROUTE_CHAT_PATH,
&*ROUTE_MESSAGES_PATH,
&*ROUTE_MODELS_PATH,
ROUTE_TOKENS_PATH,
ROUTE_TOKENS_GET_PATH,
@@ -97,20 +92,12 @@ static ENDPOINTS: std::sync::LazyLock<[&'static str; 34]> = std::sync::LazyLock:
]
});
pub async fn handle_health(
State(state): State<Arc<Mutex<AppState>>>,
headers: HeaderMap,
) -> Json<HealthCheckResponse> {
pub async fn handle_health(State(state): State<Arc<Mutex<AppState>>>) -> Json<HealthCheckResponse> {
let start_time = get_start_time();
let uptime = (chrono::Local::now() - start_time).num_seconds();
// first check whether the headers contain valid authentication info
let stats = if headers
.get(AUTHORIZATION)
.and_then(|h| h.to_str().ok())
.and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX))
.is_some_and(|token| token == AUTH_TOKEN.as_str())
{
let stats = {
// only create the instance when system info is needed
let mut sys = System::new_with_specifics(
RefreshKind::nothing()
@@ -135,7 +122,7 @@ pub async fn handle_health(
let state = state.lock().await;
Some(SystemStats {
SystemStats {
started: start_time.to_string(),
total_requests: state.request_manager.total_requests,
active_requests: state.request_manager.active_requests,
@@ -147,9 +134,7 @@ pub async fn handle_health(
usage: cpu_usage, // CPU usage (percent)
},
},
})
} else {
None
}
};
Json(HealthCheckResponse {

View File

@@ -1,8 +1,8 @@
use crate::{
app::{
constant::{
AUTHORIZATION_BEARER_PREFIX, CONTENT_TYPE_TEXT_HTML_WITH_UTF8,
CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, ROUTE_LOGS_PATH,
AUTHORIZATION_BEARER_PREFIX, ROUTE_LOGS_PATH, header_value_text_html_utf8,
header_value_text_plain_utf8,
},
lazy::AUTH_TOKEN,
model::{AppConfig, AppState, LogStatus, PageContent, RequestLog},
@@ -30,15 +30,15 @@ use tokio::sync::Mutex;
pub async fn handle_logs() -> impl IntoResponse {
match AppConfig::get_page_content(ROUTE_LOGS_PATH).unwrap_or_default() {
PageContent::Default => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_html_utf8())
.body(Body::from(include_str!("../../../static/logs.min.html")))
.unwrap(),
PageContent::Text(content) => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_plain_utf8())
.body(Body::from(content))
.unwrap(),
PageContent::Html(content) => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_html_utf8())
.body(Body::from(content))
.unwrap(),
}
@@ -92,7 +92,7 @@ pub async fn handle_logs_post(
active: None,
error: None,
logs: Vec::new(),
timestamp: Local::now().to_string(),
timestamp: Local::now(),
}));
}
}
@@ -108,7 +108,7 @@ pub async fn handle_logs_post(
active: None,
error: None,
logs: Vec::new(),
timestamp: Local::now().to_string(),
timestamp: Local::now(),
}));
}
}
@@ -117,123 +117,112 @@ pub async fn handle_logs_post(
};
// prepare the log data (the admin's or a specific user's)
let logs = if auth_header == auth_token {
state.request_manager.request_logs.clone()
} else {
let mut iterator = Box::new(state.request_manager.request_logs.iter())
as Box<dyn Iterator<Item = &RequestLog>>;
if auth_header != auth_token {
// parse the token
let token_part = extract_token(auth_header).ok_or(StatusCode::UNAUTHORIZED)?;
// filter the matching logs
let filtered_logs: Vec<RequestLog> = state
.request_manager
.request_logs
.iter()
.filter(|log| log.token_info.token == token_part)
.cloned()
.collect();
if filtered_logs.is_empty() {
return Err(StatusCode::UNAUTHORIZED);
}
filtered_logs
iterator = Box::new(iterator.filter(move |log| log.token_info.token == token_part));
};
// apply query-parameter filters
let mut result_logs = logs;
// filter by status
if let Some(status) = &request.query.status {
result_logs.retain(|log| log.status.as_str_name() == *status);
iterator = Box::new(iterator.filter(move |log| log.status.as_str_name() == status));
}
// filter by model
if let Some(model) = &request.query.model {
result_logs.retain(|log| log.model.contains(model));
iterator = Box::new(iterator.filter(move |log| log.model.contains(model)));
}
// filter by user email
if let Some(email) = &request.query.email {
result_logs.retain(|log| {
iterator = Box::new(iterator.filter(move |log| {
log.token_info
.profile
.as_ref()
.map(|p| p.user.email.contains(email))
.unwrap_or(false)
});
}));
}
// filter by membership type
if let Some(membership_type) = membership_enum {
result_logs.retain(|log| {
iterator = Box::new(iterator.filter(move |log| {
log.token_info
.profile
.as_ref()
.map(|p| p.stripe.membership_type == membership_type)
.unwrap_or(false)
});
}));
}
// Filter by total-time range
if let Some(min_time) = request.query.min_total_time {
result_logs.retain(|log| log.timing.total >= min_time);
iterator = Box::new(iterator.filter(move |log| log.timing.total >= min_time));
}
if let Some(max_time) = request.query.max_total_time {
result_logs.retain(|log| log.timing.total <= max_time);
iterator = Box::new(iterator.filter(move |log| log.timing.total <= max_time));
}
// Filter by whether the request was streaming
if let Some(stream) = request.query.stream {
result_logs.retain(|log| log.stream == stream);
iterator = Box::new(iterator.filter(move |log| log.stream == stream));
}
// Filter by whether an error occurred
if let Some(has_error) = request.query.has_error {
result_logs.retain(|log| log.error.is_some() == has_error);
iterator = Box::new(iterator.filter(move |log| log.error.is_some() == has_error));
}
// Filter by whether a chain exists
if let Some(has_chain) = request.query.has_chain {
result_logs.retain(|log| log.chain.is_some() == has_chain);
iterator = Box::new(iterator.filter(move |log| log.chain.is_some() == has_chain));
}
// Filter by date range
if let Some(from_date) = request.query.from_date {
result_logs.retain(|log| log.timestamp >= from_date);
iterator = Box::new(iterator.filter(move |log| log.timestamp >= from_date));
}
if let Some(to_date) = request.query.to_date {
result_logs.retain(|log| log.timestamp <= to_date);
iterator = Box::new(iterator.filter(move |log| log.timestamp <= to_date));
}
// Get the total count
let total = result_logs.len() as u64;
let filtered_log_refs: Vec<_> = iterator.collect();
let total = filtered_log_refs.len() as u64;
// Apply pagination
if let Some(offset) = request.query.offset {
result_logs = result_logs.into_iter().skip(offset).collect();
}
let paginated_log_refs = filtered_log_refs
.into_iter()
.skip(request.query.offset.unwrap_or(0))
.take(request.query.limit.unwrap_or(usize::MAX));
if let Some(limit) = request.query.limit {
result_logs = result_logs.into_iter().take(limit).collect();
}
let result_logs: Vec<RequestLog> = paginated_log_refs.cloned().collect();
let active = if auth_header == auth_token {
Some(state.request_manager.active_requests)
} else {
None
};
let error = if auth_header == auth_token {
Some(state.request_manager.error_requests)
} else {
None
};
drop(state);
Ok(Json(LogsResponse {
status: ApiStatus::Success,
total,
active: if auth_header == auth_token {
Some(state.request_manager.active_requests)
} else {
None
},
error: if auth_header == auth_token {
Some(state.request_manager.error_requests)
} else {
None
},
active,
error,
logs: result_logs,
timestamp: Local::now().to_string(),
timestamp: Local::now(),
}))
}
@@ -246,5 +235,5 @@ pub struct LogsResponse {
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<u64>,
pub logs: Vec<RequestLog>,
pub timestamp: String,
pub timestamp: DateTime<Local>,
}
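The hunk above replaces the clone-then-retain filtering with a single boxed iterator that is narrowed step by step and only cloned after pagination. A minimal sketch of that pattern, using a simplified stand-in for `RequestLog` (the fields and function here are illustrative, not the project's actual types):

```
struct RequestLog {
    id: u64,
    model: &'static str,
    stream: bool,
}

fn filter_logs(
    logs: &[RequestLog],
    model: Option<String>,
    stream: Option<bool>,
    offset: usize,
    limit: usize,
) -> (u64, Vec<u64>) {
    // Start from a borrowing iterator and keep boxing narrower filters on top of it.
    let mut iter: Box<dyn Iterator<Item = &RequestLog> + '_> = Box::new(logs.iter());
    if let Some(model) = model {
        iter = Box::new(iter.filter(move |log| log.model.contains(&model)));
    }
    if let Some(stream) = stream {
        iter = Box::new(iter.filter(move |log| log.stream == stream));
    }
    // Count everything that survived the filters, then materialize only the requested page.
    let filtered: Vec<&RequestLog> = iter.collect();
    let total = filtered.len() as u64;
    let page = filtered
        .into_iter()
        .skip(offset)
        .take(limit)
        .map(|log| log.id)
        .collect();
    (total, page)
}
```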

View File

@@ -1,9 +1,9 @@
use crate::app::{
constant::{
CONTENT_TYPE_TEXT_CSS_WITH_UTF8, CONTENT_TYPE_TEXT_HTML_WITH_UTF8,
CONTENT_TYPE_TEXT_JS_WITH_UTF8, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8, ROUTE_ABOUT_PATH,
ROUTE_API_PATH, ROUTE_BUILD_KEY_PATH, ROUTE_CONFIG_PATH, ROUTE_PROXIES_PATH,
ROUTE_README_PATH, ROUTE_SHARED_JS_PATH, ROUTE_SHARED_STYLES_PATH, ROUTE_TOKENS_PATH,
ROUTE_ABOUT_PATH, ROUTE_API_PATH, ROUTE_BUILD_KEY_PATH, ROUTE_CONFIG_PATH,
ROUTE_PROXIES_PATH, ROUTE_README_PATH, ROUTE_SHARED_JS_PATH, ROUTE_SHARED_STYLES_PATH,
ROUTE_TOKENS_PATH, header_value_text_css_utf8, header_value_text_html_utf8,
header_value_text_js_utf8, header_value_text_plain_utf8,
},
model::{AppConfig, PageContent},
};
@@ -19,7 +19,7 @@ use axum::{
pub async fn handle_env_example() -> impl IntoResponse {
Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_plain_utf8())
.body(Body::from(include_str!("../../../.env.example")))
.unwrap()
}
@@ -28,15 +28,15 @@ pub async fn handle_env_example() -> impl IntoResponse {
pub async fn handle_config_page() -> impl IntoResponse {
match AppConfig::get_page_content(ROUTE_CONFIG_PATH).unwrap_or_default() {
PageContent::Default => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_html_utf8())
.body(Body::from(include_str!("../../../static/config.min.html")))
.unwrap(),
PageContent::Text(content) => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_plain_utf8())
.body(Body::from(content))
.unwrap(),
PageContent::Html(content) => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_html_utf8())
.body(Body::from(content))
.unwrap(),
}
@@ -47,13 +47,13 @@ pub async fn handle_static(Path(path): Path<String>) -> impl IntoResponse {
"shared-styles.css" => {
match AppConfig::get_page_content(ROUTE_SHARED_STYLES_PATH).unwrap_or_default() {
PageContent::Default => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_css_utf8())
.body(Body::from(include_str!(
"../../../static/shared-styles.min.css"
)))
.unwrap(),
PageContent::Text(content) | PageContent::Html(content) => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_CSS_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_css_utf8())
.body(Body::from(content))
.unwrap(),
}
@@ -61,13 +61,13 @@ pub async fn handle_static(Path(path): Path<String>) -> impl IntoResponse {
"shared.js" => {
match AppConfig::get_page_content(ROUTE_SHARED_JS_PATH).unwrap_or_default() {
PageContent::Default => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_js_utf8())
.body(Body::from(
include_str!("../../../static/shared.min.js").to_string(),
))
.unwrap(),
PageContent::Text(content) | PageContent::Html(content) => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_JS_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_js_utf8())
.body(Body::from(content))
.unwrap(),
}
@@ -82,15 +82,15 @@ pub async fn handle_static(Path(path): Path<String>) -> impl IntoResponse {
pub async fn handle_readme() -> impl IntoResponse {
match AppConfig::get_page_content(ROUTE_README_PATH).unwrap_or_default() {
PageContent::Default => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_html_utf8())
.body(Body::from(include_str!("../../../static/readme.min.html")))
.unwrap(),
PageContent::Text(content) => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_plain_utf8())
.body(Body::from(content))
.unwrap(),
PageContent::Html(content) => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_html_utf8())
.body(Body::from(content))
.unwrap(),
}
@@ -104,11 +104,11 @@ pub async fn handle_about() -> impl IntoResponse {
.body(Body::empty())
.unwrap(),
PageContent::Text(content) => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_plain_utf8())
.body(Body::from(content))
.unwrap(),
PageContent::Html(content) => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_html_utf8())
.body(Body::from(content))
.unwrap(),
}
@@ -117,17 +117,17 @@ pub async fn handle_about() -> impl IntoResponse {
pub async fn handle_build_key_page() -> impl IntoResponse {
match AppConfig::get_page_content(ROUTE_BUILD_KEY_PATH).unwrap_or_default() {
PageContent::Default => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_html_utf8())
.body(Body::from(include_str!(
"../../../static/build_key.min.html"
)))
.unwrap(),
PageContent::Text(content) => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_plain_utf8())
.body(Body::from(content))
.unwrap(),
PageContent::Html(content) => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_html_utf8())
.body(Body::from(content))
.unwrap(),
}
@@ -136,15 +136,15 @@ pub async fn handle_build_key_page() -> impl IntoResponse {
pub async fn handle_tokens_page() -> impl IntoResponse {
match AppConfig::get_page_content(ROUTE_TOKENS_PATH).unwrap_or_default() {
PageContent::Default => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_html_utf8())
.body(Body::from(include_str!("../../../static/tokens.min.html")))
.unwrap(),
PageContent::Text(content) => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_plain_utf8())
.body(Body::from(content))
.unwrap(),
PageContent::Html(content) => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_html_utf8())
.body(Body::from(content))
.unwrap(),
}
@@ -153,15 +153,15 @@ pub async fn handle_tokens_page() -> impl IntoResponse {
pub async fn handle_proxies_page() -> impl IntoResponse {
match AppConfig::get_page_content(ROUTE_PROXIES_PATH).unwrap_or_default() {
PageContent::Default => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_html_utf8())
.body(Body::from(include_str!("../../../static/proxies.min.html")))
.unwrap(),
PageContent::Text(content) => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_plain_utf8())
.body(Body::from(content))
.unwrap(),
PageContent::Html(content) => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_html_utf8())
.body(Body::from(content))
.unwrap(),
}
@@ -170,15 +170,15 @@ pub async fn handle_proxies_page() -> impl IntoResponse {
pub async fn handle_api_page() -> impl IntoResponse {
match AppConfig::get_page_content(ROUTE_API_PATH).unwrap_or_default() {
PageContent::Default => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_html_utf8())
.body(Body::from(include_str!("../../../static/api.min.html")))
.unwrap(),
PageContent::Text(content) => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_PLAIN_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_plain_utf8())
.body(Body::from(content))
.unwrap(),
PageContent::Html(content) => Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_TEXT_HTML_WITH_UTF8)
.header(CONTENT_TYPE, header_value_text_html_utf8())
.body(Body::from(content))
.unwrap(),
}
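The `header_value_*()` helpers that replace the old `CONTENT_TYPE_*` string constants are not shown in this diff. A plausible shape, assuming they hand back a cached `http::HeaderValue`, is sketched below; this is hypothetical, not the project's actual implementation:

```
use http::HeaderValue;
use std::sync::OnceLock;

// Hypothetical sketch: the real helper lives in app::constant and may differ.
pub fn header_value_text_html_utf8() -> HeaderValue {
    static VALUE: OnceLock<HeaderValue> = OnceLock::new();
    VALUE
        .get_or_init(|| HeaderValue::from_static("text/html; charset=utf-8"))
        .clone() // HeaderValue clones are cheap (shared bytes)
}
```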

View File

@@ -2,7 +2,8 @@ use crate::{
app::{
constant::{
AUTHORIZATION_BEARER_PREFIX, FINISH_REASON_STOP, OBJECT_CHAT_COMPLETION,
OBJECT_CHAT_COMPLETION_CHUNK,
OBJECT_CHAT_COMPLETION_CHUNK, header_value_chunked, header_value_event_stream,
header_value_json, header_value_keep_alive, header_value_no_cache_revalidate,
},
lazy::{
AUTH_TOKEN, GENERAL_TIMEZONE, IS_NO_REQUEST_LOGS, IS_UNLIMITED_REQUEST_LOGS,
@@ -26,12 +27,10 @@ use crate::{
},
core::{
config::KeyConfig,
constant::{Models, USAGE_CHECK_MODELS},
constant::Models,
error::StreamError,
model::{
ChatResponse, Choice, Delta, Message, MessageContent, ModelsResponse, Role, Usage,
},
stream::{StreamDecoder, StreamMessage},
model::{ChatResponse, Choice, Delta, Message, MessageContent, ModelsResponse, Role},
stream::decoder::{StreamDecoder, StreamMessage},
},
leak::intern_string,
};
@@ -51,17 +50,17 @@ use axum::{
use bytes::Bytes;
use futures::StreamExt;
use prost::Message as _;
use std::{borrow::Cow, sync::atomic::{AtomicUsize, Ordering}};
use std::{
borrow::Cow,
sync::atomic::{AtomicUsize, Ordering},
};
use std::{
convert::Infallible,
sync::{Arc, atomic::AtomicBool},
};
use tokio::sync::Mutex;
use super::model::{ChatRequest, Model};
const NO_CACHE: &str = "no-cache, must-revalidate";
const KEEP_ALIVE: &str = "keep-alive";
use super::{constant::FREE_MODELS, model::ChatRequest};
static CURRENT_KEY_INDEX: AtomicUsize = AtomicUsize::new(0);
@@ -106,7 +105,7 @@ pub async fn handle_models(
));
}
let index = CURRENT_KEY_INDEX.fetch_add(1, Ordering::SeqCst) % token_infos.len();
let index = CURRENT_KEY_INDEX.load(Ordering::Acquire) % token_infos.len();
let token_info = &token_infos[index];
is_pri = true;
(
@@ -317,19 +316,18 @@ pub async fn handle_chat(
if log.token_info.token == auth_token {
if let Some(profile) = &log.token_info.profile {
if profile.stripe.membership_type == MembershipType::Free {
let is_premium = USAGE_CHECK_MODELS.contains(&model);
need_profile_check = if is_premium {
profile
.usage
.premium
.max_requests
.is_some_and(|max| profile.usage.premium.num_requests >= max)
} else {
need_profile_check = if FREE_MODELS.contains(&model.id) {
profile
.usage
.standard
.max_requests
.is_some_and(|max| profile.usage.standard.num_requests >= max)
} else {
profile
.usage
.premium
.max_requests
.is_some_and(|max| profile.usage.premium.num_requests >= max)
};
}
break;
@@ -350,15 +348,14 @@ pub async fn handle_chat(
let next_id = state
.request_manager
.request_logs
.last()
.back()
.map_or(1, |log| log.id + 1);
current_id = next_id;
// If usage info is needed, spawn a background task to fetch the profile
if Model::is_usage_check(
model,
UsageCheck::from_proto(current_config.usage_check_models.as_ref()),
) {
if model.is_usage_check(UsageCheck::from_proto(
current_config.usage_check_models.as_ref(),
)) {
let auth_token_clone = auth_token.clone();
let state_clone = state_clone.clone();
let log_id = next_id;
@@ -398,7 +395,7 @@ pub async fn handle_chat(
});
}
state.request_manager.request_logs.push(RequestLog {
state.request_manager.request_logs.push_back(RequestLog {
id: next_id,
timestamp: request_time,
model: intern_string(request.model),
@@ -546,19 +543,16 @@ pub async fn handle_chat(
let is_start = Arc::new(AtomicBool::new(true));
let start_time = std::time::Instant::now();
let decoder = Arc::new(Mutex::new(StreamDecoder::new()));
let is_usage_sent = Arc::new(AtomicBool::new(false));
let need_usage = if request.stream_options.is_some_and(|opt| opt.include_usage) {
Arc::new(Mutex::new(NeedUsage::Need {
client,
auth_token,
checksum,
client_key,
timezone,
is_pri,
}))
} else {
Arc::new(Mutex::new(NeedUsage::None))
};
let need_usage = Arc::new(Mutex::new(NeedUsage::Need {
is_need: request.stream_options.is_some_and(|opt| opt.include_usage),
client,
auth_token,
checksum,
client_key,
timezone,
is_pri,
}));
let usage_uuid = Arc::new(Mutex::new(None));
// Context struct for the message processor
struct MessageProcessContext<'a> {
@@ -568,15 +562,17 @@ pub async fn handle_chat(
start_time: std::time::Instant,
state: &'a Mutex<AppState>,
current_id: u64,
usage_uuid: &'a Mutex<Option<String>>,
need_usage: &'a Mutex<NeedUsage>,
is_usage_sent: &'a AtomicBool,
created: i64,
}
#[derive(Default)]
enum NeedUsage {
#[default]
None,
Taked,
Need {
is_need: bool,
client: reqwest::Client,
auth_token: String,
checksum: String,
@@ -589,7 +585,10 @@ pub async fn handle_chat(
impl NeedUsage {
#[inline(always)]
const fn is_need(&self) -> bool {
matches!(*self, Self::Need { .. })
match self {
Self::Taked => false,
Self::Need { is_need, .. } => *is_need,
}
}
#[inline(always)]
@@ -602,8 +601,8 @@ pub async fn handle_chat(
async fn process_messages(
messages: Vec<StreamMessage>,
ctx: &MessageProcessContext<'_>,
) -> String {
let mut response_data = String::new();
) -> Vec<u8> {
let mut response_data = Vec::new();
for message in messages {
match message {
@@ -611,7 +610,7 @@ pub async fn handle_chat(
let is_first = ctx.is_start.load(Ordering::Acquire);
let response = ChatResponse {
id: ctx.response_id.to_string(),
id: ctx.response_id,
object: OBJECT_CHAT_COMPLETION_CHUNK,
created: chrono::Utc::now().timestamp(),
model: if is_first { Some(ctx.model) } else { None },
@@ -641,65 +640,13 @@ pub async fn handle_chat(
},
};
response_data.push_str(&format!(
"data: {}\n\n",
serde_json::to_string(&response).unwrap()
));
response_data.extend_from_slice(b"data: ");
response_data.extend_from_slice(&serde_json::to_vec(&response).unwrap());
response_data.extend_from_slice(b"\n\n");
}
StreamMessage::Usage(usage_uuid) => {
if !ctx.is_usage_sent.load(Ordering::Acquire) {
if let NeedUsage::Need {
client,
auth_token,
checksum,
client_key,
timezone,
is_pri,
} = ctx.need_usage.lock().await.take()
{
let usage = get_token_usage(
client,
&auth_token,
&checksum,
&client_key,
timezone,
is_pri,
usage_uuid,
)
.await;
if let Some(ref usage) = usage {
let mut state = ctx.state.lock().await;
if let Some(log) = state
.request_manager
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == ctx.current_id)
{
if let Some(chain) = &mut log.chain {
chain.usage = OptionUsage::Uasge {
input: usage.prompt_tokens,
output: usage.completion_tokens,
}
}
}
}
let response = ChatResponse {
id: ctx.response_id.to_string(),
object: OBJECT_CHAT_COMPLETION_CHUNK,
created: chrono::Utc::now().timestamp(),
model: None,
choices: vec![],
usage: TriState::Some(usage.unwrap_or_default()),
};
response_data.push_str(&format!(
"data: {}\n\n",
serde_json::to_string(&response).unwrap()
));
ctx.is_usage_sent.store(true, Ordering::Release);
}
} else {
crate::debug_println!("usage is sent, but find {usage_uuid}");
if !usage_uuid.is_empty() && ctx.need_usage.lock().await.is_need() {
*ctx.usage_uuid.lock().await = Some(usage_uuid);
}
}
StreamMessage::StreamEnd => {
@@ -720,7 +667,7 @@ pub async fn handle_chat(
}
let response = ChatResponse {
id: ctx.response_id.to_string(),
id: ctx.response_id,
object: OBJECT_CHAT_COMPLETION_CHUNK,
created: chrono::Utc::now().timestamp(),
model: None,
@@ -740,31 +687,65 @@ pub async fn handle_chat(
TriState::None
},
};
response_data.push_str(&format!(
"data: {}\n\n",
serde_json::to_string(&response).unwrap()
));
if !ctx.is_usage_sent.load(Ordering::Acquire)
&& ctx.need_usage.lock().await.is_need()
{
let response = ChatResponse {
id: ctx.response_id.to_string(),
object: OBJECT_CHAT_COMPLETION_CHUNK,
created: chrono::Utc::now().timestamp(),
model: None,
choices: vec![],
usage: TriState::Some(Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
}),
response_data.extend_from_slice(b"data: ");
response_data.extend_from_slice(&serde_json::to_vec(&response).unwrap());
response_data.extend_from_slice(b"\n\n");
if let Some(usage_uuid) = ctx.usage_uuid.lock().await.take() {
if let NeedUsage::Need {
is_need,
client,
auth_token,
checksum,
client_key,
timezone,
is_pri,
} = ctx.need_usage.lock().await.take()
{
let usage = if *crate::app::lazy::REAL_USAGE {
let usage = tokio::spawn(get_token_usage(
client, auth_token, checksum, client_key, timezone, is_pri,
usage_uuid,
))
.await
.unwrap_or_default();
if let Some(ref usage) = usage {
let mut state = ctx.state.lock().await;
if let Some(log) = state
.request_manager
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == ctx.current_id)
{
if let Some(chain) = &mut log.chain {
chain.usage = OptionUsage::Uasge {
input: usage.prompt_tokens,
output: usage.completion_tokens,
}
}
}
}
usage
} else {
None
};
if is_need {
let response = ChatResponse {
id: ctx.response_id,
object: OBJECT_CHAT_COMPLETION_CHUNK,
created: ctx.created,
model: None,
choices: vec![],
usage: TriState::Some(usage.unwrap_or_default()),
};
response_data.extend_from_slice(b"data: ");
response_data
.extend_from_slice(&serde_json::to_vec(&response).unwrap());
response_data.extend_from_slice(b"\n\n");
}
};
response_data.push_str(&format!(
"data: {}\n\n",
serde_json::to_string(&response).unwrap()
));
ctx.is_usage_sent.store(true, Ordering::Release);
};
}
}
StreamMessage::Debug(debug_prompt) => {
if let Ok(mut state) = ctx.state.try_lock() {
@@ -781,7 +762,7 @@ pub async fn handle_chat(
} else {
log.chain = Some(Chain {
prompt: Prompt::new(debug_prompt),
delays: vec![],
delays: None,
usage: OptionUsage::None,
});
}
@@ -866,6 +847,8 @@ pub async fn handle_chat(
}
}
let created = Arc::new(std::sync::OnceLock::new());
// Process the rest of the stream
let stream = stream
.then({
@@ -877,27 +860,29 @@ pub async fn handle_chat(
let response_id = response_id.clone();
let is_start = is_start.clone();
let state = state.clone();
let is_usage_sent = is_usage_sent.clone();
let need_usage = need_usage.clone();
let usage_uuid = usage_uuid.clone();
let created = created.clone();
async move {
let chunk = match chunk {
Ok(c) => c,
Err(e) => {
crate::debug_println!("Find chunk error: {e}");
Err(_) => {
// crate::debug_println!("Find chunk error: {e}");
return Ok::<_, Infallible>(Bytes::new());
}
};
let ctx = MessageProcessContext {
response_id: &response_id,
model,
model: model.id,
is_start: &is_start,
start_time,
state: &state,
current_id,
usage_uuid: &usage_uuid,
need_usage: &need_usage,
is_usage_sent: &is_usage_sent,
created: *created.get_or_init(|| chrono::Utc::now().timestamp()),
};
// Process the chunk with the decoder
@@ -930,16 +915,16 @@ pub async fn handle_chat(
}
};
let mut response_data = String::new();
let mut response_data = Vec::new();
if let Some(first_msg) = decoder.lock().await.take_first_result() {
let first_response = process_messages(first_msg, &ctx).await;
response_data.push_str(&first_response);
response_data.extend_from_slice(&first_response);
}
let current_response = process_messages(messages, &ctx).await;
if !current_response.is_empty() {
response_data.push_str(&current_response);
response_data.extend_from_slice(&current_response);
}
Ok(Bytes::from(response_data))
@@ -970,10 +955,10 @@ pub async fn handle_chat(
}));
Ok(Response::builder()
.header(CACHE_CONTROL, NO_CACHE)
.header(CONNECTION, KEEP_ALIVE)
.header(CONTENT_TYPE, "text/event-stream")
.header(TRANSFER_ENCODING, "chunked")
.header(CACHE_CONTROL, header_value_no_cache_revalidate())
.header(CONNECTION, header_value_keep_alive())
.header(CONTENT_TYPE, header_value_event_stream())
.header(TRANSFER_ENCODING, header_value_chunked())
.body(Body::from_stream(stream))
.unwrap())
} else {
@@ -1095,13 +1080,7 @@ pub async fn handle_chat(
let (usage1, usage2) = if !usage_uuid.is_empty() {
let result = get_token_usage(
client,
&auth_token,
&checksum,
&client_key,
timezone,
is_pri,
usage_uuid,
client, auth_token, checksum, client_key, timezone, is_pri, usage_uuid,
)
.await;
let result2 = match result {
@@ -1117,10 +1096,10 @@ pub async fn handle_chat(
};
let response_data = ChatResponse {
id: format!("chatcmpl-{trace_id}"),
id: &format!("chatcmpl-{trace_id}"),
object: OBJECT_CHAT_COMPLETION,
created: chrono::Utc::now().timestamp(),
model: Some(model),
model: Some(model.id),
choices: vec![Choice {
index: 0,
message: Some(Message {
@@ -1155,11 +1134,11 @@ pub async fn handle_chat(
}
}
let data = serde_json::to_string(&response_data).unwrap();
let data = serde_json::to_vec(&response_data).unwrap();
Ok(Response::builder()
.header(CACHE_CONTROL, NO_CACHE)
.header(CONNECTION, KEEP_ALIVE)
.header(CONTENT_TYPE, "application/json")
.header(CACHE_CONTROL, header_value_no_cache_revalidate())
.header(CONNECTION, header_value_keep_alive())
.header(CONTENT_TYPE, header_value_json())
.header(CONTENT_LENGTH, data.len())
.body(Body::from(data))
.unwrap())
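The streaming hunks above now build the SSE body as raw bytes (`data: <json>\n\n`) rather than formatting a `String`. A minimal sketch of that framing for any `serde::Serialize` payload; the project routes `serde_json` to sonic-rs, but the call shape is the same with plain serde_json:

```
use serde::Serialize;

// Append one server-sent event to an in-memory byte buffer.
fn push_sse_event<T: Serialize>(buf: &mut Vec<u8>, payload: &T) {
    buf.extend_from_slice(b"data: ");
    buf.extend_from_slice(&serde_json::to_vec(payload).expect("payload must serialize"));
    buf.extend_from_slice(b"\n\n");
}
```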

View File

@@ -1,2 +1 @@
mod decoder;
pub use decoder::*;
pub mod decoder;

View File

@@ -1,14 +1,28 @@
use crate::common::utils::InstantExt as _;
use crate::core::{
aiserver::v1::{StreamChatResponse, WebReference},
error::{ChatError, StreamError},
};
use bytes::{Buf, BytesMut};
use flate2::read::GzDecoder;
use prost::Message;
use std::io::Read;
use std::time::Instant;
// Decompress gzip data
pub trait InstantExt {
fn duration_as_secs_f32(&mut self) -> f32;
}
impl InstantExt for Instant {
#[inline]
fn duration_as_secs_f32(&mut self) -> f32 {
let now = Instant::now();
let duration = now.duration_since(*self);
*self = now;
duration.as_secs_f32()
}
}
/// Decompress gzip data
#[inline]
fn decompress_gzip(data: &[u8]) -> Option<Vec<u8>> {
if data.len() < 3
@@ -31,32 +45,6 @@ fn decompress_gzip(data: &[u8]) -> Option<Vec<u8>> {
}
}
pub trait ToMarkdown {
fn to_markdown(&self) -> String;
}
impl ToMarkdown for Vec<WebReference> {
#[inline]
fn to_markdown(&self) -> String {
if self.is_empty() {
return String::new();
}
let mut result = String::from("WebReferences:\n");
for (i, web_ref) in self.iter().enumerate() {
result.push_str(&format!(
"{}. [{}]({})<{}>\n",
i + 1,
web_ref.title,
web_ref.url,
web_ref.chunk
));
}
result.push('\n');
result
}
}
#[derive(PartialEq, Clone)]
pub enum StreamMessage {
// Debug
@@ -64,6 +52,7 @@ pub enum StreamMessage {
// Web references
WebReference(Vec<WebReference>),
// Content start marker
#[cfg(test)]
ContentStart,
// Message content
Content(String),
@@ -77,18 +66,35 @@ impl StreamMessage {
#[inline]
fn convert_web_ref_to_content(self) -> Self {
match self {
StreamMessage::WebReference(refs) => StreamMessage::Content(refs.to_markdown()),
StreamMessage::WebReference(refs) => {
if refs.is_empty() {
return StreamMessage::Content(String::new());
}
let mut result = String::from("WebReferences:\n");
for (i, web_ref) in refs.iter().enumerate() {
result.push_str(&format!(
"{}. [{}]({})<{}>\n",
i + 1,
web_ref.title,
web_ref.url,
web_ref.chunk
));
}
result.push('\n');
StreamMessage::Content(result)
}
other => other,
}
}
}
pub struct StreamDecoder {
// Main data buffer (24 bytes)
buffer: Vec<u8>,
// Result-related (24 bytes + 24 bytes)
// Main data buffer
buffer: BytesMut,
// Result-related (24 bytes + 48 bytes)
first_result: Option<Vec<StreamMessage>>,
content_delays: Vec<(String, f64)>,
content_delays: Option<(String, Vec<(u32, f32)>)>,
// Counters and time (8 bytes + 8 bytes)
empty_stream_count: usize,
last_content_time: Instant,
@@ -101,9 +107,9 @@ pub struct StreamDecoder {
impl StreamDecoder {
pub fn new() -> Self {
Self {
buffer: Vec::new(),
buffer: BytesMut::new(),
first_result: None,
content_delays: Vec::new(),
content_delays: None,
empty_stream_count: 0,
last_content_time: Instant::now(),
first_result_ready: false,
@@ -112,10 +118,12 @@ impl StreamDecoder {
}
}
#[inline]
pub fn get_empty_stream_count(&self) -> usize {
self.empty_stream_count
}
#[inline]
pub fn reset_empty_stream_count(&mut self) {
if self.empty_stream_count > 0 {
crate::debug_println!(
@@ -126,10 +134,7 @@ impl StreamDecoder {
}
}
pub fn increment_empty_stream_count(&mut self) {
self.empty_stream_count += 1;
}
#[inline]
pub fn take_first_result(&mut self) -> Option<Vec<StreamMessage>> {
if !self.buffer.is_empty() {
return None;
@@ -145,14 +150,17 @@ impl StreamDecoder {
!self.buffer.is_empty()
}
#[inline]
pub fn is_first_result_ready(&self) -> bool {
self.first_result_ready
}
pub fn take_content_delays(&mut self) -> Vec<(String, f64)> {
#[inline]
pub fn take_content_delays(&mut self) -> Option<(String, Vec<(u32, f32)>)> {
std::mem::take(&mut self.content_delays)
}
#[inline]
pub fn no_first_cache(mut self) -> Self {
self.first_result_ready = true;
self.first_result_taken = true;
@@ -172,7 +180,7 @@ impl StreamDecoder {
if self.buffer.len() < 5 {
if self.buffer.is_empty() {
self.increment_empty_stream_count();
self.empty_stream_count += 1;
return Err(StreamError::EmptyStream);
}
@@ -182,20 +190,68 @@ impl StreamDecoder {
self.reset_empty_stream_count();
let mut messages = Vec::new();
let reserve = {
let mut offset = 0;
let mut count = 0;
while offset + 5 <= self.buffer.len() {
let msg_len: usize;
// SAFETY: The loop condition `offset + 5 <= self.buffer.len()` guarantees
// that indices `offset` through `offset + 4` are within bounds.
unsafe {
msg_len = u32::from_be_bytes([
*self.buffer.get_unchecked(offset + 1),
*self.buffer.get_unchecked(offset + 2),
*self.buffer.get_unchecked(offset + 3),
*self.buffer.get_unchecked(offset + 4),
]) as usize;
}
if msg_len == 0 {
offset += 5;
continue;
}
if offset + 5 + msg_len > self.buffer.len() {
break;
}
offset += 5 + msg_len;
count += 1;
}
count
};
if let Some(content_delays) = self.content_delays.as_mut() {
content_delays.0.reserve(reserve);
content_delays.1.reserve(reserve);
} else {
self.content_delays =
Some((String::with_capacity(reserve), Vec::with_capacity(reserve)));
}
let mut messages = Vec::with_capacity(reserve);
let mut offset = 0;
while offset + 5 <= self.buffer.len() {
let msg_type = self.buffer[offset];
let msg_len = u32::from_be_bytes([
self.buffer[offset + 1],
self.buffer[offset + 2],
self.buffer[offset + 3],
self.buffer[offset + 4],
]) as usize;
let msg_type: u8;
let msg_len: usize;
// SAFETY: The loop condition `offset + 5 <= self.buffer.len()` guarantees
// that indices `offset` through `offset + 4` are within bounds.
unsafe {
msg_type = *self.buffer.get_unchecked(offset);
msg_len = u32::from_be_bytes([
*self.buffer.get_unchecked(offset + 1),
*self.buffer.get_unchecked(offset + 2),
*self.buffer.get_unchecked(offset + 3),
*self.buffer.get_unchecked(offset + 4),
]) as usize;
}
if msg_len == 0 {
offset += 5;
#[cfg(test)]
messages.push(StreamMessage::ContentStart);
continue;
}
@@ -209,8 +265,13 @@ impl StreamDecoder {
if let Some(msg) = self.process_message(msg_type, msg_data)? {
if let StreamMessage::Content(content) = &msg {
self.has_seen_content = true;
let delay = self.last_content_time.duration_as_secs_f64();
self.content_delays.push((content.clone(), delay));
let delay = self.last_content_time.duration_as_secs_f32();
if let Some(content_delays) = self.content_delays.as_mut() {
content_delays.0.push_str(content);
content_delays
.1
.push((content.chars().count() as u32, delay));
}
}
if convert_web_ref {
messages.push(msg.convert_web_ref_to_content());
@@ -222,7 +283,7 @@ impl StreamDecoder {
offset += 5 + msg_len;
}
self.buffer.drain(..offset);
self.buffer.advance(offset);
if !self.first_result_taken && !messages.is_empty() {
if self.first_result.is_none() {
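For reference, the loops above walk a simple frame format: one message-type byte, a 4-byte big-endian payload length, then the payload, with zero-length frames skipped. A safe, minimal sketch of the same parse, without the decoder's buffering or unchecked indexing:

```
// Split a buffer into (msg_type, payload) frames; incomplete trailing frames are left alone.
fn split_frames(buffer: &[u8]) -> Vec<(u8, &[u8])> {
    let mut frames = Vec::new();
    let mut offset = 0;
    while offset + 5 <= buffer.len() {
        let msg_type = buffer[offset];
        let msg_len = u32::from_be_bytes([
            buffer[offset + 1],
            buffer[offset + 2],
            buffer[offset + 3],
            buffer[offset + 4],
        ]) as usize;
        if msg_len == 0 {
            // Empty frame: only the 5-byte header, nothing to emit.
            offset += 5;
            continue;
        }
        if offset + 5 + msg_len > buffer.len() {
            break; // incomplete frame; wait for more data
        }
        frames.push((msg_type, &buffer[offset + 5..offset + 5 + msg_len]));
        offset += 5 + msg_len;
    }
    frames
}
```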

View File

@@ -7,12 +7,11 @@ struct StringPool {
}
impl StringPool {
// Intern a string
/// Intern a string
fn intern(&mut self, s: &str) -> &'static str {
if let Some(&interned) = self.pool.get(s) {
interned
} else {
// If the string is not yet in the pool, leak it with Box::leak and insert it
let leaked: &'static str = Box::leak(Box::from(s));
self.pool.insert(leaked);
leaked
@@ -20,60 +19,9 @@ impl StringPool {
}
}
// Global StringPool instance
static STRING_POOL: LazyLock<Mutex<StringPool>> =
LazyLock::new(|| Mutex::new(StringPool::default()));
pub fn intern_string<S: AsRef<str>>(s: S) -> &'static str {
STRING_POOL.lock().intern(s.as_ref())
}
// #[derive(Clone, Copy, PartialEq, Eq, Hash)]
// pub struct InternedString(&'static str);
// impl InternedString {
// #[inline(always)]
// pub fn as_str(&self) -> &'static str {
// self.0
// }
// }
// impl Deref for InternedString {
// type Target = str;
// #[inline(always)]
// fn deref(&self) -> &'static Self::Target {
// self.0
// }
// }
// impl Borrow<str> for InternedString {
// #[inline(always)]
// fn borrow(&self) -> &'static str {
// self.0
// }
// }
// impl core::fmt::Debug for InternedString {
// #[inline(always)]
// fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
// self.0.fmt(f)
// }
// }
// impl core::fmt::Display for InternedString {
// #[inline(always)]
// fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
// self.0.fmt(f)
// }
// }
// impl serde::Serialize for InternedString {
// #[inline]
// fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
// where
// S: serde::Serializer,
// {
// serializer.serialize_str(self.0)
// }
// }
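Usage of `intern_string` looks like the sketch below, assuming the function is in scope. Equal inputs resolve to the same leaked allocation, which is why interned model names can be stored in logs as `&'static str` without lifetimes; the pool grows for the lifetime of the process, which is the intended trade-off of `Box::leak` here.

```
fn main() {
    // Both calls return the same leaked &'static str.
    let a: &'static str = intern_string("gpt-4o");
    let b: &'static str = intern_string(String::from("gpt-4o"));
    assert!(std::ptr::eq(a, b));
}
```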

View File

@@ -649,6 +649,7 @@
<th>Token信息</th>
<th>对话</th>
<th>用时</th>
<th>输入/输出</th>
<th>流式响应</th>
<th>状态</th>
<th>错误信息</th>
@@ -1043,20 +1044,16 @@
updatePaginationControls();
tbody.innerHTML = data.logs.map(log => {
// Pre-process the delays data to avoid HTML parsing issues
let delaysData = '';
if (log.chain && log.chain.delays) {
// Build a safe text/number pair for each delay entry
const safeDelays = log.chain.delays.map(item => {
if (!Array.isArray(item) || item.length < 2) return ["", 0];
// Base64-encode the text part of each delay entry to avoid HTML parsing issues
const textPart = typeof item[0] === 'string' ? btoa(encodeURIComponent(item[0])) : "";
const delayPart = typeof item[1] === 'number' ? item[1] : 0;
return [textPart, delayPart];
});
delaysData = JSON.stringify(safeDelays);
let delaysDataAttribute = '';
if (log.chain && Array.isArray(log.chain.delays) && log.chain.delays.length === 2 && typeof log.chain.delays[0] === 'string' && Array.isArray(log.chain.delays[1])) {
try {
const originalDelaysJson = JSON.stringify(log.chain.delays);
const escapedJsonString = escapeHtml(originalDelaysJson);
delaysDataAttribute = `data-delays='${escapedJsonString}'`;
} catch (e) {
console.error('Failed to stringify delays data:', e, log.chain.delays);
delaysDataAttribute = '';
}
}
return `<tr>
@@ -1064,11 +1061,12 @@
<td>${new Date(log.timestamp).toLocaleString()}</td>
<td>${log.model}</td>
<td><div class="token-info-tooltip"><button class="info-button" onclick='showTokenModal(${JSON.stringify(log.token_info)})'>查看详情<div class="tooltip-content">${formatSimpleTokenInfo(log.token_info)}</div></button></div></td>
<td>${log.chain ? `<div class="token-info-tooltip prompt-preview"><button class="info-button view-conversation" data-prompt="${encodeURIComponent(JSON.stringify(log.chain.prompt))}" data-delays='${delaysData}'>查看对话<div class="tooltip-content">${formatDialogPreview(JSON.stringify(log.chain.prompt))}</div></button></div>` : '-'}</td>
<td>${log.chain ? `<div class="token-info-tooltip prompt-preview"><button class="info-button view-conversation" data-prompt="${encodeURIComponent(JSON.stringify(log.chain.prompt))}" ${delaysDataAttribute}>查看对话<div class="tooltip-content">${formatDialogPreview(JSON.stringify(log.chain.prompt))}</div></button></div>` : '-'}</td>
<td>${formatTiming(log.timing.total)}</td>
<td>${formatUsage(log.chain?.usage)}</td>
<td>${log.stream ? '是' : '否'}</td>
<td>${log.status}</td>
<td>${log.error.details || log.error || '-'}</td>
<td>${typeof log.error === 'string' ? log.error : log.error?.error ?? '-'}</td>
</tr>`;
}).join('');
@@ -1372,6 +1370,36 @@
return messages;
}
function escapeHtml(content) {
// Escape HTML special characters
const escaped = content
.replace(/&/g, '&amp;')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;')
.replace(/"/g, '&quot;')
.replace(/'/g, '&#039;');
return escaped;
}
function escapeHtmlAndControlChars(content) {
// First escape HTML special characters
let escaped = content
.replace(/&/g, '&amp;')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;')
.replace(/"/g, '&quot;')
.replace(/'/g, '&#039;');
// Then escape control characters
escaped = escaped
.replace(/\\/g, '\\\\')
.replace(/\n/g, '\\n')
.replace(/\t/g, '\\t')
.replace(/\r/g, '\\r');
return escaped;
}
/**
* Format conversation messages as an HTML table
* @param {Array<{role: string, content: string}>} messages - conversation message array
@@ -1388,17 +1416,6 @@
'assistant': '助手'
};
function escapeHtml(content) {
// Escape HTML special characters
const escaped = content
.replace(/&/g, '&amp;')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;')
.replace(/"/g, '&quot;')
.replace(/'/g, '&#039;');
return escaped;
}
return `<table class="message-table"><thead><tr><th>角色</th><th>内容</th></tr></thead><tbody>${messages.map(msg =>
`<tr><td>${roleLabels[msg.role] || msg.role}</td><td>${escapeHtml(msg.content).replace(/\n/g, '<br>')}</td></tr>`
).join('')}</tbody></table>`;
@@ -1407,72 +1424,102 @@
/**
* Show the conversation detail modal
* @param {string} promptStr - conversation prompt string
* @param {Array} delays - delay data array
* @param {Array} delaysTuple - delay data tuple
*/
function showConversationModal(promptStr, delays) {
function showConversationModal(promptStr, delaysTuple) {
try {
const modal = document.getElementById('conversationModal');
const dialogContent = document.getElementById('dialogContent');
const delaysContent = document.getElementById('delaysContent');
const tabPrompt = document.getElementById('tab-prompt');
const tabDelays = document.getElementById('tab-delays');
const delaysTableBody = document.querySelector('#delaysTable tbody');
const delayChartContainer = document.querySelector('.delay-chart-container');
if (!modal || !dialogContent || !delaysContent || !tabPrompt || !tabDelays) {
if (!modal || !dialogContent || !delaysContent || !tabPrompt || !tabDelays || !delaysTableBody || !delayChartContainer) {
console.error('Modal elements not found');
return;
}
// Show the conversation content
const messages = parsePrompt(promptStr);
dialogContent.innerHTML = formatPromptToTable(messages);
// Handle the prompt
try {
const messages = parsePrompt(promptStr);
dialogContent.innerHTML = formatPromptToTable(messages);
} catch (e) {
console.error('解析 Prompt 数据失败:', e);
dialogContent.innerHTML = '<p>无法加载对话内容。</p>';
}
// Handle the delay data
if (delays && delays.length > 0) {
const delaysTableBody = document.querySelector('#delaysTable tbody');
delaysTableBody.innerHTML = '';
// Handle the delays
delaysTableBody.innerHTML = '';
delayChartContainer.innerHTML = '<canvas id="delayChart"></canvas>';
let fullText = '';
let delayPoints = [];
let chartDataPoints = [{ time: 0, chars: 0, text: '' }];
if (delaysTuple) {
if (Array.isArray(delaysTuple) && delaysTuple.length === 2 && typeof delaysTuple[0] === 'string' && Array.isArray(delaysTuple[1])) {
fullText = delaysTuple[0];
delayPoints = delaysTuple[1];
} else {
console.warn('Delays data format is incorrect:', delaysTuple);
}
}
if (delayPoints.length > 0) {
let currentIndex = 0;
let totalChars = 0;
let totalTime = 0;
// Decode and display the delay data
delays.forEach(([encodedText, delay], index) => {
try {
const text = encodedText ? decodeURIComponent(atob(encodedText)) : '';
totalChars += text.length;
totalTime += delay;
const rate = text.length / delay;
const avgRate = totalChars / totalTime;
const row = document.createElement('tr');
row.innerHTML = `
<td>${index + 1}</td>
<td>${text.replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;').replace(/"/g, '&quot;').replace(/'/g, '&#039;')}</td>
<td>${delay.toFixed(3)}</td>
<td>${rate.toFixed(1)} (平均: ${avgRate.toFixed(1)})</td>
`;
delaysTableBody.appendChild(row);
} catch (e) {
console.error('处理延迟数据项失败:', e);
delayPoints.forEach(([length, deltaTime], index) => {
if (typeof length !== 'number' || typeof deltaTime !== 'number' || length < 0 || deltaTime < 0) {
console.warn(`Skipping invalid delay point at index ${index}:`, [length, deltaTime]);
return;
}
const chunkText = fullText.substring(currentIndex, currentIndex + length);
currentIndex += length;
totalChars += length;
totalTime += deltaTime;
const rate = deltaTime > 0 ? (length / deltaTime) : Infinity;
const avgRate = totalTime > 0 ? (totalChars / totalTime) : Infinity;
const row = document.createElement('tr');
row.innerHTML = `
<td>${index + 1}</td>
<td>${escapeHtmlAndControlChars(chunkText)}</td>
<td>${deltaTime.toFixed(3)}</td>
<td>${isFinite(rate) ? rate.toFixed(1) : 'N/A'} (平均: ${isFinite(avgRate) ? avgRate.toFixed(1) : 'N/A'})</td>
`;
delaysTableBody.appendChild(row);
chartDataPoints.push({
time: totalTime,
chars: totalChars,
text: chunkText
});
});
// Initialize the delay chart
initDelayChart(delays);
initDelayChart(chartDataPoints);
tabDelays.style.display = '';
} else {
document.querySelector('#delaysTable tbody').innerHTML = '<tr><td colspan="2">无延迟数据</td></tr>';
document.querySelector('.delay-chart-container').innerHTML = '<div style="text-align: center; padding: 20px;">无延迟数据可供分析</div>';
delaysTableBody.innerHTML = '<tr><td colspan="4">无延迟数据</td></tr>';
delayChartContainer.innerHTML = '<div style="text-align: center; padding: 20px;">无延迟数据可供分析</div>';
tabDelays.style.display = 'none';
}
// Set up tab-switch handlers
// Set up tabs
tabPrompt.onclick = () => setActiveTab('prompt');
tabDelays.onclick = () => setActiveTab('delays');
// Set the default active tab
setActiveTab('prompt');
modal.style.display = 'block';
} catch (e) {
console.error('显示对话详情失败:', e);
console.error('原始prompt:', promptStr);
}
}
@@ -1506,10 +1553,14 @@
/**
* Initialize the delay chart
* @param {Array} delays - delay array
* @param {Array<{time: number, chars: number, text: string}>} chartDataPoints - data points carrying cumulative time, cumulative character count, and chunk text
*/
function initDelayChart(delays) {
if (!delays || delays.length <= 1) {
function initDelayChart(chartDataPoints) {
if (!chartDataPoints || chartDataPoints.length <= 1) {
const container = document.querySelector('.delay-chart-container');
if (container) {
container.innerHTML = '<div style="text-align: center; padding: 20px;">延迟数据不足,无法绘制图表</div>';
}
return;
}
@@ -1519,71 +1570,35 @@
return;
}
// Destroy the previous chart (if any)
const existingChart = Chart.getChart(ctx);
if (existingChart) {
existingChart.destroy();
}
// Compute character counts and cumulative time
const rawDataPoints = [];
let totalChars = 0;
let accumulatedTime = 0;
const maxPoints = 100;
let sampledPoints = chartDataPoints;
// Decode all data points and accumulate time
for (let i = 0; i < delays.length; i++) {
const [encodedText, delay] = delays[i];
if (encodedText && typeof delay === 'number') {
try {
const text = decodeURIComponent(atob(encodedText));
totalChars += text.length;
accumulatedTime += delay; // accumulate delay time
rawDataPoints.push({
time: accumulatedTime, // use cumulative time
chars: totalChars,
text: text
});
} catch (e) {
console.error('解码延迟数据失败:', e);
}
if (chartDataPoints.length > maxPoints) {
sampledPoints = [];
const interval = (chartDataPoints.length - 2) / (maxPoints - 2);
sampledPoints.push(chartDataPoints[0]);
for (let i = 1; i < maxPoints - 1; i++) {
const rawIndex = Math.round(i * interval);
const actualIndex = Math.min(rawIndex + 1, chartDataPoints.length - 2);
sampledPoints.push(chartDataPoints[actualIndex]);
}
sampledPoints.push(chartDataPoints[chartDataPoints.length - 1]);
}
// Optimize data-point density
const maxPoints = 10;
let dataPoints = [];
if (rawDataPoints.length > maxPoints) {
// Compute the sampling interval
const interval = Math.floor(rawDataPoints.length / maxPoints);
// Make sure the first point is included
dataPoints.push(rawDataPoints[0]);
// Sample intermediate points with a larger interval
for (let i = interval; i < rawDataPoints.length - interval; i += interval) {
// Average around each sample point to smooth the curve
const start = Math.max(0, i - Math.floor(interval / 2));
const end = Math.min(rawDataPoints.length, i + Math.floor(interval / 2));
const avgPoint = rawDataPoints[i];
dataPoints.push(avgPoint);
}
// Make sure the last point is included
if (dataPoints[dataPoints.length - 1] !== rawDataPoints[rawDataPoints.length - 1]) {
dataPoints.push(rawDataPoints[rawDataPoints.length - 1]);
}
} else {
dataPoints = rawDataPoints;
}
// Create a new chart
new Chart(ctx, {
type: 'line',
data: {
datasets: [{
label: '累计输出字符数',
data: dataPoints.map(point => ({
data: sampledPoints.map(point => ({
x: point.time,
y: point.chars
})),
@@ -1606,16 +1621,33 @@
tooltip: {
callbacks: {
title: function (context) {
return `用时: ${context[0].raw.x.toFixed(1)}`;
const pointIndex = context[0].dataIndex;
if (pointIndex < sampledPoints.length) {
return `用时: ${sampledPoints[pointIndex].time.toFixed(1)}`;
}
return '';
},
label: function (context) {
const point = context.raw;
const rate = point.x > 0 ? (point.y / point.x).toFixed(1) : 0;
return [
`字符数: ${point.y}`,
`平均速率: ${rate} 字符/秒`,
`文本: ${dataPoints[context.dataIndex].text}`
];
const pointIndex = context.dataIndex;
if (pointIndex < sampledPoints.length) {
const point = sampledPoints[pointIndex];
const prevPoint = pointIndex > 0 ? sampledPoints[pointIndex - 1] : { time: 0, chars: 0, text: '' };
const deltaTime = point.time - prevPoint.time;
const deltaChars = point.chars - prevPoint.chars;
const currentRate = deltaTime > 0 ? (deltaChars / deltaTime).toFixed(1) : 'N/A';
const avgRate = point.time > 0 ? (point.chars / point.time).toFixed(1) : 'N/A';
const chunkText = point.text || '';
return [
`累计字符: ${point.chars}`,
`平均速率: ${avgRate} 字符/秒`,
`当前块速率: ${currentRate} 字符/秒`,
`当前块文本: ${escapeHtmlAndControlChars(chunkText.length > 50 ? chunkText.substring(0, 50) + '...' : chunkText)}`
];
}
return '';
}
}
}
@@ -1625,14 +1657,15 @@
beginAtZero: true,
title: {
display: true,
text: '字符数'
text: '累计字符数'
}
},
x: {
type: 'linear',
beginAtZero: true,
title: {
display: true,
text: '用时(秒)'
text: '累计用时(秒)'
},
ticks: {
callback: function (value) {
@@ -1644,6 +1677,16 @@
}
});
}
/**
* Format usage information
* @param {Object} usage - usage object { input: number, output: number }
* @returns {string} formatted string
*/
function formatUsage(usage) {
if (!usage) return '-';
return `${usage.input} / ${usage.output}`;
}
</script>
</body>
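The `delaysTuple` parsing above expects the shape the decoder's `content_delays` field now produces: the full text plus a list of `[chars, seconds]` pairs. A small sketch of how such a tuple serializes, using a simplified stand-in struct and plain serde_json in place of the project's sonic-rs alias:

```
use serde::Serialize;

// Simplified stand-in for the log chain entry; only the delays field is shown.
#[derive(Serialize)]
struct Chain {
    delays: Option<(String, Vec<(u32, f32)>)>,
}

fn main() {
    let chain = Chain {
        delays: Some(("Hello".to_string(), vec![(2, 0.5), (3, 0.25)])),
    };
    // Prints: {"delays":["Hello",[[2,0.5],[3,0.25]]]}
    println!("{}", serde_json::to_string(&chain).unwrap());
}
```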