修复一些问题,作为v0.1.3-rc.4的补充

This commit is contained in:
wisdgod
2025-01-27 17:39:31 +08:00
parent c58f2697f0
commit 00a6980da9
8 changed files with 162 additions and 141 deletions

View File

@@ -491,8 +491,13 @@ pub async fn handle_chat(
}
StreamMessage::Debug(debug_prompt) => {
if let Ok(mut state) = ctx.state.try_lock() {
if let Some(last_log) = state.request_logs.last_mut() {
last_log.prompt = Some(debug_prompt);
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == ctx.current_id)
{
log.prompt = Some(debug_prompt);
}
}
}
@@ -507,7 +512,7 @@ pub async fn handle_chat(
let mut stream = response.bytes_stream();
// 处理第一个chunk并获取first_result
while decoder.lock().await.has_no_first_result() {
while !decoder.lock().await.is_first_result_ready() {
match stream.next().await {
Some(first_chunk) => {
let chunk = first_chunk.map_err(|e| {
@@ -642,9 +647,8 @@ pub async fn handle_chat(
let mut decoder = StreamDecoder::new();
let mut full_text = String::with_capacity(1024);
let mut stream = response.bytes_stream();
let mut all_chunks = Vec::new();
// 收集所有的chunks
// 逐个处理chunks
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(|e| {
let error_message = format!("Failed to read response chunk: {}", e);
@@ -653,47 +657,50 @@ pub async fn handle_chat(
Json(ChatError::RequestFailed(error_message).to_json()),
)
})?;
all_chunks.extend(chunk);
}
// 一次性解码所有数据
let messages = match decoder.decode(&all_chunks, convert_web_ref) {
Ok(msgs) => msgs,
Err(StreamError::ChatError(error)) => {
let error_response = error.to_error_response();
return Err((
error_response.status_code(),
Json(error_response.to_common()),
));
}
Err(e) => {
let error_response = ErrorResponse {
status: ApiStatus::Error,
code: Some(500),
error: Some(e.to_string()),
message: None,
};
return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(error_response)));
}
};
// 处理所有消息
for message in messages {
match message {
StreamMessage::Content(text) => {
if first_chunk_time.is_none() {
first_chunk_time = Some(start_time.elapsed().as_secs_f64());
}
full_text.push_str(&text);
}
StreamMessage::Debug(debug_prompt) => {
if let Ok(mut state) = state.try_lock() {
if let Some(last_log) = state.request_logs.last_mut() {
last_log.prompt = Some(debug_prompt);
// 立即处理当前chunk
match decoder.decode(&chunk, convert_web_ref) {
Ok(messages) => {
for message in messages {
match message {
StreamMessage::Content(text) => {
if first_chunk_time.is_none() {
first_chunk_time = Some(start_time.elapsed().as_secs_f64());
}
full_text.push_str(&text);
}
StreamMessage::Debug(debug_prompt) => {
if let Ok(mut state) = state.try_lock() {
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.prompt = Some(debug_prompt);
}
}
}
_ => {}
}
}
}
_ => {}
Err(StreamError::ChatError(error)) => {
let error_response = error.to_error_response();
return Err((
error_response.status_code(),
Json(error_response.to_common()),
));
}
Err(e) => {
let error_response = ErrorResponse {
status: ApiStatus::Error,
code: Some(500),
error: Some(e.to_string()),
message: None,
};
return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(error_response)));
}
}
}