v0.1.3-rc.3

This commit is contained in:
wisdgod
2025-01-14 09:13:13 +08:00
parent 732cfbc58e
commit 061156fb79
36 changed files with 3585 additions and 787 deletions

View File

@@ -1,14 +1,15 @@
use super::constant::AVAILABLE_MODELS;
use crate::{
app::{
constant::{
AUTHORIZATION_BEARER_PREFIX, CURSOR_API2_STREAM_CHAT, FINISH_REASON_STOP,
OBJECT_CHAT_COMPLETION, OBJECT_CHAT_COMPLETION_CHUNK, STATUS_FAILED, STATUS_SUCCESS,
AUTHORIZATION_BEARER_PREFIX, FINISH_REASON_STOP,
OBJECT_CHAT_COMPLETION, OBJECT_CHAT_COMPLETION_CHUNK, STATUS_FAILED, STATUS_PENDING,
STATUS_SUCCESS,
},
lazy::AUTH_TOKEN,
model::{AppConfig, AppState, ChatRequest, RequestLog, TokenInfo},
model::{AppConfig, AppState, ChatRequest, RequestLog, TimingInfo, TokenInfo},
},
chat::{
constant::{AVAILABLE_MODELS, USAGE_CHECK_MODELS},
error::StreamError,
model::{
ChatResponse, Choice, Delta, Message, MessageContent, ModelsResponse, Role, Usage,
@@ -17,8 +18,10 @@ use crate::{
},
common::{
client::build_client,
models::{error::ChatError, ErrorResponse},
utils::{get_user_usage, validate_token_and_checksum},
models::{error::ChatError, userinfo::MembershipType, ErrorResponse},
utils::{
format_time_ms, get_token_profile, validate_token_and_checksum,
},
},
};
use axum::{
@@ -93,7 +96,7 @@ pub async fn handle_chat(
))?;
// 验证 AuthToken 和 获取 token 信息
let (auth_token, checksum, alias) = if auth_header == AUTH_TOKEN.as_str() {
let (auth_token, checksum) = if auth_header == AUTH_TOKEN.as_str() {
// 如果是管理员Token,使用原有逻辑
static CURRENT_KEY_INDEX: AtomicUsize = AtomicUsize::new(0);
let state_guard = state.lock().await;
@@ -108,11 +111,7 @@ pub async fn handle_chat(
let index = CURRENT_KEY_INDEX.fetch_add(1, Ordering::SeqCst) % token_infos.len();
let token_info = &token_infos[index];
(
token_info.token.clone(),
token_info.checksum.clone(),
token_info.alias.clone(),
)
(token_info.token.clone(), token_info.checksum.clone())
} else {
// 否则尝试解析token
validate_token_and_checksum(auth_header).ok_or((
@@ -121,6 +120,8 @@ pub async fn handle_chat(
))?
};
let current_id: u64;
// 更新请求日志
{
let state_clone = state.clone();
@@ -128,29 +129,68 @@ pub async fn handle_chat(
state.total_requests += 1;
state.active_requests += 1;
// 如果有model且需要获取使用情况,创建后台任务获取
if let Some(model) = model {
if model.is_usage_check() {
let auth_token_clone = auth_token.clone();
let checksum_clone = checksum.clone();
let state_clone = state_clone.clone();
// 查找最新的相同token的日志,检查使用情况
let need_profile_check = state
.request_logs
.iter()
.rev()
.find(|log| log.token_info.token == auth_token && log.token_info.profile.is_some())
.and_then(|log| log.token_info.profile.as_ref())
.map(|profile| {
if profile.stripe.membership_type != MembershipType::Free {
return false;
}
tokio::spawn(async move {
let usage = get_user_usage(&auth_token_clone, &checksum_clone).await;
let mut state = state_clone.lock().await;
// 根据时间戳找到对应的日志
if let Some(log) = state
.request_logs
.iter_mut()
.find(|log| log.timestamp == request_time)
{
log.token_info.usage = usage;
}
});
}
let is_premium = USAGE_CHECK_MODELS.contains(&request.model.as_str());
let standard = &profile.usage.standard;
let premium = &profile.usage.premium;
if is_premium {
premium
.max_requests
.map_or(false, |max| premium.num_requests >= max)
} else {
standard
.max_requests
.map_or(false, |max| standard.num_requests >= max)
}
})
.unwrap_or(false);
// 如果达到限制,直接返回未授权错误
if need_profile_check {
state.active_requests -= 1;
state.error_requests += 1;
return Err((
StatusCode::UNAUTHORIZED,
Json(ChatError::Unauthorized.to_json()),
));
}
let next_id = state.request_logs.last().map_or(1, |log| log.id + 1);
current_id = next_id;
// 如果需要获取用户使用情况,创建后台任务获取profile
if model.map(|m| m.is_usage_check()).unwrap_or(false) {
let auth_token_clone = auth_token.clone();
let state_clone = state_clone.clone();
let log_id = next_id;
tokio::spawn(async move {
let profile = get_token_profile(&auth_token_clone).await;
let mut state = state_clone.lock().await;
// 根据id查找对应的日志
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == log_id)
{
log.token_info.profile = profile;
}
});
}
state.request_logs.push(RequestLog {
id: next_id,
timestamp: request_time,
@@ -158,12 +198,15 @@ pub async fn handle_chat(
token_info: TokenInfo {
token: auth_token.clone(),
checksum: checksum.clone(),
alias: alias.clone(),
usage: None,
profile: None,
},
prompt: None,
timing: TimingInfo {
total: 0.0,
first: None,
},
stream: request.stream,
status: "pending",
status: STATUS_PENDING,
error: None,
});
@@ -173,19 +216,54 @@ pub async fn handle_chat(
}
// 将消息转换为hex格式
let hex_data = super::adapter::encode_chat_message(request.messages, &request.model)
.await
.map_err(|_| {
(
let hex_data = match super::adapter::encode_chat_message(request.messages, &request.model).await
{
Ok(data) => data,
Err(e) => {
let mut state = state.lock().await;
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.error = Some(e.to_string());
}
state.active_requests -= 1;
state.error_requests += 1;
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(
ChatError::RequestFailed("Failed to encode chat message".to_string()).to_json(),
),
)
})?;
));
}
};
// 构建请求客户端
let client = build_client(&auth_token, &checksum, CURSOR_API2_STREAM_CHAT);
// let client_key = match generate_client_key(&checksum) {
// Some(key) => key,
// None => {
// let mut state = state.lock().await;
// if let Some(log) = state
// .request_logs
// .iter_mut()
// .rev()
// .find(|log| log.id == current_id)
// {
// log.status = STATUS_FAILED;
// log.error = Some(ERR_CHECKSUM_NO_GOOD.to_string());
// }
// state.active_requests -= 1;
// state.error_requests += 1;
// return Err((
// StatusCode::BAD_REQUEST,
// Json(ChatError::RequestFailed(ERR_CHECKSUM_NO_GOOD.to_string()).to_json()),
// ));
// }
// };
let client = build_client(&auth_token, &checksum);
let response = client.body(hex_data).send().await;
// 处理请求结果
@@ -194,7 +272,14 @@ pub async fn handle_chat(
// 更新请求日志为成功
{
let mut state = state.lock().await;
state.request_logs.last_mut().unwrap().status = STATUS_SUCCESS;
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_SUCCESS;
}
}
resp
}
@@ -202,10 +287,17 @@ pub async fn handle_chat(
// 更新请求日志为失败
{
let mut state = state.lock().await;
if let Some(last_log) = state.request_logs.last_mut() {
last_log.status = STATUS_FAILED;
last_log.error = Some(e.to_string());
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.error = Some(e.to_string());
}
state.active_requests -= 1;
state.error_requests += 1;
}
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
@@ -224,6 +316,8 @@ pub async fn handle_chat(
let response_id = format!("chatcmpl-{}", Uuid::new_v4().simple());
let full_text = Arc::new(Mutex::new(String::with_capacity(1024)));
let is_start = Arc::new(AtomicBool::new(true));
let start_time = std::time::Instant::now();
let first_chunk_time = Arc::new(Mutex::new(None));
let stream = {
// 创建新的 stream
@@ -250,9 +344,16 @@ pub async fn handle_chat(
// 更新请求日志为失败
{
let mut state = state.lock().await;
if let Some(last_log) = state.request_logs.last_mut() {
last_log.status = STATUS_FAILED;
last_log.error = Some(error_respone.native_code());
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.error = Some(error_respone.native_code());
log.timing.total = format_time_ms(start_time.elapsed().as_secs_f64());
state.error_requests += 1;
}
}
return Err((
@@ -279,9 +380,15 @@ pub async fn handle_chat(
// 更新请求日志为失败
{
let mut state = state.lock().await;
if let Some(last_log) = state.request_logs.last_mut() {
last_log.status = STATUS_FAILED;
last_log.error = Some("Empty stream response".to_string());
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.error = Some("Empty stream response".to_string());
state.error_requests += 1;
}
}
return Err((
@@ -299,7 +406,9 @@ pub async fn handle_chat(
}
}
.then({
let buffer = Arc::new(Mutex::new(Vec::new())); // 创建共享的buffer
let buffer = Arc::new(Mutex::new(Vec::new()));
let first_chunk_time = first_chunk_time.clone();
let state = state.clone();
move |chunk| {
let buffer = buffer.clone();
@@ -307,6 +416,7 @@ pub async fn handle_chat(
let model = request.model.clone();
let is_start = is_start.clone();
let full_text = full_text.clone();
let first_chunk_time = first_chunk_time.clone();
let state = state.clone();
async move {
@@ -319,6 +429,14 @@ pub async fn handle_chat(
buffer_guard.clear();
let mut response_data = String::new();
// 记录首字时间(如果还未记录)
if let Ok(mut first_time) = first_chunk_time.try_lock() {
if first_time.is_none() {
*first_time = Some(format_time_ms(start_time.elapsed().as_secs_f64()));
}
}
// 处理文本内容
for text in texts {
let mut text_guard = full_text.lock().await;
text_guard.push_str(&text);
@@ -387,6 +505,23 @@ pub async fn handle_chat(
// 根据配置决定是否发送最后的 finish_reason
let include_finish_reason = AppConfig::get_stop_stream();
// 计算总时间和首次片段时间
let total_time = format_time_ms(start_time.elapsed().as_secs_f64());
let first_time = first_chunk_time.lock().await.unwrap_or(total_time);
{
let mut state = state.lock().await;
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.timing.total = total_time;
log.timing.first = Some(first_time);
}
}
if include_finish_reason {
let response = ChatResponse {
id: response_id.clone(),
@@ -443,13 +578,29 @@ pub async fn handle_chat(
.unwrap())
} else {
// 非流式响应
let mut full_text = String::with_capacity(1024); // 预分配合适的容量
let start_time = std::time::Instant::now();
let mut first_chunk_received = false;
let mut first_chunk_time = 0.0;
let mut full_text = String::with_capacity(1024);
let mut stream = response.bytes_stream();
let mut prompt = None;
let mut buffer = Vec::new();
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(|e| {
// 更新请求日志为失败
if let Ok(mut state) = state.try_lock() {
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.error = Some(format!("Failed to read response chunk: {}", e));
state.error_requests += 1;
}
}
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(
@@ -463,14 +614,16 @@ pub async fn handle_chat(
match parse_stream_data(&buffer) {
Ok(StreamMessage::Content(texts)) => {
if !first_chunk_received {
first_chunk_time = format_time_ms(start_time.elapsed().as_secs_f64());
first_chunk_received = true;
}
for text in texts {
full_text.push_str(&text);
}
buffer.clear();
}
Ok(StreamMessage::Incomplete) => {
continue;
}
Ok(StreamMessage::Incomplete) => continue,
Ok(StreamMessage::Debug(debug_prompt)) => {
prompt = Some(debug_prompt);
buffer.clear();
@@ -479,11 +632,23 @@ pub async fn handle_chat(
buffer.clear();
}
Err(StreamError::ChatError(error)) => {
return Err((
StatusCode::from_u16(error.status_code())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
Json(error.to_error_response().to_common()),
));
let error = error.to_error_response();
// 更新请求日志为失败
{
let mut state = state.lock().await;
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.error = Some(error.native_code());
log.timing.total = format_time_ms(start_time.elapsed().as_secs_f64());
state.error_requests += 1;
}
}
return Err((error.status_code(), Json(error.to_common())));
}
Err(_) => {
buffer.clear();
@@ -492,21 +657,23 @@ pub async fn handle_chat(
}
}
let prompt_tokens = prompt.as_ref().map(|p| p.len() as u32).unwrap_or(0);
let completion_tokens = full_text.len() as u32;
let total_tokens = prompt_tokens + completion_tokens;
// 检查响应是否为空
if full_text.is_empty() {
// 更新请求日志为失败
{
let mut state = state.lock().await;
if let Some(last_log) = state.request_logs.last_mut() {
last_log.status = STATUS_FAILED;
last_log.error = Some("Empty response received".to_string());
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.error = Some("Empty response received".to_string());
if let Some(p) = prompt {
last_log.prompt = Some(p);
log.prompt = Some(p);
}
state.error_requests += 1;
}
}
return Err((
@@ -515,14 +682,6 @@ pub async fn handle_chat(
));
}
// 更新请求日志提示词
{
let mut state = state.lock().await;
if let Some(last_log) = state.request_logs.last_mut() {
last_log.prompt = prompt;
}
}
let response_data = ChatResponse {
id: format!("chatcmpl-{}", Uuid::new_v4().simple()),
object: OBJECT_CHAT_COMPLETION.to_string(),
@@ -538,12 +697,29 @@ pub async fn handle_chat(
finish_reason: Some(FINISH_REASON_STOP.to_string()),
}],
usage: Some(Usage {
prompt_tokens,
completion_tokens,
total_tokens,
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
}),
};
{
// 更新请求日志时间信息和状态
let total_time = format_time_ms(start_time.elapsed().as_secs_f64());
let mut state = state.lock().await;
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.timing.total = total_time;
log.timing.first = Some(first_chunk_time);
log.prompt = prompt;
log.status = STATUS_SUCCESS;
}
}
Ok(Response::builder()
.header(CONTENT_TYPE, "application/json")
.body(Body::from(serde_json::to_string(&response_data).unwrap()))