mirror of
https://github.com/wisdgod/cursor-api.git
synced 2025-12-24 13:38:01 +08:00
468 lines
15 KiB
Rust
468 lines
15 KiB
Rust
use crate::chat::{
|
||
aiserver::v1::{StreamChatResponse, WebReference},
|
||
error::{ChatError, StreamError},
|
||
};
|
||
use crate::common::utils::InstantExt as _;
|
||
use flate2::read::GzDecoder;
|
||
use prost::Message;
|
||
use std::io::Read;
|
||
use std::time::Instant;
|
||
|
||
// 解压gzip数据
|
||
fn decompress_gzip(data: &[u8]) -> Option<Vec<u8>> {
|
||
let mut decoder = GzDecoder::new(data);
|
||
let mut decompressed = Vec::new();
|
||
|
||
match decoder.read_to_end(&mut decompressed) {
|
||
Ok(_) => Some(decompressed),
|
||
Err(_) => {
|
||
// println!("gzip解压失败: {}", e);
|
||
None
|
||
}
|
||
}
|
||
}
|
||
|
||
pub trait ToMarkdown {
|
||
fn to_markdown(&self) -> String;
|
||
}
|
||
|
||
impl ToMarkdown for Vec<WebReference> {
|
||
#[inline]
|
||
fn to_markdown(&self) -> String {
|
||
if self.is_empty() {
|
||
return String::new();
|
||
}
|
||
|
||
let mut result = String::from("WebReferences:\n");
|
||
for (i, web_ref) in self.iter().enumerate() {
|
||
result.push_str(&format!(
|
||
"{}. [{}]({})<{}>\n",
|
||
i + 1,
|
||
web_ref.title,
|
||
web_ref.url,
|
||
web_ref.chunk
|
||
));
|
||
}
|
||
result.push('\n');
|
||
result
|
||
}
|
||
}
|
||
|
||
#[derive(PartialEq, Clone)]
|
||
pub enum StreamMessage {
|
||
// 调试
|
||
Debug(String),
|
||
// 网络引用
|
||
WebReference(Vec<WebReference>),
|
||
// 内容开始标志
|
||
ContentStart,
|
||
// 消息内容
|
||
Content(String),
|
||
// 流结束标志
|
||
StreamEnd,
|
||
}
|
||
|
||
impl StreamMessage {
|
||
#[inline]
|
||
fn convert_web_ref_to_content(self) -> Self {
|
||
match self {
|
||
StreamMessage::WebReference(refs) => StreamMessage::Content(refs.to_markdown()),
|
||
other => other,
|
||
}
|
||
}
|
||
}
|
||
|
||
pub struct StreamDecoder {
|
||
// 主要数据缓冲区 (24字节)
|
||
buffer: Vec<u8>,
|
||
// 结果相关 (24字节 + 24字节)
|
||
first_result: Option<Vec<StreamMessage>>,
|
||
content_delays: Vec<(String, f64)>,
|
||
// 计数器和时间 (8字节 + 8字节)
|
||
empty_stream_count: usize,
|
||
last_content_time: Instant,
|
||
// 状态标志 (1字节 + 1字节 + 1字节)
|
||
first_result_ready: bool,
|
||
first_result_taken: bool,
|
||
has_seen_content: bool,
|
||
}
|
||
|
||
impl StreamDecoder {
|
||
pub fn new() -> Self {
|
||
Self {
|
||
buffer: Vec::new(),
|
||
first_result: None,
|
||
content_delays: Vec::new(),
|
||
empty_stream_count: 0,
|
||
last_content_time: Instant::now(),
|
||
first_result_ready: false,
|
||
first_result_taken: false,
|
||
has_seen_content: false,
|
||
}
|
||
}
|
||
|
||
pub fn get_empty_stream_count(&self) -> usize {
|
||
self.empty_stream_count
|
||
}
|
||
|
||
pub fn reset_empty_stream_count(&mut self) {
|
||
if self.empty_stream_count > 0 {
|
||
crate::debug_println!(
|
||
"重置连续空流计数,之前的计数为: {}",
|
||
self.empty_stream_count
|
||
);
|
||
self.empty_stream_count = 0;
|
||
}
|
||
}
|
||
|
||
pub fn increment_empty_stream_count(&mut self) {
|
||
self.empty_stream_count += 1;
|
||
}
|
||
|
||
pub fn take_first_result(&mut self) -> Option<Vec<StreamMessage>> {
|
||
if !self.buffer.is_empty() {
|
||
return None;
|
||
}
|
||
if self.first_result.is_some() {
|
||
self.first_result_taken = true;
|
||
}
|
||
self.first_result.take()
|
||
}
|
||
|
||
#[cfg(test)]
|
||
fn is_incomplete(&self) -> bool {
|
||
!self.buffer.is_empty()
|
||
}
|
||
|
||
pub fn is_first_result_ready(&self) -> bool {
|
||
self.first_result_ready
|
||
}
|
||
|
||
pub fn take_content_delays(&mut self) -> Vec<(String, f64)> {
|
||
std::mem::take(&mut self.content_delays)
|
||
}
|
||
|
||
pub fn decode(
|
||
&mut self,
|
||
data: &[u8],
|
||
convert_web_ref: bool,
|
||
) -> Result<Vec<StreamMessage>, StreamError> {
|
||
if !data.is_empty() {
|
||
self.reset_empty_stream_count();
|
||
}
|
||
|
||
self.buffer.extend_from_slice(data);
|
||
|
||
if self.buffer.len() < 5 {
|
||
if self.buffer.is_empty() {
|
||
self.increment_empty_stream_count();
|
||
|
||
return Err(StreamError::EmptyStream);
|
||
}
|
||
crate::debug_println!("数据长度小于5字节,当前数据: {}", hex::encode(&self.buffer));
|
||
return Err(StreamError::DataLengthLessThan5);
|
||
}
|
||
|
||
self.reset_empty_stream_count();
|
||
|
||
let mut messages = Vec::new();
|
||
let mut offset = 0;
|
||
|
||
while offset + 5 <= self.buffer.len() {
|
||
let msg_type = self.buffer[offset];
|
||
let msg_len = u32::from_be_bytes([
|
||
self.buffer[offset + 1],
|
||
self.buffer[offset + 2],
|
||
self.buffer[offset + 3],
|
||
self.buffer[offset + 4],
|
||
]) as usize;
|
||
|
||
if msg_len == 0 {
|
||
offset += 5;
|
||
messages.push(StreamMessage::ContentStart);
|
||
continue;
|
||
}
|
||
|
||
if offset + 5 + msg_len > self.buffer.len() {
|
||
break;
|
||
}
|
||
|
||
let msg_data = &self.buffer[offset + 5..offset + 5 + msg_len];
|
||
|
||
if let Some(msg) = self.process_message(msg_type, msg_data)? {
|
||
if let StreamMessage::Content(content) = &msg {
|
||
self.has_seen_content = true;
|
||
let delay = self.last_content_time.duration_as_secs_f64();
|
||
self.content_delays.push((content.clone(), delay));
|
||
}
|
||
if convert_web_ref {
|
||
messages.push(msg.convert_web_ref_to_content());
|
||
} else {
|
||
messages.push(msg);
|
||
}
|
||
}
|
||
|
||
offset += 5 + msg_len;
|
||
}
|
||
|
||
self.buffer.drain(..offset);
|
||
|
||
if !self.first_result_taken && !messages.is_empty() {
|
||
if self.first_result.is_none() {
|
||
self.first_result = Some(std::mem::take(&mut messages));
|
||
} else if !self.first_result_ready {
|
||
if let Some(first_result) = &mut self.first_result {
|
||
first_result.append(&mut messages);
|
||
}
|
||
}
|
||
}
|
||
if !self.first_result_ready {
|
||
self.first_result_ready = self.first_result.is_some()
|
||
&& self.buffer.is_empty()
|
||
&& !self.first_result_taken
|
||
&& self.has_seen_content;
|
||
}
|
||
Ok(messages)
|
||
}
|
||
|
||
fn process_message(
|
||
&self,
|
||
msg_type: u8,
|
||
msg_data: &[u8],
|
||
) -> Result<Option<StreamMessage>, StreamError> {
|
||
match msg_type {
|
||
0 => self.handle_text_message(msg_data),
|
||
1 => self.handle_gzip_message(msg_data),
|
||
2 => self.handle_json_message(msg_data),
|
||
3 => self.handle_gzip_json_message(msg_data),
|
||
t => {
|
||
eprintln!("收到未知消息类型: {},请尝试联系开发者以获取支持", t);
|
||
crate::debug_println!("消息类型: {},消息内容: {}", t, hex::encode(msg_data));
|
||
Ok(None)
|
||
}
|
||
}
|
||
}
|
||
|
||
fn handle_text_message(&self, msg_data: &[u8]) -> Result<Option<StreamMessage>, StreamError> {
|
||
if let Ok(response) = StreamChatResponse::decode(msg_data) {
|
||
// println!("[text] StreamChatResponse [hex: {}]: {:?}", hex::encode(msg_data), response);
|
||
if !response.text.is_empty() {
|
||
Ok(Some(StreamMessage::Content(response.text)))
|
||
} else if let Some(filled_prompt) = response.filled_prompt {
|
||
Ok(Some(StreamMessage::Debug(filled_prompt)))
|
||
} else if let Some(web_citation) = response.web_citation {
|
||
Ok(Some(StreamMessage::WebReference(web_citation.references)))
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
}
|
||
|
||
fn handle_gzip_message(&self, msg_data: &[u8]) -> Result<Option<StreamMessage>, StreamError> {
|
||
if let Some(text) = decompress_gzip(msg_data) {
|
||
if let Ok(response) = StreamChatResponse::decode(&text[..]) {
|
||
// println!("[gzip] StreamChatResponse [hex: {}]: {:?}", hex::encode(msg_data), response);
|
||
if !response.text.is_empty() {
|
||
Ok(Some(StreamMessage::Content(response.text)))
|
||
} else if let Some(filled_prompt) = response.filled_prompt {
|
||
Ok(Some(StreamMessage::Debug(filled_prompt)))
|
||
} else if let Some(web_citation) = response.web_citation {
|
||
Ok(Some(StreamMessage::WebReference(web_citation.references)))
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
} else {
|
||
Ok(None)
|
||
}
|
||
}
|
||
|
||
fn handle_json_message(&self, msg_data: &[u8]) -> Result<Option<StreamMessage>, StreamError> {
|
||
if msg_data.len() == 2 {
|
||
return Ok(Some(StreamMessage::StreamEnd));
|
||
}
|
||
if let Ok(text) = String::from_utf8(msg_data.to_vec()) {
|
||
// println!("[text] JSON消息 [hex: {}]: {}", hex::encode(msg_data), text);
|
||
if let Ok(error) = serde_json::from_str::<ChatError>(&text) {
|
||
return Err(StreamError::ChatError(error));
|
||
}
|
||
}
|
||
Ok(None)
|
||
}
|
||
|
||
fn handle_gzip_json_message(
|
||
&self,
|
||
msg_data: &[u8],
|
||
) -> Result<Option<StreamMessage>, StreamError> {
|
||
if let Some(text) = decompress_gzip(msg_data) {
|
||
if text.len() == 2 {
|
||
return Ok(Some(StreamMessage::StreamEnd));
|
||
}
|
||
if let Ok(text) = String::from_utf8(text) {
|
||
// println!("[gzip] JSON消息 [hex: {}]: {}", hex::encode(msg_data), text);
|
||
if let Ok(error) = serde_json::from_str::<ChatError>(&text) {
|
||
return Err(StreamError::ChatError(error));
|
||
}
|
||
}
|
||
}
|
||
Ok(None)
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the hex-encoded fixture and decodes it into raw stream bytes.
    ///
    /// Surrounding whitespace (such as a trailing newline) is ignored so that
    /// editing the fixture cannot make hex decoding panic.
    fn load_stream_bytes() -> Vec<u8> {
        let stream_data = include_str!("../../../tests/data/stream_data.txt");
        hex::decode(stream_data.trim()).expect("stream_data.txt must contain only hex digit pairs")
    }

    /// Prints one decoded message, appending ` [hex: …]` to its label when
    /// `hex_str` is given. Returns `true` if the message ends the stream.
    fn print_message(message: StreamMessage, hex_str: Option<&str>) -> bool {
        let tag = hex_str
            .map(|h| format!(" [hex: {}]", h))
            .unwrap_or_default();
        match message {
            StreamMessage::StreamEnd => {
                println!("流结束{}", tag);
                true
            }
            StreamMessage::Content(msg) => {
                println!("消息内容{}: {}", tag, msg);
                false
            }
            StreamMessage::WebReference(refs) => {
                println!("网页引用{}:", tag);
                for (i, web_ref) in refs.iter().enumerate() {
                    println!(
                        "{}. {} - {} - {}",
                        i, web_ref.url, web_ref.title, web_ref.chunk
                    );
                }
                false
            }
            StreamMessage::Debug(prompt) => {
                println!("调试信息{}: {}", tag, prompt);
                false
            }
            StreamMessage::ContentStart => {
                println!("流开始{}", tag);
                false
            }
        }
    }

    #[test]
    fn test_single_chunk() {
        let bytes = load_stream_bytes();
        let mut decoder = StreamDecoder::new();

        match decoder.decode(&bytes, false) {
            Ok(mut messages) => {
                // `decode` parks the earliest messages in the decoder's first
                // result instead of returning them; print those first.
                let mut all = decoder.take_first_result().unwrap_or_default();
                all.append(&mut messages);
                for message in all {
                    if print_message(message, None) {
                        break;
                    }
                }
            }
            Err(e) => {
                println!("解析错误: {}", e);
            }
        }

        if decoder.is_incomplete() {
            println!("数据不完整");
        }
    }

    #[test]
    fn test_multiple_chunks() {
        let bytes = load_stream_bytes();
        let mut decoder = StreamDecoder::new();

        // Length of the frame at the start of `bytes` (5-byte header plus
        // payload), clamped to `bytes.len()` so a truncated trailing frame
        // cannot produce an out-of-bounds slice.
        fn find_next_message_boundary(bytes: &[u8]) -> usize {
            if bytes.len() < 5 {
                return bytes.len();
            }
            let msg_len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]) as usize;
            5usize.saturating_add(msg_len).min(bytes.len())
        }

        // Feed the stream to the decoder one frame at a time.
        let mut offset = 0;
        while offset < bytes.len() {
            let remaining_bytes = &bytes[offset..];
            let msg_boundary = find_next_message_boundary(remaining_bytes);
            let current_msg_bytes = &remaining_bytes[..msg_boundary];
            let hex_str = hex::encode_upper(current_msg_bytes);

            let mut messages = match decoder.decode(current_msg_bytes, false) {
                Ok(messages) => messages,
                Err(e) => {
                    println!("解析错误 [hex: {}]: {}", hex_str, e);
                    break;
                }
            };

            // Early messages are parked in the first result until it is ready;
            // surface them, in order, as soon as the decoder reports readiness.
            if decoder.is_first_result_ready() {
                if let Some(mut first) = decoder.take_first_result() {
                    first.append(&mut messages);
                    messages = first;
                }
            }

            // `any` stops at the first StreamEnd, mirroring a `break` in the loop.
            if messages
                .into_iter()
                .any(|message| print_message(message, Some(&hex_str)))
            {
                break;
            }

            if decoder.is_incomplete() {
                println!("数据不完整 [hex: {}]", hex_str);
                break;
            }

            offset += msg_boundary;
        }
    }
}
|