Files
mq/codec/serializer.go
2025-07-31 09:31:28 +05:45

419 lines
10 KiB
Go

package codec
import (
"bytes"
"compress/gzip"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/oarkflow/json"
"golang.org/x/crypto/chacha20poly1305"
)
// Error definitions for serialization
var (
ErrSerializationFailed = errors.New("serialization failed")
ErrDeserializationFailed = errors.New("deserialization failed")
ErrCompressionFailed = errors.New("compression failed")
ErrDecompressionFailed = errors.New("decompression failed")
ErrEncryptionFailed = errors.New("encryption failed")
ErrDecryptionFailed = errors.New("decryption failed")
ErrInvalidKey = errors.New("invalid encryption key")
)
// ContentType represents the content type of serialized data
type ContentType string
const (
ContentTypeJSON ContentType = "application/json"
ContentTypeMsgPack ContentType = "application/msgpack"
ContentTypeCBOR ContentType = "application/cbor"
)
// Marshaller interface for pluggable serialization
type Marshaller interface {
Marshal(v any) ([]byte, error)
ContentType() ContentType
}
// Unmarshaller interface for pluggable deserialization
type Unmarshaller interface {
Unmarshal(data []byte, v any) error
ContentType() ContentType
}
// MarshallerFunc adapter
type MarshallerFunc func(v any) ([]byte, error)
func (f MarshallerFunc) Marshal(v any) ([]byte, error) {
return f(v)
}
func (f MarshallerFunc) ContentType() ContentType {
return ContentTypeJSON
}
// UnmarshallerFunc adapter
type UnmarshallerFunc func(data []byte, v any) error
func (f UnmarshallerFunc) Unmarshal(data []byte, v any) error {
return f(data, v)
}
func (f UnmarshallerFunc) ContentType() ContentType {
return ContentTypeJSON
}
// SerializationConfig holds serialization configuration
type SerializationConfig struct {
EnableCompression bool
CompressionLevel int
MaxCompressionRatio float64
EnableEncryption bool
EncryptionKey []byte
PreferredCipher string // "chacha20poly1305" or "aes-gcm"
}
// DefaultSerializationConfig returns default configuration
func DefaultSerializationConfig() *SerializationConfig {
return &SerializationConfig{
EnableCompression: false,
CompressionLevel: gzip.DefaultCompression,
MaxCompressionRatio: 0.8, // Only compress if we save at least 20%
EnableEncryption: false,
PreferredCipher: "chacha20poly1305",
}
}
// SerializationManager manages serialization with configuration
type SerializationManager struct {
marshaller Marshaller
unmarshaller Unmarshaller
config *SerializationConfig
mu sync.RWMutex
cachedCipher cipher.AEAD
cipherMu sync.Mutex
}
// NewSerializationManager creates a new serialization manager
func NewSerializationManager(config *SerializationConfig) *SerializationManager {
if config == nil {
config = DefaultSerializationConfig()
}
return &SerializationManager{
marshaller: MarshallerFunc(json.Marshal),
unmarshaller: UnmarshallerFunc(func(data []byte, v any) error {
return json.Unmarshal(data, v)
}),
config: config,
}
}
// SetMarshaller sets custom marshaller
func (sm *SerializationManager) SetMarshaller(marshaller Marshaller) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.marshaller = marshaller
}
// SetUnmarshaller sets custom unmarshaller
func (sm *SerializationManager) SetUnmarshaller(unmarshaller Unmarshaller) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.unmarshaller = unmarshaller
}
// SetEncryptionKey sets the encryption key
func (sm *SerializationManager) SetEncryptionKey(key []byte) error {
if sm.config.EnableEncryption && len(key) == 0 {
return ErrInvalidKey
}
sm.mu.Lock()
defer sm.mu.Unlock()
sm.config.EncryptionKey = key
// Clear the cached cipher so it will be recreated with the new key
sm.cipherMu.Lock()
defer sm.cipherMu.Unlock()
sm.cachedCipher = nil
return nil
}
// Marshal serializes data with optional compression and encryption
func (sm *SerializationManager) Marshal(v any) ([]byte, error) {
sm.mu.RLock()
marshaller := sm.marshaller
config := sm.config
sm.mu.RUnlock()
// Serialize the data
data, err := marshaller.Marshal(v)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrSerializationFailed, err)
}
// Apply compression if enabled and beneficial
if config.EnableCompression && len(data) > 256 { // Only compress larger payloads
compressed, err := sm.compress(data)
if err != nil {
// Continue with uncompressed data on error
// but don't return error to allow for graceful degradation
} else if float64(len(compressed))/float64(len(data)) <= config.MaxCompressionRatio {
// Only use compression if it provides significant benefit
data = compressed
}
}
// Apply encryption if enabled
if config.EnableEncryption && len(config.EncryptionKey) > 0 {
encrypted, err := sm.encrypt(data)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrEncryptionFailed, err)
}
data = encrypted
}
return data, nil
}
// Unmarshal deserializes data with optional decompression and decryption
func (sm *SerializationManager) Unmarshal(data []byte, v any) error {
if len(data) == 0 {
return fmt.Errorf("%w: empty data", ErrDeserializationFailed)
}
sm.mu.RLock()
unmarshaller := sm.unmarshaller
config := sm.config
sm.mu.RUnlock()
// Apply decryption if enabled
if config.EnableEncryption && len(config.EncryptionKey) > 0 {
decrypted, err := sm.decrypt(data)
if err != nil {
return fmt.Errorf("%w: %v", ErrDecryptionFailed, err)
}
data = decrypted
}
// Try decompression if enabled
if config.EnableCompression && sm.isCompressed(data) {
decompressed, err := sm.decompress(data)
if err != nil {
return fmt.Errorf("%w: %v", ErrDecompressionFailed, err)
}
data = decompressed
}
// Deserialize the data
if err := unmarshaller.Unmarshal(data, v); err != nil {
return fmt.Errorf("%w: %v", ErrDeserializationFailed, err)
}
return nil
}
// compress compresses data using gzip
func (sm *SerializationManager) compress(data []byte) ([]byte, error) {
var buf bytes.Buffer
// Add compression marker
buf.WriteByte(0x1f) // gzip magic number
writer, err := gzip.NewWriterLevel(&buf, sm.config.CompressionLevel)
if err != nil {
return nil, err
}
if _, err := writer.Write(data); err != nil {
writer.Close()
return nil, err
}
if err := writer.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// decompress decompresses gzip data
func (sm *SerializationManager) decompress(data []byte) ([]byte, error) {
if len(data) < 1 {
return nil, fmt.Errorf("invalid compressed data")
}
// Remove compression marker
if data[0] != 0x1f {
return data, nil // Not compressed
}
reader, err := gzip.NewReader(bytes.NewReader(data[1:]))
if err != nil {
return nil, err
}
defer reader.Close()
var buf bytes.Buffer
if _, err := buf.ReadFrom(reader); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// isCompressed checks if data is compressed
func (sm *SerializationManager) isCompressed(data []byte) bool {
return len(data) > 0 && data[0] == 0x1f
}
// getCipher returns a cipher.AEAD instance for encryption/decryption
func (sm *SerializationManager) getCipher() (cipher.AEAD, error) {
sm.cipherMu.Lock()
defer sm.cipherMu.Unlock()
if sm.cachedCipher != nil {
return sm.cachedCipher, nil
}
var aead cipher.AEAD
var err error
sm.mu.RLock()
key := sm.config.EncryptionKey
preferredCipher := sm.config.PreferredCipher
sm.mu.RUnlock()
switch preferredCipher {
case "chacha20poly1305":
aead, err = chacha20poly1305.New(key)
case "aes-gcm":
block, e := aes.NewCipher(key)
if e != nil {
err = e
break
}
aead, err = cipher.NewGCM(block)
default:
// Default to ChaCha20-Poly1305
aead, err = chacha20poly1305.New(key)
}
if err != nil {
return nil, err
}
sm.cachedCipher = aead
return aead, nil
}
// encrypt encrypts data using authenticated encryption
func (sm *SerializationManager) encrypt(data []byte) ([]byte, error) {
aead, err := sm.getCipher()
if err != nil {
return nil, err
}
// Generate a random nonce
nonce := make([]byte, aead.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}
// Add timestamp to associated data for replay protection
timestampBytes := make([]byte, 8)
binary.BigEndian.PutUint64(timestampBytes, uint64(time.Now().Unix()))
// Encrypt and authenticate
ciphertext := aead.Seal(nil, nonce, data, timestampBytes)
// Prepend nonce and timestamp to ciphertext
result := make([]byte, len(nonce)+len(timestampBytes)+len(ciphertext))
copy(result, nonce)
copy(result[len(nonce):], timestampBytes)
copy(result[len(nonce)+len(timestampBytes):], ciphertext)
return result, nil
}
// decrypt decrypts data using authenticated decryption
func (sm *SerializationManager) decrypt(data []byte) ([]byte, error) {
aead, err := sm.getCipher()
if err != nil {
return nil, err
}
// Extract nonce
nonceSize := aead.NonceSize()
if len(data) < nonceSize+8 { // 8 bytes for timestamp
return nil, fmt.Errorf("ciphertext too short")
}
nonce := data[:nonceSize]
timestampBytes := data[nonceSize : nonceSize+8]
ciphertext := data[nonceSize+8:]
// Decrypt and verify
plaintext, err := aead.Open(nil, nonce, ciphertext, timestampBytes)
if err != nil {
return nil, err
}
return plaintext, nil
}
// Global serialization manager instance
var globalSerializationManager = NewSerializationManager(nil)
var globalSerializationManagerMu sync.RWMutex
// Global functions for backward compatibility
func SetMarshaller(marshaller MarshallerFunc) {
globalSerializationManagerMu.Lock()
defer globalSerializationManagerMu.Unlock()
globalSerializationManager.SetMarshaller(marshaller)
}
func SetUnmarshaller(unmarshaller UnmarshallerFunc) {
globalSerializationManagerMu.Lock()
defer globalSerializationManagerMu.Unlock()
globalSerializationManager.SetUnmarshaller(unmarshaller)
}
func EnableCompression(enable bool) {
globalSerializationManagerMu.Lock()
defer globalSerializationManagerMu.Unlock()
globalSerializationManager.config.EnableCompression = enable
}
func EnableEncryption(enable bool, key []byte) error {
globalSerializationManagerMu.Lock()
defer globalSerializationManagerMu.Unlock()
globalSerializationManager.config.EnableEncryption = enable
if enable {
return globalSerializationManager.SetEncryptionKey(key)
}
return nil
}
func Marshal(v any) ([]byte, error) {
globalSerializationManagerMu.RLock()
defer globalSerializationManagerMu.RUnlock()
return globalSerializationManager.Marshal(v)
}
func Unmarshal(data []byte, v any) error {
globalSerializationManagerMu.RLock()
defer globalSerializationManagerMu.RUnlock()
return globalSerializationManager.Unmarshal(data, v)
}