mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-14 13:53:44 +08:00
feat: use "GetTags" and "SetTags"
This commit is contained in:
@@ -1,39 +1,418 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/json"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
type MarshallerFunc func(v any) ([]byte, error)
|
||||
// Error definitions for serialization
|
||||
var (
|
||||
ErrSerializationFailed = errors.New("serialization failed")
|
||||
ErrDeserializationFailed = errors.New("deserialization failed")
|
||||
ErrCompressionFailed = errors.New("compression failed")
|
||||
ErrDecompressionFailed = errors.New("decompression failed")
|
||||
ErrEncryptionFailed = errors.New("encryption failed")
|
||||
ErrDecryptionFailed = errors.New("decryption failed")
|
||||
ErrInvalidKey = errors.New("invalid encryption key")
|
||||
)
|
||||
|
||||
type UnmarshallerFunc func(data []byte, v any) error
|
||||
// ContentType represents the content type of serialized data
|
||||
type ContentType string
|
||||
|
||||
const (
|
||||
ContentTypeJSON ContentType = "application/json"
|
||||
ContentTypeMsgPack ContentType = "application/msgpack"
|
||||
ContentTypeCBOR ContentType = "application/cbor"
|
||||
)
|
||||
|
||||
// Marshaller interface for pluggable serialization
|
||||
type Marshaller interface {
|
||||
Marshal(v any) ([]byte, error)
|
||||
ContentType() ContentType
|
||||
}
|
||||
|
||||
// Unmarshaller interface for pluggable deserialization
|
||||
type Unmarshaller interface {
|
||||
Unmarshal(data []byte, v any) error
|
||||
ContentType() ContentType
|
||||
}
|
||||
|
||||
// MarshallerFunc adapter
|
||||
type MarshallerFunc func(v any) ([]byte, error)
|
||||
|
||||
func (f MarshallerFunc) Marshal(v any) ([]byte, error) {
|
||||
return f(v)
|
||||
}
|
||||
|
||||
func (f MarshallerFunc) ContentType() ContentType {
|
||||
return ContentTypeJSON
|
||||
}
|
||||
|
||||
// UnmarshallerFunc adapter
|
||||
type UnmarshallerFunc func(data []byte, v any) error
|
||||
|
||||
func (f UnmarshallerFunc) Unmarshal(data []byte, v any) error {
|
||||
return f(data, v)
|
||||
}
|
||||
|
||||
var defaultMarshaller MarshallerFunc = json.Marshal
|
||||
|
||||
var defaultUnmarshaller UnmarshallerFunc = func(data []byte, v any) error {
|
||||
return json.Unmarshal(data, v)
|
||||
func (f UnmarshallerFunc) ContentType() ContentType {
|
||||
return ContentTypeJSON
|
||||
}
|
||||
|
||||
// SerializationConfig holds serialization configuration
|
||||
type SerializationConfig struct {
|
||||
EnableCompression bool
|
||||
CompressionLevel int
|
||||
MaxCompressionRatio float64
|
||||
EnableEncryption bool
|
||||
EncryptionKey []byte
|
||||
PreferredCipher string // "chacha20poly1305" or "aes-gcm"
|
||||
}
|
||||
|
||||
// DefaultSerializationConfig returns default configuration
|
||||
func DefaultSerializationConfig() *SerializationConfig {
|
||||
return &SerializationConfig{
|
||||
EnableCompression: false,
|
||||
CompressionLevel: gzip.DefaultCompression,
|
||||
MaxCompressionRatio: 0.8, // Only compress if we save at least 20%
|
||||
EnableEncryption: false,
|
||||
PreferredCipher: "chacha20poly1305",
|
||||
}
|
||||
}
|
||||
|
||||
// SerializationManager manages serialization with configuration
|
||||
type SerializationManager struct {
|
||||
marshaller Marshaller
|
||||
unmarshaller Unmarshaller
|
||||
config *SerializationConfig
|
||||
mu sync.RWMutex
|
||||
cachedCipher cipher.AEAD
|
||||
cipherMu sync.Mutex
|
||||
}
|
||||
|
||||
// NewSerializationManager creates a new serialization manager
|
||||
func NewSerializationManager(config *SerializationConfig) *SerializationManager {
|
||||
if config == nil {
|
||||
config = DefaultSerializationConfig()
|
||||
}
|
||||
|
||||
return &SerializationManager{
|
||||
marshaller: MarshallerFunc(json.Marshal),
|
||||
unmarshaller: UnmarshallerFunc(func(data []byte, v any) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}),
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// SetMarshaller sets custom marshaller
|
||||
func (sm *SerializationManager) SetMarshaller(marshaller Marshaller) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.marshaller = marshaller
|
||||
}
|
||||
|
||||
// SetUnmarshaller sets custom unmarshaller
|
||||
func (sm *SerializationManager) SetUnmarshaller(unmarshaller Unmarshaller) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.unmarshaller = unmarshaller
|
||||
}
|
||||
|
||||
// SetEncryptionKey sets the encryption key
|
||||
func (sm *SerializationManager) SetEncryptionKey(key []byte) error {
|
||||
if sm.config.EnableEncryption && len(key) == 0 {
|
||||
return ErrInvalidKey
|
||||
}
|
||||
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
sm.config.EncryptionKey = key
|
||||
// Clear the cached cipher so it will be recreated with the new key
|
||||
sm.cipherMu.Lock()
|
||||
defer sm.cipherMu.Unlock()
|
||||
sm.cachedCipher = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Marshal serializes data with optional compression and encryption
|
||||
func (sm *SerializationManager) Marshal(v any) ([]byte, error) {
|
||||
sm.mu.RLock()
|
||||
marshaller := sm.marshaller
|
||||
config := sm.config
|
||||
sm.mu.RUnlock()
|
||||
|
||||
// Serialize the data
|
||||
data, err := marshaller.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrSerializationFailed, err)
|
||||
}
|
||||
|
||||
// Apply compression if enabled and beneficial
|
||||
if config.EnableCompression && len(data) > 256 { // Only compress larger payloads
|
||||
compressed, err := sm.compress(data)
|
||||
if err != nil {
|
||||
// Continue with uncompressed data on error
|
||||
// but don't return error to allow for graceful degradation
|
||||
} else if float64(len(compressed))/float64(len(data)) <= config.MaxCompressionRatio {
|
||||
// Only use compression if it provides significant benefit
|
||||
data = compressed
|
||||
}
|
||||
}
|
||||
|
||||
// Apply encryption if enabled
|
||||
if config.EnableEncryption && len(config.EncryptionKey) > 0 {
|
||||
encrypted, err := sm.encrypt(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrEncryptionFailed, err)
|
||||
}
|
||||
data = encrypted
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// Unmarshal deserializes data with optional decompression and decryption
|
||||
func (sm *SerializationManager) Unmarshal(data []byte, v any) error {
|
||||
if len(data) == 0 {
|
||||
return fmt.Errorf("%w: empty data", ErrDeserializationFailed)
|
||||
}
|
||||
|
||||
sm.mu.RLock()
|
||||
unmarshaller := sm.unmarshaller
|
||||
config := sm.config
|
||||
sm.mu.RUnlock()
|
||||
|
||||
// Apply decryption if enabled
|
||||
if config.EnableEncryption && len(config.EncryptionKey) > 0 {
|
||||
decrypted, err := sm.decrypt(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrDecryptionFailed, err)
|
||||
}
|
||||
data = decrypted
|
||||
}
|
||||
|
||||
// Try decompression if enabled
|
||||
if config.EnableCompression && sm.isCompressed(data) {
|
||||
decompressed, err := sm.decompress(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrDecompressionFailed, err)
|
||||
}
|
||||
data = decompressed
|
||||
}
|
||||
|
||||
// Deserialize the data
|
||||
if err := unmarshaller.Unmarshal(data, v); err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrDeserializationFailed, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// compress compresses data using gzip
|
||||
func (sm *SerializationManager) compress(data []byte) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
// Add compression marker
|
||||
buf.WriteByte(0x1f) // gzip magic number
|
||||
|
||||
writer, err := gzip.NewWriterLevel(&buf, sm.config.CompressionLevel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := writer.Write(data); err != nil {
|
||||
writer.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := writer.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// decompress decompresses gzip data
|
||||
func (sm *SerializationManager) decompress(data []byte) ([]byte, error) {
|
||||
if len(data) < 1 {
|
||||
return nil, fmt.Errorf("invalid compressed data")
|
||||
}
|
||||
|
||||
// Remove compression marker
|
||||
if data[0] != 0x1f {
|
||||
return data, nil // Not compressed
|
||||
}
|
||||
|
||||
reader, err := gzip.NewReader(bytes.NewReader(data[1:]))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
var buf bytes.Buffer
|
||||
if _, err := buf.ReadFrom(reader); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// isCompressed checks if data is compressed
|
||||
func (sm *SerializationManager) isCompressed(data []byte) bool {
|
||||
return len(data) > 0 && data[0] == 0x1f
|
||||
}
|
||||
|
||||
// getCipher returns a cipher.AEAD instance for encryption/decryption
|
||||
func (sm *SerializationManager) getCipher() (cipher.AEAD, error) {
|
||||
sm.cipherMu.Lock()
|
||||
defer sm.cipherMu.Unlock()
|
||||
|
||||
if sm.cachedCipher != nil {
|
||||
return sm.cachedCipher, nil
|
||||
}
|
||||
|
||||
var aead cipher.AEAD
|
||||
var err error
|
||||
|
||||
sm.mu.RLock()
|
||||
key := sm.config.EncryptionKey
|
||||
preferredCipher := sm.config.PreferredCipher
|
||||
sm.mu.RUnlock()
|
||||
|
||||
switch preferredCipher {
|
||||
case "chacha20poly1305":
|
||||
aead, err = chacha20poly1305.New(key)
|
||||
case "aes-gcm":
|
||||
block, e := aes.NewCipher(key)
|
||||
if e != nil {
|
||||
err = e
|
||||
break
|
||||
}
|
||||
aead, err = cipher.NewGCM(block)
|
||||
default:
|
||||
// Default to ChaCha20-Poly1305
|
||||
aead, err = chacha20poly1305.New(key)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sm.cachedCipher = aead
|
||||
return aead, nil
|
||||
}
|
||||
|
||||
// encrypt encrypts data using authenticated encryption
|
||||
func (sm *SerializationManager) encrypt(data []byte) ([]byte, error) {
|
||||
aead, err := sm.getCipher()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Generate a random nonce
|
||||
nonce := make([]byte, aead.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add timestamp to associated data for replay protection
|
||||
timestampBytes := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(timestampBytes, uint64(time.Now().Unix()))
|
||||
|
||||
// Encrypt and authenticate
|
||||
ciphertext := aead.Seal(nil, nonce, data, timestampBytes)
|
||||
|
||||
// Prepend nonce and timestamp to ciphertext
|
||||
result := make([]byte, len(nonce)+len(timestampBytes)+len(ciphertext))
|
||||
copy(result, nonce)
|
||||
copy(result[len(nonce):], timestampBytes)
|
||||
copy(result[len(nonce)+len(timestampBytes):], ciphertext)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// decrypt decrypts data using authenticated decryption
|
||||
func (sm *SerializationManager) decrypt(data []byte) ([]byte, error) {
|
||||
aead, err := sm.getCipher()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Extract nonce
|
||||
nonceSize := aead.NonceSize()
|
||||
if len(data) < nonceSize+8 { // 8 bytes for timestamp
|
||||
return nil, fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
nonce := data[:nonceSize]
|
||||
timestampBytes := data[nonceSize : nonceSize+8]
|
||||
ciphertext := data[nonceSize+8:]
|
||||
|
||||
// Decrypt and verify
|
||||
plaintext, err := aead.Open(nil, nonce, ciphertext, timestampBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// Global serialization manager instance
|
||||
var globalSerializationManager = NewSerializationManager(nil)
|
||||
var globalSerializationManagerMu sync.RWMutex
|
||||
|
||||
// Global functions for backward compatibility
|
||||
func SetMarshaller(marshaller MarshallerFunc) {
|
||||
defaultMarshaller = marshaller
|
||||
globalSerializationManagerMu.Lock()
|
||||
defer globalSerializationManagerMu.Unlock()
|
||||
globalSerializationManager.SetMarshaller(marshaller)
|
||||
}
|
||||
|
||||
func SetUnmarshaller(unmarshaller UnmarshallerFunc) {
|
||||
defaultUnmarshaller = unmarshaller
|
||||
globalSerializationManagerMu.Lock()
|
||||
defer globalSerializationManagerMu.Unlock()
|
||||
globalSerializationManager.SetUnmarshaller(unmarshaller)
|
||||
}
|
||||
|
||||
func EnableCompression(enable bool) {
|
||||
globalSerializationManagerMu.Lock()
|
||||
defer globalSerializationManagerMu.Unlock()
|
||||
globalSerializationManager.config.EnableCompression = enable
|
||||
}
|
||||
|
||||
func EnableEncryption(enable bool, key []byte) error {
|
||||
globalSerializationManagerMu.Lock()
|
||||
defer globalSerializationManagerMu.Unlock()
|
||||
globalSerializationManager.config.EnableEncryption = enable
|
||||
if enable {
|
||||
return globalSerializationManager.SetEncryptionKey(key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func Marshal(v any) ([]byte, error) {
|
||||
return defaultMarshaller(v)
|
||||
globalSerializationManagerMu.RLock()
|
||||
defer globalSerializationManagerMu.RUnlock()
|
||||
return globalSerializationManager.Marshal(v)
|
||||
}
|
||||
|
||||
func Unmarshal(data []byte, v any) error {
|
||||
return defaultUnmarshaller(data, v)
|
||||
globalSerializationManagerMu.RLock()
|
||||
defer globalSerializationManagerMu.RUnlock()
|
||||
return globalSerializationManager.Unmarshal(data, v)
|
||||
}
|
||||
|
Reference in New Issue
Block a user