diff --git a/Cargo.lock b/Cargo.lock index e0dae0d..1685b3e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -361,7 +361,7 @@ dependencies = [ [[package]] name = "cursor-api" -version = "0.1.3-rc.4" +version = "0.1.3-rc.4.1" dependencies = [ "axum", "base64", diff --git a/Cargo.toml b/Cargo.toml index fe760ca..e40b3ee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "cursor-api" -version = "0.1.3-rc.4" +version = "0.1.3-rc.4.1" edition = "2021" authors = ["wisdgod "] description = "OpenAI format compatibility layer for the Cursor API" diff --git a/src/chat/service.rs b/src/chat/service.rs index 165ebd0..9a8ef35 100644 --- a/src/chat/service.rs +++ b/src/chat/service.rs @@ -508,48 +508,15 @@ pub async fn handle_chat( response_data } - let stream = { - let mut stream = response.bytes_stream(); - - // 处理第一个chunk并获取first_result - while !decoder.lock().await.is_first_result_ready() { - match stream.next().await { - Some(first_chunk) => { - let chunk = first_chunk.map_err(|e| { - let error_message = format!("Failed to read response chunk: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ChatError::RequestFailed(error_message).to_json()), - ) - })?; - - if let Err(StreamError::ChatError(error)) = - decoder.lock().await.decode(&chunk, convert_web_ref) - { - let error_response = error.to_error_response(); - // 更新请求日志为失败 - { - let mut state = state.lock().await; - if let Some(log) = state - .request_logs - .iter_mut() - .rev() - .find(|log| log.id == current_id) - { - log.status = LogStatus::Failed; - log.error = Some(error_response.native_code()); - log.timing.total = - format_time_ms(start_time.elapsed().as_secs_f64()); - state.error_requests += 1; - } - } - return Err(( - error_response.status_code(), - Json(error_response.to_common()), - )); - } - } - None => { + // 首先处理stream直到获得第一个结果 + let mut stream = response.bytes_stream(); + while !decoder.lock().await.is_first_result_ready() { + match stream.next().await { + Some(Ok(chunk)) => { + if let Err(StreamError::ChatError(error)) = + decoder.lock().await.decode(&chunk, convert_web_ref) + { + let error_response = error.to_error_response(); // 更新请求日志为失败 { let mut state = state.lock().await; @@ -560,77 +527,102 @@ pub async fn handle_chat( .find(|log| log.id == current_id) { log.status = LogStatus::Failed; - log.error = Some("Empty stream response".to_string()); + log.error = Some(error_response.native_code()); + log.timing.total = format_time_ms(start_time.elapsed().as_secs_f64()); state.error_requests += 1; } } return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json( - ChatError::RequestFailed("Empty stream response".to_string()) - .to_json(), - ), + error_response.status_code(), + Json(error_response.to_common()), )); } } + Some(Err(e)) => { + let error_message = format!("Failed to read response chunk: {}", e); + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ChatError::RequestFailed(error_message).to_json()), + )); + } + None => { + // 更新请求日志为失败 + { + let mut state = state.lock().await; + if let Some(log) = state + .request_logs + .iter_mut() + .rev() + .find(|log| log.id == current_id) + { + log.status = LogStatus::Failed; + log.error = Some("Empty stream response".to_string()); + state.error_requests += 1; + } + } + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ChatError::RequestFailed("Empty stream response".to_string()).to_json()), + )); + } } + } - stream.then({ + // 处理后续的stream + let stream = stream.then({ + let decoder = decoder.clone(); + let response_id = response_id.clone(); + let model = request.model.clone(); + let is_start = is_start.clone(); + let first_chunk_time = first_chunk_time.clone(); + let state = state.clone(); + + move |chunk| { let decoder = decoder.clone(); let response_id = response_id.clone(); - let model = request.model.clone(); + let model = model.clone(); let is_start = is_start.clone(); let first_chunk_time = first_chunk_time.clone(); let state = state.clone(); - move |chunk| { - let decoder = decoder.clone(); - let response_id = response_id.clone(); - let model = model.clone(); - let is_start = is_start.clone(); - let first_chunk_time = first_chunk_time.clone(); - let state = state.clone(); + async move { + let chunk = chunk.unwrap_or_default(); - async move { - let chunk = chunk.unwrap_or_default(); + let ctx = MessageProcessContext { + response_id: &response_id, + model: &model, + is_start: &is_start, + first_chunk_time: &first_chunk_time, + start_time, + state: &state, + current_id, + }; - let ctx = MessageProcessContext { - response_id: &response_id, - model: &model, - is_start: &is_start, - first_chunk_time: &first_chunk_time, - start_time, - state: &state, - current_id, - }; + // 使用decoder处理chunk + let messages = match decoder.lock().await.decode(&chunk, convert_web_ref) { + Ok(msgs) => msgs, + Err(e) => { + eprintln!("[警告] Stream error: {}", e); + return Ok::<_, Infallible>(Bytes::new()); + } + }; - let messages = match decoder.lock().await.decode(&chunk, convert_web_ref) { - Ok(msgs) => msgs, - Err(e) => { - eprintln!("[警告] Stream error: {}", e); - return Ok::<_, Infallible>(Bytes::new()); - } - }; + let mut response_data = String::new(); - // let mut response_data = String::new(); - - // if let Some(first_msg) = decoder.lock().await.take_first_result() { - // let first_response = process_messages(first_msg, &ctx).await; - // if !first_response.is_empty() { - // response_data.push_str(&first_response); - // } - // } - - let response_data = process_messages(messages, &ctx).await; - // if !current_response.is_empty() { - // response_data.push_str(¤t_response); - // } - - Ok(Bytes::from(response_data)) + if let Some(first_msg) = decoder.lock().await.take_first_result() { + let first_response = process_messages(first_msg, &ctx).await; + response_data.push_str(&first_response); } + + let current_response = process_messages(messages, &ctx).await; + if !current_response.is_empty() { + response_data.push_str(¤t_response); + } + + Ok(Bytes::from(response_data)) } - }) - }; + } + }); Ok(Response::builder() .header("Cache-Control", "no-cache") diff --git a/src/chat/stream/decoder.rs b/src/chat/stream/decoder.rs index 9e80e20..de03472 100644 --- a/src/chat/stream/decoder.rs +++ b/src/chat/stream/decoder.rs @@ -77,15 +77,15 @@ impl StreamDecoder { } } - // pub fn take_first_result(&mut self) -> Option> { - // if !self.buffer.is_empty() { - // return None; - // } - // if self.first_result.is_some() { - // self.first_result_taken = true; - // } - // self.first_result.take() - // } + pub fn take_first_result(&mut self) -> Option> { + if !self.buffer.is_empty() { + return None; + } + if self.first_result.is_some() { + self.first_result_taken = true; + } + self.first_result.take() + } #[cfg(test)] fn is_incomplete(&self) -> bool {