mirror of
https://github.com/wisdgod/cursor-api.git
synced 2025-10-16 03:30:36 +08:00
修复了一些问题
This commit is contained in:
@@ -4,10 +4,13 @@ use crate::{
|
||||
AUTHORIZATION_BEARER_PREFIX, FINISH_REASON_STOP, OBJECT_CHAT_COMPLETION,
|
||||
OBJECT_CHAT_COMPLETION_CHUNK, STATUS_FAILED, STATUS_PENDING, STATUS_SUCCESS,
|
||||
},
|
||||
lazy::{AUTH_TOKEN, SHARED_AUTH_TOKEN, USE_SHARE},
|
||||
model::{AppConfig, AppState, ChatRequest, RequestLog, TimingInfo, TokenInfo},
|
||||
lazy::{
|
||||
AUTH_TOKEN, KEY_PREFIX, KEY_PREFIX_LEN, REQUEST_LOGS_LIMIT, SERVICE_TIMEOUT,
|
||||
},
|
||||
model::{AppConfig, AppState, ChatRequest, RequestLog, TimingInfo, TokenInfo, UsageCheck},
|
||||
},
|
||||
chat::{
|
||||
config::KeyConfig,
|
||||
constant::{AVAILABLE_MODELS, USAGE_CHECK_MODELS},
|
||||
error::StreamError,
|
||||
model::{
|
||||
@@ -17,8 +20,11 @@ use crate::{
|
||||
},
|
||||
common::{
|
||||
client::build_client,
|
||||
models::{error::ChatError, userinfo::MembershipType, ErrorResponse},
|
||||
utils::{format_time_ms, get_token_profile, validate_token_and_checksum},
|
||||
model::{error::ChatError, userinfo::MembershipType, ErrorResponse},
|
||||
utils::{
|
||||
format_time_ms, from_base64, get_token_profile, tokeninfo_to_token,
|
||||
validate_token_and_checksum,
|
||||
},
|
||||
},
|
||||
};
|
||||
use axum::{
|
||||
@@ -33,6 +39,7 @@ use axum::{
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use futures::{Stream, StreamExt};
|
||||
use prost::Message as _;
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
sync::{atomic::AtomicBool, Arc},
|
||||
@@ -44,8 +51,6 @@ use std::{
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
const REQUEST_LOGS_LIMIT: usize = 1000;
|
||||
|
||||
// 模型列表处理
|
||||
pub async fn handle_models() -> Json<ModelsResponse> {
|
||||
Json(ModelsResponse {
|
||||
@@ -92,10 +97,15 @@ pub async fn handle_chat(
|
||||
Json(ChatError::Unauthorized.to_json()),
|
||||
))?;
|
||||
|
||||
let mut current_config = KeyConfig::new_with_global();
|
||||
|
||||
// 验证认证token并获取token信息
|
||||
let (auth_token, checksum) = match auth_header {
|
||||
// 管理员Token验证逻辑
|
||||
token if token == AUTH_TOKEN.as_str() || (*USE_SHARE && token == SHARED_AUTH_TOKEN.as_str()) => {
|
||||
token
|
||||
if token == AUTH_TOKEN.as_str()
|
||||
|| (AppConfig::is_share() && token == AppConfig::get_share_token().as_str()) =>
|
||||
{
|
||||
static CURRENT_KEY_INDEX: AtomicUsize = AtomicUsize::new(0);
|
||||
let state_guard = state.lock().await;
|
||||
let token_infos = &state_guard.token_infos;
|
||||
@@ -103,7 +113,7 @@ pub async fn handle_chat(
|
||||
// 检查是否存在可用的token
|
||||
if token_infos.is_empty() {
|
||||
return Err((
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(ChatError::NoTokens.to_json()),
|
||||
));
|
||||
}
|
||||
@@ -112,7 +122,21 @@ pub async fn handle_chat(
|
||||
let index = CURRENT_KEY_INDEX.fetch_add(1, Ordering::SeqCst) % token_infos.len();
|
||||
let token_info = &token_infos[index];
|
||||
(token_info.token.clone(), token_info.checksum.clone())
|
||||
},
|
||||
}
|
||||
|
||||
token if AppConfig::get_dynamic_key() && token.starts_with(&*KEY_PREFIX) => {
|
||||
from_base64(&token[*KEY_PREFIX_LEN..])
|
||||
.and_then(|decoded_bytes| KeyConfig::decode(&decoded_bytes[..]).ok())
|
||||
.and_then(|key_config| {
|
||||
key_config.copy_without_auth_token(&mut current_config);
|
||||
key_config.auth_token
|
||||
})
|
||||
.and_then(|token_info| tokeninfo_to_token(&token_info))
|
||||
.ok_or((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(ChatError::Unauthorized.to_json()),
|
||||
))?
|
||||
}
|
||||
|
||||
// 普通用户Token验证逻辑
|
||||
token => validate_token_and_checksum(token).ok_or((
|
||||
@@ -121,6 +145,8 @@ pub async fn handle_chat(
|
||||
))?,
|
||||
};
|
||||
|
||||
let current_config = current_config;
|
||||
|
||||
let current_id: u64;
|
||||
|
||||
// 更新请求日志
|
||||
@@ -172,7 +198,14 @@ pub async fn handle_chat(
|
||||
current_id = next_id;
|
||||
|
||||
// 如果需要获取用户使用情况,创建后台任务获取profile
|
||||
if model.map(|m| m.is_usage_check()).unwrap_or(false) {
|
||||
if model
|
||||
.map(|m| {
|
||||
m.is_usage_check(UsageCheck::from_proto(
|
||||
current_config.usage_check_models.as_ref(),
|
||||
))
|
||||
})
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let auth_token_clone = auth_token.clone();
|
||||
let state_clone = state_clone.clone();
|
||||
let log_id = next_id;
|
||||
@@ -211,13 +244,19 @@ pub async fn handle_chat(
|
||||
error: None,
|
||||
});
|
||||
|
||||
if state.request_logs.len() > REQUEST_LOGS_LIMIT {
|
||||
if state.request_logs.len() > *REQUEST_LOGS_LIMIT {
|
||||
state.request_logs.remove(0);
|
||||
}
|
||||
}
|
||||
|
||||
// 将消息转换为hex格式
|
||||
let hex_data = match super::adapter::encode_chat_message(request.messages, &request.model).await
|
||||
let hex_data = match super::adapter::encode_chat_message(
|
||||
request.messages,
|
||||
&request.model,
|
||||
current_config.disable_vision(),
|
||||
current_config.enable_slow_pool(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
@@ -244,27 +283,55 @@ pub async fn handle_chat(
|
||||
|
||||
// 构建请求客户端
|
||||
let client = build_client(&auth_token, &checksum);
|
||||
let response = client.body(hex_data).send().await;
|
||||
// 添加超时设置
|
||||
let response = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(*SERVICE_TIMEOUT),
|
||||
client.body(hex_data).send(),
|
||||
)
|
||||
.await;
|
||||
|
||||
// 处理请求结果
|
||||
let response = match response {
|
||||
Ok(resp) => {
|
||||
// 更新请求日志为成功
|
||||
{
|
||||
let mut state = state.lock().await;
|
||||
if let Some(log) = state
|
||||
.request_logs
|
||||
.iter_mut()
|
||||
.rev()
|
||||
.find(|log| log.id == current_id)
|
||||
Ok(inner_response) => match inner_response {
|
||||
Ok(resp) => {
|
||||
// 更新请求日志为成功
|
||||
{
|
||||
log.status = STATUS_SUCCESS;
|
||||
let mut state = state.lock().await;
|
||||
if let Some(log) = state
|
||||
.request_logs
|
||||
.iter_mut()
|
||||
.rev()
|
||||
.find(|log| log.id == current_id)
|
||||
{
|
||||
log.status = STATUS_SUCCESS;
|
||||
}
|
||||
}
|
||||
resp
|
||||
}
|
||||
resp
|
||||
}
|
||||
Err(e) => {
|
||||
// 更新请求日志为失败
|
||||
Err(e) => {
|
||||
// 更新请求日志为失败
|
||||
{
|
||||
let mut state = state.lock().await;
|
||||
if let Some(log) = state
|
||||
.request_logs
|
||||
.iter_mut()
|
||||
.rev()
|
||||
.find(|log| log.id == current_id)
|
||||
{
|
||||
log.status = STATUS_FAILED;
|
||||
log.error = Some(e.to_string());
|
||||
}
|
||||
state.active_requests -= 1;
|
||||
state.error_requests += 1;
|
||||
}
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ChatError::RequestFailed(e.to_string()).to_json()),
|
||||
));
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
// 处理超时错误
|
||||
{
|
||||
let mut state = state.lock().await;
|
||||
if let Some(log) = state
|
||||
@@ -274,14 +341,14 @@ pub async fn handle_chat(
|
||||
.find(|log| log.id == current_id)
|
||||
{
|
||||
log.status = STATUS_FAILED;
|
||||
log.error = Some(e.to_string());
|
||||
log.error = Some("Request timeout".to_string());
|
||||
}
|
||||
state.active_requests -= 1;
|
||||
state.error_requests += 1;
|
||||
}
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ChatError::RequestFailed(e.to_string()).to_json()),
|
||||
StatusCode::GATEWAY_TIMEOUT,
|
||||
Json(ChatError::RequestFailed("Request timeout".to_string()).to_json()),
|
||||
));
|
||||
}
|
||||
};
|
||||
@@ -303,9 +370,7 @@ pub async fn handle_chat(
|
||||
// 创建新的 stream
|
||||
let mut stream = response.bytes_stream();
|
||||
|
||||
let enable_stream_check = AppConfig::get_stream_check();
|
||||
|
||||
if enable_stream_check {
|
||||
if current_config.enable_stream_check() {
|
||||
// 检查第一个 chunk
|
||||
match stream.next().await {
|
||||
Some(first_chunk) => {
|
||||
@@ -399,6 +464,8 @@ pub async fn handle_chat(
|
||||
let full_text = full_text.clone();
|
||||
let first_chunk_time = first_chunk_time.clone();
|
||||
let state = state.clone();
|
||||
// 根据配置决定是否发送最后的 finish_reason
|
||||
let include_finish_reason = current_config.include_stop_stream();
|
||||
|
||||
async move {
|
||||
let chunk = chunk.unwrap_or_default();
|
||||
@@ -484,8 +551,6 @@ pub async fn handle_chat(
|
||||
}
|
||||
Ok(StreamMessage::StreamEnd) => {
|
||||
buffer_guard.clear();
|
||||
// 根据配置决定是否发送最后的 finish_reason
|
||||
let include_finish_reason = AppConfig::get_stop_stream();
|
||||
|
||||
// 计算总时间和首次片段时间
|
||||
let total_time = format_time_ms(start_time.elapsed().as_secs_f64());
|
||||
|
Reference in New Issue
Block a user