v0.1.3-rc.4.2

This commit is contained in:
wisdgod
2025-01-28 15:00:27 +08:00
parent cb244d7282
commit 22121f3beb
6 changed files with 51 additions and 23 deletions

2
Cargo.lock generated
View File

@@ -361,7 +361,7 @@ dependencies = [
[[package]] [[package]]
name = "cursor-api" name = "cursor-api"
version = "0.1.3-rc.4.1" version = "0.1.3-rc.4.2"
dependencies = [ dependencies = [
"axum", "axum",
"base64", "base64",

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "cursor-api" name = "cursor-api"
version = "0.1.3-rc.4.1" version = "0.1.3-rc.4.2"
edition = "2021" edition = "2021"
authors = ["wisdgod <nav@wisdgod.com>"] authors = ["wisdgod <nav@wisdgod.com>"]
description = "OpenAI format compatibility layer for the Cursor API" description = "OpenAI format compatibility layer for the Cursor API"

View File

@@ -24,7 +24,7 @@ use crate::{
model::{error::ChatError, userinfo::MembershipType, ApiStatus, ErrorResponse}, model::{error::ChatError, userinfo::MembershipType, ApiStatus, ErrorResponse},
utils::{ utils::{
format_time_ms, from_base64, get_token_profile, tokeninfo_to_token, format_time_ms, from_base64, get_token_profile, tokeninfo_to_token,
validate_token_and_checksum, validate_token_and_checksum, TrimNewlines as _,
}, },
}, },
}; };
@@ -410,15 +410,13 @@ pub async fn handle_chat(
for message in messages { for message in messages {
match message { match message {
StreamMessage::Content(text) => { StreamMessage::Content(text) => {
// 记录首字时间(如果还未记录) let is_first = ctx.is_start.load(Ordering::SeqCst);
if let Ok(mut first_time) = ctx.first_chunk_time.try_lock() { if is_first {
if first_time.is_none() { if let Ok(mut first_time) = ctx.first_chunk_time.try_lock() {
*first_time = Some(ctx.start_time.elapsed().as_secs_f64()); *first_time = Some(ctx.start_time.elapsed().as_secs_f64());
} }
} }
let is_first = ctx.is_start.load(Ordering::SeqCst);
let response = ChatResponse { let response = ChatResponse {
id: ctx.response_id.to_string(), id: ctx.response_id.to_string(),
object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(), object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(),
@@ -433,12 +431,16 @@ pub async fn handle_chat(
message: None, message: None,
delta: Some(Delta { delta: Some(Delta {
role: if is_first { role: if is_first {
ctx.is_start.store(false, Ordering::SeqCst);
Some(Role::Assistant) Some(Role::Assistant)
} else { } else {
None None
}, },
content: Some(text), content: if is_first {
ctx.is_start.store(false, Ordering::SeqCst);
Some(text.trim_leading_newlines())
} else {
Some(text)
},
}), }),
finish_reason: None, finish_reason: None,
}], }],
@@ -528,7 +530,8 @@ pub async fn handle_chat(
{ {
log.status = LogStatus::Failed; log.status = LogStatus::Failed;
log.error = Some(error_response.native_code()); log.error = Some(error_response.native_code());
log.timing.total = format_time_ms(start_time.elapsed().as_secs_f64()); log.timing.total =
format_time_ms(start_time.elapsed().as_secs_f64());
state.error_requests += 1; state.error_requests += 1;
} }
} }
@@ -562,7 +565,9 @@ pub async fn handle_chat(
} }
return Err(( return Err((
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(ChatError::RequestFailed("Empty stream response".to_string()).to_json()), Json(
ChatError::RequestFailed("Empty stream response".to_string()).to_json(),
),
)); ));
} }
} }

View File

