v0.1.3-rc.5

wisdgod
2025-02-24 08:50:37 +08:00
parent fb0de13712
commit 0e65370ca2
59 changed files with 8861 additions and 2505 deletions


@@ -4,7 +4,10 @@ use crate::{
AUTHORIZATION_BEARER_PREFIX, FINISH_REASON_STOP, OBJECT_CHAT_COMPLETION,
OBJECT_CHAT_COMPLETION_CHUNK,
},
lazy::{AUTH_TOKEN, KEY_PREFIX, KEY_PREFIX_LEN, REQUEST_LOGS_LIMIT, SERVICE_TIMEOUT},
lazy::{
AUTH_TOKEN, CURSOR_API2_CHAT_URL, CURSOR_API2_CHAT_WEB_URL, IS_UNLIMITED_REQUEST_LOGS,
KEY_PREFIX, KEY_PREFIX_LEN, REQUEST_LOGS_LIMIT, SERVICE_TIMEOUT,
},
model::{
AppConfig, AppState, ChatRequest, LogStatus, RequestLog, TimingInfo, TokenInfo,
UsageCheck,
@@ -12,7 +15,7 @@ use crate::{
},
chat::{
config::KeyConfig,
constant::{AVAILABLE_MODELS, USAGE_CHECK_MODELS},
constant::{Models, USAGE_CHECK_MODELS},
error::StreamError,
model::{
ChatResponse, Choice, Delta, Message, MessageContent, ModelsResponse, Role, Usage,
@@ -21,22 +24,22 @@ use crate::{
},
common::{
client::build_client,
model::{error::ChatError, userinfo::MembershipType, ApiStatus, ErrorResponse},
model::{ApiStatus, ErrorResponse, error::ChatError, userinfo::MembershipType},
utils::{
format_time_ms, from_base64, get_token_profile, tokeninfo_to_token,
validate_token_and_checksum, TrimNewlines as _,
TrimNewlines as _, format_time_ms, from_base64, get_available_models,
get_token_profile, tokeninfo_to_token, validate_token_and_checksum,
},
},
};
use axum::{
Json,
body::Body,
extract::State,
http::{
header::{AUTHORIZATION, CONTENT_TYPE},
HeaderMap, StatusCode,
header::{AUTHORIZATION, CONTENT_TYPE},
},
response::Response,
Json,
};
use bytes::Bytes;
use futures::StreamExt;
@@ -44,17 +47,129 @@ use prost::Message as _;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::{
convert::Infallible,
sync::{atomic::AtomicBool, Arc},
sync::{Arc, atomic::AtomicBool},
};
use tokio::sync::Mutex;
use uuid::Uuid;
use super::{constant::LONG_CONTEXT_MODELS, model::Model};
// Helper function: extract the auth token
fn extract_auth_token(headers: &HeaderMap) -> Result<&str, (StatusCode, Json<ErrorResponse>)> {
headers
.get(AUTHORIZATION)
.and_then(|h| h.to_str().ok())
.and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX))
.ok_or((
StatusCode::UNAUTHORIZED,
Json(ChatError::Unauthorized.to_json()),
))
}
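// Returns the header value with AUTHORIZATION_BEARER_PREFIX stripped; a
// missing or malformed Authorization header maps to 401 Unauthorized.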
// Helper function: resolve token info
async fn resolve_token_info(
auth_header: &str,
state: &Arc<Mutex<AppState>>,
) -> Result<(String, String), (StatusCode, Json<ErrorResponse>)> {
match auth_header {
// Admin token handling
token if is_admin_token(token) => resolve_admin_token(state).await,
// Dynamic key handling
token if is_dynamic_key(token) => resolve_dynamic_key(token),
// Regular user token handling
token => validate_token_and_checksum(token).ok_or((
StatusCode::UNAUTHORIZED,
Json(ChatError::Unauthorized.to_json()),
)),
}
}
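// The match above dispatches on token type: an admin token rotates through
// the tokens held by the token manager, a dynamic key decodes the KeyConfig
// embedded in the key itself, and anything else is validated as a regular
// user token.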
// Helper function: check whether the token is an admin token
fn is_admin_token(token: &str) -> bool {
token == AUTH_TOKEN.as_str()
|| (AppConfig::is_share() && token == AppConfig::get_share_token().as_str())
}
// Helper function: check whether the token is a dynamic key
fn is_dynamic_key(token: &str) -> bool {
AppConfig::get_dynamic_key() && token.starts_with(&*KEY_PREFIX)
}
// Helper function: handle an admin token
async fn resolve_admin_token(
state: &Arc<Mutex<AppState>>,
) -> Result<(String, String), (StatusCode, Json<ErrorResponse>)> {
static CURRENT_KEY_INDEX: AtomicUsize = AtomicUsize::new(0);
let state_guard = state.lock().await;
let token_infos = &state_guard.token_manager.tokens;
if token_infos.is_empty() {
return Err((
StatusCode::SERVICE_UNAVAILABLE,
Json(ChatError::NoTokens.to_json()),
));
}
let index = CURRENT_KEY_INDEX.fetch_add(1, Ordering::SeqCst) % token_infos.len();
let token_info = &token_infos[index];
Ok((token_info.token.clone(), token_info.checksum.clone()))
}
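// CURRENT_KEY_INDEX wraps via the modulo above, so admin requests are
// distributed round-robin across all stored tokens.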
// Helper function: handle a dynamic key
fn resolve_dynamic_key(token: &str) -> Result<(String, String), (StatusCode, Json<ErrorResponse>)> {
from_base64(&token[*KEY_PREFIX_LEN..])
.and_then(|decoded_bytes| KeyConfig::decode(&decoded_bytes[..]).ok())
.and_then(|key_config| key_config.auth_token)
.and_then(|token_info| tokeninfo_to_token(&token_info))
.ok_or((
StatusCode::UNAUTHORIZED,
Json(ChatError::Unauthorized.to_json()),
))
}
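// The chain above expects a dynamic key shaped as KEY_PREFIX followed by
// base64-encoded, protobuf-serialized KeyConfig bytes; any failure along the
// decode path collapses into a single 401 Unauthorized.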
// Model list handling
pub async fn handle_models() -> Json<ModelsResponse> {
Json(ModelsResponse {
object: "list",
data: &AVAILABLE_MODELS,
})
pub async fn handle_models(
State(state): State<Arc<Mutex<AppState>>>,
headers: HeaderMap,
) -> Result<Json<ModelsResponse>, (StatusCode, Json<ErrorResponse>)> {
// If there is no Authorization header, return the default available models
if headers.get(AUTHORIZATION).is_none() {
return Ok(Json(ModelsResponse::with_default_models()));
}
// Extract and validate the auth token
let auth_token = extract_auth_token(&headers)?;
let (token, checksum) = resolve_token_info(auth_token, &state).await?;
// Fetch the list of available models
let models = get_available_models(&token, &checksum).await.ok_or((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
status: ApiStatus::Failure,
code: Some(StatusCode::INTERNAL_SERVER_ERROR.as_u16()),
error: Some("Failed to fetch available models".to_string()),
message: Some("Unable to get available models".to_string()),
}),
))?;
// Update the model list
if let Err(e) = Models::update(models) {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
status: ApiStatus::Failure,
code: Some(StatusCode::INTERNAL_SERVER_ERROR.as_u16()),
error: Some("Failed to update models".to_string()),
message: Some(e.to_string()),
}),
));
}
Ok(Json(ModelsResponse::new(Models::to_arc())))
}
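// A minimal wiring sketch (not part of this commit) for mounting the handler
// above on an axum Router; the "/v1/models" path and the models_router name
// are hypothetical, for illustration only:
//
//     use axum::{Router, routing::get};
//
//     pub fn models_router(state: Arc<Mutex<AppState>>) -> Router {
//         Router::new()
//             .route("/v1/models", get(handle_models))
//             .with_state(state)
//     }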
// Signature of the chat handler
@@ -73,15 +188,15 @@ pub async fn handle_chat(
};
// Verify the model is supported and fetch its info
let model = AVAILABLE_MODELS.iter().find(|m| m.id == model_name);
let model_supported = model.is_some();
if !(model_supported || allow_claude && request.model.starts_with("claude")) {
return Err((
StatusCode::BAD_REQUEST,
Json(ChatError::ModelNotSupported(request.model).to_json()),
));
}
let model =
if Models::exists(&model_name) || (allow_claude && request.model.starts_with("claude")) {
Some(&model_name)
} else {
return Err((
StatusCode::BAD_REQUEST,
Json(ChatError::ModelNotSupported(request.model).to_json()),
));
};
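// A model id is accepted if it is registered in Models, or if allow_claude is
// enabled and the id starts with "claude"; anything else is rejected with
// 400 BAD_REQUEST before any upstream call is made.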
let request_time = chrono::Local::now();
@@ -114,7 +229,7 @@ pub async fn handle_chat(
{
static CURRENT_KEY_INDEX: AtomicUsize = AtomicUsize::new(0);
let state_guard = state.lock().await;
let token_infos = &state_guard.token_infos;
let token_infos = &state_guard.token_manager.tokens;
// Check whether any usable tokens exist
if token_infos.is_empty() {
@@ -159,56 +274,85 @@ pub async fn handle_chat(
{
let state_clone = state.clone();
let mut state = state.lock().await;
state.total_requests += 1;
state.active_requests += 1;
state.request_manager.total_requests += 1;
state.request_manager.active_requests += 1;
// Find the most recent logs for the same token and check usage
let need_profile_check = state
.request_logs
.iter()
.rev()
.find(|log| log.token_info.token == auth_token && log.token_info.profile.is_some())
.and_then(|log| log.token_info.profile.as_ref())
.map(|profile| {
if profile.stripe.membership_type != MembershipType::Free {
return false;
let mut found_count: u32 = 0;
let mut no_prompt_count: u32 = 0;
let mut need_profile_check = false;
for log in state.request_manager.request_logs.iter().rev() {
if log.token_info.token == auth_token {
if !LONG_CONTEXT_MODELS.contains(&log.model.as_str()) {
found_count += 1;
}
let is_premium = USAGE_CHECK_MODELS.contains(&model_name.as_str());
let standard = &profile.usage.standard;
let premium = &profile.usage.premium;
if is_premium {
premium
.max_requests
.map_or(false, |max| premium.num_requests >= max)
} else {
standard
.max_requests
.map_or(false, |max| standard.num_requests >= max)
if log.prompt.is_none() {
no_prompt_count += 1;
}
})
.unwrap_or(false);
// If the limit has been reached, return an unauthorized error immediately
if found_count == 1 && log.token_info.profile.is_some() {
if let Some(profile) = &log.token_info.profile {
if profile.stripe.membership_type == MembershipType::Free {
let is_premium = USAGE_CHECK_MODELS.contains(&model_name.as_str());
need_profile_check =
if is_premium {
profile.usage.premium.max_requests.is_some_and(|max| {
profile.usage.premium.num_requests >= max
})
} else {
profile.usage.standard.max_requests.is_some_and(|max| {
profile.usage.standard.num_requests >= max
})
};
}
}
}
if found_count == 2 {
break;
}
}
}
if found_count == 2 && no_prompt_count == 2 {
state.request_manager.active_requests -= 1;
state.request_manager.error_requests += 1;
return Err((
StatusCode::TOO_MANY_REQUESTS,
Json(ErrorResponse {
status: ApiStatus::Error,
code: Some(429),
error: Some("rate_limit_exceeded".to_string()),
message: Some("Too many requests without prompt".to_string()),
}),
));
}
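// Rate-limit heuristic: when the recent history for this token shows two
// requests without a recorded prompt, the new request is rejected with 429
// (rate_limit_exceeded) before contacting the upstream service.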
// Handle the check result
if need_profile_check {
state.active_requests -= 1;
state.error_requests += 1;
state.request_manager.active_requests -= 1;
state.request_manager.error_requests += 1;
return Err((
StatusCode::UNAUTHORIZED,
Json(ChatError::Unauthorized.to_json()),
));
}
let next_id = state.request_logs.last().map_or(1, |log| log.id + 1);
let next_id = state
.request_manager
.request_logs
.last()
.map_or(1, |log| log.id + 1);
current_id = next_id;
// If the user's usage needs to be checked, spawn a background task to fetch the profile
if model
.map(|m| {
m.is_usage_check(UsageCheck::from_proto(
current_config.usage_check_models.as_ref(),
))
Model::is_usage_check(
m,
UsageCheck::from_proto(current_config.usage_check_models.as_ref()),
)
})
.unwrap_or(false)
{
@@ -222,30 +366,35 @@ pub async fn handle_chat(
// First find the indices of all entries that need updating
let token_info_idx = state
.token_infos
.token_manager
.tokens
.iter()
.position(|info| info.token == auth_token_clone);
let log_idx = state.request_logs.iter().rposition(|log| log.id == log_id);
let log_idx = state
.request_manager
.request_logs
.iter()
.rposition(|log| log.id == log_id);
// Update by index
match (token_info_idx, log_idx) {
(Some(t_idx), Some(l_idx)) => {
state.token_infos[t_idx].profile = profile.clone();
state.request_logs[l_idx].token_info.profile = profile;
state.token_manager.tokens[t_idx].profile = profile.clone();
state.request_manager.request_logs[l_idx].token_info.profile = profile;
}
(Some(t_idx), None) => {
state.token_infos[t_idx].profile = profile;
state.token_manager.tokens[t_idx].profile = profile;
}
(None, Some(l_idx)) => {
state.request_logs[l_idx].token_info.profile = profile;
state.request_manager.request_logs[l_idx].token_info.profile = profile;
}
(None, None) => {}
}
});
}
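// The profile fetch above runs in a spawned background task; the result is
// written back by index into the token manager entry and/or the matching
// request log, whichever still exist when the task completes.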
state.request_logs.push(RequestLog {
state.request_manager.request_logs.push(RequestLog {
id: next_id,
timestamp: request_time,
model: request.model.clone(),
@@ -253,6 +402,7 @@ pub async fn handle_chat(
token: auth_token.clone(),
checksum: checksum.clone(),
profile: None,
tags: None,
},
prompt: None,
timing: TimingInfo {
@@ -264,8 +414,10 @@ pub async fn handle_chat(
error: None,
});
if state.request_logs.len() > *REQUEST_LOGS_LIMIT {
state.request_logs.remove(0);
if !*IS_UNLIMITED_REQUEST_LOGS
&& state.request_manager.request_logs.len() > *REQUEST_LOGS_LIMIT
{
state.request_manager.request_logs.remove(0);
}
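// Unless unlimited request logs are enabled, the oldest entry is evicted once
// the log grows past REQUEST_LOGS_LIMIT (a simple FIFO cap).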
}
@@ -283,6 +435,7 @@ pub async fn handle_chat(
Err(e) => {
let mut state = state.lock().await;
if let Some(log) = state
.request_manager
.request_logs
.iter_mut()
.rev()
@@ -291,8 +444,8 @@ pub async fn handle_chat(
log.status = LogStatus::Failed;
log.error = Some(e.to_string());
}
state.active_requests -= 1;
state.error_requests += 1;
state.request_manager.active_requests -= 1;
state.request_manager.error_requests += 1;
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(
@@ -303,7 +456,16 @@ pub async fn handle_chat(
};
// Build the request client
let client = build_client(&auth_token, &checksum, is_search);
let client = build_client(
&auth_token,
&checksum,
if is_search {
&CURSOR_API2_CHAT_WEB_URL
} else {
&CURSOR_API2_CHAT_URL
},
true,
);
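// The client now receives its target URL explicitly: search requests go to
// CURSOR_API2_CHAT_WEB_URL, regular chat requests to CURSOR_API2_CHAT_URL.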
// Add a timeout
let response = tokio::time::timeout(
std::time::Duration::from_secs(*SERVICE_TIMEOUT),
@@ -319,6 +481,7 @@ pub async fn handle_chat(
{
let mut state = state.lock().await;
if let Some(log) = state
.request_manager
.request_logs
.iter_mut()
.rev()
@@ -334,6 +497,7 @@ pub async fn handle_chat(
{
let mut state = state.lock().await;
if let Some(log) = state
.request_manager
.request_logs
.iter_mut()
.rev()
@@ -342,8 +506,8 @@ pub async fn handle_chat(
log.status = LogStatus::Failed;
log.error = Some(e.to_string());
}
state.active_requests -= 1;
state.error_requests += 1;
state.request_manager.active_requests -= 1;
state.request_manager.error_requests += 1;
}
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
@@ -356,6 +520,7 @@ pub async fn handle_chat(
{
let mut state = state.lock().await;
if let Some(log) = state
.request_manager
.request_logs
.iter_mut()
.rev()
@@ -364,8 +529,8 @@ pub async fn handle_chat(
log.status = LogStatus::Failed;
log.error = Some("Request timeout".to_string());
}
state.active_requests -= 1;
state.error_requests += 1;
state.request_manager.active_requests -= 1;
state.request_manager.error_requests += 1;
}
return Err((
StatusCode::GATEWAY_TIMEOUT,
@@ -377,7 +542,7 @@ pub async fn handle_chat(
// Release the active request count
{
let mut state = state.lock().await;
state.active_requests -= 1;
state.request_manager.active_requests -= 1;
}
let convert_web_ref = current_config.include_web_references();
@@ -460,6 +625,7 @@ pub async fn handle_chat(
{
let mut state = ctx.state.lock().await;
if let Some(log) = state
.request_manager
.request_logs
.iter_mut()
.rev()
@@ -494,6 +660,7 @@ pub async fn handle_chat(
StreamMessage::Debug(debug_prompt) => {
if let Ok(mut state) = ctx.state.try_lock() {
if let Some(log) = state
.request_manager
.request_logs
.iter_mut()
.rev()
@@ -518,11 +685,12 @@ pub async fn handle_chat(
if let Err(StreamError::ChatError(error)) =
decoder.lock().await.decode(&chunk, convert_web_ref)
{
let error_response = error.to_error_response();
let error_response = error.into_error_response();
// Mark the request log as failed
{
let mut state = state.lock().await;
if let Some(log) = state
.request_manager
.request_logs
.iter_mut()
.rev()
@@ -532,12 +700,12 @@ pub async fn handle_chat(
log.error = Some(error_response.native_code());
log.timing.total =
format_time_ms(start_time.elapsed().as_secs_f64());
state.error_requests += 1;
state.request_manager.error_requests += 1;
}
}
return Err((
error_response.status_code(),
Json(error_response.to_common()),
Json(error_response.into_common()),
));
}
}
@@ -553,6 +721,7 @@ pub async fn handle_chat(
{
let mut state = state.lock().await;
if let Some(log) = state
.request_manager
.request_logs
.iter_mut()
.rev()
@@ -560,7 +729,7 @@ pub async fn handle_chat(
{
log.status = LogStatus::Failed;
log.error = Some("Empty stream response".to_string());
state.error_requests += 1;
state.request_manager.error_requests += 1;
}
}
return Err((
@@ -667,6 +836,7 @@ pub async fn handle_chat(
StreamMessage::Debug(debug_prompt) => {
if let Ok(mut state) = state.try_lock() {
if let Some(log) = state
.request_manager
.request_logs
.iter_mut()
.rev()
@@ -681,10 +851,10 @@ pub async fn handle_chat(
}
}
Err(StreamError::ChatError(error)) => {
let error_response = error.to_error_response();
let error_response = error.into_error_response();
return Err((
error_response.status_code(),
Json(error_response.to_common()),
Json(error_response.into_common()),
));
}
Err(e) => {
@@ -705,6 +875,7 @@ pub async fn handle_chat(
{
let mut state = state.lock().await;
if let Some(log) = state
.request_manager
.request_logs
.iter_mut()
.rev()
@@ -712,7 +883,7 @@ pub async fn handle_chat(
{
log.status = LogStatus::Failed;
log.error = Some("Empty response received".to_string());
state.error_requests += 1;
state.request_manager.error_requests += 1;
}
}
return Err((
@@ -747,6 +918,7 @@ pub async fn handle_chat(
let total_time = format_time_ms(start_time.elapsed().as_secs_f64());
let mut state = state.lock().await;
if let Some(log) = state
.request_manager
.request_logs
.iter_mut()
.rev()