Files
pg/secure/aescbc/aescbc.go
2024-09-27 11:20:17 +08:00

159 lines
3.5 KiB
Go

package aescbc
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"errors"
"fmt"
"io"
"sync"
"github.com/sigcn/pg/lru"
"github.com/sigcn/pg/secure"
)
func PKCS7Padding(data []byte, blockSize int) []byte {
padding := blockSize - len(data)%blockSize
padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(data, padText...)
}
func PKCS7UnPadding(data []byte) ([]byte, error) {
length := len(data)
padding := int(data[length-1])
if padding < 1 || padding > aes.BlockSize {
return nil, fmt.Errorf("invalid padding size")
}
for i := length - padding; i < length; i++ {
if data[i] != byte(padding) {
return nil, fmt.Errorf("invalid padding")
}
}
return data[:len(data)-int(padding)], nil
}
func Encrypt(key, data []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
dataBytes := PKCS7Padding([]byte(data), aes.BlockSize)
cipherText := make([]byte, aes.BlockSize+len(dataBytes))
iv := cipherText[:aes.BlockSize]
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return nil, err
}
mode := cipher.NewCBCEncrypter(block, iv)
mode.CryptBlocks(cipherText[aes.BlockSize:], dataBytes)
return cipherText, nil
}
func Decrypt(key, data []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
if len(data) < aes.BlockSize || len(data)%aes.BlockSize != 0 {
return nil, fmt.Errorf("invalid encrypted data")
}
iv := data[:aes.BlockSize]
data = data[aes.BlockSize:]
mode := cipher.NewCBCDecrypter(block, iv)
mode.CryptBlocks(data, data)
return PKCS7UnPadding(data)
}
var _ secure.SymmAlgo = (*AESCBC)(nil)
type AESCBC struct {
mut sync.RWMutex
cipher *lru.Cache[string, cipher.Block]
provideSecretKey secure.ProvideSecretKey
}
func (s *AESCBC) Encrypt(b []byte, pubKey string) ([]byte, error) {
if s == nil {
return nil, errors.New("aesCBC is nil")
}
block, err := s.ensureChiperBlock(pubKey)
if err != nil {
return nil, err
}
dataBytes := PKCS7Padding([]byte(b), aes.BlockSize)
cipherText := make([]byte, aes.BlockSize+len(dataBytes))
iv := cipherText[:aes.BlockSize]
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return nil, err
}
mode := cipher.NewCBCEncrypter(block, iv)
mode.CryptBlocks(cipherText[aes.BlockSize:], dataBytes)
return cipherText, nil
}
func (s *AESCBC) Decrypt(b []byte, pubKey string) ([]byte, error) {
if s == nil {
return nil, errors.New("aesCBC is nil")
}
block, err := s.ensureChiperBlock(pubKey)
if err != nil {
return nil, err
}
if len(b) < aes.BlockSize || len(b)%aes.BlockSize != 0 {
return nil, fmt.Errorf("invalid encrypted data")
}
iv := b[:aes.BlockSize]
b = b[aes.BlockSize:]
plainBytes := make([]byte, len(b))
mode := cipher.NewCBCDecrypter(block, iv)
mode.CryptBlocks(plainBytes, b)
return PKCS7UnPadding(plainBytes)
}
func (s *AESCBC) SecretKey() secure.ProvideSecretKey {
return s.provideSecretKey
}
func (s *AESCBC) ensureChiperBlock(pubKey string) (cipher.Block, error) {
s.mut.RLock()
block, ok := s.cipher.Get(pubKey)
s.mut.RUnlock()
if !ok {
secretKey, err := s.provideSecretKey(pubKey)
if err != nil {
return nil, err
}
b, err := aes.NewCipher(secretKey)
if err != nil {
return nil, err
}
block = b
s.mut.Lock()
s.cipher.Put(pubKey, block)
s.mut.Unlock()
}
return block, nil
}
func New(provideSecretKey secure.ProvideSecretKey) secure.SymmAlgo {
return &AESCBC{
cipher: lru.New[string, cipher.Block](128),
provideSecretKey: provideSecretKey,
}
}