update: 项目结构

This commit is contained in:
zeke
2024-11-26 17:14:48 +08:00
parent 6f91c5f0f8
commit 4db169e244
7 changed files with 537 additions and 518 deletions

View File

@@ -0,0 +1,295 @@
use axum::body::Body;
use axum::extract::Request;
use axum::response::sse::Event;
use axum::Json;
use axum::{
http::{HeaderMap, StatusCode},
response::{sse::Sse, IntoResponse, Response},
};
use bytes::Bytes;
use futures::channel::mpsc;
use futures::stream::StreamExt;
use futures::{SinkExt, Stream};
use std::convert::Infallible;
use std::error::Error;
// use http::HeaderName as HttpHeaderName;
use crate::hex_utils::{chunk_to_utf8_string, string_to_hex};
use crate::models;
use regex::Regex;
use std::str::FromStr;
use std::time::Duration;
use uuid::Uuid;
/// Handle an OpenAI-compatible chat-completion request by proxying it to
/// the Cursor `StreamChat` endpoint.
///
/// Steps: read the (size-capped) body, parse it as a `ChatRequest`,
/// validate the `Authorization: Bearer` token, flatten the conversation and
/// encode it via `string_to_hex`, forward it upstream, then return either an
/// SSE stream (`stream: true`) or a single JSON completion.
///
/// Errors:
/// - `BAD_REQUEST`: unreadable body, invalid JSON, or an `o1-*` model with
///   streaming requested (those models do not support streamed output).
/// - `UNAUTHORIZED`: missing/malformed bearer token, or a token that is not
///   a valid HTTP header value.
/// - `INTERNAL_SERVER_ERROR`: HTTP-client construction or upstream failure.
pub async fn chat_completions(
    headers: HeaderMap,
    request: Request<Body>,
) -> Result<Response, StatusCode> {
    // Cap the body size so a hostile client cannot make us buffer unbounded data.
    const MAX_BODY_SIZE: usize = 20 * 1024 * 1024;
    let bytes = match axum::body::to_bytes(request.into_body(), MAX_BODY_SIZE).await {
        Ok(bytes) => bytes,
        Err(err) => {
            tracing::error!("读取请求体失败: {}", err);
            return Err(StatusCode::BAD_REQUEST);
        }
    };
    // Log the raw request body (only when it is valid UTF-8).
    if let Ok(body_str) = String::from_utf8(bytes.to_vec()) {
        tracing::info!("原始请求体: {}", body_str);
    }
    // Parse the JSON payload.
    let chat_request: models::chat::ChatRequest = match serde_json::from_slice(&bytes) {
        Ok(req) => req,
        Err(err) => {
            tracing::error!("JSON解析失败: {}", err);
            return Err(StatusCode::BAD_REQUEST);
        }
    };
    // Validate authentication. `strip_prefix` both checks the scheme and
    // removes only the leading "Bearer " — unlike the former `replace`,
    // which would also mangle a token that happens to contain that
    // substring somewhere in the middle.
    let mut auth_token = headers
        .get("authorization")
        .and_then(|h| h.to_str().ok())
        .and_then(|h| h.strip_prefix("Bearer "))
        .ok_or(StatusCode::UNAUTHORIZED)?
        .to_string();
    // o1-family models do not support streaming output.
    if chat_request.model.starts_with("o1-") && chat_request.stream {
        return Err(StatusCode::BAD_REQUEST);
    }
    tracing::info!("chat_request: {:?}", chat_request);
    // If several comma-separated keys were supplied, use only the first one.
    if auth_token.contains(',') {
        auth_token = auth_token.split(',').next().unwrap().trim().to_string();
    }
    // Tokens may arrive URL-encoded as "<user>%3A%3A<token>"; keep the token half.
    if auth_token.contains("%3A%3A") {
        auth_token = auth_token
            .split("%3A%3A")
            .nth(1)
            .unwrap_or(&auth_token)
            .to_string();
    }
    // Flatten the conversation into "role:content" lines for encoding.
    let formatted_messages = chat_request
        .messages
        .iter()
        .map(|msg| {
            let content = msg
                .content
                .iter()
                .map(|part| part.to_string())
                .collect::<Vec<_>>()
                .join(", ");
            format!("{}:{}", msg.role, content)
        })
        .collect::<Vec<_>>()
        .join("\n");
    // Encode the request into the Connect-protocol binary payload.
    let hex_data = string_to_hex(&formatted_messages, &chat_request.model);
    // Build the upstream request headers.
    let request_id = Uuid::new_v4();
    let mut req_headers = reqwest::header::HeaderMap::new();
    req_headers.insert(
        reqwest::header::CONTENT_TYPE,
        reqwest::header::HeaderValue::from_static("application/connect+proto"),
    );
    // The token came from client input, so building the header value can
    // fail; reject such tokens instead of panicking.
    req_headers.insert(
        reqwest::header::AUTHORIZATION,
        reqwest::header::HeaderValue::from_str(&format!("Bearer {}", auth_token))
            .map_err(|_| StatusCode::UNAUTHORIZED)?,
    );
    // Static protocol headers expected by the Cursor Connect endpoint.
    // `from_static` header names must be lowercase; HTTP header names are
    // case-insensitive, so this matches the original mixed-case spelling.
    for (name, value) in [
        ("connect-accept-encoding", "gzip,br"),
        ("connect-protocol-version", "1"),
        ("user-agent", "connect-es/1.4.0"),
        (
            "x-cursor-checksum",
            "zo6Qjequ9b9734d1f13c3438ba25ea31ac93d9287248b9d30434934e9fcbfa6b3b22029e/7e4af391f67188693b722eff0090e8e6608bca8fa320ef20a0ccb5d7d62dfdef",
        ),
        ("x-cursor-client-version", "0.42.3"),
        ("x-cursor-timezone", "Asia/Shanghai"),
        ("x-ghost-mode", "false"),
        ("host", "api2.cursor.sh"),
    ] {
        req_headers.insert(
            reqwest::header::HeaderName::from_static(name),
            reqwest::header::HeaderValue::from_static(value),
        );
    }
    // Per-request tracing identifiers; a UUID string is always a valid
    // header value, so these cannot fail.
    req_headers.insert(
        reqwest::header::HeaderName::from_static("x-amzn-trace-id"),
        reqwest::header::HeaderValue::from_str(&format!("Root={}", Uuid::new_v4()))
            .expect("UUID is always a valid header value"),
    );
    req_headers.insert(
        reqwest::header::HeaderName::from_static("x-request-id"),
        reqwest::header::HeaderValue::from_str(&request_id.to_string())
            .expect("UUID is always a valid header value"),
    );
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(300))
        .build()
        .map_err(|e| {
            tracing::error!("创建HTTP客户端失败: {:?}", e);
            tracing::error!(error = %e, "错误详情");
            if let Some(source) = e.source() {
                tracing::error!(source = %source, "错误源");
            }
            StatusCode::INTERNAL_SERVER_ERROR
        })?;
    let response = client
        .post("https://api2.cursor.sh/aiserver.v1.AiService/StreamChat")
        .headers(req_headers)
        .body(hex_data)
        .send()
        .await
        .map_err(|e| {
            // Log as much failure detail as reqwest exposes before
            // collapsing everything to a 500 for the caller.
            tracing::error!("请求失败: {:?}", e);
            tracing::error!(error = %e, "错误详情");
            if e.is_timeout() {
                tracing::error!("请求超时");
            }
            if e.is_connect() {
                tracing::error!("连接失败");
            }
            if let Some(url) = e.url() {
                tracing::error!(url = %url, "请求URL");
            }
            if let Some(status) = e.status() {
                tracing::error!(status = %status, "HTTP状态码");
            }
            StatusCode::INTERNAL_SERVER_ERROR
        })?;
    if chat_request.stream {
        // NOTE(review): the whole upstream response is buffered before the
        // SSE stream starts, so the client sees no tokens until the
        // upstream request completes — consider forwarding incrementally.
        let mut chunks = Vec::new();
        let mut stream = response.bytes_stream();
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(chunk) => chunks.push(chunk),
                Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
            }
        }
        let stream = process_stream(chunks).await;
        return Ok(Sse::new(stream).into_response());
    }
    // Non-streaming path: decode and concatenate every chunk.
    let mut text = String::new();
    let mut stream = response.bytes_stream();
    while let Some(chunk) = stream.next().await {
        match chunk {
            Ok(chunk) => {
                let res = chunk_to_utf8_string(&chunk);
                if !res.is_empty() {
                    text.push_str(&res);
                }
            }
            Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
        }
    }
    // Clean the response text: drop the echoed prompt up to <|END_USER|>,
    // a stray leading newline+letter, and ASCII control characters. The
    // patterns are compile-time constants, so `unwrap` cannot fail.
    let re = Regex::new(r"^.*<\|END_USER\|>").unwrap();
    text = re.replace(&text, "").to_string();
    let re = Regex::new(r"^\n[a-zA-Z]?").unwrap();
    text = re.replace(&text, "").trim().to_string();
    let re = Regex::new(r"[\x00-\x1F\x7F]").unwrap();
    text = re.replace_all(&text, "").to_string();
    let response = models::chat::ChatResponse {
        id: format!("chatcmpl-{}", Uuid::new_v4()),
        object: "chat.completion".to_string(),
        created: chrono::Utc::now().timestamp(),
        model: chat_request.model,
        choices: vec![models::chat::Choice {
            index: 0,
            message: models::chat::ResponseMessage {
                role: "assistant".to_string(),
                content: text,
            },
            finish_reason: "stop".to_string(),
        }],
        // Token accounting is not implemented for the proxied backend.
        usage: models::chat::Usage {
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
        },
    };
    Ok(Json(response).into_response())
}
/// Convert buffered upstream chunks into an OpenAI-style SSE stream.
///
/// Each chunk is decoded with `chunk_to_utf8_string`, cleaned of the echoed
/// `<|END_USER|>` marker, a stray leading letter, and ASCII control
/// characters, then emitted as a `chat.completion.chunk` event. A final
/// `[DONE]` event terminates the stream.
async fn process_stream(
    chunks: Vec<Bytes>,
) -> impl Stream<Item = Result<Event, Infallible>> + Send {
    let (mut tx, rx) = mpsc::channel(100);
    let response_id = format!("chatcmpl-{}", Uuid::new_v4());
    tokio::spawn(async move {
        // Compile the regex once, outside the per-chunk loop; the pattern
        // is a compile-time constant, so `unwrap` cannot fail.
        let control_chars = Regex::new(r"[\x00-\x1F\x7F]").unwrap();
        for chunk in chunks {
            let text = chunk_to_utf8_string(&chunk);
            if text.is_empty() {
                continue;
            }
            let text = text.trim();
            // Drop everything up to and including the <|END_USER|> marker.
            let text = match text.find("<|END_USER|>") {
                Some(idx) => text[idx + "<|END_USER|>".len()..].trim(),
                None => text,
            };
            // The upstream sometimes prefixes a stray letter; drop it.
            // Slicing by `len_utf8` keeps the cut on a char boundary, so a
            // multi-byte alphabetic first char (e.g. a CJK letter) cannot
            // panic the task — `text[1..]` would.
            let text = match text.chars().next() {
                Some(c) if c.is_alphabetic() => text[c.len_utf8()..].trim(),
                _ => text,
            };
            // Strip ASCII control characters.
            let text = control_chars.replace_all(text, "");
            if text.is_empty() {
                continue;
            }
            let response = models::chat::StreamResponse {
                id: response_id.clone(),
                object: "chat.completion.chunk".to_string(),
                created: chrono::Utc::now().timestamp(),
                choices: vec![models::chat::StreamChoice {
                    index: 0,
                    delta: models::chat::Delta {
                        content: text.to_string(),
                    },
                }],
            };
            // A plain struct of strings/ints always serializes; send errors
            // (receiver dropped) are deliberately ignored.
            let json_data = serde_json::to_string(&response).unwrap();
            let _ = tx.send(Ok(Event::default().data(json_data))).await;
        }
        let _ = tx.send(Ok(Event::default().data("[DONE]"))).await;
    });
    rx
}

