mirror of
https://github.com/wisdgod/cursor-api.git
synced 2025-10-27 17:10:27 +08:00
0.1.3-rc.5.2.5
This commit is contained in:
@@ -2,7 +2,8 @@ use crate::{
|
||||
app::{
|
||||
constant::{
|
||||
AUTHORIZATION_BEARER_PREFIX, FINISH_REASON_STOP, OBJECT_CHAT_COMPLETION,
|
||||
OBJECT_CHAT_COMPLETION_CHUNK,
|
||||
OBJECT_CHAT_COMPLETION_CHUNK, header_value_chunked, header_value_event_stream,
|
||||
header_value_json, header_value_keep_alive, header_value_no_cache_revalidate,
|
||||
},
|
||||
lazy::{
|
||||
AUTH_TOKEN, GENERAL_TIMEZONE, IS_NO_REQUEST_LOGS, IS_UNLIMITED_REQUEST_LOGS,
|
||||
@@ -26,12 +27,10 @@ use crate::{
|
||||
},
|
||||
core::{
|
||||
config::KeyConfig,
|
||||
constant::{Models, USAGE_CHECK_MODELS},
|
||||
constant::Models,
|
||||
error::StreamError,
|
||||
model::{
|
||||
ChatResponse, Choice, Delta, Message, MessageContent, ModelsResponse, Role, Usage,
|
||||
},
|
||||
stream::{StreamDecoder, StreamMessage},
|
||||
model::{ChatResponse, Choice, Delta, Message, MessageContent, ModelsResponse, Role},
|
||||
stream::decoder::{StreamDecoder, StreamMessage},
|
||||
},
|
||||
leak::intern_string,
|
||||
};
|
||||
@@ -51,17 +50,17 @@ use axum::{
|
||||
use bytes::Bytes;
|
||||
use futures::StreamExt;
|
||||
use prost::Message as _;
|
||||
use std::{borrow::Cow, sync::atomic::{AtomicUsize, Ordering}};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
};
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
sync::{Arc, atomic::AtomicBool},
|
||||
};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use super::model::{ChatRequest, Model};
|
||||
|
||||
const NO_CACHE: &str = "no-cache, must-revalidate";
|
||||
const KEEP_ALIVE: &str = "keep-alive";
|
||||
use super::{constant::FREE_MODELS, model::ChatRequest};
|
||||
|
||||
static CURRENT_KEY_INDEX: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
@@ -106,7 +105,7 @@ pub async fn handle_models(
|
||||
));
|
||||
}
|
||||
|
||||
let index = CURRENT_KEY_INDEX.fetch_add(1, Ordering::SeqCst) % token_infos.len();
|
||||
let index = CURRENT_KEY_INDEX.load(Ordering::Acquire) % token_infos.len();
|
||||
let token_info = &token_infos[index];
|
||||
is_pri = true;
|
||||
(
|
||||
@@ -317,19 +316,18 @@ pub async fn handle_chat(
|
||||
if log.token_info.token == auth_token {
|
||||
if let Some(profile) = &log.token_info.profile {
|
||||
if profile.stripe.membership_type == MembershipType::Free {
|
||||
let is_premium = USAGE_CHECK_MODELS.contains(&model);
|
||||
need_profile_check = if is_premium {
|
||||
profile
|
||||
.usage
|
||||
.premium
|
||||
.max_requests
|
||||
.is_some_and(|max| profile.usage.premium.num_requests >= max)
|
||||
} else {
|
||||
need_profile_check = if FREE_MODELS.contains(&model.id) {
|
||||
profile
|
||||
.usage
|
||||
.standard
|
||||
.max_requests
|
||||
.is_some_and(|max| profile.usage.standard.num_requests >= max)
|
||||
} else {
|
||||
profile
|
||||
.usage
|
||||
.premium
|
||||
.max_requests
|
||||
.is_some_and(|max| profile.usage.premium.num_requests >= max)
|
||||
};
|
||||
}
|
||||
break;
|
||||
@@ -350,15 +348,14 @@ pub async fn handle_chat(
|
||||
let next_id = state
|
||||
.request_manager
|
||||
.request_logs
|
||||
.last()
|
||||
.back()
|
||||
.map_or(1, |log| log.id + 1);
|
||||
current_id = next_id;
|
||||
|
||||
// 如果需要获取用户使用情况,创建后台任务获取profile
|
||||
if Model::is_usage_check(
|
||||
model,
|
||||
UsageCheck::from_proto(current_config.usage_check_models.as_ref()),
|
||||
) {
|
||||
if model.is_usage_check(UsageCheck::from_proto(
|
||||
current_config.usage_check_models.as_ref(),
|
||||
)) {
|
||||
let auth_token_clone = auth_token.clone();
|
||||
let state_clone = state_clone.clone();
|
||||
let log_id = next_id;
|
||||
@@ -398,7 +395,7 @@ pub async fn handle_chat(
|
||||
});
|
||||
}
|
||||
|
||||
state.request_manager.request_logs.push(RequestLog {
|
||||
state.request_manager.request_logs.push_back(RequestLog {
|
||||
id: next_id,
|
||||
timestamp: request_time,
|
||||
model: intern_string(request.model),
|
||||
@@ -546,19 +543,16 @@ pub async fn handle_chat(
|
||||
let is_start = Arc::new(AtomicBool::new(true));
|
||||
let start_time = std::time::Instant::now();
|
||||
let decoder = Arc::new(Mutex::new(StreamDecoder::new()));
|
||||
let is_usage_sent = Arc::new(AtomicBool::new(false));
|
||||
let need_usage = if request.stream_options.is_some_and(|opt| opt.include_usage) {
|
||||
Arc::new(Mutex::new(NeedUsage::Need {
|
||||
client,
|
||||
auth_token,
|
||||
checksum,
|
||||
client_key,
|
||||
timezone,
|
||||
is_pri,
|
||||
}))
|
||||
} else {
|
||||
Arc::new(Mutex::new(NeedUsage::None))
|
||||
};
|
||||
let need_usage = Arc::new(Mutex::new(NeedUsage::Need {
|
||||
is_need: request.stream_options.is_some_and(|opt| opt.include_usage),
|
||||
client,
|
||||
auth_token,
|
||||
checksum,
|
||||
client_key,
|
||||
timezone,
|
||||
is_pri,
|
||||
}));
|
||||
let usage_uuid = Arc::new(Mutex::new(None));
|
||||
|
||||
// 定义消息处理器的上下文结构体
|
||||
struct MessageProcessContext<'a> {
|
||||
@@ -568,15 +562,17 @@ pub async fn handle_chat(
|
||||
start_time: std::time::Instant,
|
||||
state: &'a Mutex<AppState>,
|
||||
current_id: u64,
|
||||
usage_uuid: &'a Mutex<Option<String>>,
|
||||
need_usage: &'a Mutex<NeedUsage>,
|
||||
is_usage_sent: &'a AtomicBool,
|
||||
created: i64,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
enum NeedUsage {
|
||||
#[default]
|
||||
None,
|
||||
Taked,
|
||||
Need {
|
||||
is_need: bool,
|
||||
client: reqwest::Client,
|
||||
auth_token: String,
|
||||
checksum: String,
|
||||
@@ -589,7 +585,10 @@ pub async fn handle_chat(
|
||||
impl NeedUsage {
|
||||
#[inline(always)]
|
||||
const fn is_need(&self) -> bool {
|
||||
matches!(*self, Self::Need { .. })
|
||||
match self {
|
||||
Self::Taked => false,
|
||||
Self::Need { is_need, .. } => *is_need,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
@@ -602,8 +601,8 @@ pub async fn handle_chat(
|
||||
async fn process_messages(
|
||||
messages: Vec<StreamMessage>,
|
||||
ctx: &MessageProcessContext<'_>,
|
||||
) -> String {
|
||||
let mut response_data = String::new();
|
||||
) -> Vec<u8> {
|
||||
let mut response_data = Vec::new();
|
||||
|
||||
for message in messages {
|
||||
match message {
|
||||
@@ -611,7 +610,7 @@ pub async fn handle_chat(
|
||||
let is_first = ctx.is_start.load(Ordering::Acquire);
|
||||
|
||||
let response = ChatResponse {
|
||||
id: ctx.response_id.to_string(),
|
||||
id: ctx.response_id,
|
||||
object: OBJECT_CHAT_COMPLETION_CHUNK,
|
||||
created: chrono::Utc::now().timestamp(),
|
||||
model: if is_first { Some(ctx.model) } else { None },
|
||||
@@ -641,65 +640,13 @@ pub async fn handle_chat(
|
||||
},
|
||||
};
|
||||
|
||||
response_data.push_str(&format!(
|
||||
"data: {}\n\n",
|
||||
serde_json::to_string(&response).unwrap()
|
||||
));
|
||||
response_data.extend_from_slice(b"data: ");
|
||||
response_data.extend_from_slice(&serde_json::to_vec(&response).unwrap());
|
||||
response_data.extend_from_slice(b"\n\n");
|
||||
}
|
||||
StreamMessage::Usage(usage_uuid) => {
|
||||
if !ctx.is_usage_sent.load(Ordering::Acquire) {
|
||||
if let NeedUsage::Need {
|
||||
client,
|
||||
auth_token,
|
||||
checksum,
|
||||
client_key,
|
||||
timezone,
|
||||
is_pri,
|
||||
} = ctx.need_usage.lock().await.take()
|
||||
{
|
||||
let usage = get_token_usage(
|
||||
client,
|
||||
&auth_token,
|
||||
&checksum,
|
||||
&client_key,
|
||||
timezone,
|
||||
is_pri,
|
||||
usage_uuid,
|
||||
)
|
||||
.await;
|
||||
if let Some(ref usage) = usage {
|
||||
let mut state = ctx.state.lock().await;
|
||||
if let Some(log) = state
|
||||
.request_manager
|
||||
.request_logs
|
||||
.iter_mut()
|
||||
.rev()
|
||||
.find(|log| log.id == ctx.current_id)
|
||||
{
|
||||
if let Some(chain) = &mut log.chain {
|
||||
chain.usage = OptionUsage::Uasge {
|
||||
input: usage.prompt_tokens,
|
||||
output: usage.completion_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let response = ChatResponse {
|
||||
id: ctx.response_id.to_string(),
|
||||
object: OBJECT_CHAT_COMPLETION_CHUNK,
|
||||
created: chrono::Utc::now().timestamp(),
|
||||
model: None,
|
||||
choices: vec![],
|
||||
usage: TriState::Some(usage.unwrap_or_default()),
|
||||
};
|
||||
response_data.push_str(&format!(
|
||||
"data: {}\n\n",
|
||||
serde_json::to_string(&response).unwrap()
|
||||
));
|
||||
ctx.is_usage_sent.store(true, Ordering::Release);
|
||||
}
|
||||
} else {
|
||||
crate::debug_println!("usage is sent, but find {usage_uuid}");
|
||||
if !usage_uuid.is_empty() && ctx.need_usage.lock().await.is_need() {
|
||||
*ctx.usage_uuid.lock().await = Some(usage_uuid);
|
||||
}
|
||||
}
|
||||
StreamMessage::StreamEnd => {
|
||||
@@ -720,7 +667,7 @@ pub async fn handle_chat(
|
||||
}
|
||||
|
||||
let response = ChatResponse {
|
||||
id: ctx.response_id.to_string(),
|
||||
id: ctx.response_id,
|
||||
object: OBJECT_CHAT_COMPLETION_CHUNK,
|
||||
created: chrono::Utc::now().timestamp(),
|
||||
model: None,
|
||||
@@ -740,31 +687,65 @@ pub async fn handle_chat(
|
||||
TriState::None
|
||||
},
|
||||
};
|
||||
response_data.push_str(&format!(
|
||||
"data: {}\n\n",
|
||||
serde_json::to_string(&response).unwrap()
|
||||
));
|
||||
if !ctx.is_usage_sent.load(Ordering::Acquire)
|
||||
&& ctx.need_usage.lock().await.is_need()
|
||||
{
|
||||
let response = ChatResponse {
|
||||
id: ctx.response_id.to_string(),
|
||||
object: OBJECT_CHAT_COMPLETION_CHUNK,
|
||||
created: chrono::Utc::now().timestamp(),
|
||||
model: None,
|
||||
choices: vec![],
|
||||
usage: TriState::Some(Usage {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0,
|
||||
}),
|
||||
response_data.extend_from_slice(b"data: ");
|
||||
response_data.extend_from_slice(&serde_json::to_vec(&response).unwrap());
|
||||
response_data.extend_from_slice(b"\n\n");
|
||||
if let Some(usage_uuid) = ctx.usage_uuid.lock().await.take() {
|
||||
if let NeedUsage::Need {
|
||||
is_need,
|
||||
client,
|
||||
auth_token,
|
||||
checksum,
|
||||
client_key,
|
||||
timezone,
|
||||
is_pri,
|
||||
} = ctx.need_usage.lock().await.take()
|
||||
{
|
||||
let usage = if *crate::app::lazy::REAL_USAGE {
|
||||
let usage = tokio::spawn(get_token_usage(
|
||||
client, auth_token, checksum, client_key, timezone, is_pri,
|
||||
usage_uuid,
|
||||
))
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
if let Some(ref usage) = usage {
|
||||
let mut state = ctx.state.lock().await;
|
||||
if let Some(log) = state
|
||||
.request_manager
|
||||
.request_logs
|
||||
.iter_mut()
|
||||
.rev()
|
||||
.find(|log| log.id == ctx.current_id)
|
||||
{
|
||||
if let Some(chain) = &mut log.chain {
|
||||
chain.usage = OptionUsage::Uasge {
|
||||
input: usage.prompt_tokens,
|
||||
output: usage.completion_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
usage
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if is_need {
|
||||
let response = ChatResponse {
|
||||
id: ctx.response_id,
|
||||
object: OBJECT_CHAT_COMPLETION_CHUNK,
|
||||
created: ctx.created,
|
||||
model: None,
|
||||
choices: vec![],
|
||||
usage: TriState::Some(usage.unwrap_or_default()),
|
||||
};
|
||||
response_data.extend_from_slice(b"data: ");
|
||||
response_data
|
||||
.extend_from_slice(&serde_json::to_vec(&response).unwrap());
|
||||
response_data.extend_from_slice(b"\n\n");
|
||||
}
|
||||
};
|
||||
response_data.push_str(&format!(
|
||||
"data: {}\n\n",
|
||||
serde_json::to_string(&response).unwrap()
|
||||
));
|
||||
ctx.is_usage_sent.store(true, Ordering::Release);
|
||||
};
|
||||
}
|
||||
}
|
||||
StreamMessage::Debug(debug_prompt) => {
|
||||
if let Ok(mut state) = ctx.state.try_lock() {
|
||||
@@ -781,7 +762,7 @@ pub async fn handle_chat(
|
||||
} else {
|
||||
log.chain = Some(Chain {
|
||||
prompt: Prompt::new(debug_prompt),
|
||||
delays: vec![],
|
||||
delays: None,
|
||||
usage: OptionUsage::None,
|
||||
});
|
||||
}
|
||||
@@ -866,6 +847,8 @@ pub async fn handle_chat(
|
||||
}
|
||||
}
|
||||
|
||||
let created = Arc::new(std::sync::OnceLock::new());
|
||||
|
||||
// 处理后续的stream
|
||||
let stream = stream
|
||||
.then({
|
||||
@@ -877,27 +860,29 @@ pub async fn handle_chat(
|
||||
let response_id = response_id.clone();
|
||||
let is_start = is_start.clone();
|
||||
let state = state.clone();
|
||||
let is_usage_sent = is_usage_sent.clone();
|
||||
let need_usage = need_usage.clone();
|
||||
let usage_uuid = usage_uuid.clone();
|
||||
let created = created.clone();
|
||||
|
||||
async move {
|
||||
let chunk = match chunk {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
crate::debug_println!("Find chunk error: {e}");
|
||||
Err(_) => {
|
||||
// crate::debug_println!("Find chunk error: {e}");
|
||||
return Ok::<_, Infallible>(Bytes::new());
|
||||
}
|
||||
};
|
||||
|
||||
let ctx = MessageProcessContext {
|
||||
response_id: &response_id,
|
||||
model,
|
||||
model: model.id,
|
||||
is_start: &is_start,
|
||||
start_time,
|
||||
state: &state,
|
||||
current_id,
|
||||
usage_uuid: &usage_uuid,
|
||||
need_usage: &need_usage,
|
||||
is_usage_sent: &is_usage_sent,
|
||||
created: *created.get_or_init(|| chrono::Utc::now().timestamp()),
|
||||
};
|
||||
|
||||
// 使用decoder处理chunk
|
||||
@@ -930,16 +915,16 @@ pub async fn handle_chat(
|
||||
}
|
||||
};
|
||||
|
||||
let mut response_data = String::new();
|
||||
let mut response_data = Vec::new();
|
||||
|
||||
if let Some(first_msg) = decoder.lock().await.take_first_result() {
|
||||
let first_response = process_messages(first_msg, &ctx).await;
|
||||
response_data.push_str(&first_response);
|
||||
response_data.extend_from_slice(&first_response);
|
||||
}
|
||||
|
||||
let current_response = process_messages(messages, &ctx).await;
|
||||
if !current_response.is_empty() {
|
||||
response_data.push_str(¤t_response);
|
||||
response_data.extend_from_slice(¤t_response);
|
||||
}
|
||||
|
||||
Ok(Bytes::from(response_data))
|
||||
@@ -970,10 +955,10 @@ pub async fn handle_chat(
|
||||
}));
|
||||
|
||||
Ok(Response::builder()
|
||||
.header(CACHE_CONTROL, NO_CACHE)
|
||||
.header(CONNECTION, KEEP_ALIVE)
|
||||
.header(CONTENT_TYPE, "text/event-stream")
|
||||
.header(TRANSFER_ENCODING, "chunked")
|
||||
.header(CACHE_CONTROL, header_value_no_cache_revalidate())
|
||||
.header(CONNECTION, header_value_keep_alive())
|
||||
.header(CONTENT_TYPE, header_value_event_stream())
|
||||
.header(TRANSFER_ENCODING, header_value_chunked())
|
||||
.body(Body::from_stream(stream))
|
||||
.unwrap())
|
||||
} else {
|
||||
@@ -1095,13 +1080,7 @@ pub async fn handle_chat(
|
||||
|
||||
let (usage1, usage2) = if !usage_uuid.is_empty() {
|
||||
let result = get_token_usage(
|
||||
client,
|
||||
&auth_token,
|
||||
&checksum,
|
||||
&client_key,
|
||||
timezone,
|
||||
is_pri,
|
||||
usage_uuid,
|
||||
client, auth_token, checksum, client_key, timezone, is_pri, usage_uuid,
|
||||
)
|
||||
.await;
|
||||
let result2 = match result {
|
||||
@@ -1117,10 +1096,10 @@ pub async fn handle_chat(
|
||||
};
|
||||
|
||||
let response_data = ChatResponse {
|
||||
id: format!("chatcmpl-{trace_id}"),
|
||||
id: &format!("chatcmpl-{trace_id}"),
|
||||
object: OBJECT_CHAT_COMPLETION,
|
||||
created: chrono::Utc::now().timestamp(),
|
||||
model: Some(model),
|
||||
model: Some(model.id),
|
||||
choices: vec![Choice {
|
||||
index: 0,
|
||||
message: Some(Message {
|
||||
@@ -1155,11 +1134,11 @@ pub async fn handle_chat(
|
||||
}
|
||||
}
|
||||
|
||||
let data = serde_json::to_string(&response_data).unwrap();
|
||||
let data = serde_json::to_vec(&response_data).unwrap();
|
||||
Ok(Response::builder()
|
||||
.header(CACHE_CONTROL, NO_CACHE)
|
||||
.header(CONNECTION, KEEP_ALIVE)
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.header(CACHE_CONTROL, header_value_no_cache_revalidate())
|
||||
.header(CONNECTION, header_value_keep_alive())
|
||||
.header(CONTENT_TYPE, header_value_json())
|
||||
.header(CONTENT_LENGTH, data.len())
|
||||
.body(Body::from(data))
|
||||
.unwrap())
|
||||
|
||||
Reference in New Issue
Block a user