feat: implement TLS support

This commit is contained in:
sujit
2024-10-05 11:47:35 +05:45
parent a1ef941268
commit 9571676dd9
10 changed files with 1423 additions and 179 deletions

View File

@@ -2,13 +2,7 @@ package codec
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"io"
@@ -18,164 +12,216 @@ import (
type Message struct {
Headers map[string]string `json:"h"`
Topic string `json:"t"`
Queue string `json:"q"`
Command consts.CMD `json:"c"`
Payload json.RawMessage `json:"p"`
// Metadata map[string]any `json:"m"`
}
func EncryptPayload(key []byte, plaintext []byte) ([]byte, []byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, nil, err
func NewMessage(cmd consts.CMD, payload json.RawMessage, queue string, headers map[string]string) *Message {
return &Message{
Headers: headers,
Queue: queue,
Command: cmd,
Payload: payload,
// Metadata: nil,
}
aesGCM, err := cipher.NewGCM(block)
if err != nil {
return nil, nil, err
}
nonce := make([]byte, aesGCM.NonceSize())
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
return nil, nil, err
}
ciphertext := aesGCM.Seal(nil, nonce, plaintext, nil)
return ciphertext, nonce, nil
}
func DecryptPayload(key []byte, ciphertext []byte, nonce []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aesGCM, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
return aesGCM.Open(nil, nonce, ciphertext, nil)
}
func CalculateHMAC(key []byte, data []byte) string {
h := hmac.New(sha256.New, key)
h.Write(data)
return hex.EncodeToString(h.Sum(nil))
}
func VerifyHMAC(key []byte, data []byte, receivedHMAC string) bool {
expectedHMAC := CalculateHMAC(key, data)
return hmac.Equal([]byte(expectedHMAC), []byte(receivedHMAC))
}
func (m *Message) Serialize(aesKey []byte, hmacKey []byte) ([]byte, string, error) {
func (m *Message) Serialize(aesKey, hmacKey []byte, encrypt bool) ([]byte, string, error) {
var buf bytes.Buffer
headersBytes, err := json.Marshal(m.Headers)
if err != nil {
return nil, "", fmt.Errorf("error marshalling headers: %v", err)
// Serialize Headers, Queue, Command, Payload, and Metadata
if err := writeLengthPrefixedJSON(&buf, m.Headers); err != nil {
return nil, "", fmt.Errorf("error serializing headers: %v", err)
}
headersLen := uint32(len(headersBytes))
if err := binary.Write(&buf, binary.LittleEndian, headersLen); err != nil {
return nil, "", err
}
buf.Write(headersBytes)
topicBytes := []byte(m.Topic)
topicLen := uint8(len(topicBytes))
if err := binary.Write(&buf, binary.LittleEndian, topicLen); err != nil {
return nil, "", err
}
buf.Write(topicBytes)
if !m.Command.IsValid() {
return nil, "", fmt.Errorf("invalid command: %s", m.Command)
if err := writeLengthPrefixed(&buf, []byte(m.Queue)); err != nil {
return nil, "", fmt.Errorf("error serializing topic: %v", err)
}
if err := binary.Write(&buf, binary.LittleEndian, m.Command); err != nil {
return nil, "", fmt.Errorf("error serializing command: %v", err)
}
if err := writePayload(&buf, aesKey, m.Payload, encrypt); err != nil {
return nil, "", err
}
payloadBytes, err := json.Marshal(m.Payload)
if err != nil {
return nil, "", fmt.Errorf("error marshalling payload: %v", err)
}
encryptedPayload, nonce, err := EncryptPayload(aesKey, payloadBytes)
if err != nil {
return nil, "", err
}
payloadLen := uint32(len(encryptedPayload))
if err := binary.Write(&buf, binary.LittleEndian, payloadLen); err != nil {
return nil, "", err
}
buf.Write(encryptedPayload)
buf.Write(nonce)
/*if err := writeLengthPrefixedJSON(&buf, m.Metadata); err != nil {
return nil, "", fmt.Errorf("error serializing metadata: %v", err)
}*/
// Calculate HMAC
messageBytes := buf.Bytes()
hmacSignature := CalculateHMAC(hmacKey, messageBytes)
return messageBytes, hmacSignature, nil
}
func Deserialize(data []byte, aesKey []byte, hmacKey []byte, receivedHMAC string) (*Message, error) {
func Deserialize(data, aesKey, hmacKey []byte, receivedHMAC string, decrypt bool) (*Message, error) {
if !VerifyHMAC(hmacKey, data, receivedHMAC) {
return nil, fmt.Errorf("HMAC verification failed")
}
buf := bytes.NewReader(data)
var topicLen uint8
var payloadLen uint32
var headersLen uint32
if err := binary.Read(buf, binary.LittleEndian, &headersLen); err != nil {
return nil, err
// Deserialize Headers, Queue, Command, Payload, and Metadata
headers := make(map[string]string)
if err := readLengthPrefixedJSON(buf, &headers); err != nil {
return nil, fmt.Errorf("error deserializing headers: %v", err)
}
headersBytes := make([]byte, headersLen)
if _, err := buf.Read(headersBytes); err != nil {
return nil, err
topic, err := readLengthPrefixedString(buf)
if err != nil {
return nil, fmt.Errorf("error deserializing topic: %v", err)
}
var headers map[string]string
if err := json.Unmarshal(headersBytes, &headers); err != nil {
return nil, err
}
if err := binary.Read(buf, binary.LittleEndian, &topicLen); err != nil {
return nil, err
}
topicBytes := make([]byte, topicLen)
if _, err := buf.Read(topicBytes); err != nil {
return nil, err
}
topic := string(topicBytes)
var command consts.CMD
if err := binary.Read(buf, binary.LittleEndian, &command); err != nil {
return nil, err
return nil, fmt.Errorf("error deserializing command: %v", err)
}
if !command.IsValid() {
return nil, fmt.Errorf("invalid command: %s", command)
payload, err := readPayload(buf, aesKey, decrypt)
if err != nil {
return nil, fmt.Errorf("error deserializing payload: %v", err)
}
if err := binary.Read(buf, binary.LittleEndian, &payloadLen); err != nil {
return nil, err
/*metadata := make(map[string]any)
if err := readLengthPrefixedJSON(buf, &metadata); err != nil {
return nil, fmt.Errorf("error deserializing metadata: %v", err)
}*/
return &Message{
Headers: headers,
Queue: topic,
Command: command,
Payload: payload,
// Metadata: metadata,
}, nil
}
func SendMessage(conn io.Writer, msg *Message, aesKey, hmacKey []byte, encrypt bool) error {
sentData, hmacSignature, err := msg.Serialize(aesKey, hmacKey, encrypt)
if err != nil {
return fmt.Errorf("error serializing message: %v", err)
}
encryptedPayload := make([]byte, payloadLen)
if _, err := buf.Read(encryptedPayload); err != nil {
return nil, err
if err := writeMessageWithHMAC(conn, sentData, hmacSignature); err != nil {
return fmt.Errorf("error sending message: %v", err)
}
nonce := make([]byte, 12)
if _, err := buf.Read(nonce); err != nil {
return nil, err
}
payloadBytes, err := DecryptPayload(aesKey, encryptedPayload, nonce)
return nil
}
func ReadMessage(conn io.Reader, aesKey, hmacKey []byte, decrypt bool) (*Message, error) {
data, receivedHMAC, err := readMessageWithHMAC(conn)
if err != nil {
return nil, err
}
return Deserialize(data, aesKey, hmacKey, receivedHMAC, decrypt)
}
func writeLengthPrefixed(buf *bytes.Buffer, data []byte) error {
length := uint32(len(data))
if err := binary.Write(buf, binary.LittleEndian, length); err != nil {
return err
}
buf.Write(data)
return nil
}
func writeLengthPrefixedJSON(buf *bytes.Buffer, v any) error {
data, err := json.Marshal(v)
if err != nil {
return err
}
return writeLengthPrefixed(buf, data)
}
func readLengthPrefixed(r *bytes.Reader) ([]byte, error) {
var length uint32
if err := binary.Read(r, binary.LittleEndian, &length); err != nil {
return nil, err
}
data := make([]byte, length)
if _, err := r.Read(data); err != nil {
return nil, err
}
return data, nil
}
func readLengthPrefixedJSON(r *bytes.Reader, v any) error {
data, err := readLengthPrefixed(r)
if err != nil {
return err
}
return json.Unmarshal(data, v)
}
func readLengthPrefixedString(r *bytes.Reader) (string, error) {
data, err := readLengthPrefixed(r)
if err != nil {
return "", err
}
return string(data), nil
}
func writePayload(buf *bytes.Buffer, aesKey []byte, payload json.RawMessage, encrypt bool) error {
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("error marshalling payload: %v", err)
}
var encryptedPayload, nonce []byte
if encrypt {
encryptedPayload, nonce, err = EncryptPayload(aesKey, payloadBytes)
if err != nil {
return err
}
} else {
encryptedPayload = payloadBytes
}
if err := writeLengthPrefixed(buf, encryptedPayload); err != nil {
return err
}
if encrypt {
buf.Write(nonce)
}
return nil
}
func readPayload(r *bytes.Reader, aesKey []byte, decrypt bool) (json.RawMessage, error) {
encryptedPayload, err := readLengthPrefixed(r)
if err != nil {
return nil, err
}
var payloadBytes []byte
if decrypt {
nonce := make([]byte, 12)
if _, err := r.Read(nonce); err != nil {
return nil, err
}
payloadBytes, err = DecryptPayload(aesKey, encryptedPayload, nonce)
if err != nil {
return nil, fmt.Errorf("error decrypting payload: %v", err)
}
} else {
payloadBytes = encryptedPayload
}
var payload json.RawMessage
if err := json.Unmarshal(payloadBytes, &payload); err != nil {
return nil, fmt.Errorf("error unmarshalling payload: %v", err)
}
return &Message{
Headers: headers,
Topic: topic,
Command: command,
Payload: payload,
}, nil
}
func SendMessage(conn io.Writer, msg *Message, aesKey []byte, hmacKey []byte) error {
sentData, hmacSignature, err := msg.Serialize(aesKey, hmacKey)
if err != nil {
return fmt.Errorf("error serializing message: %v", err)
}
if err := binary.Write(conn, binary.LittleEndian, uint32(len(sentData))); err != nil {
return payload, nil
}
func writeMessageWithHMAC(conn io.Writer, messageBytes []byte, hmacSignature string) error {
if err := binary.Write(conn, binary.LittleEndian, uint32(len(messageBytes))); err != nil {
return err
}
if _, err := conn.Write(sentData); err != nil {
if _, err := conn.Write(messageBytes); err != nil {
return err
}
if _, err := conn.Write([]byte(hmacSignature)); err != nil {
@@ -184,20 +230,21 @@ func SendMessage(conn io.Writer, msg *Message, aesKey []byte, hmacKey []byte) er
return nil
}
func ReadMessage(conn io.Reader, aesKey []byte, hmacKey []byte) (*Message, error) {
func readMessageWithHMAC(conn io.Reader) ([]byte, string, error) {
var length uint32
if err := binary.Read(conn, binary.LittleEndian, &length); err != nil {
return nil, err
return nil, "", err
}
data := make([]byte, length)
if _, err := io.ReadFull(conn, data); err != nil {
return nil, err
return nil, "", err
}
hmacBytes := make([]byte, 64)
if _, err := io.ReadFull(conn, hmacBytes); err != nil {
return nil, err
return nil, "", err
}
receivedHMAC := string(hmacBytes)
return Deserialize(data, aesKey, hmacKey, receivedHMAC)
return data, receivedHMAC, nil
}