0.1.3-rc.5.2.5

wisdgod
2025-04-16 00:59:07 +08:00
parent 5071997ffc
commit c3bfb3b66e
29 changed files with 1017 additions and 883 deletions


@@ -2,7 +2,8 @@ use crate::{
     app::{
         constant::{
             AUTHORIZATION_BEARER_PREFIX, FINISH_REASON_STOP, OBJECT_CHAT_COMPLETION,
-            OBJECT_CHAT_COMPLETION_CHUNK,
+            OBJECT_CHAT_COMPLETION_CHUNK, header_value_chunked, header_value_event_stream,
+            header_value_json, header_value_keep_alive, header_value_no_cache_revalidate,
         },
         lazy::{
             AUTH_TOKEN, GENERAL_TIMEZONE, IS_NO_REQUEST_LOGS, IS_UNLIMITED_REQUEST_LOGS,
@@ -26,12 +27,10 @@ use crate::{
     },
     core::{
         config::KeyConfig,
-        constant::{Models, USAGE_CHECK_MODELS},
+        constant::Models,
         error::StreamError,
-        model::{
-            ChatResponse, Choice, Delta, Message, MessageContent, ModelsResponse, Role, Usage,
-        },
-        stream::{StreamDecoder, StreamMessage},
+        model::{ChatResponse, Choice, Delta, Message, MessageContent, ModelsResponse, Role},
+        stream::decoder::{StreamDecoder, StreamMessage},
     },
     leak::intern_string,
 };
@@ -51,17 +50,17 @@ use axum::{
 use bytes::Bytes;
 use futures::StreamExt;
 use prost::Message as _;
-use std::{borrow::Cow, sync::atomic::{AtomicUsize, Ordering}};
+use std::{
+    borrow::Cow,
+    sync::atomic::{AtomicUsize, Ordering},
+};
 use std::{
     convert::Infallible,
     sync::{Arc, atomic::AtomicBool},
 };
 use tokio::sync::Mutex;
 
-use super::model::{ChatRequest, Model};
-
-const NO_CACHE: &str = "no-cache, must-revalidate";
-const KEEP_ALIVE: &str = "keep-alive";
+use super::{constant::FREE_MODELS, model::ChatRequest};
 
 static CURRENT_KEY_INDEX: AtomicUsize = AtomicUsize::new(0);
@@ -106,7 +105,7 @@ pub async fn handle_models(
         ));
     }
 
-    let index = CURRENT_KEY_INDEX.fetch_add(1, Ordering::SeqCst) % token_infos.len();
+    let index = CURRENT_KEY_INDEX.load(Ordering::Acquire) % token_infos.len();
     let token_info = &token_infos[index];
     is_pri = true;
     (
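The lookup in handle_models no longer advances the key index: fetch_add rotated to a new token on every call, while load only reads whatever index is current, leaving rotation to whichever code path still writes the counter. A minimal sketch of the two access patterns:

use std::sync::atomic::{AtomicUsize, Ordering};

static CURRENT_KEY_INDEX: AtomicUsize = AtomicUsize::new(0);

// Old behavior: every caller advances the counter, so keys rotate per request.
fn next_key_round_robin(len: usize) -> usize {
    CURRENT_KEY_INDEX.fetch_add(1, Ordering::SeqCst) % len
}

// New behavior: read-only; some other code path decides when to advance.
fn current_key(len: usize) -> usize {
    CURRENT_KEY_INDEX.load(Ordering::Acquire) % len
}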
@@ -317,19 +316,18 @@ pub async fn handle_chat(
                 if log.token_info.token == auth_token {
                     if let Some(profile) = &log.token_info.profile {
                         if profile.stripe.membership_type == MembershipType::Free {
-                            let is_premium = USAGE_CHECK_MODELS.contains(&model);
-                            need_profile_check = if is_premium {
-                                profile
-                                    .usage
-                                    .premium
-                                    .max_requests
-                                    .is_some_and(|max| profile.usage.premium.num_requests >= max)
-                            } else {
+                            need_profile_check = if FREE_MODELS.contains(&model.id) {
                                 profile
                                     .usage
                                     .standard
                                     .max_requests
                                     .is_some_and(|max| profile.usage.standard.num_requests >= max)
+                            } else {
+                                profile
+                                    .usage
+                                    .premium
+                                    .max_requests
+                                    .is_some_and(|max| profile.usage.premium.num_requests >= max)
                             };
                         }
                         break;
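The quota check is inverted: instead of flagging premium models via USAGE_CHECK_MODELS, the code now asks whether the model id is in FREE_MODELS and checks the standard bucket for free models, the premium bucket for everything else. A sketch of the decision with hypothetical stand-in types:

// Hypothetical minimal shapes, for illustration only.
struct Quota {
    num_requests: u64,
    max_requests: Option<u64>,
}
struct UsageBuckets {
    standard: Quota,
    premium: Quota,
}

const FREE_MODELS: &[&str] = &["free-model-a", "free-model-b"]; // stand-in list

// True when the bucket that applies to `model_id` is exhausted.
fn quota_exhausted(usage: &UsageBuckets, model_id: &str) -> bool {
    let bucket = if FREE_MODELS.contains(&model_id) {
        &usage.standard
    } else {
        &usage.premium
    };
    bucket.max_requests.is_some_and(|max| bucket.num_requests >= max)
}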
@@ -350,15 +348,14 @@ pub async fn handle_chat(
         let next_id = state
             .request_manager
             .request_logs
-            .last()
+            .back()
             .map_or(1, |log| log.id + 1);
 
         current_id = next_id;
 
         // If usage info is needed, spawn a background task to fetch the profile
-        if Model::is_usage_check(
-            model,
-            UsageCheck::from_proto(current_config.usage_check_models.as_ref()),
-        ) {
+        if model.is_usage_check(UsageCheck::from_proto(
+            current_config.usage_check_models.as_ref(),
+        )) {
             let auth_token_clone = auth_token.clone();
             let state_clone = state_clone.clone();
             let log_id = next_id;
@@ -398,7 +395,7 @@ pub async fn handle_chat(
             });
         }
 
-        state.request_manager.request_logs.push(RequestLog {
+        state.request_manager.request_logs.push_back(RequestLog {
             id: next_id,
             timestamp: request_time,
             model: intern_string(request.model),
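The `.last()` to `.back()` and `push` to `push_back` changes indicate request_logs moved from a Vec to a VecDeque, which can also shed its oldest entries from the front in O(1). A sketch of the append path under that assumption (the trimming step is my addition):

use std::collections::VecDeque;

struct RequestLog {
    id: u64,
} // hypothetical minimal shape

fn append_log(logs: &mut VecDeque<RequestLog>, limit: usize) {
    // The next id comes from the newest entry at the back.
    let next_id = logs.back().map_or(1, |log| log.id + 1);
    logs.push_back(RequestLog { id: next_id });
    // Unlike Vec, dropping the oldest entry is O(1).
    if logs.len() > limit {
        logs.pop_front();
    }
}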
@@ -546,19 +543,16 @@ pub async fn handle_chat(
         let is_start = Arc::new(AtomicBool::new(true));
         let start_time = std::time::Instant::now();
         let decoder = Arc::new(Mutex::new(StreamDecoder::new()));
-        let is_usage_sent = Arc::new(AtomicBool::new(false));
-        let need_usage = if request.stream_options.is_some_and(|opt| opt.include_usage) {
-            Arc::new(Mutex::new(NeedUsage::Need {
-                client,
-                auth_token,
-                checksum,
-                client_key,
-                timezone,
-                is_pri,
-            }))
-        } else {
-            Arc::new(Mutex::new(NeedUsage::None))
-        };
+        let need_usage = Arc::new(Mutex::new(NeedUsage::Need {
+            is_need: request.stream_options.is_some_and(|opt| opt.include_usage),
+            client,
+            auth_token,
+            checksum,
+            client_key,
+            timezone,
+            is_pri,
+        }));
+        let usage_uuid = Arc::new(Mutex::new(None));
 
         // Define the context struct for the message processor
         struct MessageProcessContext<'a> {
@@ -568,15 +562,17 @@ pub async fn handle_chat(
             start_time: std::time::Instant,
             state: &'a Mutex<AppState>,
             current_id: u64,
+            usage_uuid: &'a Mutex<Option<String>>,
             need_usage: &'a Mutex<NeedUsage>,
-            is_usage_sent: &'a AtomicBool,
+            created: i64,
         }
 
         #[derive(Default)]
         enum NeedUsage {
             #[default]
-            None,
+            Taked,
             Need {
+                is_need: bool,
                 client: reqwest::Client,
                 auth_token: String,
                 checksum: String,
@@ -589,7 +585,10 @@ pub async fn handle_chat(
         impl NeedUsage {
             #[inline(always)]
             const fn is_need(&self) -> bool {
-                matches!(*self, Self::Need { .. })
+                match self {
+                    Self::Taked => false,
+                    Self::Need { is_need, .. } => *is_need,
+                }
             }
 
             #[inline(always)]
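NeedUsage now defaults to Taked and carries the include_usage flag inside Need, so `.take()` (used further down) can presumably lean on `std::mem::take` to move the payload out exactly once, replacing the old separate is_usage_sent flag. A minimal sketch under that assumption, with the payload fields trimmed:

#[derive(Default)]
enum NeedUsage {
    #[default]
    Taked,
    Need { is_need: bool, auth_token: String },
}

impl NeedUsage {
    const fn is_need(&self) -> bool {
        match self {
            Self::Taked => false,
            Self::Need { is_need, .. } => *is_need,
        }
    }

    // Move the payload out, leaving `Taked` behind so it cannot be consumed twice.
    fn take(&mut self) -> Self {
        std::mem::take(self)
    }
}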
@@ -602,8 +601,8 @@ pub async fn handle_chat(
         async fn process_messages(
             messages: Vec<StreamMessage>,
             ctx: &MessageProcessContext<'_>,
-        ) -> String {
-            let mut response_data = String::new();
+        ) -> Vec<u8> {
+            let mut response_data = Vec::new();
 
             for message in messages {
                 match message {
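process_messages now returns raw bytes: serde_json::to_vec plus extend_from_slice replaces format! and push_str, skipping the intermediate String (and its UTF-8 bookkeeping) on every chunk. The recurring pattern is one SSE frame per payload; a helper of this shape captures it (the function is mine, not the commit's):

use serde::Serialize;

// Append one `data: <json>\n\n` SSE frame to an output buffer.
fn push_sse_frame<T: Serialize>(buf: &mut Vec<u8>, payload: &T) -> serde_json::Result<()> {
    buf.extend_from_slice(b"data: ");
    buf.extend_from_slice(&serde_json::to_vec(payload)?);
    buf.extend_from_slice(b"\n\n");
    Ok(())
}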
@@ -611,7 +610,7 @@ pub async fn handle_chat(
                     let is_first = ctx.is_start.load(Ordering::Acquire);
 
                     let response = ChatResponse {
-                        id: ctx.response_id.to_string(),
+                        id: ctx.response_id,
                         object: OBJECT_CHAT_COMPLETION_CHUNK,
                         created: chrono::Utc::now().timestamp(),
                         model: if is_first { Some(ctx.model) } else { None },
@@ -641,65 +640,13 @@ pub async fn handle_chat(
                         },
                     };
 
-                    response_data.push_str(&format!(
-                        "data: {}\n\n",
-                        serde_json::to_string(&response).unwrap()
-                    ));
+                    response_data.extend_from_slice(b"data: ");
+                    response_data.extend_from_slice(&serde_json::to_vec(&response).unwrap());
+                    response_data.extend_from_slice(b"\n\n");
                 }
                 StreamMessage::Usage(usage_uuid) => {
-                    if !ctx.is_usage_sent.load(Ordering::Acquire) {
-                        if let NeedUsage::Need {
-                            client,
-                            auth_token,
-                            checksum,
-                            client_key,
-                            timezone,
-                            is_pri,
-                        } = ctx.need_usage.lock().await.take()
-                        {
-                            let usage = get_token_usage(
-                                client,
-                                &auth_token,
-                                &checksum,
-                                &client_key,
-                                timezone,
-                                is_pri,
-                                usage_uuid,
-                            )
-                            .await;
-                            if let Some(ref usage) = usage {
-                                let mut state = ctx.state.lock().await;
-                                if let Some(log) = state
-                                    .request_manager
-                                    .request_logs
-                                    .iter_mut()
-                                    .rev()
-                                    .find(|log| log.id == ctx.current_id)
-                                {
-                                    if let Some(chain) = &mut log.chain {
-                                        chain.usage = OptionUsage::Uasge {
-                                            input: usage.prompt_tokens,
-                                            output: usage.completion_tokens,
-                                        }
-                                    }
-                                }
-                            }
-                            let response = ChatResponse {
-                                id: ctx.response_id.to_string(),
-                                object: OBJECT_CHAT_COMPLETION_CHUNK,
-                                created: chrono::Utc::now().timestamp(),
-                                model: None,
-                                choices: vec![],
-                                usage: TriState::Some(usage.unwrap_or_default()),
-                            };
-                            response_data.push_str(&format!(
-                                "data: {}\n\n",
-                                serde_json::to_string(&response).unwrap()
-                            ));
-                            ctx.is_usage_sent.store(true, Ordering::Release);
-                        }
-                    } else {
-                        crate::debug_println!("usage is sent, but find {usage_uuid}");
+                    if !usage_uuid.is_empty() && ctx.need_usage.lock().await.is_need() {
+                        *ctx.usage_uuid.lock().await = Some(usage_uuid);
                     }
                 }
                 StreamMessage::StreamEnd => {
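The Usage arm no longer fetches token usage inline; it only stashes the server-provided usage_uuid, and the StreamEnd arm (two hunks down) performs the fetch once at the end of the stream. The stash is just a shared Option behind tokio's Mutex; a minimal sketch of the pattern:

use tokio::sync::Mutex;

// The Usage message records the uuid...
async fn stash_uuid(slot: &Mutex<Option<String>>, uuid: String) {
    *slot.lock().await = Some(uuid);
}

// ...and StreamEnd consumes it exactly once.
async fn take_uuid(slot: &Mutex<Option<String>>) -> Option<String> {
    slot.lock().await.take()
}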
@@ -720,7 +667,7 @@ pub async fn handle_chat(
                     }
 
                     let response = ChatResponse {
-                        id: ctx.response_id.to_string(),
+                        id: ctx.response_id,
                         object: OBJECT_CHAT_COMPLETION_CHUNK,
                         created: chrono::Utc::now().timestamp(),
                         model: None,
@@ -740,31 +687,65 @@ pub async fn handle_chat(
                             TriState::None
                         },
                     };
-                    response_data.push_str(&format!(
-                        "data: {}\n\n",
-                        serde_json::to_string(&response).unwrap()
-                    ));
-                    if !ctx.is_usage_sent.load(Ordering::Acquire)
-                        && ctx.need_usage.lock().await.is_need()
-                    {
-                        let response = ChatResponse {
-                            id: ctx.response_id.to_string(),
-                            object: OBJECT_CHAT_COMPLETION_CHUNK,
-                            created: chrono::Utc::now().timestamp(),
-                            model: None,
-                            choices: vec![],
-                            usage: TriState::Some(Usage {
-                                prompt_tokens: 0,
-                                completion_tokens: 0,
-                                total_tokens: 0,
-                            }),
+                    response_data.extend_from_slice(b"data: ");
+                    response_data.extend_from_slice(&serde_json::to_vec(&response).unwrap());
+                    response_data.extend_from_slice(b"\n\n");
+                    if let Some(usage_uuid) = ctx.usage_uuid.lock().await.take() {
+                        if let NeedUsage::Need {
+                            is_need,
+                            client,
+                            auth_token,
+                            checksum,
+                            client_key,
+                            timezone,
+                            is_pri,
+                        } = ctx.need_usage.lock().await.take()
+                        {
+                            let usage = if *crate::app::lazy::REAL_USAGE {
+                                let usage = tokio::spawn(get_token_usage(
+                                    client, auth_token, checksum, client_key, timezone, is_pri,
+                                    usage_uuid,
+                                ))
+                                .await
+                                .unwrap_or_default();
+                                if let Some(ref usage) = usage {
+                                    let mut state = ctx.state.lock().await;
+                                    if let Some(log) = state
+                                        .request_manager
+                                        .request_logs
+                                        .iter_mut()
+                                        .rev()
+                                        .find(|log| log.id == ctx.current_id)
+                                    {
+                                        if let Some(chain) = &mut log.chain {
+                                            chain.usage = OptionUsage::Uasge {
+                                                input: usage.prompt_tokens,
+                                                output: usage.completion_tokens,
+                                            }
+                                        }
+                                    }
+                                }
+                                usage
+                            } else {
+                                None
+                            };
+                            if is_need {
+                                let response = ChatResponse {
+                                    id: ctx.response_id,
+                                    object: OBJECT_CHAT_COMPLETION_CHUNK,
+                                    created: ctx.created,
+                                    model: None,
+                                    choices: vec![],
+                                    usage: TriState::Some(usage.unwrap_or_default()),
+                                };
+                                response_data.extend_from_slice(b"data: ");
+                                response_data
+                                    .extend_from_slice(&serde_json::to_vec(&response).unwrap());
+                                response_data.extend_from_slice(b"\n\n");
+                            }
                         };
-                        response_data.push_str(&format!(
-                            "data: {}\n\n",
-                            serde_json::to_string(&response).unwrap()
-                        ));
-                        ctx.is_usage_sent.store(true, Ordering::Release);
-                    };
+                    }
                 }
             }
             StreamMessage::Debug(debug_prompt) => {
                 if let Ok(mut state) = ctx.state.try_lock() {
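The end-of-stream path now gates the remote lookup behind the REAL_USAGE flag and runs get_token_usage inside tokio::spawn, flattening the JoinHandle's Result with unwrap_or_default so a panicked task degrades to None instead of poisoning the response. The shape, with a stand-in fetcher:

// Sketch: run an optional fetch as a task and tolerate its failure.
async fn fetch_usage(real_usage: bool) -> Option<u64> {
    if !real_usage {
        return None;
    }
    // Stand-in for `get_token_usage(...)`.
    let task = tokio::spawn(async { Some(42u64) });
    // `task.await` yields Result<Option<u64>, JoinError>;
    // unwrap_or_default turns a JoinError into None.
    task.await.unwrap_or_default()
}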
@@ -781,7 +762,7 @@ pub async fn handle_chat(
                     } else {
                         log.chain = Some(Chain {
                             prompt: Prompt::new(debug_prompt),
-                            delays: vec![],
+                            delays: None,
                             usage: OptionUsage::None,
                         });
                     }
@@ -866,6 +847,8 @@ pub async fn handle_chat(
             }
         }
 
+        let created = Arc::new(std::sync::OnceLock::new());
+
         // Process the subsequent stream
         let stream = stream
             .then({
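`created` is a OnceLock initialized on the first chunk via get_or_init (next hunk), so every SSE chunk of one response reports the same created timestamp instead of re-reading the clock. A minimal demonstration:

use std::sync::{Arc, OnceLock};

fn demo() {
    let created: Arc<OnceLock<i64>> = Arc::new(OnceLock::new());

    // The first call stores the current timestamp; later calls return the stored value.
    let first = *created.get_or_init(|| chrono::Utc::now().timestamp());
    let again = *created.get_or_init(|| chrono::Utc::now().timestamp());
    assert_eq!(first, again);
}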
@@ -877,27 +860,29 @@ pub async fn handle_chat(
                 let response_id = response_id.clone();
                 let is_start = is_start.clone();
                 let state = state.clone();
-                let is_usage_sent = is_usage_sent.clone();
                 let need_usage = need_usage.clone();
+                let usage_uuid = usage_uuid.clone();
+                let created = created.clone();
 
                 async move {
                     let chunk = match chunk {
                         Ok(c) => c,
-                        Err(e) => {
-                            crate::debug_println!("Find chunk error: {e}");
+                        Err(_) => {
+                            // crate::debug_println!("Find chunk error: {e}");
                             return Ok::<_, Infallible>(Bytes::new());
                         }
                     };
 
                     let ctx = MessageProcessContext {
                         response_id: &response_id,
-                        model,
+                        model: model.id,
                         is_start: &is_start,
                         start_time,
                         state: &state,
                         current_id,
+                        usage_uuid: &usage_uuid,
                         need_usage: &need_usage,
-                        is_usage_sent: &is_usage_sent,
+                        created: *created.get_or_init(|| chrono::Utc::now().timestamp()),
                     };
 
                     // Process the chunk with the decoder
@@ -930,16 +915,16 @@ pub async fn handle_chat(
                         }
                     };
 
-                    let mut response_data = String::new();
+                    let mut response_data = Vec::new();
 
                     if let Some(first_msg) = decoder.lock().await.take_first_result() {
                         let first_response = process_messages(first_msg, &ctx).await;
-                        response_data.push_str(&first_response);
+                        response_data.extend_from_slice(&first_response);
                     }
 
                     let current_response = process_messages(messages, &ctx).await;
                     if !current_response.is_empty() {
-                        response_data.push_str(&current_response);
+                        response_data.extend_from_slice(&current_response);
                    }
 
                     Ok(Bytes::from(response_data))
@@ -970,10 +955,10 @@ pub async fn handle_chat(
         }));
 
         Ok(Response::builder()
-            .header(CACHE_CONTROL, NO_CACHE)
-            .header(CONNECTION, KEEP_ALIVE)
-            .header(CONTENT_TYPE, "text/event-stream")
-            .header(TRANSFER_ENCODING, "chunked")
+            .header(CACHE_CONTROL, header_value_no_cache_revalidate())
+            .header(CONNECTION, header_value_keep_alive())
+            .header(CONTENT_TYPE, header_value_event_stream())
+            .header(TRANSFER_ENCODING, header_value_chunked())
             .body(Body::from_stream(stream))
             .unwrap())
     } else {
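The inline header literals give way to the header_value_* accessors imported at the top of the file, so the header bytes are validated once rather than re-parsed on every response. Their implementation is not shown in this diff; the usual shape would be a thin wrapper over HeaderValue::from_static:

use http::HeaderValue;

// Assumed shape of the accessors: build from a static string, reuse everywhere.
pub fn header_value_event_stream() -> HeaderValue {
    HeaderValue::from_static("text/event-stream")
}

pub fn header_value_no_cache_revalidate() -> HeaderValue {
    HeaderValue::from_static("no-cache, must-revalidate")
}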
@@ -1095,13 +1080,7 @@ pub async fn handle_chat(
             let (usage1, usage2) = if !usage_uuid.is_empty() {
                 let result = get_token_usage(
-                    client,
-                    &auth_token,
-                    &checksum,
-                    &client_key,
-                    timezone,
-                    is_pri,
-                    usage_uuid,
+                    client, auth_token, checksum, client_key, timezone, is_pri, usage_uuid,
                 )
                 .await;
                 let result2 = match result {
@@ -1117,10 +1096,10 @@ pub async fn handle_chat(
             };
 
             let response_data = ChatResponse {
-                id: format!("chatcmpl-{trace_id}"),
+                id: &format!("chatcmpl-{trace_id}"),
                 object: OBJECT_CHAT_COMPLETION,
                 created: chrono::Utc::now().timestamp(),
-                model: Some(model),
+                model: Some(model.id),
                 choices: vec![Choice {
                     index: 0,
                     message: Some(Message {
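Together with `id: ctx.response_id` in the streaming path, `id: &format!(...)` suggests ChatResponse's id field changed from an owned String to a borrowed &str, so per-chunk serialization stops allocating a fresh id. A sketch of the pattern (the field type is my inference):

use serde::Serialize;

#[derive(Serialize)]
struct ChatResponse<'a> {
    id: &'a str, // was `String`; serde emits the same JSON either way
}

fn build(trace_id: &str) -> Vec<u8> {
    let id = format!("chatcmpl-{trace_id}");
    serde_json::to_vec(&ChatResponse { id: &id }).unwrap()
}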
@@ -1155,11 +1134,11 @@ pub async fn handle_chat(
             }
         }
 
-        let data = serde_json::to_string(&response_data).unwrap();
+        let data = serde_json::to_vec(&response_data).unwrap();
 
         Ok(Response::builder()
-            .header(CACHE_CONTROL, NO_CACHE)
-            .header(CONNECTION, KEEP_ALIVE)
-            .header(CONTENT_TYPE, "application/json")
+            .header(CACHE_CONTROL, header_value_no_cache_revalidate())
+            .header(CONNECTION, header_value_keep_alive())
+            .header(CONTENT_TYPE, header_value_json())
             .header(CONTENT_LENGTH, data.len())
             .body(Body::from(data))
             .unwrap())
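The non-streaming path gets the same byte-first treatment: serde_json::to_vec serializes straight to a buffer whose length doubles as the Content-Length header. A self-contained sketch of that response shape (axum-style, details assumed):

use axum::body::Body;
use axum::http::{Response, header::{CONTENT_LENGTH, CONTENT_TYPE}};

// Minimal sketch of the non-streaming JSON response path.
fn json_response<T: serde::Serialize>(payload: &T) -> Response<Body> {
    let data = serde_json::to_vec(payload).unwrap();
    Response::builder()
        .header(CONTENT_TYPE, "application/json")
        .header(CONTENT_LENGTH, data.len())
        .body(Body::from(data))
        .unwrap()
}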