View File

@@ -0,0 +1,2 @@
// HTTP request handlers, one module per endpoint group.
pub mod chat; // chat-completion proxy endpoint
pub mod models; // static model-catalog endpoint

View File

@@ -0,0 +1,82 @@
use axum::Json;
// Handle the model-list request: return the static, OpenAI-style catalog
// of models this proxy advertises.
pub async fn models() -> Json<serde_json::Value> {
    // (id, created-timestamp, owner) for every advertised model.
    const CATALOG: [(&str, i64, &str); 12] = [
        ("claude-3-5-sonnet-20241022", 1713744000, "anthropic"),
        ("claude-3-opus", 1709251200, "anthropic"),
        ("claude-3.5-haiku", 1711929600, "anthropic"),
        ("claude-3.5-sonnet", 1711929600, "anthropic"),
        ("cursor-small", 1712534400, "cursor"),
        ("gpt-3.5-turbo", 1677649200, "openai"),
        ("gpt-4", 1687392000, "openai"),
        ("gpt-4-turbo-2024-04-09", 1712620800, "openai"),
        ("gpt-4o", 1712620800, "openai"),
        ("gpt-4o-mini", 1712620800, "openai"),
        ("o1-mini", 1712620800, "openai"),
        ("o1-preview", 1712620800, "openai"),
    ];
    let data: Vec<serde_json::Value> = CATALOG
        .iter()
        .map(|&(id, created, owned_by)| {
            serde_json::json!({
                "id": id,
                "object": "model",
                "created": created,
                "owned_by": owned_by
            })
        })
        .collect();
    Json(serde_json::json!({
        "object": "list",
        "data": data
    }))
}

