diff --git a/rs-capi/src/main.rs b/rs-capi/src/main.rs index 979d179..2d42510 100644 --- a/rs-capi/src/main.rs +++ b/rs-capi/src/main.rs @@ -18,12 +18,12 @@ use futures::{ }; // use http::HeaderName as HttpHeaderName; use regex::Regex; +use serde::Deserializer; use serde::{Deserialize, Serialize}; use std::str::FromStr; use std::{convert::Infallible, time::Duration}; use tower_http::cors::{Any, CorsLayer}; use uuid::Uuid; - mod hex_utils; use hex_utils::{chunk_to_utf8_string, string_to_hex}; @@ -31,9 +31,30 @@ use hex_utils::{chunk_to_utf8_string, string_to_hex}; #[derive(Debug, Deserialize)] struct Message { role: String, + #[serde(deserialize_with = "deserialize_single_or_vec")] content: Vec, } +// 添加一个辅助枚举 +#[derive(Deserialize)] +#[serde(untagged)] +enum SingleOrVec { + Single(T), + Vec(Vec), +} + +// 简单的辅助函数 +fn deserialize_single_or_vec<'de, D, T>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, + T: Deserialize<'de>, +{ + let value = SingleOrVec::deserialize(deserializer)?; + Ok(match value { + SingleOrVec::Single(x) => vec![x], + SingleOrVec::Vec(x) => x, + }) +} #[derive(Debug, Deserialize)] #[serde(tag = "type")] enum ContentPart {