Files
cursor-api/src/common/utils/base64.rs
2025-01-23 12:34:56 +08:00

149 lines
4.7 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Base64 字符集 (a-z, A-Z, 0-9, -, _)
const BASE64_CHARS: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_";
// 预计算的 Base64 查找表,用于快速解码
const BASE64_LOOKUP: [i8; 256] = {
let mut lookup = [-1i8; 256];
let mut i = 0;
while i < BASE64_CHARS.len() {
lookup[BASE64_CHARS[i] as usize] = i as i8;
i += 1;
}
lookup
};
/// 将字节切片编码为 Base64 字符串。
///
/// # Arguments
///
/// * `bytes`: 要编码的字节切片
///
/// # Returns
///
/// 编码后的 Base64 字符串
pub fn to_base64(bytes: &[u8]) -> String {
// 预分配足够容量,避免多次分配内存
let capacity = (bytes.len() + 2) / 3 * 4;
let mut result = Vec::with_capacity(capacity);
// 每三个字节为一组进行处理
for chunk in bytes.chunks(3) {
// 将三个字节合并为一个 u32
let b1 = chunk[0] as u32;
let b2 = chunk.get(1).map_or(0, |&b| b as u32);
let b3 = chunk.get(2).map_or(0, |&b| b as u32);
let n = (b1 << 16) | (b2 << 8) | b3;
// 将 u32 拆分成四个 6 位的值,并根据查找表转换为 Base64 字符
result.push(BASE64_CHARS[(n >> 18) as usize]);
result.push(BASE64_CHARS[((n >> 12) & 0x3F) as usize]);
// 如果 chunk 长度大于 1则需要处理第二个字符
if chunk.len() > 1 {
result.push(BASE64_CHARS[((n >> 6) & 0x3F) as usize]);
// 如果 chunk 长度大于 2则需要处理第三个字符
if chunk.len() > 2 {
result.push(BASE64_CHARS[(n & 0x3F) as usize]);
}
}
}
// 使用 from_utf8_unchecked 提高性能,因为 BASE64_CHARS 都是有效的 ASCII 字符
unsafe { String::from_utf8_unchecked(result) }
}
/// 将 Base64 字符串解码为字节数组。
///
/// # Arguments
///
/// * `input`: 要解码的 Base64 字符串
///
/// # Returns
///
/// 如果解码成功,返回 Some(解码后的字节数组);如果输入无效,返回 None
pub fn from_base64(input: &str) -> Option<Vec<u8>> {
let input = input.as_bytes();
// 检查输入长度Base64 编码的长度必须是 4 的倍数或余 2/3
if input.is_empty() || input.len() % 4 == 1 {
return None;
}
// 检查是否包含无效字符无效字符直接返回None
if input.iter().any(|&b| BASE64_LOOKUP[b as usize] == -1) {
return None;
}
// 预分配足够容量,避免多次分配内存
let capacity = input.len() / 4 * 3;
let mut result = Vec::with_capacity(capacity);
// 每四个字符为一组进行处理
let mut chunks = input.chunks_exact(4);
for chunk in &mut chunks {
// 使用查找表将 Base64 字符转换为 6 位的值
let n1 = BASE64_LOOKUP[chunk[0] as usize] as u32;
let n2 = BASE64_LOOKUP[chunk[1] as usize] as u32;
let n3 = BASE64_LOOKUP[chunk[2] as usize] as u32;
let n4 = BASE64_LOOKUP[chunk[3] as usize] as u32;
// 将四个 6 位的值合并为一个 u32并拆分成三个字节
let n = (n1 << 18) | (n2 << 12) | (n3 << 6) | n4;
result.push((n >> 16) as u8);
result.push(((n >> 8) & 0xFF) as u8);
result.push((n & 0xFF) as u8);
}
// 处理剩余的字符
let remainder = chunks.remainder();
if !remainder.is_empty() {
let n1 = BASE64_LOOKUP[remainder[0] as usize] as u32;
let n2 = BASE64_LOOKUP[remainder[1] as usize] as u32;
let mut n = (n1 << 18) | (n2 << 12);
result.push((n >> 16) as u8);
// 如果剩余字符长度大于 2则需要处理第二个字节
if remainder.len() > 2 {
let n3 = BASE64_LOOKUP[remainder[2] as usize] as u32;
n |= n3 << 6;
result.push(((n >> 8) & 0xFF) as u8);
}
}
Some(result)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_base64_roundtrip() {
let test_cases = vec![
vec![0u8, 1, 2, 3],
vec![255u8, 254, 253],
vec![0u8],
vec![0u8, 1],
vec![0u8, 1, 2],
vec![255u8; 1000],
];
for case in test_cases {
let encoded = to_base64(&case);
let decoded = from_base64(&encoded).unwrap();
assert_eq!(case, decoded);
}
}
#[test]
fn test_invalid_input() {
assert_eq!(from_base64(""), None); // 空字符串
assert_eq!(from_base64("a"), None); // 长度为 1
assert_eq!(from_base64("!@#$"), None); // 无效字符
assert_eq!(from_base64("YWJj!"), None); // 包含无效字符
assert!(from_base64("YWJj").is_some()); // 有效输入
}
}