diff --git a/Cargo.lock b/Cargo.lock index 1685b3e..bf76934 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -361,7 +361,7 @@ dependencies = [ [[package]] name = "cursor-api" -version = "0.1.3-rc.4.1" +version = "0.1.3-rc.4.2" dependencies = [ "axum", "base64", diff --git a/Cargo.toml b/Cargo.toml index e40b3ee..0108b28 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "cursor-api" -version = "0.1.3-rc.4.1" +version = "0.1.3-rc.4.2" edition = "2021" authors = ["wisdgod "] description = "OpenAI format compatibility layer for the Cursor API" diff --git a/src/chat/service.rs b/src/chat/service.rs index 9a8ef35..e7a431a 100644 --- a/src/chat/service.rs +++ b/src/chat/service.rs @@ -24,7 +24,7 @@ use crate::{ model::{error::ChatError, userinfo::MembershipType, ApiStatus, ErrorResponse}, utils::{ format_time_ms, from_base64, get_token_profile, tokeninfo_to_token, - validate_token_and_checksum, + validate_token_and_checksum, TrimNewlines as _, }, }, }; @@ -410,15 +410,13 @@ pub async fn handle_chat( for message in messages { match message { StreamMessage::Content(text) => { - // 记录首字时间(如果还未记录) - if let Ok(mut first_time) = ctx.first_chunk_time.try_lock() { - if first_time.is_none() { + let is_first = ctx.is_start.load(Ordering::SeqCst); + if is_first { + if let Ok(mut first_time) = ctx.first_chunk_time.try_lock() { *first_time = Some(ctx.start_time.elapsed().as_secs_f64()); } } - let is_first = ctx.is_start.load(Ordering::SeqCst); - let response = ChatResponse { id: ctx.response_id.to_string(), object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(), @@ -433,12 +431,16 @@ pub async fn handle_chat( message: None, delta: Some(Delta { role: if is_first { - ctx.is_start.store(false, Ordering::SeqCst); Some(Role::Assistant) } else { None }, - content: Some(text), + content: if is_first { + ctx.is_start.store(false, Ordering::SeqCst); + Some(text.trim_leading_newlines()) + } else { + Some(text) + }, }), finish_reason: None, }], @@ -528,7 +530,8 @@ pub async fn handle_chat( { log.status = LogStatus::Failed; log.error = Some(error_response.native_code()); - log.timing.total = format_time_ms(start_time.elapsed().as_secs_f64()); + log.timing.total = + format_time_ms(start_time.elapsed().as_secs_f64()); state.error_requests += 1; } } @@ -562,7 +565,9 @@ pub async fn handle_chat( } return Err(( StatusCode::INTERNAL_SERVER_ERROR, - Json(ChatError::RequestFailed("Empty stream response".to_string()).to_json()), + Json( + ChatError::RequestFailed("Empty stream response".to_string()).to_json(), + ), )); } } diff --git a/src/chat/stream/decoder.rs b/src/chat/stream/decoder.rs index de03472..f2926d4 100644 --- a/src/chat/stream/decoder.rs +++ b/src/chat/stream/decoder.rs @@ -39,7 +39,7 @@ impl ToMarkdown for BTreeMap { } } -#[derive(PartialEq, Clone)] +#[derive(PartialEq, Clone, Debug)] pub enum StreamMessage { // 调试 Debug(String), @@ -65,6 +65,7 @@ impl StreamMessage { pub struct StreamDecoder { buffer: Vec, first_result: Option>, + first_result_ready: bool, first_result_taken: bool, } @@ -73,6 +74,7 @@ impl StreamDecoder { Self { buffer: Vec::new(), first_result: None, + first_result_ready: false, first_result_taken: false, } } @@ -93,7 +95,7 @@ impl StreamDecoder { } pub fn is_first_result_ready(&self) -> bool { - self.first_result.is_some() && self.buffer.is_empty() && !self.first_result_taken + self.first_result_ready } pub fn decode(&mut self, data: &[u8], convert_web_ref: bool) -> Result, StreamError> { @@ -150,10 +152,13 @@ impl StreamDecoder { if !self.first_result_taken && !messages.is_empty() { if self.first_result.is_none() { self.first_result = Some(messages.clone()); - } else { + } else if !self.first_result_ready { self.first_result.as_mut().unwrap().extend(messages.clone()); } } + if !self.first_result_ready { + self.first_result_ready = self.first_result.is_some() && self.buffer.is_empty() && !self.first_result_taken; + } Ok(messages) } diff --git a/src/common/utils.rs b/src/common/utils.rs index 7d85634..94fb836 100644 --- a/src/common/utils.rs +++ b/src/common/utils.rs @@ -48,6 +48,24 @@ pub fn parse_usize_from_env(key: &str, default: usize) -> usize { .unwrap_or(default) } +pub trait TrimNewlines { + fn trim_leading_newlines(self) -> Self; +} + +impl TrimNewlines for String { + #[inline(always)] + fn trim_leading_newlines(mut self) -> Self { + if self.as_bytes().get(..2) == Some(b"\n\n".as_slice()) { + unsafe { + let vec = self.as_mut_vec(); + vec.copy_within(2.., 0); + vec.truncate(vec.len() - 2); + } + } + self + } +} + pub async fn get_token_profile(auth_token: &str) -> Option { let user_id = extract_user_id(auth_token)?; diff --git a/src/main.rs b/src/main.rs index 13a85e0..2d0658a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -39,14 +39,14 @@ use tower_http::{cors::CorsLayer, limit::RequestBodyLimitLayer}; #[tokio::main] async fn main() { // 设置自定义 panic hook - // std::panic::set_hook(Box::new(|info| { - // // std::env::set_var("RUST_BACKTRACE", "1"); - // if let Some(msg) = info.payload().downcast_ref::() { - // eprintln!("{}", msg); - // } else if let Some(msg) = info.payload().downcast_ref::<&str>() { - // eprintln!("{}", msg); - // } - // })); + std::panic::set_hook(Box::new(|info| { + // std::env::set_var("RUST_BACKTRACE", "1"); + if let Some(msg) = info.payload().downcast_ref::() { + eprintln!("{}", msg); + } else if let Some(msg) = info.payload().downcast_ref::<&str>() { + eprintln!("{}", msg); + } + })); // 加载环境变量 dotenvy::dotenv().ok();