v0.1.3-rc.4正式版

This commit is contained in:
wisdgod
2025-01-27 14:03:46 +08:00
parent 76d5b55b5a
commit c58f2697f0
41 changed files with 1956 additions and 964 deletions

View File

@@ -2,12 +2,13 @@ use crate::{
app::{
constant::{
AUTHORIZATION_BEARER_PREFIX, FINISH_REASON_STOP, OBJECT_CHAT_COMPLETION,
OBJECT_CHAT_COMPLETION_CHUNK, STATUS_FAILED, STATUS_PENDING, STATUS_SUCCESS,
OBJECT_CHAT_COMPLETION_CHUNK,
},
lazy::{
AUTH_TOKEN, KEY_PREFIX, KEY_PREFIX_LEN, REQUEST_LOGS_LIMIT, SERVICE_TIMEOUT,
lazy::{AUTH_TOKEN, KEY_PREFIX, KEY_PREFIX_LEN, REQUEST_LOGS_LIMIT, SERVICE_TIMEOUT},
model::{
AppConfig, AppState, ChatRequest, LogStatus, RequestLog, TimingInfo, TokenInfo,
UsageCheck,
},
model::{AppConfig, AppState, ChatRequest, RequestLog, TimingInfo, TokenInfo, UsageCheck},
},
chat::{
config::KeyConfig,
@@ -16,11 +17,11 @@ use crate::{
model::{
ChatResponse, Choice, Delta, Message, MessageContent, ModelsResponse, Role, Usage,
},
stream::{parse_stream_data, StreamMessage},
stream::{StreamDecoder, StreamMessage},
},
common::{
client::build_client,
model::{error::ChatError, userinfo::MembershipType, ErrorResponse},
model::{error::ChatError, userinfo::MembershipType, ApiStatus, ErrorResponse},
utils::{
format_time_ms, from_base64, get_token_profile, tokeninfo_to_token,
validate_token_and_checksum,
@@ -38,16 +39,13 @@ use axum::{
Json,
};
use bytes::Bytes;
use futures::{Stream, StreamExt};
use futures::StreamExt;
use prost::Message as _;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::{
convert::Infallible,
sync::{atomic::AtomicBool, Arc},
};
use std::{
pin::Pin,
sync::atomic::{AtomicUsize, Ordering},
};
use tokio::sync::Mutex;
use uuid::Uuid;
@@ -66,8 +64,16 @@ pub async fn handle_chat(
Json(request): Json<ChatRequest>,
) -> Result<Response<Body>, (StatusCode, Json<ErrorResponse>)> {
let allow_claude = AppConfig::get_allow_claude();
let is_search = request.model.ends_with("-online");
let model_name = if is_search {
request.model[..request.model.len() - 7].to_string()
} else {
request.model.clone()
};
// 验证模型是否支持并获取模型信息
let model = AVAILABLE_MODELS.iter().find(|m| m.id == request.model);
let model = AVAILABLE_MODELS.iter().find(|m| m.id == model_name);
let model_supported = model.is_some();
if !(model_supported || allow_claude && request.model.starts_with("claude")) {
@@ -168,7 +174,7 @@ pub async fn handle_chat(
return false;
}
let is_premium = USAGE_CHECK_MODELS.contains(&request.model.as_str());
let is_premium = USAGE_CHECK_MODELS.contains(&model_name.as_str());
let standard = &profile.usage.standard;
let premium = &profile.usage.premium;
@@ -213,14 +219,28 @@ pub async fn handle_chat(
tokio::spawn(async move {
let profile = get_token_profile(&auth_token_clone).await;
let mut state = state_clone.lock().await;
// 根据id查找对应的日志
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == log_id)
{
log.token_info.profile = profile;
// 先找到所有需要更新的位置的索引
let token_info_idx = state
.token_infos
.iter()
.position(|info| info.token == auth_token_clone);
let log_idx = state.request_logs.iter().rposition(|log| log.id == log_id);
// 根据索引更新
match (token_info_idx, log_idx) {
(Some(t_idx), Some(l_idx)) => {
state.token_infos[t_idx].profile = profile.clone();
state.request_logs[l_idx].token_info.profile = profile;
}
(Some(t_idx), None) => {
state.token_infos[t_idx].profile = profile;
}
(None, Some(l_idx)) => {
state.request_logs[l_idx].token_info.profile = profile;
}
(None, None) => {}
}
});
}
@@ -240,7 +260,7 @@ pub async fn handle_chat(
first: None,
},
stream: request.stream,
status: STATUS_PENDING,
status: LogStatus::Pending,
error: None,
});
@@ -252,9 +272,10 @@ pub async fn handle_chat(
// 将消息转换为hex格式
let hex_data = match super::adapter::encode_chat_message(
request.messages,
&request.model,
&model_name,
current_config.disable_vision(),
current_config.enable_slow_pool(),
is_search,
)
.await
{
@@ -267,7 +288,7 @@ pub async fn handle_chat(
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.status = LogStatus::Failed;
log.error = Some(e.to_string());
}
state.active_requests -= 1;
@@ -282,7 +303,7 @@ pub async fn handle_chat(
};
// 构建请求客户端
let client = build_client(&auth_token, &checksum);
let client = build_client(&auth_token, &checksum, is_search);
// 添加超时设置
let response = tokio::time::timeout(
std::time::Duration::from_secs(*SERVICE_TIMEOUT),
@@ -303,7 +324,7 @@ pub async fn handle_chat(
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_SUCCESS;
log.status = LogStatus::Success;
}
}
resp
@@ -318,7 +339,7 @@ pub async fn handle_chat(
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.status = LogStatus::Failed;
log.error = Some(e.to_string());
}
state.active_requests -= 1;
@@ -340,7 +361,7 @@ pub async fn handle_chat(
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.status = LogStatus::Failed;
log.error = Some("Request timeout".to_string());
}
state.active_requests -= 1;
@@ -359,70 +380,171 @@ pub async fn handle_chat(
state.active_requests -= 1;
}
let convert_web_ref = current_config.include_web_references();
if request.stream {
let response_id = format!("chatcmpl-{}", Uuid::new_v4().simple());
let full_text = Arc::new(Mutex::new(String::with_capacity(1024)));
let is_start = Arc::new(AtomicBool::new(true));
let start_time = std::time::Instant::now();
let first_chunk_time = Arc::new(Mutex::new(None));
let first_chunk_time = Arc::new(Mutex::new(None::<f64>));
let decoder = Arc::new(Mutex::new(StreamDecoder::new()));
// 定义消息处理器的上下文结构体
struct MessageProcessContext<'a> {
response_id: &'a str,
model: &'a str,
is_start: &'a AtomicBool,
first_chunk_time: &'a Mutex<Option<f64>>,
start_time: std::time::Instant,
state: &'a Mutex<AppState>,
current_id: u64,
}
// 处理消息并生成响应数据的辅助函数
async fn process_messages(
messages: Vec<StreamMessage>,
ctx: &MessageProcessContext<'_>,
) -> String {
let mut response_data = String::new();
for message in messages {
match message {
StreamMessage::Content(text) => {
// 记录首字时间(如果还未记录)
if let Ok(mut first_time) = ctx.first_chunk_time.try_lock() {
if first_time.is_none() {
*first_time = Some(ctx.start_time.elapsed().as_secs_f64());
}
}
let is_first = ctx.is_start.load(Ordering::SeqCst);
let response = ChatResponse {
id: ctx.response_id.to_string(),
object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(),
created: chrono::Utc::now().timestamp(),
model: if is_first {
Some(ctx.model.to_string())
} else {
None
},
choices: vec![Choice {
index: 0,
message: None,
delta: Some(Delta {
role: if is_first {
ctx.is_start.store(false, Ordering::SeqCst);
Some(Role::Assistant)
} else {
None
},
content: Some(text),
}),
finish_reason: None,
}],
usage: None,
};
response_data.push_str(&format!(
"data: {}\n\n",
serde_json::to_string(&response).unwrap()
));
}
StreamMessage::StreamEnd => {
// 计算总时间和首次片段时间
let total_time = ctx.start_time.elapsed().as_secs_f64();
let first_time = ctx.first_chunk_time.lock().await.unwrap_or(total_time);
{
let mut state = ctx.state.lock().await;
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == ctx.current_id)
{
log.timing.total = format_time_ms(total_time);
log.timing.first = Some(format_time_ms(first_time));
}
}
let response = ChatResponse {
id: ctx.response_id.to_string(),
object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(),
created: chrono::Utc::now().timestamp(),
model: None,
choices: vec![Choice {
index: 0,
message: None,
delta: Some(Delta {
role: None,
content: None,
}),
finish_reason: Some(FINISH_REASON_STOP.to_string()),
}],
usage: None,
};
response_data.push_str(&format!(
"data: {}\n\ndata: [DONE]\n\n",
serde_json::to_string(&response).unwrap()
));
}
StreamMessage::Debug(debug_prompt) => {
if let Ok(mut state) = ctx.state.try_lock() {
if let Some(last_log) = state.request_logs.last_mut() {
last_log.prompt = Some(debug_prompt);
}
}
}
_ => {} // 忽略其他消息类型
}
}
response_data
}
let stream = {
// 创建新的 stream
let mut stream = response.bytes_stream();
if current_config.enable_stream_check() {
// 检查第一个 chunk
// 处理第一个chunk并获取first_result
while decoder.lock().await.has_no_first_result() {
match stream.next().await {
Some(first_chunk) => {
let chunk = first_chunk.map_err(|e| {
let error_message = format!("Failed to read response chunk: {}", e);
// 理论上,若程序正常,必定成功,因为前面判断过了
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ChatError::RequestFailed(error_message).to_json()),
)
})?;
match parse_stream_data(&chunk) {
Err(StreamError::ChatError(error)) => {
let error_respone = error.to_error_response();
// 更新请求日志为失败
if let Err(StreamError::ChatError(error)) =
decoder.lock().await.decode(&chunk, convert_web_ref)
{
let error_response = error.to_error_response();
// 更新请求日志为失败
{
let mut state = state.lock().await;
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
let mut state = state.lock().await;
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.error = Some(error_respone.native_code());
log.timing.total =
format_time_ms(start_time.elapsed().as_secs_f64());
state.error_requests += 1;
}
log.status = LogStatus::Failed;
log.error = Some(error_response.native_code());
log.timing.total =
format_time_ms(start_time.elapsed().as_secs_f64());
state.error_requests += 1;
}
return Err((
error_respone.status_code(),
Json(error_respone.to_common()),
));
}
Ok(_) | Err(_) => {
// 创建一个包含第一个 chunk 的 stream
Box::pin(
futures::stream::once(async move { Ok(chunk) }).chain(stream),
)
as Pin<
Box<
dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send,
>,
>
}
return Err((
error_response.status_code(),
Json(error_response.to_common()),
));
}
}
None => {
// Box::pin(stream)
// as Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>
// 更新请求日志为失败
{
let mut state = state.lock().await;
@@ -432,7 +554,7 @@ pub async fn handle_chat(
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.status = LogStatus::Failed;
log.error = Some("Empty stream response".to_string());
state.error_requests += 1;
}
@@ -446,176 +568,66 @@ pub async fn handle_chat(
));
}
}
} else {
Box::pin(stream)
as Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>
}
}
.then({
let buffer = Arc::new(Mutex::new(Vec::new()));
let first_chunk_time = first_chunk_time.clone();
let state = state.clone();
move |chunk| {
let buffer = buffer.clone();
// 处理后续的stream
stream.then({
let decoder = decoder.clone();
let response_id = response_id.clone();
let model = request.model.clone();
let is_start = is_start.clone();
let full_text = full_text.clone();
let first_chunk_time = first_chunk_time.clone();
let state = state.clone();
// 根据配置决定是否发送最后的 finish_reason
let include_finish_reason = current_config.include_stop_stream();
async move {
let chunk = chunk.unwrap_or_default();
let mut buffer_guard = buffer.lock().await;
buffer_guard.extend_from_slice(&chunk);
move |chunk| {
let decoder = decoder.clone();
let response_id = response_id.clone();
let model = model.clone();
let is_start = is_start.clone();
let first_chunk_time = first_chunk_time.clone();
let state = state.clone();
match parse_stream_data(&buffer_guard) {
Ok(StreamMessage::Content(texts)) => {
buffer_guard.clear();
let mut response_data = String::new();
async move {
let chunk = chunk.unwrap_or_default();
// 记录首字时间(如果还未记录)
if let Ok(mut first_time) = first_chunk_time.try_lock() {
if first_time.is_none() {
*first_time =
Some(format_time_ms(start_time.elapsed().as_secs_f64()));
}
let ctx = MessageProcessContext {
response_id: &response_id,
model: &model,
is_start: &is_start,
first_chunk_time: &first_chunk_time,
start_time,
state: &state,
current_id,
};
// 使用decoder处理chunk
let messages = match decoder.lock().await.decode(&chunk, convert_web_ref) {
Ok(msgs) => msgs,
Err(e) => {
eprintln!("[警告] Stream error: {}", e);
return Ok::<_, Infallible>(Bytes::new());
}
};
// 处理文本内容
for text in texts {
let mut text_guard = full_text.lock().await;
text_guard.push_str(&text);
let is_first = is_start.load(Ordering::SeqCst);
let mut response_data = String::new();
let response = ChatResponse {
id: response_id.clone(),
object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(),
created: chrono::Utc::now().timestamp(),
model: if is_first { Some(model.clone()) } else { None },
choices: vec![Choice {
index: 0,
message: None,
delta: Some(Delta {
role: if is_first {
is_start.store(false, Ordering::SeqCst);
Some(Role::Assistant)
} else {
None
},
content: Some(text),
}),
finish_reason: None,
}],
usage: None,
};
response_data.push_str(&format!(
"data: {}\n\n",
serde_json::to_string(&response).unwrap()
));
}
Ok::<_, Infallible>(Bytes::from(response_data))
}
Ok(StreamMessage::StreamStart) => {
buffer_guard.clear();
// 发送初始响应,包含模型信息
let response = ChatResponse {
id: response_id.clone(),
object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(),
created: chrono::Utc::now().timestamp(),
model: {
is_start.store(true, Ordering::SeqCst);
Some(model.clone())
},
choices: vec![Choice {
index: 0,
message: None,
delta: Some(Delta {
role: Some(Role::Assistant),
content: Some(String::new()),
}),
finish_reason: None,
}],
usage: None,
};
Ok(Bytes::from(format!(
"data: {}\n\n",
serde_json::to_string(&response).unwrap()
)))
}
Ok(StreamMessage::StreamEnd) => {
buffer_guard.clear();
// 计算总时间和首次片段时间
let total_time = format_time_ms(start_time.elapsed().as_secs_f64());
let first_time = first_chunk_time.lock().await.unwrap_or(total_time);
{
let mut state = state.lock().await;
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.timing.total = total_time;
log.timing.first = Some(first_time);
}
}
if include_finish_reason {
let response = ChatResponse {
id: response_id.clone(),
object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(),
created: chrono::Utc::now().timestamp(),
model: None,
choices: vec![Choice {
index: 0,
message: None,
delta: Some(Delta {
role: None,
content: None,
}),
finish_reason: Some(FINISH_REASON_STOP.to_string()),
}],
usage: None,
};
Ok(Bytes::from(format!(
"data: {}\n\ndata: [DONE]\n\n",
serde_json::to_string(&response).unwrap()
)))
} else {
Ok(Bytes::from("data: [DONE]\n\n"))
if let Some(first_msg) = decoder.lock().await.take_first_result() {
let first_response = process_messages(first_msg, &ctx).await;
if !first_response.is_empty() {
response_data.push_str(&first_response);
}
}
Ok(StreamMessage::Incomplete) => {
// 保持buffer中的数据以待下一个chunk
Ok(Bytes::new())
}
Ok(StreamMessage::Debug(debug_prompt)) => {
buffer_guard.clear();
if let Ok(mut state) = state.try_lock() {
if let Some(last_log) = state.request_logs.last_mut() {
last_log.prompt = Some(debug_prompt.clone());
}
}
Ok(Bytes::new())
}
Err(e) => {
buffer_guard.clear();
eprintln!("[警告] Stream error: {}", e);
Ok(Bytes::new())
let current_response = process_messages(messages, &ctx).await;
if !current_response.is_empty() {
response_data.push_str(&current_response);
}
Ok(Bytes::from(response_data))
}
}
}
});
})
};
Ok(Response::builder()
.header("Cache-Control", "no-cache")
@@ -626,81 +638,62 @@ pub async fn handle_chat(
} else {
// 非流式响应
let start_time = std::time::Instant::now();
let mut first_chunk_received = false;
let mut first_chunk_time = 0.0;
let mut first_chunk_time = None::<f64>;
let mut decoder = StreamDecoder::new();
let mut full_text = String::with_capacity(1024);
let mut stream = response.bytes_stream();
let mut prompt = None;
let mut all_chunks = Vec::new();
let mut buffer = Vec::new();
// 收集所有的chunks
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(|e| {
// 更新请求日志为失败
if let Ok(mut state) = state.try_lock() {
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.error = Some(format!("Failed to read response chunk: {}", e));
state.error_requests += 1;
}
}
let error_message = format!("Failed to read response chunk: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(
ChatError::RequestFailed(format!("Failed to read response chunk: {}", e))
.to_json(),
),
Json(ChatError::RequestFailed(error_message).to_json()),
)
})?;
all_chunks.extend(chunk);
}
buffer.extend_from_slice(&chunk);
// 一次性解码所有数据
let messages = match decoder.decode(&all_chunks, convert_web_ref) {
Ok(msgs) => msgs,
Err(StreamError::ChatError(error)) => {
let error_response = error.to_error_response();
return Err((
error_response.status_code(),
Json(error_response.to_common()),
));
}
Err(e) => {
let error_response = ErrorResponse {
status: ApiStatus::Error,
code: Some(500),
error: Some(e.to_string()),
message: None,
};
return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(error_response)));
}
};
match parse_stream_data(&buffer) {
Ok(StreamMessage::Content(texts)) => {
if !first_chunk_received {
first_chunk_time = format_time_ms(start_time.elapsed().as_secs_f64());
first_chunk_received = true;
// 处理所有消息
for message in messages {
match message {
StreamMessage::Content(text) => {
if first_chunk_time.is_none() {
first_chunk_time = Some(start_time.elapsed().as_secs_f64());
}
for text in texts {
full_text.push_str(&text);
}
buffer.clear();
full_text.push_str(&text);
}
Ok(StreamMessage::Incomplete) => continue,
Ok(StreamMessage::Debug(debug_prompt)) => {
prompt = Some(debug_prompt);
buffer.clear();
}
Ok(StreamMessage::StreamStart) | Ok(StreamMessage::StreamEnd) => {
buffer.clear();
}
Err(StreamError::ChatError(error)) => {
let error = error.to_error_response();
// 更新请求日志为失败
{
let mut state = state.lock().await;
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.error = Some(error.native_code());
log.timing.total = format_time_ms(start_time.elapsed().as_secs_f64());
state.error_requests += 1;
StreamMessage::Debug(debug_prompt) => {
if let Ok(mut state) = state.try_lock() {
if let Some(last_log) = state.request_logs.last_mut() {
last_log.prompt = Some(debug_prompt);
}
}
return Err((error.status_code(), Json(error.to_common())));
}
Err(_) => {
buffer.clear();
continue;
}
_ => {}
}
}
@@ -715,11 +708,8 @@ pub async fn handle_chat(
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.status = LogStatus::Failed;
log.error = Some("Empty response received".to_string());
if let Some(p) = prompt {
log.prompt = Some(p);
}
state.error_requests += 1;
}
}
@@ -761,9 +751,8 @@ pub async fn handle_chat(
.find(|log| log.id == current_id)
{
log.timing.total = total_time;
log.timing.first = Some(first_chunk_time);
log.prompt = prompt;
log.status = STATUS_SUCCESS;
log.timing.first = first_chunk_time;
log.status = LogStatus::Success;
}
}