v0.1.3-rc.4.2

This commit is contained in:
wisdgod
2025-01-28 15:00:27 +08:00
parent cb244d7282
commit 22121f3beb
6 changed files with 51 additions and 23 deletions

2
Cargo.lock generated
View File

@@ -361,7 +361,7 @@ dependencies = [
[[package]]
name = "cursor-api"
version = "0.1.3-rc.4.1"
version = "0.1.3-rc.4.2"
dependencies = [
"axum",
"base64",

View File

@@ -1,6 +1,6 @@
[package]
name = "cursor-api"
version = "0.1.3-rc.4.1"
version = "0.1.3-rc.4.2"
edition = "2021"
authors = ["wisdgod <nav@wisdgod.com>"]
description = "OpenAI format compatibility layer for the Cursor API"

View File

@@ -24,7 +24,7 @@ use crate::{
model::{error::ChatError, userinfo::MembershipType, ApiStatus, ErrorResponse},
utils::{
format_time_ms, from_base64, get_token_profile, tokeninfo_to_token,
validate_token_and_checksum,
validate_token_and_checksum, TrimNewlines as _,
},
},
};
@@ -410,15 +410,13 @@ pub async fn handle_chat(
for message in messages {
match message {
StreamMessage::Content(text) => {
// 记录首字时间(如果还未记录)
if let Ok(mut first_time) = ctx.first_chunk_time.try_lock() {
if first_time.is_none() {
let is_first = ctx.is_start.load(Ordering::SeqCst);
if is_first {
if let Ok(mut first_time) = ctx.first_chunk_time.try_lock() {
*first_time = Some(ctx.start_time.elapsed().as_secs_f64());
}
}
let is_first = ctx.is_start.load(Ordering::SeqCst);
let response = ChatResponse {
id: ctx.response_id.to_string(),
object: OBJECT_CHAT_COMPLETION_CHUNK.to_string(),
@@ -433,12 +431,16 @@ pub async fn handle_chat(
message: None,
delta: Some(Delta {
role: if is_first {
ctx.is_start.store(false, Ordering::SeqCst);
Some(Role::Assistant)
} else {
None
},
content: Some(text),
content: if is_first {
ctx.is_start.store(false, Ordering::SeqCst);
Some(text.trim_leading_newlines())
} else {
Some(text)
},
}),
finish_reason: None,
}],
@@ -528,7 +530,8 @@ pub async fn handle_chat(
{
log.status = LogStatus::Failed;
log.error = Some(error_response.native_code());
log.timing.total = format_time_ms(start_time.elapsed().as_secs_f64());
log.timing.total =
format_time_ms(start_time.elapsed().as_secs_f64());
state.error_requests += 1;
}
}
@@ -562,7 +565,9 @@ pub async fn handle_chat(
}
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ChatError::RequestFailed("Empty stream response".to_string()).to_json()),
Json(
ChatError::RequestFailed("Empty stream response".to_string()).to_json(),
),
));
}
}

View File

@@ -39,7 +39,7 @@ impl ToMarkdown for BTreeMap<String, String> {
}
}
#[derive(PartialEq, Clone)]
#[derive(PartialEq, Clone, Debug)]
pub enum StreamMessage {
// 调试
Debug(String),
@@ -65,6 +65,7 @@ impl StreamMessage {
/// Incremental decoder for the upstream byte stream: accepts raw chunks via
/// `decode` and buffers partial frames between calls.
pub struct StreamDecoder {
    // Unparsed bytes carried over between successive `decode` calls.
    buffer: Vec<u8>,
    // Accumulated copy of the first decoded message batch(es); extended on
    // later decodes until the result is marked ready.
    first_result: Option<Vec<StreamMessage>>,
    // Latched to true once `first_result` is complete (messages present,
    // buffer fully drained, result not yet taken); once set, later decodes
    // stop extending `first_result`.
    first_result_ready: bool,
    // NOTE(review): presumably set when the first result is handed to the
    // caller — the setter is outside this hunk, confirm against `take` logic.
    first_result_taken: bool,
}
@@ -73,6 +74,7 @@ impl StreamDecoder {
Self {
buffer: Vec::new(),
first_result: None,
first_result_ready: false,
first_result_taken: false,
}
}
@@ -93,7 +95,7 @@ impl StreamDecoder {
}
pub fn is_first_result_ready(&self) -> bool {
self.first_result.is_some() && self.buffer.is_empty() && !self.first_result_taken
self.first_result_ready
}
pub fn decode(&mut self, data: &[u8], convert_web_ref: bool) -> Result<Vec<StreamMessage>, StreamError> {
@@ -150,10 +152,13 @@ impl StreamDecoder {
if !self.first_result_taken && !messages.is_empty() {
if self.first_result.is_none() {
self.first_result = Some(messages.clone());
} else {
} else if !self.first_result_ready {
self.first_result.as_mut().unwrap().extend(messages.clone());
}
}
if !self.first_result_ready {
self.first_result_ready = self.first_result.is_some() && self.buffer.is_empty() && !self.first_result_taken;
}
Ok(messages)
}

View File

@@ -48,6 +48,24 @@ pub fn parse_usize_from_env(key: &str, default: usize) -> usize {
.unwrap_or(default)
}
/// Removal of the exact `"\n\n"` prefix from an owned value, in place.
///
/// Used on the first streamed chat content chunk (see `handle_chat`) so the
/// assistant's reply does not start with a blank line.
pub trait TrimNewlines {
    /// Consumes the value and returns it with a leading `"\n\n"` stripped.
    fn trim_leading_newlines(self) -> Self;
}

impl TrimNewlines for String {
    // `#[inline(always)]` downgraded to `#[inline]`: the function is tiny and
    // LLVM inlines it fine; forcing it buys nothing.
    #[inline]
    fn trim_leading_newlines(mut self) -> Self {
        // Only an exact two-byte "\n\n" prefix is stripped; a lone "\n" (or a
        // third consecutive newline) is deliberately left alone — this matches
        // the original behavior.
        //
        // `drain(..2)` shifts the tail down and shrinks the length — the same
        // effect as the previous `unsafe` `as_mut_vec` + `copy_within` +
        // `truncate` dance, but fully safe. Removing two ASCII bytes can
        // never break UTF-8 validity, and `drain` handles the shift for us.
        if self.starts_with("\n\n") {
            self.drain(..2);
        }
        self
    }
}
pub async fn get_token_profile(auth_token: &str) -> Option<TokenProfile> {
let user_id = extract_user_id(auth_token)?;

View File

@@ -39,14 +39,14 @@ use tower_http::{cors::CorsLayer, limit::RequestBodyLimitLayer};
#[tokio::main]
async fn main() {
// 设置自定义 panic hook
// std::panic::set_hook(Box::new(|info| {
// // std::env::set_var("RUST_BACKTRACE", "1");
// if let Some(msg) = info.payload().downcast_ref::<String>() {
// eprintln!("{}", msg);
// } else if let Some(msg) = info.payload().downcast_ref::<&str>() {
// eprintln!("{}", msg);
// }
// }));
std::panic::set_hook(Box::new(|info| {
// std::env::set_var("RUST_BACKTRACE", "1");
if let Some(msg) = info.payload().downcast_ref::<String>() {
eprintln!("{}", msg);
} else if let Some(msg) = info.payload().downcast_ref::<&str>() {
eprintln!("{}", msg);
}
}));
// 加载环境变量
dotenvy::dotenv().ok();