View File

@@ -1,14 +1,14 @@
pub fn string_to_hex(text: &str, model_name: &str) -> Vec<u8> {
let text_bytes = text.as_bytes();
let text_length = text_bytes.len();
// 固定常量
const FIXED_HEADER: usize = 2;
const SEPARATOR: usize = 1;
let model_name_bytes = model_name.as_bytes();
let fixed_suffix_length = 0xA3 + model_name_bytes.len();
// 计算第一个长度字段
let (text_length_field1, text_length_field_size1) = if text_length < 128 {
(format!("{:02x}", text_length), 1)
@@ -29,8 +29,12 @@ pub fn string_to_hex(text: &str, model_name: &str) -> Vec<u8> {
};
// 计算总消息长度
let message_total_length = FIXED_HEADER + text_length_field_size + SEPARATOR +
text_length_field_size1 + text_length + fixed_suffix_length;
let message_total_length = FIXED_HEADER
+ text_length_field_size
+ SEPARATOR
+ text_length_field_size1
+ text_length
+ fixed_suffix_length;
// 构造十六进制字符串
let model_name_length_hex = format!("{:02X}", model_name_bytes.len());
@@ -54,7 +58,8 @@ pub fn string_to_hex(text: &str, model_name: &str) -> Vec<u8> {
hex::encode_upper(text_bytes),
model_name_length_hex,
model_name_hex
).to_uppercase();
)
.to_uppercase();
// 将十六进制字符串转换为字节数组
hex::decode(hex_string).unwrap_or_default()
@@ -64,7 +69,7 @@ pub fn chunk_to_utf8_string(chunk: &[u8]) -> String {
if chunk.len() < 2 {
return String::new();
}
if chunk[0] == 0x01 || chunk[0] == 0x02 || (chunk[0] == 0x60 && chunk[1] == 0x0C) {
return String::new();
}
@@ -72,15 +77,15 @@ pub fn chunk_to_utf8_string(chunk: &[u8]) -> String {
// 尝试找到0x0A并从其后开始处理
let chunk = match chunk.iter().position(|&x| x == 0x0A) {
Some(pos) => &chunk[pos + 1..],
None => chunk
None => chunk,
};
let mut filtered_chunk = Vec::new();
let mut i = 0;
while i < chunk.len() {
// 检查是否有连续的0x00
if i + 4 <= chunk.len() && chunk[i..i+4].iter().all(|&x| x == 0x00) {
if i + 4 <= chunk.len() && chunk[i..i + 4].iter().all(|&x| x == 0x00) {
i += 4;
while i < chunk.len() && chunk[i] <= 0x0F {
i += 1;
@@ -108,4 +113,4 @@ pub fn chunk_to_utf8_string(chunk: &[u8]) -> String {
// 转换为UTF-8字符串
String::from_utf8_lossy(&filtered_chunk).trim().to_string()
}
}

View File

@@ -1,213 +1,14 @@
use axum::body::Body;
use axum::extract::Request;
mod handlers;
mod models;
use axum::{
http::{HeaderMap, StatusCode},
response::{
sse::{Event, Sse},
IntoResponse, Response,
},
routing::{get, post},
Json, Router,
Router,
};
use bytes::Bytes;
use futures::{
channel::mpsc,
stream::{Stream, StreamExt},
SinkExt,
};
use std::error::Error;
use tower_http::trace::TraceLayer;
// use http::HeaderName as HttpHeaderName;
use regex::Regex;
use serde::Deserializer;
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use std::{convert::Infallible, time::Duration};
use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid;
mod hex_utils;
use hex_utils::{chunk_to_utf8_string, string_to_hex};
// 定义请求模型
#[derive(Debug, Deserialize)]
struct Message {
role: String,
#[serde(deserialize_with = "deserialize_content")]
content: Vec<ContentPart>,
}
// 添加一个辅助枚举
#[derive(Deserialize)]
#[serde(untagged)]
enum SingleOrVec<T> {
Single(T),
Vec(Vec<T>),
}
// 新增一个字符串或ContentPart的枚举
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ContentItem {
String(String),
Part(ContentPart),
}
// 新的反序列化函数
fn deserialize_content<'de, D>(deserializer: D) -> Result<Vec<ContentPart>, D::Error>
where
D: Deserializer<'de>,
{
// 首先尝试作为字符串反序列化
let content = SingleOrVec::<ContentItem>::deserialize(deserializer)?;
Ok(match content {
SingleOrVec::Single(item) => match item {
ContentItem::String(s) => vec![ContentPart::Text { text: s }],
ContentItem::Part(p) => vec![p],
},
SingleOrVec::Vec(items) => items
.into_iter()
.map(|item| match item {
ContentItem::String(s) => ContentPart::Text { text: s },
ContentItem::Part(p) => p,
})
.collect(),
})
}
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum ContentPart {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image_url")]
ImageUrl { image_url: ImageUrl },
}
#[derive(Debug, Deserialize)]
struct ImageUrl {
url: String,
}
impl std::fmt::Display for ContentPart {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ContentPart::Text { text } => write!(f, "{}", text),
ContentPart::ImageUrl { image_url } => write!(f, "[Image: {}]", image_url.url),
}
}
}
#[derive(Debug, Deserialize)]
struct ChatRequest {
model: String,
messages: Vec<Message>,
#[serde(default)]
stream: bool,
}
// 定义响应模型
#[derive(Debug, Serialize)]
struct ChatResponse {
id: String,
object: String,
created: i64,
model: String,
choices: Vec<Choice>,
usage: Usage,
}
#[derive(Debug, Serialize)]
struct Choice {
index: i32,
message: ResponseMessage,
finish_reason: String,
}
#[derive(Debug, Serialize)]
struct ResponseMessage {
role: String,
content: String,
}
#[derive(Debug, Serialize)]
struct Usage {
prompt_tokens: i32,
completion_tokens: i32,
total_tokens: i32,
}
#[derive(Debug, Serialize)]
struct StreamResponse {
id: String,
object: String,
created: i64,
choices: Vec<StreamChoice>,
}
#[derive(Debug, Serialize)]
struct StreamChoice {
index: i32,
delta: Delta,
}
#[derive(Debug, Serialize)]
struct Delta {
content: String,
}
async fn process_stream(
chunks: Vec<Bytes>,
) -> impl Stream<Item = Result<Event, Infallible>> + Send {
let (mut tx, rx) = mpsc::channel(100);
let response_id = format!("chatcmpl-{}", Uuid::new_v4());
tokio::spawn(async move {
for chunk in chunks {
let text = chunk_to_utf8_string(&chunk);
if !text.is_empty() {
let text = text.trim();
let text = if let Some(idx) = text.find("<|END_USER|>") {
text[idx + "<|END_USER|>".len()..].trim()
} else {
text
};
let text = if !text.is_empty() && text.chars().next().unwrap().is_alphabetic() {
text[1..].trim()
} else {
text
};
let re = Regex::new(r"[\x00-\x1F\x7F]").unwrap();
let text = re.replace_all(text, "");
if !text.is_empty() {
let response = StreamResponse {
id: response_id.clone(),
object: "chat.completion.chunk".to_string(),
created: chrono::Utc::now().timestamp(),
choices: vec![StreamChoice {
index: 0,
delta: Delta {
content: text.to_string(),
},
}],
};
let json_data = serde_json::to_string(&response).unwrap();
if !json_data.is_empty() {
let _ = tx.send(Ok(Event::default().data(json_data))).await;
}
}
}
}
let _ = tx.send(Ok(Event::default().data("[DONE]"))).await;
});
rx
}
#[tokio::main]
async fn main() {
@@ -222,9 +23,12 @@ async fn main() {
// 创建路由
let app = Router::new()
.route("/v1/chat/completions", post(chat_completions))
.route("/models", get(models))
.route("/v1/models", get(models))
.route(
"/v1/chat/completions",
post(handlers::chat::chat_completions),
)
.route("/models", get(handlers::models::models))
.route("/v1/models", get(handlers::models::models))
.layer(cors)
.layer(
TraceLayer::new_for_http()
@@ -251,309 +55,9 @@ async fn main() {
// 启动服务器
let port = std::env::var("PORT").unwrap_or_else(|_| "3000".to_string());
let addr = format!("0.0.0.0:{}", port);
println!("Server running on {}", addr);
tracing::info!("Server running on {}", addr);
// 修改服务器启动代码
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
axum::serve(listener, app).await.unwrap();
}
// 处理聊天完成请求
async fn chat_completions(
headers: HeaderMap,
request: Request<Body>,
// Json(chat_request): Json<ChatRequest>,
) -> Result<Response, StatusCode> {
// 提取并打印原始请求体
const MAX_BODY_SIZE: usize = 20 * 1024 * 1024;
let bytes = match axum::body::to_bytes(request.into_body(), MAX_BODY_SIZE).await {
Ok(bytes) => bytes,
Err(err) => {
tracing::error!("读取请求体失败: {}", err);
return Err(StatusCode::BAD_REQUEST);
}
};
// 打印原始请求体
if let Ok(body_str) = String::from_utf8(bytes.to_vec()) {
tracing::info!("原始请求体: {}", body_str);
}
// 尝试解析 JSON
let chat_request: ChatRequest = match serde_json::from_slice(&bytes) {
Ok(req) => req,
Err(err) => {
tracing::error!("JSON解析失败: {}", err);
return Err(StatusCode::BAD_REQUEST);
}
};
// 验证认证
let auth_header = headers
.get("authorization")
.and_then(|h| h.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;
if !auth_header.starts_with("Bearer ") {
return Err(StatusCode::UNAUTHORIZED);
}
let mut auth_token = auth_header.replace("Bearer ", "");
// 验证o1模型不支持流式输出
if chat_request.model.starts_with("o1-") && chat_request.stream {
return Err(StatusCode::BAD_REQUEST);
}
tracing::info!("chat_request: {:?}", chat_request);
// 处理多个密钥
if auth_token.contains(',') {
auth_token = auth_token.split(',').next().unwrap().trim().to_string();
}
if auth_token.contains("%3A%3A") {
auth_token = auth_token
.split("%3A%3A")
.nth(1)
.unwrap_or(&auth_token)
.to_string();
}
// 格式化消息
// let formatted_messages = chat_request
// .messages
// .iter()
// .map(|msg| format!("{}:{}", msg.role, msg.content))
// .collect::<Vec<_>>()
// .join("\n");
let formatted_messages = chat_request
.messages
.iter()
.map(|msg| {
let content = msg
.content
.iter()
.map(|part| part.to_string())
.collect::<Vec<_>>()
.join(", ");
format!("{}:{}", msg.role, content)
})
.collect::<Vec<_>>()
.join("\n");
// 生成请求数据
let hex_data = string_to_hex(&formatted_messages, &chat_request.model);
// 准备请求头
let request_id = Uuid::new_v4();
let headers = reqwest::header::HeaderMap::from_iter([
(reqwest::header::CONTENT_TYPE, "application/connect+proto"),
(reqwest::header::AUTHORIZATION, &format!("Bearer {}", auth_token)),
// 对于标准 HTTP 头部,使用预定义的常量
(reqwest::header::HeaderName::from_str("Connect-Accept-Encoding").unwrap(), "gzip,br"),
(reqwest::header::HeaderName::from_str("Connect-Protocol-Version").unwrap(), "1"),
(reqwest::header::HeaderName::from_str("User-Agent").unwrap(), "connect-es/1.4.0"),
(reqwest::header::HeaderName::from_str("X-Amzn-Trace-Id").unwrap(), &format!("Root={}", Uuid::new_v4())),
(reqwest::header::HeaderName::from_str("X-Cursor-Checksum").unwrap(), "zo6Qjequ9b9734d1f13c3438ba25ea31ac93d9287248b9d30434934e9fcbfa6b3b22029e/7e4af391f67188693b722eff0090e8e6608bca8fa320ef20a0ccb5d7d62dfdef"),
(reqwest::header::HeaderName::from_str("X-Cursor-Client-Version").unwrap(), "0.42.3"),
(reqwest::header::HeaderName::from_str("X-Cursor-Timezone").unwrap(), "Asia/Shanghai"),
(reqwest::header::HeaderName::from_str("X-Ghost-Mode").unwrap(), "false"),
(reqwest::header::HeaderName::from_str("X-Request-Id").unwrap(), &request_id.to_string()),
(reqwest::header::HeaderName::from_str("Host").unwrap(), "api2.cursor.sh"),
].iter().map(|(k, v)| (
k.clone(),
reqwest::header::HeaderValue::from_str(v).unwrap()
)));
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(300))
.build()
.map_err(|e| {
tracing::error!("创建HTTP客户端失败: {:?}", e);
tracing::error!(error = %e, "错误详情");
if let Some(source) = e.source() {
tracing::error!(source = %source, "错误源");
}
StatusCode::INTERNAL_SERVER_ERROR
})?;
let response = client
.post("https://api2.cursor.sh/aiserver.v1.AiService/StreamChat")
.headers(headers)
.body(hex_data)
.send()
.await
.map_err(|e| {
tracing::error!("请求失败: {:?}", e);
tracing::error!(error = %e, "错误详情");
// 如果是超时错误
if e.is_timeout() {
tracing::error!("请求超时");
}
// 如果是连接错误
if e.is_connect() {
tracing::error!("连接失败");
}
// 如果有请求信息
if let Some(url) = e.url() {
tracing::error!(url = %url, "请求URL");
}
// 如果有状态码
if let Some(status) = e.status() {
tracing::error!(status = %status, "HTTP状态码");
}
StatusCode::INTERNAL_SERVER_ERROR
})?;
if chat_request.stream {
let mut chunks = Vec::new();
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
match chunk {
Ok(chunk) => chunks.push(chunk),
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
}
}
let stream = process_stream(chunks).await;
return Ok(Sse::new(stream).into_response());
}
// 非流式响应
let mut text = String::new();
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
match chunk {
Ok(chunk) => {
let res = chunk_to_utf8_string(&chunk);
if !res.is_empty() {
text.push_str(&res);
}
}
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
}
}
// 清理响应文本
let re = Regex::new(r"^.*<\|END_USER\|>").unwrap();
text = re.replace(&text, "").to_string();
let re = Regex::new(r"^\n[a-zA-Z]?").unwrap();
text = re.replace(&text, "").trim().to_string();
let re = Regex::new(r"[\x00-\x1F\x7F]").unwrap();
text = re.replace_all(&text, "").to_string();
let response = ChatResponse {
id: format!("chatcmpl-{}", Uuid::new_v4()),
object: "chat.completion".to_string(),
created: chrono::Utc::now().timestamp(),
model: chat_request.model,
choices: vec![Choice {
index: 0,
message: ResponseMessage {
role: "assistant".to_string(),
content: text,
},
finish_reason: "stop".to_string(),
}],
usage: Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
},
};
Ok(Json(response).into_response())
}
// 处理模型列表请求
async fn models() -> Json<serde_json::Value> {
Json(serde_json::json!({
"object": "list",
"data": [
{
"id": "claude-3-5-sonnet-20241022",
"object": "model",
"created": 1713744000,
"owned_by": "anthropic"
},
{
"id": "claude-3-opus",
"object": "model",
"created": 1709251200,
"owned_by": "anthropic"
},
{
"id": "claude-3.5-haiku",
"object": "model",
"created": 1711929600,
"owned_by": "anthropic"
},
{
"id": "claude-3.5-sonnet",
"object": "model",
"created": 1711929600,
"owned_by": "anthropic"
},
{
"id": "cursor-small",
"object": "model",
"created": 1712534400,
"owned_by": "cursor"
},
{
"id": "gpt-3.5-turbo",
"object": "model",
"created": 1677649200,
"owned_by": "openai"
},
{
"id": "gpt-4",
"object": "model",
"created": 1687392000,
"owned_by": "openai"
},
{
"id": "gpt-4-turbo-2024-04-09",
"object": "model",
"created": 1712620800,
"owned_by": "openai"
},
{
"id": "gpt-4o",
"object": "model",
"created": 1712620800,
"owned_by": "openai"
},
{
"id": "gpt-4o-mini",
"object": "model",
"created": 1712620800,
"owned_by": "openai"
},
{
"id": "o1-mini",
"object": "model",
"created": 1712620800,
"owned_by": "openai"
},
{
"id": "o1-preview",
"object": "model",
"created": 1712620800,
"owned_by": "openai"
}
]
}))
}

130
rs-capi/src/models/chat.rs Normal file
View File

@@ -0,0 +1,130 @@
use serde::Deserializer;
use serde::{Deserialize, Serialize};
// Request model definitions.
#[derive(Debug, Deserialize)]
pub struct Message {
    // Speaker role as supplied by the client (e.g. "user" or "assistant").
    pub role: String,
    // Message body. Clients may send either a bare string or a list of
    // typed parts; `deserialize_content` normalizes both shapes into a
    // `Vec<ContentPart>`.
    #[serde(deserialize_with = "deserialize_content")]
    pub content: Vec<ContentPart>,
}
// Helper enum: untagged, so during deserialization it accepts either one
// value or a list of values for the same JSON field.
#[derive(Deserialize)]
#[serde(untagged)]
pub enum SingleOrVec<T> {
    Single(T),
    Vec(Vec<T>),
}
// Intermediate shape for a content element: either a plain JSON string or
// a structured `ContentPart`. Used only while deserializing.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ContentItem {
    String(String),
    Part(ContentPart),
}
// Deserialize the `content` field, which may arrive either as a single
// string / content part or as a list of them; the result is always a
// normalized `Vec<ContentPart>` (bare strings become `ContentPart::Text`).
fn deserialize_content<'de, D>(deserializer: D) -> Result<Vec<ContentPart>, D::Error>
where
    D: Deserializer<'de>,
{
    // Collapse the intermediate `ContentItem` shape into a `ContentPart`.
    fn normalize(item: ContentItem) -> ContentPart {
        match item {
            ContentItem::String(s) => ContentPart::Text { text: s },
            ContentItem::Part(p) => p,
        }
    }
    let parsed = SingleOrVec::<ContentItem>::deserialize(deserializer)?;
    let parts = match parsed {
        SingleOrVec::Single(item) => vec![normalize(item)],
        SingleOrVec::Vec(items) => items.into_iter().map(normalize).collect(),
    };
    Ok(parts)
}
// One typed piece of a message body, discriminated by its "type" field.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
pub enum ContentPart {
    // Plain text: {"type": "text", "text": "..."}
    #[serde(rename = "text")]
    Text { text: String },
    // Image reference: {"type": "image_url", "image_url": {"url": "..."}}
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}
// Wrapper for an image reference inside a content part.
#[derive(Debug, Deserialize)]
pub struct ImageUrl {
    // Image location; the URL format is not validated here.
    pub url: String,
}
// Render a content part as plain text: text parts verbatim, image parts as
// a bracketed placeholder carrying the URL.
impl std::fmt::Display for ContentPart {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered = match self {
            ContentPart::Text { text } => text.clone(),
            ContentPart::ImageUrl { image_url } => format!("[Image: {}]", image_url.url),
        };
        f.write_str(&rendered)
    }
}
// Incoming chat-completion request body (OpenAI-compatible subset).
#[derive(Debug, Deserialize)]
pub struct ChatRequest {
    // Target model identifier, e.g. "gpt-4o" or "o1-mini".
    pub model: String,
    // Conversation messages; presumably oldest first — TODO confirm with callers.
    pub messages: Vec<Message>,
    // When true the caller expects an SSE stream; defaults to false if absent.
    #[serde(default)]
    pub stream: bool,
}
// Response model definitions.
// Non-streaming chat-completion response (OpenAI-compatible).
#[derive(Debug, Serialize)]
pub struct ChatResponse {
    // Unique completion id, formatted "chatcmpl-<uuid>".
    pub id: String,
    // Object type tag; always "chat.completion".
    pub object: String,
    // Unix timestamp (seconds) of creation.
    pub created: i64,
    // Model name echoed back from the request.
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
// One completion alternative inside a `ChatResponse`.
#[derive(Debug, Serialize)]
pub struct Choice {
    // Position of this choice in the list; this proxy emits a single 0.
    pub index: i32,
    pub message: ResponseMessage,
    // Why generation stopped; this proxy always reports "stop".
    pub finish_reason: String,
}
// The generated message carried by a `Choice`.
#[derive(Debug, Serialize)]
pub struct ResponseMessage {
    // Always "assistant" for generated output.
    pub role: String,
    pub content: String,
}
// Token accounting; this proxy does not compute counts and reports zeros.
#[derive(Debug, Serialize)]
pub struct Usage {
    pub prompt_tokens: i32,
    pub completion_tokens: i32,
    pub total_tokens: i32,
}
// One SSE chunk in a streamed completion; `object` is "chat.completion.chunk".
#[derive(Debug, Serialize)]
pub struct StreamResponse {
    // Shared completion id, formatted "chatcmpl-<uuid>".
    pub id: String,
    pub object: String,
    // Unix timestamp (seconds) of chunk creation.
    pub created: i64,
    pub choices: Vec<StreamChoice>,
}
// One choice within a streamed chunk; this proxy emits a single index 0.
#[derive(Debug, Serialize)]
pub struct StreamChoice {
    pub index: i32,
    pub delta: Delta,
}
// Incremental content carried by a streamed chunk.
#[derive(Debug, Serialize)]
pub struct Delta {
    pub content: String,
}

View File

@@ -0,0 +1 @@
pub mod chat;