v0.1.3-rc.5.2-pre

wisdgod
2025-03-05 04:21:37 +08:00
parent 0e65370ca2
commit 6e00911d7c
54 changed files with 6733 additions and 1877 deletions


@@ -9,8 +9,8 @@ use crate::{
KEY_PREFIX, KEY_PREFIX_LEN, REQUEST_LOGS_LIMIT, SERVICE_TIMEOUT,
},
model::{
AppConfig, AppState, ChatRequest, LogStatus, RequestLog, TimingInfo, TokenInfo,
UsageCheck,
AppConfig, AppState, Chain, LogStatus, RequestLog, TimingInfo, TokenInfo, UsageCheck,
proxy_pool::ProxyPool,
},
},
chat::{
@@ -23,10 +23,12 @@ use crate::{
stream::{StreamDecoder, StreamMessage},
},
common::{
client::build_client,
model::{ApiStatus, ErrorResponse, error::ChatError, userinfo::MembershipType},
client::build_request,
model::{
ApiStatus, ErrorResponse, error::ChatError, tri::TriState, userinfo::MembershipType,
},
utils::{
TrimNewlines as _, format_time_ms, from_base64, get_available_models,
InstantExt as _, TrimNewlines as _, format_time_ms, from_base64, get_available_models,
get_token_profile, tokeninfo_to_token, validate_token_and_checksum,
},
},
@@ -44,6 +46,7 @@ use axum::{
use bytes::Bytes;
use futures::StreamExt;
use prost::Message as _;
use reqwest::Client;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::{
convert::Infallible,
@@ -52,7 +55,7 @@ use std::{
use tokio::sync::Mutex;
use uuid::Uuid;
use super::{constant::LONG_CONTEXT_MODELS, model::Model};
use super::model::{ChatRequest, Model};
// Helper: extract the authentication token
fn extract_auth_token(headers: &HeaderMap) -> Result<&str, (StatusCode, Json<ErrorResponse>)> {
@@ -70,7 +73,7 @@ fn extract_auth_token(headers: &HeaderMap) -> Result<&str, (StatusCode, Json<Err
async fn resolve_token_info(
auth_header: &str,
state: &Arc<Mutex<AppState>>,
) -> Result<(String, String), (StatusCode, Json<ErrorResponse>)> {
) -> Result<(String, String, Client), (StatusCode, Json<ErrorResponse>)> {
match auth_header {
// Admin token handling
token if is_admin_token(token) => resolve_admin_token(state).await,
@@ -79,10 +82,13 @@ async fn resolve_token_info(
token if is_dynamic_key(token) => resolve_dynamic_key(token),
// Regular user token handling
token => validate_token_and_checksum(token).ok_or((
StatusCode::UNAUTHORIZED,
Json(ChatError::Unauthorized.to_json()),
)),
token => {
let (token, checksum) = validate_token_and_checksum(token).ok_or((
StatusCode::UNAUTHORIZED,
Json(ChatError::Unauthorized.to_json()),
))?;
Ok((token, checksum, ProxyPool::get_general_client()))
}
}
}
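Note: `resolve_token_info` and its helpers now return a third element, the `reqwest::Client` to use for the upstream call, with `ProxyPool::get_general_client()` as the fallback for regular tokens. A minimal sketch of the assumed `ProxyPool` surface behind that fallback (the real pool presumably hands out per-proxy clients; everything beyond `get_general_client` is illustrative):

```rust
use reqwest::Client;
use std::sync::OnceLock;

// Hypothetical stand-in for the repo's ProxyPool: one shared, lazily built
// client for callers that have no per-token proxy binding.
pub struct ProxyPool;

static GENERAL_CLIENT: OnceLock<Client> = OnceLock::new();

impl ProxyPool {
    pub fn get_general_client() -> Client {
        // reqwest::Client is reference-counted internally, so cloning is cheap.
        GENERAL_CLIENT.get_or_init(Client::new).clone()
    }
}
```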
@@ -100,7 +106,7 @@ fn is_dynamic_key(token: &str) -> bool {
// 辅助函数处理管理员token
async fn resolve_admin_token(
state: &Arc<Mutex<AppState>>,
) -> Result<(String, String), (StatusCode, Json<ErrorResponse>)> {
) -> Result<(String, String, Client), (StatusCode, Json<ErrorResponse>)> {
static CURRENT_KEY_INDEX: AtomicUsize = AtomicUsize::new(0);
let state_guard = state.lock().await;
@@ -116,11 +122,17 @@ async fn resolve_admin_token(
let index = CURRENT_KEY_INDEX.fetch_add(1, Ordering::SeqCst) % token_infos.len();
let token_info = &token_infos[index];
Ok((token_info.token.clone(), token_info.checksum.clone()))
Ok((
token_info.token.clone(),
token_info.checksum.clone(),
token_info.get_client(),
))
}
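The admin path rotates through the configured tokens with a plain atomic round-robin; in isolation:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Process-wide counter; only its value modulo the token count matters, so
// wrap-around on overflow is harmless. SeqCst matches the code above, though
// Relaxed would arguably suffice for a lone counter.
static NEXT: AtomicUsize = AtomicUsize::new(0);

fn round_robin_index(len: usize) -> usize {
    debug_assert!(len > 0);
    NEXT.fetch_add(1, Ordering::SeqCst) % len
}
```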
// Helper: resolve a dynamic key
fn resolve_dynamic_key(token: &str) -> Result<(String, String), (StatusCode, Json<ErrorResponse>)> {
fn resolve_dynamic_key(
token: &str,
) -> Result<(String, String, Client), (StatusCode, Json<ErrorResponse>)> {
from_base64(&token[*KEY_PREFIX_LEN..])
.and_then(|decoded_bytes| KeyConfig::decode(&decoded_bytes[..]).ok())
.and_then(|key_config| key_config.auth_token)
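Dynamic keys are the `KEY_PREFIX` followed by a base64-encoded protobuf `KeyConfig`. A hedged sketch of the decode chain above, assuming `from_base64` returns `Option<Vec<u8>>` and `KeyConfig` is a prost-generated message:

```rust
use prost::Message as _;

// Hypothetical shape of the decode pipeline: strip the prefix, base64-decode,
// then decode the protobuf payload. The real chain continues into
// key_config.auth_token.
fn decode_key_config(token: &str, prefix_len: usize) -> Option<KeyConfig> {
    let decoded: Vec<u8> = from_base64(&token[prefix_len..])?;
    KeyConfig::decode(decoded.as_slice()).ok()
}
```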
@@ -143,18 +155,20 @@ pub async fn handle_models(
// Extract and validate the authentication token
let auth_token = extract_auth_token(&headers)?;
let (token, checksum) = resolve_token_info(auth_token, &state).await?;
let (token, checksum, client) = resolve_token_info(auth_token, &state).await?;
// Fetch the list of available models
let models = get_available_models(&token, &checksum).await.ok_or((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
status: ApiStatus::Failure,
code: Some(StatusCode::INTERNAL_SERVER_ERROR.as_u16()),
error: Some("Failed to fetch available models".to_string()),
message: Some("Unable to get available models".to_string()),
}),
))?;
let models = get_available_models(client, &token, &checksum)
.await
.ok_or((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
status: ApiStatus::Failure,
code: Some(StatusCode::INTERNAL_SERVER_ERROR.as_u16()),
error: Some("Failed to fetch available models".to_string()),
message: Some("Unable to get available models".to_string()),
}),
))?;
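Correspondingly, helpers that used to build their own HTTP client now take one from the caller, so connection pools are reused and proxy selection stays in one place. A hypothetical post-change shape of the helper in `common::utils` (endpoint URL and header name are placeholders, not the real ones):

```rust
use reqwest::Client;

// Sketch only: the caller supplies the (possibly proxied) client instead of
// the helper constructing one per call.
async fn get_available_models(client: Client, token: &str, checksum: &str) -> Option<Vec<Model>> {
    let resp = client
        .get("https://api.example.invalid/models") // placeholder URL
        .bearer_auth(token)
        .header("x-checksum", checksum) // header name assumed
        .send()
        .await
        .ok()?;
    resp.json::<Vec<Model>>().await.ok()
}
```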
// Update the model list
if let Err(e) = Models::update(models) {
@@ -221,7 +235,7 @@ pub async fn handle_chat(
let mut current_config = KeyConfig::new_with_global();
// Validate the auth token and obtain token info
let (auth_token, checksum) = match auth_header {
let (auth_token, checksum, client) = match auth_header {
// Admin token validation logic
token
if token == AUTH_TOKEN.as_str()
@@ -242,7 +256,11 @@ pub async fn handle_chat(
// Select a token round-robin
let index = CURRENT_KEY_INDEX.fetch_add(1, Ordering::SeqCst) % token_infos.len();
let token_info = &token_infos[index];
(token_info.token.clone(), token_info.checksum.clone())
(
token_info.token.clone(),
token_info.checksum.clone(),
token_info.get_client(),
)
}
token if AppConfig::get_dynamic_key() && token.starts_with(&*KEY_PREFIX) => {
@@ -260,10 +278,13 @@ pub async fn handle_chat(
}
// Regular user token validation logic
token => validate_token_and_checksum(token).ok_or((
StatusCode::UNAUTHORIZED,
Json(ChatError::Unauthorized.to_json()),
))?,
token => {
let (token, checksum) = validate_token_and_checksum(token).ok_or((
StatusCode::UNAUTHORIZED,
Json(ChatError::Unauthorized.to_json()),
))?;
(token, checksum, ProxyPool::get_general_client())
}
};
let current_config = current_config;
@@ -277,58 +298,32 @@ pub async fn handle_chat(
state.request_manager.total_requests += 1;
state.request_manager.active_requests += 1;
let mut found_count: u32 = 0;
let mut no_prompt_count: u32 = 0;
let mut need_profile_check = false;
for log in state.request_manager.request_logs.iter().rev() {
if log.token_info.token == auth_token {
if !LONG_CONTEXT_MODELS.contains(&log.model.as_str()) {
found_count += 1;
}
if log.prompt.is_none() {
no_prompt_count += 1;
}
if found_count == 1 && log.token_info.profile.is_some() {
if let Some(profile) = &log.token_info.profile {
if profile.stripe.membership_type == MembershipType::Free {
let is_premium = USAGE_CHECK_MODELS.contains(&model_name.as_str());
need_profile_check =
if is_premium {
profile.usage.premium.max_requests.is_some_and(|max| {
profile.usage.premium.num_requests >= max
})
} else {
profile.usage.standard.max_requests.is_some_and(|max| {
profile.usage.standard.num_requests >= max
})
};
}
if let Some(profile) = &log.token_info.profile {
if profile.stripe.membership_type == MembershipType::Free {
let is_premium = USAGE_CHECK_MODELS.contains(&model_name.as_str());
need_profile_check = if is_premium {
profile
.usage
.premium
.max_requests
.is_some_and(|max| profile.usage.premium.num_requests >= max)
} else {
profile
.usage
.standard
.max_requests
.is_some_and(|max| profile.usage.standard.num_requests >= max)
};
}
}
if found_count == 2 {
break;
}
}
}
if found_count == 2 && no_prompt_count == 2 {
state.request_manager.active_requests -= 1;
state.request_manager.error_requests += 1;
return Err((
StatusCode::TOO_MANY_REQUESTS,
Json(ErrorResponse {
status: ApiStatus::Error,
code: Some(429),
error: Some("rate_limit_exceeded".to_string()),
message: Some("Too many requests without prompt".to_string()),
}),
));
}
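The free-tier gate computed above reduces to one predicate: block only when a maximum is configured and already reached. As a tiny standalone form (field widths assumed):

```rust
// Accounts without a configured maximum are never blocked here.
fn over_cap(num_requests: u64, max_requests: Option<u64>) -> bool {
    max_requests.is_some_and(|max| num_requests >= max)
}

// e.g. over_cap(50, Some(50)) == true; over_cap(50, None) == false
```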
// Handle the check result
if need_profile_check {
state.request_manager.active_requests -= 1;
@@ -359,9 +354,10 @@ pub async fn handle_chat(
let auth_token_clone = auth_token.clone();
let state_clone = state_clone.clone();
let log_id = next_id;
let client = client.clone();
tokio::spawn(async move {
let profile = get_token_profile(&auth_token_clone).await;
let profile = get_token_profile(client, &auth_token_clone).await;
let mut state = state_clone.lock().await;
// First find the indices of all entries that need updating
@@ -404,11 +400,8 @@ pub async fn handle_chat(
profile: None,
tags: None,
},
prompt: None,
timing: TimingInfo {
total: 0.0,
first: None,
},
chain: None,
timing: TimingInfo { total: 0.0 },
stream: request.stream,
status: LogStatus::Pending,
error: None,
@@ -441,7 +434,7 @@ pub async fn handle_chat(
.rev()
.find(|log| log.id == current_id)
{
log.status = LogStatus::Failed;
log.status = LogStatus::Failure;
log.error = Some(e.to_string());
}
state.request_manager.active_requests -= 1;
@@ -456,7 +449,8 @@ pub async fn handle_chat(
};
// Build the request client
let client = build_client(
let client = build_request(
client,
&auth_token,
&checksum,
if is_search {
@@ -492,7 +486,8 @@ pub async fn handle_chat(
}
resp
}
Err(e) => {
Err(mut e) => {
e = e.without_url();
// Mark the request log as failed
{
let mut state = state.lock().await;
@@ -503,7 +498,7 @@ pub async fn handle_chat(
.rev()
.find(|log| log.id == current_id)
{
log.status = LogStatus::Failed;
log.status = LogStatus::Failure;
log.error = Some(e.to_string());
}
state.request_manager.active_requests -= 1;
@@ -526,7 +521,7 @@ pub async fn handle_chat(
.rev()
.find(|log| log.id == current_id)
{
log.status = LogStatus::Failed;
log.status = LogStatus::Failure;
log.error = Some("Request timeout".to_string());
}
state.request_manager.active_requests -= 1;
@@ -551,18 +546,19 @@ pub async fn handle_chat(
let response_id = format!("chatcmpl-{}", Uuid::new_v4().simple());
let is_start = Arc::new(AtomicBool::new(true));
let start_time = std::time::Instant::now();
let first_chunk_time = Arc::new(Mutex::new(None::<f64>));
let decoder = Arc::new(Mutex::new(StreamDecoder::new()));
let content_time = Arc::new(Mutex::new(std::time::Instant::now()));
// Context struct for the message processor
struct MessageProcessContext<'a> {
response_id: &'a str,
model: &'a str,
is_start: &'a AtomicBool,
first_chunk_time: &'a Mutex<Option<f64>>,
start_time: std::time::Instant,
state: &'a Mutex<AppState>,
current_id: u64,
need_usage: bool,
content_time: &'a Mutex<std::time::Instant>,
}
// Helper that processes messages and produces response data
@@ -576,9 +572,26 @@ pub async fn handle_chat(
match message {
StreamMessage::Content(text) => {
let is_first = ctx.is_start.load(Ordering::SeqCst);
if is_first {
if let Ok(mut first_time) = ctx.first_chunk_time.try_lock() {
*first_time = Some(ctx.start_time.elapsed().as_secs_f64());
if let Ok(mut time_tracker) = ctx.content_time.try_lock() {
let interval = time_tracker.duration_as_secs_f64();
if let Ok(mut state) = ctx.state.try_lock() {
if let Some(log) = state
.request_manager
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == ctx.current_id)
{
if let Some(chain) = &mut log.chain {
chain.delays.push((text.clone(), interval));
} else {
log.chain = Some(Chain {
prompt: String::new(),
delays: vec![(text.clone(), interval)],
});
}
}
}
}
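`duration_as_secs_f64` comes from the new `InstantExt` helper imported above. The diff doesn't show its body, but both call sites take the instant mutably and push the result as a per-chunk delay, which implies elapsed-and-reset semantics. A hedged reconstruction:

```rust
use std::time::Instant;

// Assumed behavior: report the seconds since the previous call and reset the
// instant, so each Content chunk records the gap since the chunk before it.
trait InstantExt {
    fn duration_as_secs_f64(&mut self) -> f64;
}

impl InstantExt for Instant {
    fn duration_as_secs_f64(&mut self) -> f64 {
        let now = Instant::now();
        let secs = now.duration_since(*self).as_secs_f64();
        *self = now;
        secs
    }
}
```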
@@ -607,9 +620,14 @@ pub async fn handle_chat(
Some(text)
},
}),
logprobs: None,
finish_reason: None,
}],
usage: None,
usage: if ctx.need_usage {
TriState::Null
} else {
TriState::None
},
};
response_data.push_str(&format!(
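`TriState` distinguishes "omit the field" from "serialize an explicit null" from "serialize a value", which a plain `Option` can't: when the client asked for usage, every chunk carries `"usage": null`; otherwise the field is dropped entirely. A minimal sketch of such a type (the repo's `common::model::tri::TriState` may differ in detail):

```rust
use serde::{Serialize, Serializer};

// Three-state field: None => omitted (paired with skip_serializing_if),
// Null => explicit JSON null, Some(v) => the value.
#[derive(Default)]
pub enum TriState<T> {
    #[default]
    None,
    Null,
    Some(T),
}

impl<T> TriState<T> {
    pub fn is_none(&self) -> bool {
        matches!(self, TriState::None)
    }
}

impl<T: Serialize> Serialize for TriState<T> {
    fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        match self {
            // Only reached for Null when the field isn't skipped.
            TriState::None | TriState::Null => s.serialize_none(),
            TriState::Some(v) => v.serialize(s),
        }
    }
}

// On the response struct (assumed):
//   #[serde(skip_serializing_if = "TriState::is_none")]
//   usage: TriState<Usage>,
```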
@@ -620,7 +638,6 @@ pub async fn handle_chat(
StreamMessage::StreamEnd => {
// Compute the total time and the first-chunk time
let total_time = ctx.start_time.elapsed().as_secs_f64();
let first_time = ctx.first_chunk_time.lock().await.unwrap_or(total_time);
{
let mut state = ctx.state.lock().await;
@@ -632,7 +649,6 @@ pub async fn handle_chat(
.find(|log| log.id == ctx.current_id)
{
log.timing.total = format_time_ms(total_time);
log.timing.first = Some(format_time_ms(first_time));
}
}
@@ -648,14 +664,39 @@ pub async fn handle_chat(
role: None,
content: None,
}),
logprobs: None,
finish_reason: Some(FINISH_REASON_STOP.to_string()),
}],
usage: None,
usage: if ctx.need_usage {
TriState::Null
} else {
TriState::None
},
};
response_data.push_str(&format!(
"data: {}\n\ndata: [DONE]\n\n",
"data: {}\n\n",
serde_json::to_string(&response).unwrap()
));
if ctx.need_usage {
let response = ChatResponse {
id: ctx.response_id.to_string(),
object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(),
created: chrono::Utc::now().timestamp(),
model: None,
choices: vec![],
usage: TriState::Some(Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
}),
};
response_data.push_str(&format!(
"data: {}\n\ndata: [DONE]\n\n",
serde_json::to_string(&response).unwrap()
));
} else {
response_data.push_str("data: [DONE]\n\n");
};
}
StreamMessage::Debug(debug_prompt) => {
if let Ok(mut state) = ctx.state.try_lock() {
@@ -666,7 +707,10 @@ pub async fn handle_chat(
.rev()
.find(|log| log.id == ctx.current_id)
{
log.prompt = Some(debug_prompt);
log.chain = Some(Chain {
prompt: debug_prompt,
delays: vec![],
});
}
}
}
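`Chain` replaces the old free-standing `prompt: Option<String>` log field and `timing.first`: one record holding the captured prompt plus per-chunk `(text, delay)` pairs. Its assumed shape, inferred from the `Chain { prompt, delays }` construction sites in this diff:

```rust
use serde::Serialize;

// Inferred shape of model::Chain: the debug prompt plus
// (chunk text, seconds since the previous chunk) pairs.
#[derive(Serialize, Clone, Default)]
pub struct Chain {
    pub prompt: String,
    pub delays: Vec<(String, f64)>,
}
```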
@@ -696,7 +740,7 @@ pub async fn handle_chat(
.rev()
.find(|log| log.id == current_id)
{
log.status = LogStatus::Failed;
log.status = LogStatus::Failure;
log.error = Some(error_response.native_code());
log.timing.total =
format_time_ms(start_time.elapsed().as_secs_f64());
@@ -727,7 +771,7 @@ pub async fn handle_chat(
.rev()
.find(|log| log.id == current_id)
{
log.status = LogStatus::Failed;
log.status = LogStatus::Failure;
log.error = Some("Empty stream response".to_string());
state.request_manager.error_requests += 1;
}
@@ -748,16 +792,18 @@ pub async fn handle_chat(
let response_id = response_id.clone();
let model = request.model.clone();
let is_start = is_start.clone();
let first_chunk_time = first_chunk_time.clone();
let state = state.clone();
let need_usage = request.stream_options.is_some_and(|opt| opt.include_usage);
let content_time = content_time.clone();
move |chunk| {
let decoder = decoder.clone();
let response_id = response_id.clone();
let model = model.clone();
let is_start = is_start.clone();
let first_chunk_time = first_chunk_time.clone();
let state = state.clone();
let need_usage = need_usage;
let content_time = content_time.clone();
async move {
let chunk = chunk.unwrap_or_default();
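`need_usage` mirrors OpenAI's `stream_options` contract: the client opts into usage reporting with `stream_options: {"include_usage": true}`, after which each chunk carries `"usage": null` and one extra chunk with empty `choices` and the totals precedes `data: [DONE]`. A sketch of the request-side model this implies (names assumed):

```rust
use serde::Deserialize;

// Assumed field on ChatRequest: stream_options: Option<StreamOptions>.
#[derive(Deserialize, Clone, Copy, Default)]
pub struct StreamOptions {
    #[serde(default)]
    pub include_usage: bool,
}
```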
@@ -766,10 +812,11 @@ pub async fn handle_chat(
response_id: &response_id,
model: &model,
is_start: &is_start,
first_chunk_time: &first_chunk_time,
start_time,
state: &state,
current_id,
need_usage,
content_time: &content_time,
};
// Process the chunk with the decoder
@@ -807,10 +854,12 @@ pub async fn handle_chat(
} else {
// Non-streaming response
let start_time = std::time::Instant::now();
let mut first_chunk_time = None::<f64>;
let mut decoder = StreamDecoder::new();
let mut full_text = String::with_capacity(1024);
let mut stream = response.bytes_stream();
let mut prompt = String::new();
let mut content_time = std::time::Instant::now();
let mut delays: Vec<(String, f64)> = Vec::new();
// Process the chunks one by one
while let Some(chunk) = stream.next().await {
@@ -828,23 +877,12 @@ pub async fn handle_chat(
for message in messages {
match message {
StreamMessage::Content(text) => {
if first_chunk_time.is_none() {
first_chunk_time = Some(start_time.elapsed().as_secs_f64());
}
let interval = content_time.duration_as_secs_f64();
delays.push((text.clone(), interval));
full_text.push_str(&text);
}
StreamMessage::Debug(debug_prompt) => {
if let Ok(mut state) = state.try_lock() {
if let Some(log) = state
.request_manager
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.prompt = Some(debug_prompt);
}
}
prompt = debug_prompt;
}
_ => {}
}
@@ -881,7 +919,7 @@ pub async fn handle_chat(
.rev()
.find(|log| log.id == current_id)
{
log.status = LogStatus::Failed;
log.status = LogStatus::Failure;
log.error = Some("Empty response received".to_string());
state.request_manager.error_requests += 1;
}
@@ -904,9 +942,10 @@ pub async fn handle_chat(
content: MessageContent::Text(full_text.trim_leading_newlines()),
}),
delta: None,
logprobs: None,
finish_reason: Some(FINISH_REASON_STOP.to_string()),
}],
usage: Some(Usage {
usage: TriState::Some(Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
@@ -925,11 +964,23 @@ pub async fn handle_chat(
.find(|log| log.id == current_id)
{
log.timing.total = total_time;
log.timing.first = first_chunk_time;
log.status = LogStatus::Success;
}
}
// Update the final delay information
if let Ok(mut state) = state.try_lock() {
if let Some(log) = state
.request_manager
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.chain = Some(Chain { prompt, delays });
}
}
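One design note on the log updates here and in the streaming path: they go through `try_lock`, so a contended state mutex can never stall response delivery; at worst a telemetry update is skipped. The pattern in isolation (rationale assumed, not stated in the commit):

```rust
use tokio::sync::Mutex;

// Best-effort telemetry write: skip rather than block if the state mutex is
// currently held elsewhere.
fn record_delay(state: &Mutex<Vec<(String, f64)>>, text: String, secs: f64) {
    if let Ok(mut delays) = state.try_lock() {
        delays.push((text, secs));
    }
}
```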
Ok(Response::builder()
.header(CONTENT_TYPE, "application/json")
.body(Body::from(serde_json::to_string(&response_data).unwrap()))