v0.1.3-rc.4.1

This commit is contained in:
wisdgod
2025-01-28 08:46:13 +08:00
parent 0d8035fd02
commit cb244d7282
4 changed files with 95 additions and 103 deletions

2
Cargo.lock generated
View File

@@ -361,7 +361,7 @@ dependencies = [
[[package]]
name = "cursor-api"
version = "0.1.3-rc.4"
version = "0.1.3-rc.4.1"
dependencies = [
"axum",
"base64",

View File

@@ -1,6 +1,6 @@
[package]
name = "cursor-api"
version = "0.1.3-rc.4"
version = "0.1.3-rc.4.1"
edition = "2021"
authors = ["wisdgod <nav@wisdgod.com>"]
description = "OpenAI format compatibility layer for the Cursor API"

View File

@@ -508,48 +508,15 @@ pub async fn handle_chat(
response_data
}
let stream = {
let mut stream = response.bytes_stream();
// 处理第一个chunk并获取first_result
while !decoder.lock().await.is_first_result_ready() {
match stream.next().await {
Some(first_chunk) => {
let chunk = first_chunk.map_err(|e| {
let error_message = format!("Failed to read response chunk: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ChatError::RequestFailed(error_message).to_json()),
)
})?;
if let Err(StreamError::ChatError(error)) =
decoder.lock().await.decode(&chunk, convert_web_ref)
{
let error_response = error.to_error_response();
// 更新请求日志为失败
{
let mut state = state.lock().await;
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.status = LogStatus::Failed;
log.error = Some(error_response.native_code());
log.timing.total =
format_time_ms(start_time.elapsed().as_secs_f64());
state.error_requests += 1;
}
}
return Err((
error_response.status_code(),
Json(error_response.to_common()),
));
}
}
None => {
// 首先处理stream直到获得第一个结果
let mut stream = response.bytes_stream();
while !decoder.lock().await.is_first_result_ready() {
match stream.next().await {
Some(Ok(chunk)) => {
if let Err(StreamError::ChatError(error)) =
decoder.lock().await.decode(&chunk, convert_web_ref)
{
let error_response = error.to_error_response();
// 更新请求日志为失败
{
let mut state = state.lock().await;
@@ -560,77 +527,102 @@ pub async fn handle_chat(
.find(|log| log.id == current_id)
{
log.status = LogStatus::Failed;
log.error = Some("Empty stream response".to_string());
log.error = Some(error_response.native_code());
log.timing.total = format_time_ms(start_time.elapsed().as_secs_f64());
state.error_requests += 1;
}
}
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(
ChatError::RequestFailed("Empty stream response".to_string())
.to_json(),
),
error_response.status_code(),
Json(error_response.to_common()),
));
}
}
Some(Err(e)) => {
let error_message = format!("Failed to read response chunk: {}", e);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ChatError::RequestFailed(error_message).to_json()),
));
}
None => {
// 更新请求日志为失败
{
let mut state = state.lock().await;
if let Some(log) = state
.request_logs
.iter_mut()
.rev()
.find(|log| log.id == current_id)
{
log.status = LogStatus::Failed;
log.error = Some("Empty stream response".to_string());
state.error_requests += 1;
}
}
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ChatError::RequestFailed("Empty stream response".to_string()).to_json()),
));
}
}
}
stream.then({
// 处理后续的stream
let stream = stream.then({
let decoder = decoder.clone();
let response_id = response_id.clone();
let model = request.model.clone();
let is_start = is_start.clone();
let first_chunk_time = first_chunk_time.clone();
let state = state.clone();
move |chunk| {
let decoder = decoder.clone();
let response_id = response_id.clone();
let model = request.model.clone();
let model = model.clone();
let is_start = is_start.clone();
let first_chunk_time = first_chunk_time.clone();
let state = state.clone();
move |chunk| {
let decoder = decoder.clone();
let response_id = response_id.clone();
let model = model.clone();
let is_start = is_start.clone();
let first_chunk_time = first_chunk_time.clone();
let state = state.clone();
async move {
let chunk = chunk.unwrap_or_default();
async move {
let chunk = chunk.unwrap_or_default();
let ctx = MessageProcessContext {
response_id: &response_id,
model: &model,
is_start: &is_start,
first_chunk_time: &first_chunk_time,
start_time,
state: &state,
current_id,
};
let ctx = MessageProcessContext {
response_id: &response_id,
model: &model,
is_start: &is_start,
first_chunk_time: &first_chunk_time,
start_time,
state: &state,
current_id,
};
// 使用decoder处理chunk
let messages = match decoder.lock().await.decode(&chunk, convert_web_ref) {
Ok(msgs) => msgs,
Err(e) => {
eprintln!("[警告] Stream error: {}", e);
return Ok::<_, Infallible>(Bytes::new());
}
};
let messages = match decoder.lock().await.decode(&chunk, convert_web_ref) {
Ok(msgs) => msgs,
Err(e) => {
eprintln!("[警告] Stream error: {}", e);
return Ok::<_, Infallible>(Bytes::new());
}
};
let mut response_data = String::new();
// let mut response_data = String::new();
// if let Some(first_msg) = decoder.lock().await.take_first_result() {
// let first_response = process_messages(first_msg, &ctx).await;
// if !first_response.is_empty() {
// response_data.push_str(&first_response);
// }
// }
let response_data = process_messages(messages, &ctx).await;
// if !current_response.is_empty() {
// response_data.push_str(&current_response);
// }
Ok(Bytes::from(response_data))
if let Some(first_msg) = decoder.lock().await.take_first_result() {
let first_response = process_messages(first_msg, &ctx).await;
response_data.push_str(&first_response);
}
let current_response = process_messages(messages, &ctx).await;
if !current_response.is_empty() {
response_data.push_str(&current_response);
}
Ok(Bytes::from(response_data))
}
})
};
}
});
Ok(Response::builder()
.header("Cache-Control", "no-cache")

View File

@@ -77,15 +77,15 @@ impl StreamDecoder {
}
}
// pub fn take_first_result(&mut self) -> Option<Vec<StreamMessage>> {
// if !self.buffer.is_empty() {
// return None;
// }
// if self.first_result.is_some() {
// self.first_result_taken = true;
// }
// self.first_result.take()
// }
pub fn take_first_result(&mut self) -> Option<Vec<StreamMessage>> {
if !self.buffer.is_empty() {
return None;
}
if self.first_result.is_some() {
self.first_result_taken = true;
}
self.first_result.take()
}
#[cfg(test)]
fn is_incomplete(&self) -> bool {