@@ -39,7 +39,7 @@ impl ToMarkdown for BTreeMap<String, String> {
} }
} }
#[derive(PartialEq, Clone)] #[derive(PartialEq, Clone, Debug)]
pub enum StreamMessage { pub enum StreamMessage {
// 调试 // 调试
Debug(String), Debug(String),
@@ -65,6 +65,7 @@ impl StreamMessage {
pub struct StreamDecoder { pub struct StreamDecoder {
buffer: Vec<u8>, buffer: Vec<u8>,
first_result: Option<Vec<StreamMessage>>, first_result: Option<Vec<StreamMessage>>,
first_result_ready: bool,
first_result_taken: bool, first_result_taken: bool,
} }
@@ -73,6 +74,7 @@ impl StreamDecoder {
Self { Self {
buffer: Vec::new(), buffer: Vec::new(),
first_result: None, first_result: None,
first_result_ready: false,
first_result_taken: false, first_result_taken: false,
} }
} }
@@ -93,7 +95,7 @@ impl StreamDecoder {
} }
pub fn is_first_result_ready(&self) -> bool { pub fn is_first_result_ready(&self) -> bool {
self.first_result.is_some() && self.buffer.is_empty() && !self.first_result_taken self.first_result_ready
} }
pub fn decode(&mut self, data: &[u8], convert_web_ref: bool) -> Result<Vec<StreamMessage>, StreamError> { pub fn decode(&mut self, data: &[u8], convert_web_ref: bool) -> Result<Vec<StreamMessage>, StreamError> {
@@ -150,10 +152,13 @@ impl StreamDecoder {
if !self.first_result_taken && !messages.is_empty() { if !self.first_result_taken && !messages.is_empty() {
if self.first_result.is_none() { if self.first_result.is_none() {
self.first_result = Some(messages.clone()); self.first_result = Some(messages.clone());
} else { } else if !self.first_result_ready {
self.first_result.as_mut().unwrap().extend(messages.clone()); self.first_result.as_mut().unwrap().extend(messages.clone());
} }
} }
if !self.first_result_ready {
self.first_result_ready = self.first_result.is_some() && self.buffer.is_empty() && !self.first_result_taken;
}
Ok(messages) Ok(messages)
} }

View File

@@ -48,6 +48,24 @@ pub fn parse_usize_from_env(key: &str, default: usize) -> usize {
.unwrap_or(default) .unwrap_or(default)
} }
pub trait TrimNewlines {
fn trim_leading_newlines(self) -> Self;
}
impl TrimNewlines for String {
#[inline(always)]
fn trim_leading_newlines(mut self) -> Self {
if self.as_bytes().get(..2) == Some(b"\n\n".as_slice()) {
unsafe {
let vec = self.as_mut_vec();
vec.copy_within(2.., 0);
vec.truncate(vec.len() - 2);
}
}
self
}
}
pub async fn get_token_profile(auth_token: &str) -> Option<TokenProfile> { pub async fn get_token_profile(auth_token: &str) -> Option<TokenProfile> {
let user_id = extract_user_id(auth_token)?; let user_id = extract_user_id(auth_token)?;

View File

@@ -39,14 +39,14 @@ use tower_http::{cors::CorsLayer, limit::RequestBodyLimitLayer};
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
// 设置自定义 panic hook // 设置自定义 panic hook
// std::panic::set_hook(Box::new(|info| { std::panic::set_hook(Box::new(|info| {
// // std::env::set_var("RUST_BACKTRACE", "1"); // std::env::set_var("RUST_BACKTRACE", "1");
// if let Some(msg) = info.payload().downcast_ref::<String>() { if let Some(msg) = info.payload().downcast_ref::<String>() {
// eprintln!("{}", msg); eprintln!("{}", msg);
// } else if let Some(msg) = info.payload().downcast_ref::<&str>() { } else if let Some(msg) = info.payload().downcast_ref::<&str>() {
// eprintln!("{}", msg); eprintln!("{}", msg);
// } }
// })); }));
// 加载环境变量 // 加载环境变量
dotenvy::dotenv().ok(); dotenvy::dotenv().ok();