Mirror of https://github.com/wisdgod/cursor-api.git, synced 2025-10-15 19:20:39 +08:00
v0.1.3-rc.4 release
@@ -2,12 +2,13 @@ use crate::{
app::{
constant::{
AUTHORIZATION_BEARER_PREFIX, FINISH_REASON_STOP, OBJECT_CHAT_COMPLETION,
OBJECT_CHAT_COMPLETION_CHUNK, STATUS_FAILED, STATUS_PENDING, STATUS_SUCCESS,
OBJECT_CHAT_COMPLETION_CHUNK,
},
lazy::{
AUTH_TOKEN, KEY_PREFIX, KEY_PREFIX_LEN, REQUEST_LOGS_LIMIT, SERVICE_TIMEOUT,
lazy::{AUTH_TOKEN, KEY_PREFIX, KEY_PREFIX_LEN, REQUEST_LOGS_LIMIT, SERVICE_TIMEOUT},
model::{
AppConfig, AppState, ChatRequest, LogStatus, RequestLog, TimingInfo, TokenInfo,
UsageCheck,
},
model::{AppConfig, AppState, ChatRequest, RequestLog, TimingInfo, TokenInfo, UsageCheck},
},
chat::{
config::KeyConfig,
@@ -16,11 +17,11 @@ use crate::{
model::{
ChatResponse, Choice, Delta, Message, MessageContent, ModelsResponse, Role, Usage,
},
stream::{parse_stream_data, StreamMessage},
stream::{StreamDecoder, StreamMessage},
},
common::{
client::build_client,
model::{error::ChatError, userinfo::MembershipType, ErrorResponse},
model::{error::ChatError, userinfo::MembershipType, ApiStatus, ErrorResponse},
utils::{
format_time_ms, from_base64, get_token_profile, tokeninfo_to_token,
validate_token_and_checksum,
@@ -38,16 +39,13 @@ use axum::{
Json,
};
use bytes::Bytes;
use futures::{Stream, StreamExt};
use futures::StreamExt;
use prost::Message as _;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::{
convert::Infallible,
sync::{atomic::AtomicBool, Arc},
};
use std::{
pin::Pin,
sync::atomic::{AtomicUsize, Ordering},
};
use tokio::sync::Mutex;
use uuid::Uuid;

@@ -66,8 +64,16 @@ pub async fn handle_chat(
Json(request): Json<ChatRequest>,
) -> Result<Response<Body>, (StatusCode, Json<ErrorResponse>)> {
let allow_claude = AppConfig::get_allow_claude();

let is_search = request.model.ends_with("-online");
let model_name = if is_search {
request.model[..request.model.len() - 7].to_string()
} else {
request.model.clone()
};

// Check whether the model is supported and fetch its info
let model = AVAILABLE_MODELS.iter().find(|m| m.id == request.model);
let model = AVAILABLE_MODELS.iter().find(|m| m.id == model_name);
let model_supported = model.is_some();

if !(model_supported || allow_claude && request.model.starts_with("claude")) {
@@ -168,7 +174,7 @@ pub async fn handle_chat(
return false;
}

let is_premium = USAGE_CHECK_MODELS.contains(&request.model.as_str());
let is_premium = USAGE_CHECK_MODELS.contains(&model_name.as_str());
let standard = &profile.usage.standard;
let premium = &profile.usage.premium;

@@ -213,14 +219,28 @@ pub async fn handle_chat(
tokio::spawn(async move {
let profile = get_token_profile(&auth_token_clone).await;
let mut state = state_clone.lock().await;
// Find the matching log by id
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == log_id)
{
log.token_info.profile = profile;

// First find the indices of every entry that needs updating
let token_info_idx = state
.token_infos
.iter()
.position(|info| info.token == auth_token_clone);

let log_idx = state.request_logs.iter().rposition(|log| log.id == log_id);

// Update by index
match (token_info_idx, log_idx) {
(Some(t_idx), Some(l_idx)) => {
state.token_infos[t_idx].profile = profile.clone();
state.request_logs[l_idx].token_info.profile = profile;
}
(Some(t_idx), None) => {
state.token_infos[t_idx].profile = profile;
}
(None, Some(l_idx)) => {
state.request_logs[l_idx].token_info.profile = profile;
}
(None, None) => {}
}
});
}
@@ -240,7 +260,7 @@ pub async fn handle_chat(
first: None,
},
stream: request.stream,
status: STATUS_PENDING,
status: LogStatus::Pending,
error: None,
});
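
The hunk above replaces the string status constants STATUS_PENDING / STATUS_SUCCESS / STATUS_FAILED with a LogStatus enum imported from app::model. The enum itself is not part of this diff; a minimal sketch of what the variants used in this file imply (the derives are an assumption):

// Sketch only — the real LogStatus lives in app::model and is not shown in this commit.
#[derive(Clone, Copy, PartialEq)]
pub enum LogStatus {
    Pending, // request accepted, response not finished yet
    Success, // completed normally
    Failed,  // errored out or timed out
}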

@@ -252,9 +272,10 @@ pub async fn handle_chat(
// Convert the messages to hex format
let hex_data = match super::adapter::encode_chat_message(
request.messages,
&request.model,
&model_name,
current_config.disable_vision(),
current_config.enable_slow_pool(),
is_search,
)
.await
{
@@ -267,7 +288,7 @@ pub async fn handle_chat(
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.status = LogStatus::Failed;
log.error = Some(e.to_string());
}
state.active_requests -= 1;
@@ -282,7 +303,7 @@ pub async fn handle_chat(
};

// Build the request client
let client = build_client(&auth_token, &checksum);
let client = build_client(&auth_token, &checksum, is_search);
// Add timeout settings
let response = tokio::time::timeout(
std::time::Duration::from_secs(*SERVICE_TIMEOUT),
@@ -303,7 +324,7 @@ pub async fn handle_chat(
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_SUCCESS;
log.status = LogStatus::Success;
}
}
resp
@@ -318,7 +339,7 @@ pub async fn handle_chat(
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.status = LogStatus::Failed;
log.error = Some(e.to_string());
}
state.active_requests -= 1;
@@ -340,7 +361,7 @@ pub async fn handle_chat(
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.status = LogStatus::Failed;
log.error = Some("Request timeout".to_string());
}
state.active_requests -= 1;
@@ -359,70 +380,171 @@ pub async fn handle_chat(
state.active_requests -= 1;
}

let convert_web_ref = current_config.include_web_references();

if request.stream {
let response_id = format!("chatcmpl-{}", Uuid::new_v4().simple());
let full_text = Arc::new(Mutex::new(String::with_capacity(1024)));
let is_start = Arc::new(AtomicBool::new(true));
let start_time = std::time::Instant::now();
let first_chunk_time = Arc::new(Mutex::new(None));
let first_chunk_time = Arc::new(Mutex::new(None::<f64>));
let decoder = Arc::new(Mutex::new(StreamDecoder::new()));

// Context struct for the message processor
struct MessageProcessContext<'a> {
response_id: &'a str,
model: &'a str,
is_start: &'a AtomicBool,
first_chunk_time: &'a Mutex<Option<f64>>,
start_time: std::time::Instant,
state: &'a Mutex<AppState>,
current_id: u64,
}

// Helper that processes messages and builds the response data
async fn process_messages(
messages: Vec<StreamMessage>,
ctx: &MessageProcessContext<'_>,
) -> String {
let mut response_data = String::new();

for message in messages {
match message {
StreamMessage::Content(text) => {
// Record the time to first token (if not already recorded)
if let Ok(mut first_time) = ctx.first_chunk_time.try_lock() {
if first_time.is_none() {
*first_time = Some(ctx.start_time.elapsed().as_secs_f64());
}
}

let is_first = ctx.is_start.load(Ordering::SeqCst);

let response = ChatResponse {
id: ctx.response_id.to_string(),
object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(),
created: chrono::Utc::now().timestamp(),
model: if is_first {
Some(ctx.model.to_string())
} else {
None
},
choices: vec![Choice {
index: 0,
message: None,
delta: Some(Delta {
role: if is_first {
ctx.is_start.store(false, Ordering::SeqCst);
Some(Role::Assistant)
} else {
None
},
content: Some(text),
}),
finish_reason: None,
}],
usage: None,
};

response_data.push_str(&format!(
"data: {}\n\n",
serde_json::to_string(&response).unwrap()
));
}
StreamMessage::StreamEnd => {
// Compute the total time and the first-chunk time
let total_time = ctx.start_time.elapsed().as_secs_f64();
let first_time = ctx.first_chunk_time.lock().await.unwrap_or(total_time);

{
let mut state = ctx.state.lock().await;
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == ctx.current_id)
{
log.timing.total = format_time_ms(total_time);
log.timing.first = Some(format_time_ms(first_time));
}
}

let response = ChatResponse {
id: ctx.response_id.to_string(),
object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(),
created: chrono::Utc::now().timestamp(),
model: None,
choices: vec![Choice {
index: 0,
message: None,
delta: Some(Delta {
role: None,
content: None,
}),
finish_reason: Some(FINISH_REASON_STOP.to_string()),
}],
usage: None,
};
response_data.push_str(&format!(
"data: {}\n\ndata: [DONE]\n\n",
serde_json::to_string(&response).unwrap()
));
}
StreamMessage::Debug(debug_prompt) => {
if let Ok(mut state) = ctx.state.try_lock() {
if let Some(last_log) = state.request_logs.last_mut() {
last_log.prompt = Some(debug_prompt);
}
}
}
_ => {} // Ignore other message types
}
}

response_data
}
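
process_messages above consumes the Vec<StreamMessage> values produced by StreamDecoder::decode, which replaces the older per-chunk parse_stream_data call. The decoder and message types live in chat::stream and are not shown in this commit; a rough sketch of the shapes implied by their usage in this file (field payloads, the error type, and method bodies are assumptions):

// Sketch only — inferred from usage in this diff, not the real chat::stream definitions.
pub enum StreamMessage {
    StreamStart,      // first frame of a response
    Content(String),  // one text fragment
    StreamEnd,        // terminal frame
    Incomplete,       // more bytes are needed before anything can be decoded
    Debug(String),    // debug prompt attached to the request log
}

// Placeholder: the real error type has at least a ChatError variant, as matched in this file.
pub struct StreamError;

pub struct StreamDecoder; // internal buffering omitted

impl StreamDecoder {
    pub fn new() -> Self {
        StreamDecoder
    }
    // Feed one network chunk; may yield several messages at once.
    pub fn decode(&mut self, _chunk: &[u8], _convert_web_ref: bool) -> Result<Vec<StreamMessage>, StreamError> {
        unimplemented!()
    }
    // True until the first successful decode (used by the stream-check loop).
    pub fn has_no_first_result(&self) -> bool {
        unimplemented!()
    }
    // Hand back the buffered first result so it is emitted before later chunks.
    pub fn take_first_result(&mut self) -> Option<Vec<StreamMessage>> {
        unimplemented!()
    }
}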

let stream = {
// Create a new stream
let mut stream = response.bytes_stream();

if current_config.enable_stream_check() {
// Check the first chunk
// Process the first chunk and obtain first_result
while decoder.lock().await.has_no_first_result() {
match stream.next().await {
Some(first_chunk) => {
let chunk = first_chunk.map_err(|e| {
let error_message = format!("Failed to read response chunk: {}", e);
// In theory this always succeeds when the program is healthy, since it was already checked above
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ChatError::RequestFailed(error_message).to_json()),
)
})?;

match parse_stream_data(&chunk) {
Err(StreamError::ChatError(error)) => {
let error_respone = error.to_error_response();
// Mark the request log as failed
if let Err(StreamError::ChatError(error)) =
decoder.lock().await.decode(&chunk, convert_web_ref)
{
let error_response = error.to_error_response();
// Mark the request log as failed
{
let mut state = state.lock().await;
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
let mut state = state.lock().await;
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.error = Some(error_respone.native_code());
log.timing.total =
format_time_ms(start_time.elapsed().as_secs_f64());
state.error_requests += 1;
}
log.status = LogStatus::Failed;
log.error = Some(error_response.native_code());
log.timing.total =
format_time_ms(start_time.elapsed().as_secs_f64());
state.error_requests += 1;
}
return Err((
error_respone.status_code(),
Json(error_respone.to_common()),
));
}
Ok(_) | Err(_) => {
// Create a stream that contains the first chunk
Box::pin(
futures::stream::once(async move { Ok(chunk) }).chain(stream),
)
as Pin<
Box<
dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send,
>,
>
}
return Err((
error_response.status_code(),
Json(error_response.to_common()),
));
}
}
None => {
// Box::pin(stream)
// as Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>
// Mark the request log as failed
{
let mut state = state.lock().await;
@@ -432,7 +554,7 @@ pub async fn handle_chat(
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.status = LogStatus::Failed;
log.error = Some("Empty stream response".to_string());
state.error_requests += 1;
}
@@ -446,176 +568,66 @@ pub async fn handle_chat(
));
}
}
} else {
Box::pin(stream)
as Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>
}
}
.then({
let buffer = Arc::new(Mutex::new(Vec::new()));
let first_chunk_time = first_chunk_time.clone();
let state = state.clone();

move |chunk| {
let buffer = buffer.clone();
// Process the rest of the stream
stream.then({
let decoder = decoder.clone();
let response_id = response_id.clone();
let model = request.model.clone();
let is_start = is_start.clone();
let full_text = full_text.clone();
let first_chunk_time = first_chunk_time.clone();
let state = state.clone();
// Decide from the config whether to send the final finish_reason
let include_finish_reason = current_config.include_stop_stream();

async move {
let chunk = chunk.unwrap_or_default();
let mut buffer_guard = buffer.lock().await;
buffer_guard.extend_from_slice(&chunk);
move |chunk| {
let decoder = decoder.clone();
let response_id = response_id.clone();
let model = model.clone();
let is_start = is_start.clone();
let first_chunk_time = first_chunk_time.clone();
let state = state.clone();

match parse_stream_data(&buffer_guard) {
Ok(StreamMessage::Content(texts)) => {
buffer_guard.clear();
let mut response_data = String::new();
async move {
let chunk = chunk.unwrap_or_default();

// Record the time to first token (if not already recorded)
if let Ok(mut first_time) = first_chunk_time.try_lock() {
if first_time.is_none() {
*first_time =
Some(format_time_ms(start_time.elapsed().as_secs_f64()));
}
let ctx = MessageProcessContext {
response_id: &response_id,
model: &model,
is_start: &is_start,
first_chunk_time: &first_chunk_time,
start_time,
state: &state,
current_id,
};

// Process the chunk with the decoder
let messages = match decoder.lock().await.decode(&chunk, convert_web_ref) {
Ok(msgs) => msgs,
Err(e) => {
eprintln!("[Warning] Stream error: {}", e);
return Ok::<_, Infallible>(Bytes::new());
}
};

// Handle the text content
for text in texts {
let mut text_guard = full_text.lock().await;
text_guard.push_str(&text);
let is_first = is_start.load(Ordering::SeqCst);
let mut response_data = String::new();

let response = ChatResponse {
id: response_id.clone(),
object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(),
created: chrono::Utc::now().timestamp(),
model: if is_first { Some(model.clone()) } else { None },
choices: vec![Choice {
index: 0,
message: None,
delta: Some(Delta {
role: if is_first {
is_start.store(false, Ordering::SeqCst);
Some(Role::Assistant)
} else {
None
},
content: Some(text),
}),
finish_reason: None,
}],
usage: None,
};

response_data.push_str(&format!(
"data: {}\n\n",
serde_json::to_string(&response).unwrap()
));
}

Ok::<_, Infallible>(Bytes::from(response_data))
}
Ok(StreamMessage::StreamStart) => {
buffer_guard.clear();
// Send the initial response, including the model info
let response = ChatResponse {
id: response_id.clone(),
object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(),
created: chrono::Utc::now().timestamp(),
model: {
is_start.store(true, Ordering::SeqCst);
Some(model.clone())
},
choices: vec![Choice {
index: 0,
message: None,
delta: Some(Delta {
role: Some(Role::Assistant),
content: Some(String::new()),
}),
finish_reason: None,
}],
usage: None,
};

Ok(Bytes::from(format!(
"data: {}\n\n",
serde_json::to_string(&response).unwrap()
)))
}
Ok(StreamMessage::StreamEnd) => {
buffer_guard.clear();

// Compute the total time and the first-chunk time
let total_time = format_time_ms(start_time.elapsed().as_secs_f64());
let first_time = first_chunk_time.lock().await.unwrap_or(total_time);

{
let mut state = state.lock().await;
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.timing.total = total_time;
log.timing.first = Some(first_time);
}
}

if include_finish_reason {
let response = ChatResponse {
id: response_id.clone(),
object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(),
created: chrono::Utc::now().timestamp(),
model: None,
choices: vec![Choice {
index: 0,
message: None,
delta: Some(Delta {
role: None,
content: None,
}),
finish_reason: Some(FINISH_REASON_STOP.to_string()),
}],
usage: None,
};
Ok(Bytes::from(format!(
"data: {}\n\ndata: [DONE]\n\n",
serde_json::to_string(&response).unwrap()
)))
} else {
Ok(Bytes::from("data: [DONE]\n\n"))
if let Some(first_msg) = decoder.lock().await.take_first_result() {
let first_response = process_messages(first_msg, &ctx).await;
if !first_response.is_empty() {
response_data.push_str(&first_response);
}
}
Ok(StreamMessage::Incomplete) => {
// Keep the data in the buffer until the next chunk
Ok(Bytes::new())
}
Ok(StreamMessage::Debug(debug_prompt)) => {
buffer_guard.clear();
if let Ok(mut state) = state.try_lock() {
if let Some(last_log) = state.request_logs.last_mut() {
last_log.prompt = Some(debug_prompt.clone());
}
}
Ok(Bytes::new())
}
Err(e) => {
buffer_guard.clear();
eprintln!("[Warning] Stream error: {}", e);
Ok(Bytes::new())

let current_response = process_messages(messages, &ctx).await;
if !current_response.is_empty() {
response_data.push_str(&current_response);
}

Ok(Bytes::from(response_data))
}
}
}
});
})
};

Ok(Response::builder()
.header("Cache-Control", "no-cache")
@@ -626,81 +638,62 @@ pub async fn handle_chat(
} else {
// Non-streaming response
let start_time = std::time::Instant::now();
let mut first_chunk_received = false;
let mut first_chunk_time = 0.0;
let mut first_chunk_time = None::<f64>;
let mut decoder = StreamDecoder::new();
let mut full_text = String::with_capacity(1024);
let mut stream = response.bytes_stream();
let mut prompt = None;
let mut all_chunks = Vec::new();

let mut buffer = Vec::new();
// Collect all the chunks
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(|e| {
// Mark the request log as failed
if let Ok(mut state) = state.try_lock() {
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.error = Some(format!("Failed to read response chunk: {}", e));
state.error_requests += 1;
}
}
let error_message = format!("Failed to read response chunk: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(
ChatError::RequestFailed(format!("Failed to read response chunk: {}", e))
.to_json(),
),
Json(ChatError::RequestFailed(error_message).to_json()),
)
})?;
all_chunks.extend(chunk);
}

buffer.extend_from_slice(&chunk);
// Decode all the data in one pass
let messages = match decoder.decode(&all_chunks, convert_web_ref) {
Ok(msgs) => msgs,
Err(StreamError::ChatError(error)) => {
let error_response = error.to_error_response();
return Err((
error_response.status_code(),
Json(error_response.to_common()),
));
}
Err(e) => {
let error_response = ErrorResponse {
status: ApiStatus::Error,
code: Some(500),
error: Some(e.to_string()),
message: None,
};
return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(error_response)));
}
};

match parse_stream_data(&buffer) {
Ok(StreamMessage::Content(texts)) => {
if !first_chunk_received {
first_chunk_time = format_time_ms(start_time.elapsed().as_secs_f64());
first_chunk_received = true;
// Process all messages
for message in messages {
match message {
StreamMessage::Content(text) => {
if first_chunk_time.is_none() {
first_chunk_time = Some(start_time.elapsed().as_secs_f64());
}
for text in texts {
full_text.push_str(&text);
}
buffer.clear();
full_text.push_str(&text);
}
Ok(StreamMessage::Incomplete) => continue,
Ok(StreamMessage::Debug(debug_prompt)) => {
prompt = Some(debug_prompt);
buffer.clear();
}
Ok(StreamMessage::StreamStart) | Ok(StreamMessage::StreamEnd) => {
buffer.clear();
}
Err(StreamError::ChatError(error)) => {
let error = error.to_error_response();
// Mark the request log as failed
{
let mut state = state.lock().await;
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.error = Some(error.native_code());
log.timing.total = format_time_ms(start_time.elapsed().as_secs_f64());
state.error_requests += 1;
StreamMessage::Debug(debug_prompt) => {
if let Ok(mut state) = state.try_lock() {
if let Some(last_log) = state.request_logs.last_mut() {
last_log.prompt = Some(debug_prompt);
}
}
return Err((error.status_code(), Json(error.to_common())));
}
Err(_) => {
buffer.clear();
continue;
}
_ => {}
}
}

@@ -715,11 +708,8 @@ pub async fn handle_chat(
.rev()
.find(|log| log.id == current_id)
{
log.status = STATUS_FAILED;
log.status = LogStatus::Failed;
log.error = Some("Empty response received".to_string());
if let Some(p) = prompt {
log.prompt = Some(p);
}
state.error_requests += 1;
}
}
@@ -761,9 +751,8 @@ pub async fn handle_chat(
.find(|log| log.id == current_id)
{
log.timing.total = total_time;
log.timing.first = Some(first_chunk_time);
log.prompt = prompt;
log.status = STATUS_SUCCESS;
log.timing.first = first_chunk_time;
log.status = LogStatus::Success;
}
}
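
Both branches record timing through the TimingInfo struct imported from app::model and the format_time_ms helper from common::utils; neither is shown in this commit. A rough sketch of what the assignments in this file imply (the derives and the exact rounding behavior are assumptions; only the total and first field names are taken from the diff):

// Sketch only — inferred from the assignments in this diff, not the real definitions.
#[derive(Default, Clone)]
pub struct TimingInfo {
    pub total: f64,         // seconds, passed through format_time_ms
    pub first: Option<f64>, // seconds until the first content chunk, if any
}

// Assumed behavior: keep seconds as the unit but round to millisecond precision.
pub fn format_time_ms(seconds: f64) -> f64 {
    (seconds * 1000.0).round() / 1000.0
}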