这是可回退普通版的提交

This commit is contained in:
wisdgod
2025-01-04 02:08:16 +08:00
parent c709f9bfc7
commit 732cfbc58e
26 changed files with 953 additions and 298 deletions

View File

@@ -3,11 +3,10 @@ use crate::{
app::{
constant::{
AUTHORIZATION_BEARER_PREFIX, CURSOR_API2_STREAM_CHAT, FINISH_REASON_STOP,
HEADER_NAME_CONTENT_TYPE, OBJECT_CHAT_COMPLETION, OBJECT_CHAT_COMPLETION_CHUNK,
STATUS_FAILED, STATUS_SUCCESS,
OBJECT_CHAT_COMPLETION, OBJECT_CHAT_COMPLETION_CHUNK, STATUS_FAILED, STATUS_SUCCESS,
},
model::{AppConfig, AppState, ChatRequest, RequestLog, TokenInfo},
lazy::AUTH_TOKEN,
model::{AppConfig, AppState, ChatRequest, RequestLog, TokenInfo},
},
chat::{
error::StreamError,
@@ -19,13 +18,16 @@ use crate::{
common::{
client::build_client,
models::{error::ChatError, ErrorResponse},
utils::get_user_usage,
utils::{get_user_usage, validate_token_and_checksum},
},
};
use axum::{
body::Body,
extract::State,
http::{HeaderMap, StatusCode},
http::{
header::{AUTHORIZATION, CONTENT_TYPE},
HeaderMap, StatusCode,
},
response::Response,
Json,
};
@@ -42,6 +44,8 @@ use std::{
use tokio::sync::Mutex;
use uuid::Uuid;
const REQUEST_LOGS_LIMIT: usize = 1000;
// 模型列表处理
pub async fn handle_models() -> Json<ModelsResponse> {
Json(ModelsResponse {
@@ -79,8 +83,8 @@ pub async fn handle_chat(
}
// 获取并处理认证令牌
let auth_token = headers
.get(axum::http::header::AUTHORIZATION)
let auth_header = headers
.get(AUTHORIZATION)
.and_then(|h| h.to_str().ok())
.and_then(|h| h.strip_prefix(AUTHORIZATION_BEARER_PREFIX))
.ok_or((
@@ -88,16 +92,9 @@ pub async fn handle_chat(
Json(ChatError::Unauthorized.to_json()),
))?;
// 验证 AuthToken
if auth_token != AUTH_TOKEN.as_str() {
return Err((
StatusCode::UNAUTHORIZED,
Json(ChatError::Unauthorized.to_json()),
));
}
// 完整的令牌处理逻辑和对应的 checksum
let (auth_token, checksum, alias) = {
// 验证 AuthToken 和 获取 token 信息
let (auth_token, checksum, alias) = if auth_header == AUTH_TOKEN.as_str() {
// 如果是管理员Token,使用原有逻辑
static CURRENT_KEY_INDEX: AtomicUsize = AtomicUsize::new(0);
let state_guard = state.lock().await;
let token_infos = &state_guard.token_infos;
@@ -116,6 +113,12 @@ pub async fn handle_chat(
token_info.checksum.clone(),
token_info.alias.clone(),
)
} else {
// 否则尝试解析token
validate_token_and_checksum(auth_header).ok_or((
StatusCode::UNAUTHORIZED,
Json(ChatError::Unauthorized.to_json()),
))?
};
// 更新请求日志
@@ -147,7 +150,9 @@ pub async fn handle_chat(
}
}
let next_id = state.request_logs.last().map_or(1, |log| log.id + 1);
state.request_logs.push(RequestLog {
id: next_id,
timestamp: request_time,
model: request.model.clone(),
token_info: TokenInfo {
@@ -162,7 +167,7 @@ pub async fn handle_chat(
error: None,
});
if state.request_logs.len() > 100 {
if state.request_logs.len() > REQUEST_LOGS_LIMIT {
state.request_logs.remove(0);
}
}
@@ -420,11 +425,6 @@ pub async fn handle_chat(
}
Ok(Bytes::new())
}
Err(StreamError::ChatError(error)) => {
buffer_guard.clear();
eprintln!("Stream error occurred: {}", error.to_json());
Ok(Bytes::new())
}
Err(e) => {
buffer_guard.clear();
eprintln!("[警告] Stream error: {}", e);
@@ -438,7 +438,7 @@ pub async fn handle_chat(
Ok(Response::builder()
.header("Cache-Control", "no-cache")
.header("Connection", "keep-alive")
.header(HEADER_NAME_CONTENT_TYPE, "text/event-stream")
.header(CONTENT_TYPE, "text/event-stream")
.body(Body::from_stream(stream))
.unwrap())
} else {
@@ -480,7 +480,7 @@ pub async fn handle_chat(
}
Err(StreamError::ChatError(error)) => {
return Err((
StatusCode::from_u16(error.error.details[0].debug.status_code())
StatusCode::from_u16(error.status_code())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
Json(error.to_error_response().to_common()),
));
@@ -545,7 +545,7 @@ pub async fn handle_chat(
};
Ok(Response::builder()
.header(HEADER_NAME_CONTENT_TYPE, "application/json")
.header(CONTENT_TYPE, "application/json")
.body(Body::from(serde_json::to_string(&response_data).unwrap()))
.unwrap())
}