mirror of
https://github.com/oarkflow/mq.git
synced 2025-09-26 20:11:16 +08:00
419 lines
10 KiB
Go
419 lines
10 KiB
Go
package codec
|
|
|
|
import (
|
|
"bytes"
|
|
"compress/gzip"
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/rand"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/oarkflow/json"
|
|
"golang.org/x/crypto/chacha20poly1305"
|
|
)
|
|
|
|
// Error definitions for serialization
|
|
var (
|
|
ErrSerializationFailed = errors.New("serialization failed")
|
|
ErrDeserializationFailed = errors.New("deserialization failed")
|
|
ErrCompressionFailed = errors.New("compression failed")
|
|
ErrDecompressionFailed = errors.New("decompression failed")
|
|
ErrEncryptionFailed = errors.New("encryption failed")
|
|
ErrDecryptionFailed = errors.New("decryption failed")
|
|
ErrInvalidKey = errors.New("invalid encryption key")
|
|
)
|
|
|
|
// ContentType represents the content type of serialized data
|
|
type ContentType string
|
|
|
|
const (
|
|
ContentTypeJSON ContentType = "application/json"
|
|
ContentTypeMsgPack ContentType = "application/msgpack"
|
|
ContentTypeCBOR ContentType = "application/cbor"
|
|
)
|
|
|
|
// Marshaller interface for pluggable serialization
|
|
type Marshaller interface {
|
|
Marshal(v any) ([]byte, error)
|
|
ContentType() ContentType
|
|
}
|
|
|
|
// Unmarshaller interface for pluggable deserialization
|
|
type Unmarshaller interface {
|
|
Unmarshal(data []byte, v any) error
|
|
ContentType() ContentType
|
|
}
|
|
|
|
// MarshallerFunc adapter
|
|
type MarshallerFunc func(v any) ([]byte, error)
|
|
|
|
func (f MarshallerFunc) Marshal(v any) ([]byte, error) {
|
|
return f(v)
|
|
}
|
|
|
|
func (f MarshallerFunc) ContentType() ContentType {
|
|
return ContentTypeJSON
|
|
}
|
|
|
|
// UnmarshallerFunc adapter
|
|
type UnmarshallerFunc func(data []byte, v any) error
|
|
|
|
func (f UnmarshallerFunc) Unmarshal(data []byte, v any) error {
|
|
return f(data, v)
|
|
}
|
|
|
|
func (f UnmarshallerFunc) ContentType() ContentType {
|
|
return ContentTypeJSON
|
|
}
|
|
|
|
// SerializationConfig holds serialization configuration
|
|
type SerializationConfig struct {
|
|
EnableCompression bool
|
|
CompressionLevel int
|
|
MaxCompressionRatio float64
|
|
EnableEncryption bool
|
|
EncryptionKey []byte
|
|
PreferredCipher string // "chacha20poly1305" or "aes-gcm"
|
|
}
|
|
|
|
// DefaultSerializationConfig returns default configuration
|
|
func DefaultSerializationConfig() *SerializationConfig {
|
|
return &SerializationConfig{
|
|
EnableCompression: false,
|
|
CompressionLevel: gzip.DefaultCompression,
|
|
MaxCompressionRatio: 0.8, // Only compress if we save at least 20%
|
|
EnableEncryption: false,
|
|
PreferredCipher: "chacha20poly1305",
|
|
}
|
|
}
|
|
|
|
// SerializationManager manages serialization with configuration
|
|
type SerializationManager struct {
|
|
marshaller Marshaller
|
|
unmarshaller Unmarshaller
|
|
config *SerializationConfig
|
|
mu sync.RWMutex
|
|
cachedCipher cipher.AEAD
|
|
cipherMu sync.Mutex
|
|
}
|
|
|
|
// NewSerializationManager creates a new serialization manager
|
|
func NewSerializationManager(config *SerializationConfig) *SerializationManager {
|
|
if config == nil {
|
|
config = DefaultSerializationConfig()
|
|
}
|
|
|
|
return &SerializationManager{
|
|
marshaller: MarshallerFunc(json.Marshal),
|
|
unmarshaller: UnmarshallerFunc(func(data []byte, v any) error {
|
|
return json.Unmarshal(data, v)
|
|
}),
|
|
config: config,
|
|
}
|
|
}
|
|
|
|
// SetMarshaller sets custom marshaller
|
|
func (sm *SerializationManager) SetMarshaller(marshaller Marshaller) {
|
|
sm.mu.Lock()
|
|
defer sm.mu.Unlock()
|
|
sm.marshaller = marshaller
|
|
}
|
|
|
|
// SetUnmarshaller sets custom unmarshaller
|
|
func (sm *SerializationManager) SetUnmarshaller(unmarshaller Unmarshaller) {
|
|
sm.mu.Lock()
|
|
defer sm.mu.Unlock()
|
|
sm.unmarshaller = unmarshaller
|
|
}
|
|
|
|
// SetEncryptionKey sets the encryption key
|
|
func (sm *SerializationManager) SetEncryptionKey(key []byte) error {
|
|
if sm.config.EnableEncryption && len(key) == 0 {
|
|
return ErrInvalidKey
|
|
}
|
|
|
|
sm.mu.Lock()
|
|
defer sm.mu.Unlock()
|
|
|
|
sm.config.EncryptionKey = key
|
|
// Clear the cached cipher so it will be recreated with the new key
|
|
sm.cipherMu.Lock()
|
|
defer sm.cipherMu.Unlock()
|
|
sm.cachedCipher = nil
|
|
|
|
return nil
|
|
}
|
|
|
|
// Marshal serializes data with optional compression and encryption
|
|
func (sm *SerializationManager) Marshal(v any) ([]byte, error) {
|
|
sm.mu.RLock()
|
|
marshaller := sm.marshaller
|
|
config := sm.config
|
|
sm.mu.RUnlock()
|
|
|
|
// Serialize the data
|
|
data, err := marshaller.Marshal(v)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: %v", ErrSerializationFailed, err)
|
|
}
|
|
|
|
// Apply compression if enabled and beneficial
|
|
if config.EnableCompression && len(data) > 256 { // Only compress larger payloads
|
|
compressed, err := sm.compress(data)
|
|
if err != nil {
|
|
// Continue with uncompressed data on error
|
|
// but don't return error to allow for graceful degradation
|
|
} else if float64(len(compressed))/float64(len(data)) <= config.MaxCompressionRatio {
|
|
// Only use compression if it provides significant benefit
|
|
data = compressed
|
|
}
|
|
}
|
|
|
|
// Apply encryption if enabled
|
|
if config.EnableEncryption && len(config.EncryptionKey) > 0 {
|
|
encrypted, err := sm.encrypt(data)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: %v", ErrEncryptionFailed, err)
|
|
}
|
|
data = encrypted
|
|
}
|
|
|
|
return data, nil
|
|
}
|
|
|
|
// Unmarshal deserializes data with optional decompression and decryption
|
|
func (sm *SerializationManager) Unmarshal(data []byte, v any) error {
|
|
if len(data) == 0 {
|
|
return fmt.Errorf("%w: empty data", ErrDeserializationFailed)
|
|
}
|
|
|
|
sm.mu.RLock()
|
|
unmarshaller := sm.unmarshaller
|
|
config := sm.config
|
|
sm.mu.RUnlock()
|
|
|
|
// Apply decryption if enabled
|
|
if config.EnableEncryption && len(config.EncryptionKey) > 0 {
|
|
decrypted, err := sm.decrypt(data)
|
|
if err != nil {
|
|
return fmt.Errorf("%w: %v", ErrDecryptionFailed, err)
|
|
}
|
|
data = decrypted
|
|
}
|
|
|
|
// Try decompression if enabled
|
|
if config.EnableCompression && sm.isCompressed(data) {
|
|
decompressed, err := sm.decompress(data)
|
|
if err != nil {
|
|
return fmt.Errorf("%w: %v", ErrDecompressionFailed, err)
|
|
}
|
|
data = decompressed
|
|
}
|
|
|
|
// Deserialize the data
|
|
if err := unmarshaller.Unmarshal(data, v); err != nil {
|
|
return fmt.Errorf("%w: %v", ErrDeserializationFailed, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// compress compresses data using gzip
|
|
func (sm *SerializationManager) compress(data []byte) ([]byte, error) {
|
|
var buf bytes.Buffer
|
|
|
|
// Add compression marker
|
|
buf.WriteByte(0x1f) // gzip magic number
|
|
|
|
writer, err := gzip.NewWriterLevel(&buf, sm.config.CompressionLevel)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if _, err := writer.Write(data); err != nil {
|
|
writer.Close()
|
|
return nil, err
|
|
}
|
|
|
|
if err := writer.Close(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return buf.Bytes(), nil
|
|
}
|
|
|
|
// decompress decompresses gzip data
|
|
func (sm *SerializationManager) decompress(data []byte) ([]byte, error) {
|
|
if len(data) < 1 {
|
|
return nil, fmt.Errorf("invalid compressed data")
|
|
}
|
|
|
|
// Remove compression marker
|
|
if data[0] != 0x1f {
|
|
return data, nil // Not compressed
|
|
}
|
|
|
|
reader, err := gzip.NewReader(bytes.NewReader(data[1:]))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer reader.Close()
|
|
|
|
var buf bytes.Buffer
|
|
if _, err := buf.ReadFrom(reader); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return buf.Bytes(), nil
|
|
}
|
|
|
|
// isCompressed checks if data is compressed
|
|
func (sm *SerializationManager) isCompressed(data []byte) bool {
|
|
return len(data) > 0 && data[0] == 0x1f
|
|
}
|
|
|
|
// getCipher returns a cipher.AEAD instance for encryption/decryption
|
|
func (sm *SerializationManager) getCipher() (cipher.AEAD, error) {
|
|
sm.cipherMu.Lock()
|
|
defer sm.cipherMu.Unlock()
|
|
|
|
if sm.cachedCipher != nil {
|
|
return sm.cachedCipher, nil
|
|
}
|
|
|
|
var aead cipher.AEAD
|
|
var err error
|
|
|
|
sm.mu.RLock()
|
|
key := sm.config.EncryptionKey
|
|
preferredCipher := sm.config.PreferredCipher
|
|
sm.mu.RUnlock()
|
|
|
|
switch preferredCipher {
|
|
case "chacha20poly1305":
|
|
aead, err = chacha20poly1305.New(key)
|
|
case "aes-gcm":
|
|
block, e := aes.NewCipher(key)
|
|
if e != nil {
|
|
err = e
|
|
break
|
|
}
|
|
aead, err = cipher.NewGCM(block)
|
|
default:
|
|
// Default to ChaCha20-Poly1305
|
|
aead, err = chacha20poly1305.New(key)
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sm.cachedCipher = aead
|
|
return aead, nil
|
|
}
|
|
|
|
// encrypt encrypts data using authenticated encryption
|
|
func (sm *SerializationManager) encrypt(data []byte) ([]byte, error) {
|
|
aead, err := sm.getCipher()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Generate a random nonce
|
|
nonce := make([]byte, aead.NonceSize())
|
|
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Add timestamp to associated data for replay protection
|
|
timestampBytes := make([]byte, 8)
|
|
binary.BigEndian.PutUint64(timestampBytes, uint64(time.Now().Unix()))
|
|
|
|
// Encrypt and authenticate
|
|
ciphertext := aead.Seal(nil, nonce, data, timestampBytes)
|
|
|
|
// Prepend nonce and timestamp to ciphertext
|
|
result := make([]byte, len(nonce)+len(timestampBytes)+len(ciphertext))
|
|
copy(result, nonce)
|
|
copy(result[len(nonce):], timestampBytes)
|
|
copy(result[len(nonce)+len(timestampBytes):], ciphertext)
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// decrypt decrypts data using authenticated decryption
|
|
func (sm *SerializationManager) decrypt(data []byte) ([]byte, error) {
|
|
aead, err := sm.getCipher()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Extract nonce
|
|
nonceSize := aead.NonceSize()
|
|
if len(data) < nonceSize+8 { // 8 bytes for timestamp
|
|
return nil, fmt.Errorf("ciphertext too short")
|
|
}
|
|
|
|
nonce := data[:nonceSize]
|
|
timestampBytes := data[nonceSize : nonceSize+8]
|
|
ciphertext := data[nonceSize+8:]
|
|
|
|
// Decrypt and verify
|
|
plaintext, err := aead.Open(nil, nonce, ciphertext, timestampBytes)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return plaintext, nil
|
|
}
|
|
|
|
// Global serialization manager instance
|
|
var globalSerializationManager = NewSerializationManager(nil)
|
|
var globalSerializationManagerMu sync.RWMutex
|
|
|
|
// Global functions for backward compatibility
|
|
func SetMarshaller(marshaller MarshallerFunc) {
|
|
globalSerializationManagerMu.Lock()
|
|
defer globalSerializationManagerMu.Unlock()
|
|
globalSerializationManager.SetMarshaller(marshaller)
|
|
}
|
|
|
|
func SetUnmarshaller(unmarshaller UnmarshallerFunc) {
|
|
globalSerializationManagerMu.Lock()
|
|
defer globalSerializationManagerMu.Unlock()
|
|
globalSerializationManager.SetUnmarshaller(unmarshaller)
|
|
}
|
|
|
|
func EnableCompression(enable bool) {
|
|
globalSerializationManagerMu.Lock()
|
|
defer globalSerializationManagerMu.Unlock()
|
|
globalSerializationManager.config.EnableCompression = enable
|
|
}
|
|
|
|
func EnableEncryption(enable bool, key []byte) error {
|
|
globalSerializationManagerMu.Lock()
|
|
defer globalSerializationManagerMu.Unlock()
|
|
globalSerializationManager.config.EnableEncryption = enable
|
|
if enable {
|
|
return globalSerializationManager.SetEncryptionKey(key)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func Marshal(v any) ([]byte, error) {
|
|
globalSerializationManagerMu.RLock()
|
|
defer globalSerializationManagerMu.RUnlock()
|
|
return globalSerializationManager.Marshal(v)
|
|
}
|
|
|
|
func Unmarshal(data []byte, v any) error {
|
|
globalSerializationManagerMu.RLock()
|
|
defer globalSerializationManagerMu.RUnlock()
|
|
return globalSerializationManager.Unmarshal(data, v)
|
|
}
|