mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-06 16:36:53 +08:00
feat: implement TLS support
This commit is contained in:
305
codec/codec.go
305
codec/codec.go
@@ -2,13 +2,7 @@ package codec
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/aes"
|
|
||||||
"crypto/cipher"
|
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -18,164 +12,216 @@ import (
|
|||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Headers map[string]string `json:"h"`
|
Headers map[string]string `json:"h"`
|
||||||
Topic string `json:"t"`
|
Queue string `json:"q"`
|
||||||
Command consts.CMD `json:"c"`
|
Command consts.CMD `json:"c"`
|
||||||
Payload json.RawMessage `json:"p"`
|
Payload json.RawMessage `json:"p"`
|
||||||
|
// Metadata map[string]any `json:"m"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func EncryptPayload(key []byte, plaintext []byte) ([]byte, []byte, error) {
|
func NewMessage(cmd consts.CMD, payload json.RawMessage, queue string, headers map[string]string) *Message {
|
||||||
block, err := aes.NewCipher(key)
|
return &Message{
|
||||||
if err != nil {
|
Headers: headers,
|
||||||
return nil, nil, err
|
Queue: queue,
|
||||||
|
Command: cmd,
|
||||||
|
Payload: payload,
|
||||||
|
// Metadata: nil,
|
||||||
}
|
}
|
||||||
aesGCM, err := cipher.NewGCM(block)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
nonce := make([]byte, aesGCM.NonceSize())
|
|
||||||
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
ciphertext := aesGCM.Seal(nil, nonce, plaintext, nil)
|
|
||||||
return ciphertext, nonce, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func DecryptPayload(key []byte, ciphertext []byte, nonce []byte) ([]byte, error) {
|
func (m *Message) Serialize(aesKey, hmacKey []byte, encrypt bool) ([]byte, string, error) {
|
||||||
block, err := aes.NewCipher(key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
aesGCM, err := cipher.NewGCM(block)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return aesGCM.Open(nil, nonce, ciphertext, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func CalculateHMAC(key []byte, data []byte) string {
|
|
||||||
h := hmac.New(sha256.New, key)
|
|
||||||
h.Write(data)
|
|
||||||
return hex.EncodeToString(h.Sum(nil))
|
|
||||||
}
|
|
||||||
|
|
||||||
func VerifyHMAC(key []byte, data []byte, receivedHMAC string) bool {
|
|
||||||
expectedHMAC := CalculateHMAC(key, data)
|
|
||||||
return hmac.Equal([]byte(expectedHMAC), []byte(receivedHMAC))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Message) Serialize(aesKey []byte, hmacKey []byte) ([]byte, string, error) {
|
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
headersBytes, err := json.Marshal(m.Headers)
|
|
||||||
if err != nil {
|
// Serialize Headers, Queue, Command, Payload, and Metadata
|
||||||
return nil, "", fmt.Errorf("error marshalling headers: %v", err)
|
if err := writeLengthPrefixedJSON(&buf, m.Headers); err != nil {
|
||||||
|
return nil, "", fmt.Errorf("error serializing headers: %v", err)
|
||||||
}
|
}
|
||||||
headersLen := uint32(len(headersBytes))
|
if err := writeLengthPrefixed(&buf, []byte(m.Queue)); err != nil {
|
||||||
if err := binary.Write(&buf, binary.LittleEndian, headersLen); err != nil {
|
return nil, "", fmt.Errorf("error serializing topic: %v", err)
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
buf.Write(headersBytes)
|
|
||||||
topicBytes := []byte(m.Topic)
|
|
||||||
topicLen := uint8(len(topicBytes))
|
|
||||||
if err := binary.Write(&buf, binary.LittleEndian, topicLen); err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
buf.Write(topicBytes)
|
|
||||||
if !m.Command.IsValid() {
|
|
||||||
return nil, "", fmt.Errorf("invalid command: %s", m.Command)
|
|
||||||
}
|
}
|
||||||
if err := binary.Write(&buf, binary.LittleEndian, m.Command); err != nil {
|
if err := binary.Write(&buf, binary.LittleEndian, m.Command); err != nil {
|
||||||
|
return nil, "", fmt.Errorf("error serializing command: %v", err)
|
||||||
|
}
|
||||||
|
if err := writePayload(&buf, aesKey, m.Payload, encrypt); err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
payloadBytes, err := json.Marshal(m.Payload)
|
/*if err := writeLengthPrefixedJSON(&buf, m.Metadata); err != nil {
|
||||||
if err != nil {
|
return nil, "", fmt.Errorf("error serializing metadata: %v", err)
|
||||||
return nil, "", fmt.Errorf("error marshalling payload: %v", err)
|
}*/
|
||||||
}
|
|
||||||
encryptedPayload, nonce, err := EncryptPayload(aesKey, payloadBytes)
|
// Calculate HMAC
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
payloadLen := uint32(len(encryptedPayload))
|
|
||||||
if err := binary.Write(&buf, binary.LittleEndian, payloadLen); err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
buf.Write(encryptedPayload)
|
|
||||||
buf.Write(nonce)
|
|
||||||
messageBytes := buf.Bytes()
|
messageBytes := buf.Bytes()
|
||||||
hmacSignature := CalculateHMAC(hmacKey, messageBytes)
|
hmacSignature := CalculateHMAC(hmacKey, messageBytes)
|
||||||
|
|
||||||
return messageBytes, hmacSignature, nil
|
return messageBytes, hmacSignature, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Deserialize(data []byte, aesKey []byte, hmacKey []byte, receivedHMAC string) (*Message, error) {
|
func Deserialize(data, aesKey, hmacKey []byte, receivedHMAC string, decrypt bool) (*Message, error) {
|
||||||
if !VerifyHMAC(hmacKey, data, receivedHMAC) {
|
if !VerifyHMAC(hmacKey, data, receivedHMAC) {
|
||||||
return nil, fmt.Errorf("HMAC verification failed")
|
return nil, fmt.Errorf("HMAC verification failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := bytes.NewReader(data)
|
buf := bytes.NewReader(data)
|
||||||
var topicLen uint8
|
|
||||||
var payloadLen uint32
|
// Deserialize Headers, Queue, Command, Payload, and Metadata
|
||||||
var headersLen uint32
|
headers := make(map[string]string)
|
||||||
if err := binary.Read(buf, binary.LittleEndian, &headersLen); err != nil {
|
if err := readLengthPrefixedJSON(buf, &headers); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("error deserializing headers: %v", err)
|
||||||
}
|
}
|
||||||
headersBytes := make([]byte, headersLen)
|
|
||||||
if _, err := buf.Read(headersBytes); err != nil {
|
topic, err := readLengthPrefixedString(buf)
|
||||||
return nil, err
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error deserializing topic: %v", err)
|
||||||
}
|
}
|
||||||
var headers map[string]string
|
|
||||||
if err := json.Unmarshal(headersBytes, &headers); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := binary.Read(buf, binary.LittleEndian, &topicLen); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
topicBytes := make([]byte, topicLen)
|
|
||||||
if _, err := buf.Read(topicBytes); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
topic := string(topicBytes)
|
|
||||||
var command consts.CMD
|
var command consts.CMD
|
||||||
if err := binary.Read(buf, binary.LittleEndian, &command); err != nil {
|
if err := binary.Read(buf, binary.LittleEndian, &command); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("error deserializing command: %v", err)
|
||||||
}
|
}
|
||||||
if !command.IsValid() {
|
|
||||||
return nil, fmt.Errorf("invalid command: %s", command)
|
payload, err := readPayload(buf, aesKey, decrypt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error deserializing payload: %v", err)
|
||||||
}
|
}
|
||||||
if err := binary.Read(buf, binary.LittleEndian, &payloadLen); err != nil {
|
|
||||||
return nil, err
|
/*metadata := make(map[string]any)
|
||||||
|
if err := readLengthPrefixedJSON(buf, &metadata); err != nil {
|
||||||
|
return nil, fmt.Errorf("error deserializing metadata: %v", err)
|
||||||
|
}*/
|
||||||
|
|
||||||
|
return &Message{
|
||||||
|
Headers: headers,
|
||||||
|
Queue: topic,
|
||||||
|
Command: command,
|
||||||
|
Payload: payload,
|
||||||
|
// Metadata: metadata,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SendMessage(conn io.Writer, msg *Message, aesKey, hmacKey []byte, encrypt bool) error {
|
||||||
|
sentData, hmacSignature, err := msg.Serialize(aesKey, hmacKey, encrypt)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error serializing message: %v", err)
|
||||||
}
|
}
|
||||||
encryptedPayload := make([]byte, payloadLen)
|
|
||||||
if _, err := buf.Read(encryptedPayload); err != nil {
|
if err := writeMessageWithHMAC(conn, sentData, hmacSignature); err != nil {
|
||||||
return nil, err
|
return fmt.Errorf("error sending message: %v", err)
|
||||||
}
|
}
|
||||||
nonce := make([]byte, 12)
|
|
||||||
if _, err := buf.Read(nonce); err != nil {
|
return nil
|
||||||
return nil, err
|
}
|
||||||
}
|
|
||||||
payloadBytes, err := DecryptPayload(aesKey, encryptedPayload, nonce)
|
func ReadMessage(conn io.Reader, aesKey, hmacKey []byte, decrypt bool) (*Message, error) {
|
||||||
|
data, receivedHMAC, err := readMessageWithHMAC(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return Deserialize(data, aesKey, hmacKey, receivedHMAC, decrypt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeLengthPrefixed(buf *bytes.Buffer, data []byte) error {
|
||||||
|
length := uint32(len(data))
|
||||||
|
if err := binary.Write(buf, binary.LittleEndian, length); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
buf.Write(data)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeLengthPrefixedJSON(buf *bytes.Buffer, v any) error {
|
||||||
|
data, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return writeLengthPrefixed(buf, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func readLengthPrefixed(r *bytes.Reader) ([]byte, error) {
|
||||||
|
var length uint32
|
||||||
|
if err := binary.Read(r, binary.LittleEndian, &length); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
data := make([]byte, length)
|
||||||
|
if _, err := r.Read(data); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readLengthPrefixedJSON(r *bytes.Reader, v any) error {
|
||||||
|
data, err := readLengthPrefixed(r)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return json.Unmarshal(data, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func readLengthPrefixedString(r *bytes.Reader) (string, error) {
|
||||||
|
data, err := readLengthPrefixed(r)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writePayload(buf *bytes.Buffer, aesKey []byte, payload json.RawMessage, encrypt bool) error {
|
||||||
|
payloadBytes, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error marshalling payload: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var encryptedPayload, nonce []byte
|
||||||
|
if encrypt {
|
||||||
|
encryptedPayload, nonce, err = EncryptPayload(aesKey, payloadBytes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
encryptedPayload = payloadBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := writeLengthPrefixed(buf, encryptedPayload); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if encrypt {
|
||||||
|
buf.Write(nonce)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readPayload(r *bytes.Reader, aesKey []byte, decrypt bool) (json.RawMessage, error) {
|
||||||
|
encryptedPayload, err := readLengthPrefixed(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var payloadBytes []byte
|
||||||
|
if decrypt {
|
||||||
|
nonce := make([]byte, 12)
|
||||||
|
if _, err := r.Read(nonce); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
payloadBytes, err = DecryptPayload(aesKey, encryptedPayload, nonce)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error decrypting payload: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
payloadBytes = encryptedPayload
|
||||||
|
}
|
||||||
|
|
||||||
var payload json.RawMessage
|
var payload json.RawMessage
|
||||||
if err := json.Unmarshal(payloadBytes, &payload); err != nil {
|
if err := json.Unmarshal(payloadBytes, &payload); err != nil {
|
||||||
return nil, fmt.Errorf("error unmarshalling payload: %v", err)
|
return nil, fmt.Errorf("error unmarshalling payload: %v", err)
|
||||||
}
|
}
|
||||||
return &Message{
|
|
||||||
Headers: headers,
|
|
||||||
Topic: topic,
|
|
||||||
Command: command,
|
|
||||||
Payload: payload,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func SendMessage(conn io.Writer, msg *Message, aesKey []byte, hmacKey []byte) error {
|
return payload, nil
|
||||||
sentData, hmacSignature, err := msg.Serialize(aesKey, hmacKey)
|
}
|
||||||
if err != nil {
|
func writeMessageWithHMAC(conn io.Writer, messageBytes []byte, hmacSignature string) error {
|
||||||
return fmt.Errorf("error serializing message: %v", err)
|
if err := binary.Write(conn, binary.LittleEndian, uint32(len(messageBytes))); err != nil {
|
||||||
}
|
|
||||||
if err := binary.Write(conn, binary.LittleEndian, uint32(len(sentData))); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if _, err := conn.Write(messageBytes); err != nil {
|
||||||
if _, err := conn.Write(sentData); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := conn.Write([]byte(hmacSignature)); err != nil {
|
if _, err := conn.Write([]byte(hmacSignature)); err != nil {
|
||||||
@@ -184,20 +230,21 @@ func SendMessage(conn io.Writer, msg *Message, aesKey []byte, hmacKey []byte) er
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReadMessage(conn io.Reader, aesKey []byte, hmacKey []byte) (*Message, error) {
|
func readMessageWithHMAC(conn io.Reader) ([]byte, string, error) {
|
||||||
var length uint32
|
var length uint32
|
||||||
if err := binary.Read(conn, binary.LittleEndian, &length); err != nil {
|
if err := binary.Read(conn, binary.LittleEndian, &length); err != nil {
|
||||||
return nil, err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
data := make([]byte, length)
|
data := make([]byte, length)
|
||||||
if _, err := io.ReadFull(conn, data); err != nil {
|
if _, err := io.ReadFull(conn, data); err != nil {
|
||||||
return nil, err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
hmacBytes := make([]byte, 64)
|
hmacBytes := make([]byte, 64)
|
||||||
if _, err := io.ReadFull(conn, hmacBytes); err != nil {
|
if _, err := io.ReadFull(conn, hmacBytes); err != nil {
|
||||||
return nil, err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
receivedHMAC := string(hmacBytes)
|
receivedHMAC := string(hmacBytes)
|
||||||
return Deserialize(data, aesKey, hmacKey, receivedHMAC)
|
|
||||||
|
return data, receivedHMAC, nil
|
||||||
}
|
}
|
||||||
|
51
codec/encrypt.go
Normal file
51
codec/encrypt.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package codec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
func EncryptPayload(key []byte, plaintext []byte) ([]byte, []byte, error) {
|
||||||
|
block, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
aesGCM, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
nonce := make([]byte, aesGCM.NonceSize())
|
||||||
|
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
ciphertext := aesGCM.Seal(nil, nonce, plaintext, nil)
|
||||||
|
return ciphertext, nonce, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DecryptPayload(key []byte, ciphertext []byte, nonce []byte) ([]byte, error) {
|
||||||
|
block, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
aesGCM, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return aesGCM.Open(nil, nonce, ciphertext, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CalculateHMAC(key []byte, data []byte) string {
|
||||||
|
h := hmac.New(sha256.New, key)
|
||||||
|
h.Write(data)
|
||||||
|
return hex.EncodeToString(h.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func VerifyHMAC(key []byte, data []byte, receivedHMAC string) bool {
|
||||||
|
expectedHMAC := CalculateHMAC(key, data)
|
||||||
|
return hmac.Equal([]byte(expectedHMAC), []byte(receivedHMAC))
|
||||||
|
}
|
@@ -8,8 +8,15 @@ const (
|
|||||||
PING CMD = iota + 1
|
PING CMD = iota + 1
|
||||||
SUBSCRIBE
|
SUBSCRIBE
|
||||||
SUBSCRIBE_ACK
|
SUBSCRIBE_ACK
|
||||||
|
|
||||||
|
MESSAGE_SEND
|
||||||
|
MESSAGE_RESPONSE
|
||||||
|
MESSAGE_ACK
|
||||||
|
MESSAGE_ERROR
|
||||||
|
|
||||||
PUBLISH
|
PUBLISH
|
||||||
PUBLISH_ACK
|
PUBLISH_ACK
|
||||||
|
|
||||||
REQUEST
|
REQUEST
|
||||||
RESPONSE
|
RESPONSE
|
||||||
STOP
|
STOP
|
||||||
@@ -23,6 +30,14 @@ func (c CMD) String() string {
|
|||||||
return "SUBSCRIBE"
|
return "SUBSCRIBE"
|
||||||
case SUBSCRIBE_ACK:
|
case SUBSCRIBE_ACK:
|
||||||
return "SUBSCRIBE_ACK"
|
return "SUBSCRIBE_ACK"
|
||||||
|
case MESSAGE_SEND:
|
||||||
|
return "MESSAGE_SEND"
|
||||||
|
case MESSAGE_RESPONSE:
|
||||||
|
return "MESSAGE_RESPONSE"
|
||||||
|
case MESSAGE_ERROR:
|
||||||
|
return "MESSAGE_ERROR"
|
||||||
|
case MESSAGE_ACK:
|
||||||
|
return "MESSAGE_ACK"
|
||||||
case PUBLISH:
|
case PUBLISH:
|
||||||
return "PUBLISH"
|
return "PUBLISH"
|
||||||
case PUBLISH_ACK:
|
case PUBLISH_ACK:
|
||||||
@@ -40,6 +55,8 @@ var (
|
|||||||
ConsumerKey = "Consumer-Key"
|
ConsumerKey = "Consumer-Key"
|
||||||
PublisherKey = "Publisher-Key"
|
PublisherKey = "Publisher-Key"
|
||||||
ContentType = "Content-Type"
|
ContentType = "Content-Type"
|
||||||
|
AwaitResponseKey = "Await-Response"
|
||||||
|
QueueKey = "Queue"
|
||||||
TypeJson = "application/json"
|
TypeJson = "application/json"
|
||||||
HeaderKey = "headers"
|
HeaderKey = "headers"
|
||||||
TriggerNode = "triggerNode"
|
TriggerNode = "triggerNode"
|
||||||
|
@@ -1,68 +1,266 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/oarkflow/mq"
|
||||||
"github.com/oarkflow/mq/codec"
|
"github.com/oarkflow/mq/codec"
|
||||||
"github.com/oarkflow/mq/consts"
|
"github.com/oarkflow/mq/consts"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
type Broker struct {
|
||||||
aesKey := []byte("thisis32bytekeyforaesencryption1")
|
aesKey json.RawMessage
|
||||||
hmacKey := []byte("thisisasecrethmackey1")
|
hmacKey json.RawMessage
|
||||||
go func() {
|
subscribers map[string][]net.Conn
|
||||||
listener, err := net.Listen("tcp", ":8081")
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBroker(aesKey, hmacKey json.RawMessage) *Broker {
|
||||||
|
return &Broker{
|
||||||
|
aesKey: aesKey,
|
||||||
|
hmacKey: hmacKey,
|
||||||
|
subscribers: make(map[string][]net.Conn),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) addSubscriber(topic string, conn net.Conn) {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
b.subscribers[topic] = append(b.subscribers[topic], conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) removeSubscriber(conn net.Conn) {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
for topic, conns := range b.subscribers {
|
||||||
|
for i, c := range conns {
|
||||||
|
if c == conn {
|
||||||
|
b.subscribers[topic] = append(conns[:i], conns[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) broadcastToSubscribers(topic string, msg *codec.Message) {
|
||||||
|
b.mu.RLock()
|
||||||
|
defer b.mu.RUnlock()
|
||||||
|
|
||||||
|
subscribers, ok := b.subscribers[topic]
|
||||||
|
if !ok || len(subscribers) == 0 {
|
||||||
|
log.Printf("No subscribers for topic: %s", topic)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, conn := range subscribers {
|
||||||
|
err := codec.SendMessage(conn, msg, b.aesKey, b.hmacKey, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Printf("Error sending message to subscriber: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
||||||
|
switch msg.Command {
|
||||||
|
case consts.PUBLISH:
|
||||||
|
b.broadcastToSubscribers(msg.Queue, msg)
|
||||||
|
ack := &codec.Message{
|
||||||
|
Headers: msg.Headers,
|
||||||
|
Queue: msg.Queue,
|
||||||
|
Command: consts.PUBLISH_ACK,
|
||||||
|
}
|
||||||
|
if err := codec.SendMessage(conn, ack, b.aesKey, b.hmacKey, true); err != nil {
|
||||||
|
log.Printf("Error sending PUBLISH_ACK: %v\n", err)
|
||||||
|
}
|
||||||
|
case consts.SUBSCRIBE:
|
||||||
|
b.addSubscriber(msg.Queue, conn)
|
||||||
|
ack := &codec.Message{
|
||||||
|
Headers: msg.Headers,
|
||||||
|
Queue: msg.Queue,
|
||||||
|
Command: consts.SUBSCRIBE_ACK,
|
||||||
|
}
|
||||||
|
if err := codec.SendMessage(conn, ack, b.aesKey, b.hmacKey, true); err != nil {
|
||||||
|
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) OnClose(ctx context.Context, conn net.Conn) {
|
||||||
|
log.Println("Connection closed")
|
||||||
|
b.removeSubscriber(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) OnError(ctx context.Context, err error) {
|
||||||
|
log.Printf("Connection Error: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
|
||||||
|
msg, err := codec.ReadMessage(c, b.aesKey, b.hmacKey, true)
|
||||||
|
if err == nil {
|
||||||
|
ctx = mq.SetHeaders(ctx, msg.Headers)
|
||||||
|
b.OnMessage(ctx, msg, c)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
|
||||||
|
b.OnClose(ctx, c)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
b.OnError(ctx, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) Serve(ctx context.Context, addr string) error {
|
||||||
|
listener, err := net.Listen("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
defer listener.Close()
|
defer listener.Close()
|
||||||
log.Println("Server is listening on :8080")
|
|
||||||
for {
|
for {
|
||||||
conn, err := listener.Accept()
|
conn, err := listener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Connection error:", err)
|
b.OnError(ctx, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
go func(c net.Conn) {
|
go func(c net.Conn) {
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
for {
|
for {
|
||||||
msg, err := codec.ReadMessage(c, aesKey, hmacKey)
|
err := b.readMessage(ctx, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err.Error() == "EOF" {
|
|
||||||
log.Println("Client disconnected")
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
log.Println("Failed to receive message:", err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
log.Printf("Received Message:\n Headers: %v\n Topic: %s\n Command: %v\n Payload: %s\n",
|
|
||||||
msg.Headers, msg.Topic, msg.Command, msg.Payload)
|
|
||||||
}
|
}
|
||||||
}(conn)
|
}(conn)
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
conn, err := net.Dial("tcp", ":8081")
|
type Publisher struct {
|
||||||
|
aesKey json.RawMessage
|
||||||
|
hmacKey json.RawMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPublisher(aesKey, hmacKey json.RawMessage) *Publisher {
|
||||||
|
return &Publisher{aesKey: aesKey, hmacKey: hmacKey}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Publisher) Publish(ctx context.Context, addr, topic string, payload json.RawMessage) error {
|
||||||
|
conn, err := net.Dial("tcp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
return err
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
headers := map[string]string{"Api-Key": "121323"}
|
|
||||||
data := map[string]interface{}{"temperature": 23.5, "humidity": 60}
|
headers, _ := mq.GetHeaders(ctx)
|
||||||
payload, _ := json.Marshal(data)
|
|
||||||
msg := &codec.Message{
|
msg := &codec.Message{
|
||||||
Headers: headers,
|
Headers: headers,
|
||||||
Topic: "sensor_data",
|
Queue: topic,
|
||||||
Command: consts.SUBSCRIBE,
|
Command: consts.PUBLISH,
|
||||||
Payload: payload,
|
Payload: payload,
|
||||||
}
|
}
|
||||||
if err := codec.SendMessage(conn, msg, aesKey, hmacKey); err != nil {
|
if err := codec.SendMessage(conn, msg, p.aesKey, p.hmacKey, true); err != nil {
|
||||||
log.Fatalf("Error sending message: %v", err)
|
return err
|
||||||
}
|
}
|
||||||
fmt.Println("Message sent successfully")
|
|
||||||
time.Sleep(5 * time.Second)
|
return p.waitForAck(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Publisher) waitForAck(conn net.Conn) error {
|
||||||
|
msg, err := codec.ReadMessage(conn, p.aesKey, p.hmacKey, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if msg.Command == consts.PUBLISH_ACK {
|
||||||
|
log.Println("Received PUBLISH_ACK: Message published successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Consumer struct {
|
||||||
|
aesKey json.RawMessage
|
||||||
|
hmacKey json.RawMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConsumer(aesKey, hmacKey json.RawMessage) *Consumer {
|
||||||
|
return &Consumer{aesKey: aesKey, hmacKey: hmacKey}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) Subscribe(ctx context.Context, addr, topic string) error {
|
||||||
|
conn, err := net.Dial("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
headers, _ := mq.GetHeaders(ctx)
|
||||||
|
msg := &codec.Message{
|
||||||
|
Headers: headers,
|
||||||
|
Queue: topic,
|
||||||
|
Command: consts.SUBSCRIBE,
|
||||||
|
}
|
||||||
|
if err := codec.SendMessage(conn, msg, c.aesKey, c.hmacKey, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.waitForAck(conn)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for {
|
||||||
|
msg, err := codec.ReadMessage(conn, c.aesKey, c.hmacKey, true)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error reading message: %v\n", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
log.Printf("Received task on topic %s: %s\n", msg.Queue, msg.Payload)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) waitForAck(conn net.Conn) error {
|
||||||
|
msg, err := codec.ReadMessage(conn, c.aesKey, c.hmacKey, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if msg.Command == consts.SUBSCRIBE_ACK {
|
||||||
|
log.Println("Received SUBSCRIBE_ACK: Subscribed successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command)
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
addr := ":8081"
|
||||||
|
aesKey := []byte("thisis32bytekeyforaesencryption1")
|
||||||
|
hmacKey := []byte("thisisasecrethmackey1")
|
||||||
|
|
||||||
|
broker := NewBroker(aesKey, hmacKey)
|
||||||
|
publisher := NewPublisher(aesKey, hmacKey)
|
||||||
|
consumer := NewConsumer(aesKey, hmacKey)
|
||||||
|
|
||||||
|
go broker.Serve(context.Background(), addr)
|
||||||
|
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
go consumer.Subscribe(context.Background(), addr, "sensor_data")
|
||||||
|
|
||||||
|
time.Sleep(3 * time.Second)
|
||||||
|
data := map[string]interface{}{"temperature": 23.5, "humidity": 60}
|
||||||
|
payload, _ := json.Marshal(data)
|
||||||
|
go publisher.Publish(context.Background(), addr, "sensor_data", payload)
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Second)
|
||||||
}
|
}
|
||||||
|
34
examples/msg.go
Normal file
34
examples/msg.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
v2 "github.com/oarkflow/mq/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
broker := v2.NewBroker()
|
||||||
|
go broker.Start(context.Background())
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
|
||||||
|
consumer := v2.NewConsumer("consumer-1")
|
||||||
|
consumer.RegisterHandler("queue-1", func(ctx context.Context, task v2.Task) v2.Result {
|
||||||
|
fmt.Println("Handling on queue-1", string(task.Payload))
|
||||||
|
return v2.Result{Payload: task.Payload}
|
||||||
|
})
|
||||||
|
go func() {
|
||||||
|
err := consumer.Consume(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
publisher := v2.NewPublisher("publisher-1")
|
||||||
|
data := map[string]interface{}{"temperature": 23.5, "humidity": 60}
|
||||||
|
payload, _ := json.Marshal(data)
|
||||||
|
rs := publisher.Request(context.Background(), "queue-1", v2.Task{Payload: payload})
|
||||||
|
fmt.Println("Response:", string(rs.Payload), rs.Error)
|
||||||
|
}
|
332
v2/broker.go
Normal file
332
v2/broker.go
Normal file
@@ -0,0 +1,332 @@
|
|||||||
|
package v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oarkflow/xsync"
|
||||||
|
|
||||||
|
"github.com/oarkflow/mq/codec"
|
||||||
|
"github.com/oarkflow/mq/consts"
|
||||||
|
)
|
||||||
|
|
||||||
|
type consumer struct {
|
||||||
|
id string
|
||||||
|
conn net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *consumer) send(ctx context.Context, cmd any) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type publisher struct {
|
||||||
|
id string
|
||||||
|
conn net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *publisher) send(ctx context.Context, cmd any) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Handler func(context.Context, Task) Result
|
||||||
|
|
||||||
|
type Broker struct {
|
||||||
|
queues xsync.IMap[string, *Queue]
|
||||||
|
consumers xsync.IMap[string, *consumer]
|
||||||
|
publishers xsync.IMap[string, *publisher]
|
||||||
|
opts Options
|
||||||
|
}
|
||||||
|
|
||||||
|
type Queue struct {
|
||||||
|
name string
|
||||||
|
consumers xsync.IMap[string, *consumer]
|
||||||
|
messages xsync.IMap[string, *Task]
|
||||||
|
deferred xsync.IMap[string, *Task]
|
||||||
|
}
|
||||||
|
|
||||||
|
func newQueue(name string) *Queue {
|
||||||
|
return &Queue{
|
||||||
|
name: name,
|
||||||
|
consumers: xsync.NewMap[string, *consumer](),
|
||||||
|
messages: xsync.NewMap[string, *Task](),
|
||||||
|
deferred: xsync.NewMap[string, *Task](),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (queue *Queue) send(ctx context.Context, cmd any) {
|
||||||
|
queue.consumers.ForEach(func(_ string, client *consumer) bool {
|
||||||
|
err := client.send(ctx, cmd)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type Task struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Payload json.RawMessage `json:"payload"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
ProcessedAt time.Time `json:"processed_at"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Error error `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBroker(opts ...Option) *Broker {
|
||||||
|
options := defaultOptions()
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(&options)
|
||||||
|
}
|
||||||
|
b := &Broker{
|
||||||
|
queues: xsync.NewMap[string, *Queue](),
|
||||||
|
publishers: xsync.NewMap[string, *publisher](),
|
||||||
|
consumers: xsync.NewMap[string, *consumer](),
|
||||||
|
opts: options,
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) TLSConfig() TLSConfig {
|
||||||
|
return b.opts.tlsConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) SyncMode() bool {
|
||||||
|
return b.opts.syncMode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) OnClose(ctx context.Context, _ net.Conn) error {
|
||||||
|
consumerID, ok := GetConsumerID(ctx)
|
||||||
|
if ok && consumerID != "" {
|
||||||
|
if con, exists := b.consumers.Get(consumerID); exists {
|
||||||
|
con.conn.Close()
|
||||||
|
b.consumers.Del(consumerID)
|
||||||
|
}
|
||||||
|
b.queues.ForEach(func(_ string, queue *Queue) bool {
|
||||||
|
queue.consumers.Del(consumerID)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
publisherID, ok := GetPublisherID(ctx)
|
||||||
|
if ok && publisherID != "" {
|
||||||
|
if con, exists := b.publishers.Get(publisherID); exists {
|
||||||
|
con.conn.Close()
|
||||||
|
b.publishers.Del(publisherID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) OnError(_ context.Context, conn net.Conn, err error) {
|
||||||
|
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
||||||
|
switch msg.Command {
|
||||||
|
case consts.PUBLISH:
|
||||||
|
b.publish(ctx, conn, msg)
|
||||||
|
case consts.SUBSCRIBE:
|
||||||
|
b.subscribe(ctx, conn, msg)
|
||||||
|
case consts.MESSAGE_RESPONSE:
|
||||||
|
headers, ok := GetHeaders(ctx)
|
||||||
|
if ok {
|
||||||
|
if awaitResponse, ok := headers[consts.AwaitResponseKey]; ok && awaitResponse == "true" {
|
||||||
|
publisherID, exists := headers[consts.PublisherKey]
|
||||||
|
if exists {
|
||||||
|
con, ok := b.publishers.Get(publisherID)
|
||||||
|
if ok {
|
||||||
|
msg.Command = consts.RESPONSE
|
||||||
|
err := codec.SendMessage(con.conn, msg, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Println("consumer confirmed", headers, ok, string(msg.Payload))
|
||||||
|
case consts.MESSAGE_ACK:
|
||||||
|
fmt.Println("consumer confirmed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
|
||||||
|
msg, err := codec.ReadMessage(c, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption)
|
||||||
|
if err == nil {
|
||||||
|
ctx = SetHeaders(ctx, msg.Headers)
|
||||||
|
b.OnMessage(ctx, msg, c)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
|
||||||
|
b.OnClose(ctx, c)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
b.OnError(ctx, c, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the broker server with optional TLS support
|
||||||
|
func (b *Broker) Start(ctx context.Context) error {
|
||||||
|
var listener net.Listener
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if b.opts.tlsConfig.UseTLS {
|
||||||
|
cert, err := tls.LoadX509KeyPair(b.opts.tlsConfig.CertPath, b.opts.tlsConfig.KeyPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load TLS certificates: %v", err)
|
||||||
|
}
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
}
|
||||||
|
listener, err = tls.Listen("tcp", b.opts.brokerAddr, tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to start TLS listener: %v", err)
|
||||||
|
}
|
||||||
|
log.Println("TLS server started on", b.opts.brokerAddr)
|
||||||
|
} else {
|
||||||
|
listener, err = net.Listen("tcp", b.opts.brokerAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to start TCP listener: %v", err)
|
||||||
|
}
|
||||||
|
log.Println("TCP server started on", b.opts.brokerAddr)
|
||||||
|
}
|
||||||
|
defer listener.Close()
|
||||||
|
for {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
b.OnError(ctx, conn, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
go func(c net.Conn) {
|
||||||
|
defer c.Close()
|
||||||
|
for {
|
||||||
|
err := b.readMessage(ctx, c)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) NewQueue(qName string) *Queue {
|
||||||
|
q, ok := b.queues.Get(qName)
|
||||||
|
if ok {
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
q = newQueue(qName)
|
||||||
|
b.queues.Set(qName, q)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) AddMessageToQueue(task *Task, queueName string) (*Queue, *Task, error) {
|
||||||
|
queue := b.NewQueue(queueName)
|
||||||
|
if task.ID == "" {
|
||||||
|
task.ID = NewID()
|
||||||
|
}
|
||||||
|
task.CreatedAt = time.Now()
|
||||||
|
queue.messages.Set(task.ID, task)
|
||||||
|
return queue, task, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Conn) string {
|
||||||
|
consumerID, ok := GetConsumerID(ctx)
|
||||||
|
q, ok := b.queues.Get(queueName)
|
||||||
|
if !ok {
|
||||||
|
q = b.NewQueue(queueName)
|
||||||
|
}
|
||||||
|
con := &consumer{id: consumerID, conn: conn}
|
||||||
|
b.consumers.Set(consumerID, con)
|
||||||
|
q.consumers.Set(consumerID, con)
|
||||||
|
log.Printf("Consumer %s joined server on queue %s", consumerID, queueName)
|
||||||
|
return consumerID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Conn) *publisher {
|
||||||
|
publisherID, ok := GetPublisherID(ctx)
|
||||||
|
_, ok = b.queues.Get(queueName)
|
||||||
|
if !ok {
|
||||||
|
b.NewQueue(queueName)
|
||||||
|
}
|
||||||
|
con := &publisher{id: publisherID, conn: conn}
|
||||||
|
b.publishers.Set(publisherID, con)
|
||||||
|
return con
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) broadcastToConsumers(ctx context.Context, msg *codec.Message) {
|
||||||
|
if queue, ok := b.queues.Get(msg.Queue); ok {
|
||||||
|
queue.consumers.ForEach(func(_ string, con *consumer) bool {
|
||||||
|
msg.Command = consts.MESSAGE_SEND
|
||||||
|
if err := codec.SendMessage(con.conn, msg, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption); err != nil {
|
||||||
|
log.Printf("Error sending Message: %v\n", err)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) waitForAck(conn net.Conn) error {
|
||||||
|
msg, err := codec.ReadMessage(conn, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if msg.Command == consts.MESSAGE_ACK {
|
||||||
|
log.Println("Received CONSUMER_ACK: Subscribed successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("expected CONSUMER_ACK, got: %v", msg.Command)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) publish(ctx context.Context, conn net.Conn, msg *codec.Message) {
|
||||||
|
pub := b.addPublisher(ctx, msg.Queue, conn)
|
||||||
|
ack := &codec.Message{
|
||||||
|
Headers: msg.Headers,
|
||||||
|
Queue: msg.Queue,
|
||||||
|
Command: consts.PUBLISH_ACK,
|
||||||
|
}
|
||||||
|
if err := codec.SendMessage(conn, ack, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption); err != nil {
|
||||||
|
log.Printf("Error sending PUBLISH_ACK: %v\n", err)
|
||||||
|
}
|
||||||
|
b.broadcastToConsumers(ctx, msg)
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
b.publishers.Del(pub.id)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) subscribe(ctx context.Context, conn net.Conn, msg *codec.Message) {
|
||||||
|
consumerID := b.addConsumer(ctx, msg.Queue, conn)
|
||||||
|
ack := &codec.Message{
|
||||||
|
Headers: msg.Headers,
|
||||||
|
Queue: msg.Queue,
|
||||||
|
Command: consts.SUBSCRIBE_ACK,
|
||||||
|
}
|
||||||
|
if err := codec.SendMessage(conn, ack, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption); err != nil {
|
||||||
|
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
b.removeConsumer(msg.Queue, consumerID)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Removes connection from the queue and broker
|
||||||
|
func (b *Broker) removeConsumer(queueName, consumerID string) {
|
||||||
|
if queue, ok := b.queues.Get(queueName); ok {
|
||||||
|
con, ok := queue.consumers.Get(consumerID)
|
||||||
|
if ok {
|
||||||
|
con.conn.Close()
|
||||||
|
queue.consumers.Del(consumerID)
|
||||||
|
}
|
||||||
|
b.queues.Del(queueName)
|
||||||
|
}
|
||||||
|
}
|
191
v2/consumer.go
Normal file
191
v2/consumer.go
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
package v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oarkflow/mq/codec"
|
||||||
|
"github.com/oarkflow/mq/consts"
|
||||||
|
"github.com/oarkflow/mq/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Consumer structure to hold consumer-specific configurations and state.
|
||||||
|
type Consumer struct {
|
||||||
|
id string
|
||||||
|
handlers map[string]Handler
|
||||||
|
conn net.Conn
|
||||||
|
queues []string
|
||||||
|
opts Options
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConsumer initializes a new consumer with the provided options.
|
||||||
|
func NewConsumer(id string, opts ...Option) *Consumer {
|
||||||
|
options := defaultOptions()
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(&options)
|
||||||
|
}
|
||||||
|
b := &Consumer{
|
||||||
|
handlers: make(map[string]Handler),
|
||||||
|
id: id,
|
||||||
|
opts: options,
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the consumer's connection.
|
||||||
|
func (c *Consumer) Close() error {
|
||||||
|
return c.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe to a specific queue.
|
||||||
|
func (c *Consumer) subscribe(ctx context.Context, queue string) error {
|
||||||
|
headers := WithHeaders(ctx, map[string]string{
|
||||||
|
consts.ConsumerKey: c.id,
|
||||||
|
consts.ContentType: consts.TypeJson,
|
||||||
|
})
|
||||||
|
msg := codec.NewMessage(consts.SUBSCRIBE, nil, queue, headers)
|
||||||
|
if err := codec.SendMessage(c.conn, msg, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.waitForAck(c.conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) OnClose(ctx context.Context, _ net.Conn) error {
|
||||||
|
fmt.Println("Consumer closed")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) OnError(_ context.Context, conn net.Conn, err error) {
|
||||||
|
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
||||||
|
headers := WithHeaders(ctx, map[string]string{
|
||||||
|
consts.ConsumerKey: c.id,
|
||||||
|
consts.ContentType: consts.TypeJson,
|
||||||
|
})
|
||||||
|
reply := codec.NewMessage(consts.MESSAGE_ACK, nil, msg.Queue, headers)
|
||||||
|
if err := codec.SendMessage(conn, reply, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption); err != nil {
|
||||||
|
fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err)
|
||||||
|
}
|
||||||
|
var task Task
|
||||||
|
err := json.Unmarshal(msg.Payload, &task)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Error unmarshalling message:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue})
|
||||||
|
result := c.ProcessTask(ctx, task)
|
||||||
|
result.MessageID = task.ID
|
||||||
|
result.Queue = msg.Queue
|
||||||
|
if result.Error != nil {
|
||||||
|
result.Status = "FAILED"
|
||||||
|
} else {
|
||||||
|
result.Status = "SUCCESS"
|
||||||
|
}
|
||||||
|
bt, _ := json.Marshal(result)
|
||||||
|
reply = codec.NewMessage(consts.MESSAGE_RESPONSE, bt, msg.Queue, headers)
|
||||||
|
if err := codec.SendMessage(conn, reply, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption); err != nil {
|
||||||
|
fmt.Printf("failed to send MESSAGE_RESPONSE for queue %s: %v", msg.Queue, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessTask handles a received task message and invokes the appropriate handler.
|
||||||
|
func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result {
|
||||||
|
queue, _ := GetQueue(ctx)
|
||||||
|
handler, exists := c.handlers[queue]
|
||||||
|
if !exists {
|
||||||
|
return Result{Error: errors.New("No handler for queue " + queue)}
|
||||||
|
}
|
||||||
|
return handler(ctx, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AttemptConnect tries to establish a connection to the server, with TLS or without, based on the configuration.
|
||||||
|
func (c *Consumer) AttemptConnect() error {
|
||||||
|
var err error
|
||||||
|
delay := c.opts.initialDelay
|
||||||
|
for i := 0; i < c.opts.maxRetries; i++ {
|
||||||
|
conn, err := GetConnection(c.opts.brokerAddr, c.opts.tlsConfig)
|
||||||
|
if err == nil {
|
||||||
|
c.conn = conn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent)
|
||||||
|
fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration)
|
||||||
|
time.Sleep(sleepDuration)
|
||||||
|
delay *= 2
|
||||||
|
if delay > c.opts.maxBackoff {
|
||||||
|
delay = c.opts.maxBackoff
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error {
|
||||||
|
msg, err := codec.ReadMessage(conn, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption)
|
||||||
|
if err == nil {
|
||||||
|
ctx = SetHeaders(ctx, msg.Headers)
|
||||||
|
c.OnMessage(ctx, msg, conn)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
|
||||||
|
c.OnClose(ctx, conn)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.OnError(ctx, conn, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Consume starts the consumer to consume tasks from the queues.
|
||||||
|
func (c *Consumer) Consume(ctx context.Context) error {
|
||||||
|
err := c.AttemptConnect()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, q := range c.queues {
|
||||||
|
if err := c.subscribe(ctx, q); err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to server for queue %s: %v", q, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for {
|
||||||
|
if err := c.readMessage(ctx, c.conn); err != nil {
|
||||||
|
log.Println("Error reading message:", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) waitForAck(conn net.Conn) error {
|
||||||
|
msg, err := codec.ReadMessage(conn, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if msg.Command == consts.SUBSCRIBE_ACK {
|
||||||
|
log.Println("Received SUBSCRIBE_ACK: Subscribed successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterHandler registers a handler for a queue.
|
||||||
|
func (c *Consumer) RegisterHandler(queue string, handler Handler) {
|
||||||
|
c.queues = append(c.queues, queue)
|
||||||
|
c.handlers[queue] = handler
|
||||||
|
}
|
145
v2/ctx.go
Normal file
145
v2/ctx.go
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
package v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/oarkflow/xid"
|
||||||
|
|
||||||
|
"github.com/oarkflow/mq/consts"
|
||||||
|
)
|
||||||
|
|
||||||
|
func IsClosed(conn net.Conn) bool {
|
||||||
|
_, err := conn.Read(make([]byte, 1))
|
||||||
|
if err != nil {
|
||||||
|
if err == net.ErrClosed {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetHeaders(ctx context.Context, headers map[string]string) context.Context {
|
||||||
|
hd, ok := GetHeaders(ctx)
|
||||||
|
if !ok {
|
||||||
|
hd = make(map[string]string)
|
||||||
|
}
|
||||||
|
for key, val := range headers {
|
||||||
|
hd[key] = val
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, consts.HeaderKey, hd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithHeaders(ctx context.Context, headers map[string]string) map[string]string {
|
||||||
|
hd, ok := GetHeaders(ctx)
|
||||||
|
if !ok {
|
||||||
|
hd = make(map[string]string)
|
||||||
|
}
|
||||||
|
for key, val := range headers {
|
||||||
|
hd[key] = val
|
||||||
|
}
|
||||||
|
return hd
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetHeaders(ctx context.Context) (map[string]string, bool) {
|
||||||
|
headers, ok := ctx.Value(consts.HeaderKey).(map[string]string)
|
||||||
|
return headers, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetHeader(ctx context.Context, key string) (string, bool) {
|
||||||
|
headers, ok := ctx.Value(consts.HeaderKey).(map[string]string)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
val, ok := headers[key]
|
||||||
|
return val, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContentType(ctx context.Context) (string, bool) {
|
||||||
|
headers, ok := GetHeaders(ctx)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
contentType, ok := headers[consts.ContentType]
|
||||||
|
return contentType, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetQueue(ctx context.Context) (string, bool) {
|
||||||
|
headers, ok := GetHeaders(ctx)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
contentType, ok := headers[consts.QueueKey]
|
||||||
|
return contentType, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetConsumerID(ctx context.Context) (string, bool) {
|
||||||
|
headers, ok := GetHeaders(ctx)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
contentType, ok := headers[consts.ConsumerKey]
|
||||||
|
return contentType, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetTriggerNode(ctx context.Context) (string, bool) {
|
||||||
|
headers, ok := GetHeaders(ctx)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
contentType, ok := headers[consts.TriggerNode]
|
||||||
|
return contentType, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetPublisherID(ctx context.Context) (string, bool) {
|
||||||
|
headers, ok := GetHeaders(ctx)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
contentType, ok := headers[consts.PublisherKey]
|
||||||
|
return contentType, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewID() string {
|
||||||
|
return xid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTLSConnection(addr, certPath, keyPath string, caPath ...string) (net.Conn, error) {
|
||||||
|
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load client cert/key: %w", err)
|
||||||
|
}
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
}
|
||||||
|
if len(caPath) > 0 && caPath[0] != "" {
|
||||||
|
caCert, err := os.ReadFile(caPath[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load CA cert: %w", err)
|
||||||
|
}
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
caCertPool.AppendCertsFromPEM(caCert)
|
||||||
|
tlsConfig.RootCAs = caCertPool
|
||||||
|
tlsConfig.ClientCAs = caCertPool
|
||||||
|
}
|
||||||
|
conn, err := tls.Dial("tcp", addr, tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to dial TLS connection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetConnection(addr string, config TLSConfig) (net.Conn, error) {
|
||||||
|
if config.UseTLS {
|
||||||
|
return createTLSConnection(addr, config.CertPath, config.KeyPath, config.CAPath)
|
||||||
|
} else {
|
||||||
|
return net.Dial("tcp", addr)
|
||||||
|
}
|
||||||
|
}
|
123
v2/options.go
Normal file
123
v2/options.go
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
package v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Result struct {
|
||||||
|
Payload json.RawMessage `json:"payload"`
|
||||||
|
Queue string `json:"queue"`
|
||||||
|
MessageID string `json:"message_id"`
|
||||||
|
Error error `json:"error,omitempty"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TLSConfig struct {
|
||||||
|
UseTLS bool
|
||||||
|
CertPath string
|
||||||
|
KeyPath string
|
||||||
|
CAPath string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
syncMode bool
|
||||||
|
brokerAddr string
|
||||||
|
callback []func(context.Context, Result) Result
|
||||||
|
maxRetries int
|
||||||
|
initialDelay time.Duration
|
||||||
|
maxBackoff time.Duration
|
||||||
|
jitterPercent float64
|
||||||
|
tlsConfig TLSConfig
|
||||||
|
aesKey json.RawMessage
|
||||||
|
hmacKey json.RawMessage
|
||||||
|
enableEncryption bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultOptions() Options {
|
||||||
|
return Options{
|
||||||
|
syncMode: false,
|
||||||
|
brokerAddr: ":8080",
|
||||||
|
maxRetries: 5,
|
||||||
|
initialDelay: 2 * time.Second,
|
||||||
|
maxBackoff: 20 * time.Second,
|
||||||
|
jitterPercent: 0.5,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option defines a function type for setting options.
|
||||||
|
type Option func(*Options)
|
||||||
|
|
||||||
|
func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option {
|
||||||
|
return func(opts *Options) {
|
||||||
|
opts.aesKey = aesKey
|
||||||
|
opts.hmacKey = hmacKey
|
||||||
|
opts.enableEncryption = enableEncryption
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBrokerURL -
|
||||||
|
func WithBrokerURL(url string) Option {
|
||||||
|
return func(opts *Options) {
|
||||||
|
opts.brokerAddr = url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTLS - Option to enable/disable TLS
|
||||||
|
func WithTLS(enableTLS bool, certPath, keyPath string) Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
o.tlsConfig.UseTLS = enableTLS
|
||||||
|
o.tlsConfig.CertPath = certPath
|
||||||
|
o.tlsConfig.KeyPath = keyPath
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCAPath - Option to enable/disable TLS
|
||||||
|
func WithCAPath(caPath string) Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
o.tlsConfig.CAPath = caPath
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSyncMode -
|
||||||
|
func WithSyncMode(mode bool) Option {
|
||||||
|
return func(opts *Options) {
|
||||||
|
opts.syncMode = mode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMaxRetries -
|
||||||
|
func WithMaxRetries(val int) Option {
|
||||||
|
return func(opts *Options) {
|
||||||
|
opts.maxRetries = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithInitialDelay -
|
||||||
|
func WithInitialDelay(val time.Duration) Option {
|
||||||
|
return func(opts *Options) {
|
||||||
|
opts.initialDelay = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMaxBackoff -
|
||||||
|
func WithMaxBackoff(val time.Duration) Option {
|
||||||
|
return func(opts *Options) {
|
||||||
|
opts.maxBackoff = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCallback -
|
||||||
|
func WithCallback(val ...func(context.Context, Result) Result) Option {
|
||||||
|
return func(opts *Options) {
|
||||||
|
opts.callback = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithJitterPercent -
|
||||||
|
func WithJitterPercent(val float64) Option {
|
||||||
|
return func(opts *Options) {
|
||||||
|
opts.jitterPercent = val
|
||||||
|
}
|
||||||
|
}
|
106
v2/publisher.go
Normal file
106
v2/publisher.go
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
package v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oarkflow/mq/codec"
|
||||||
|
"github.com/oarkflow/mq/consts"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Publisher struct {
|
||||||
|
id string
|
||||||
|
opts Options
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPublisher(id string, opts ...Option) *Publisher {
|
||||||
|
options := defaultOptions()
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(&options)
|
||||||
|
}
|
||||||
|
b := &Publisher{id: id, opts: options}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command consts.CMD) error {
|
||||||
|
headers := WithHeaders(ctx, map[string]string{
|
||||||
|
consts.PublisherKey: p.id,
|
||||||
|
consts.ContentType: consts.TypeJson,
|
||||||
|
})
|
||||||
|
if task.ID == "" {
|
||||||
|
task.ID = NewID()
|
||||||
|
}
|
||||||
|
task.CreatedAt = time.Now()
|
||||||
|
payload, err := json.Marshal(task)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
msg := codec.NewMessage(command, payload, queue, headers)
|
||||||
|
if err := codec.SendMessage(conn, msg, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.waitForAck(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Publisher) waitForAck(conn net.Conn) error {
|
||||||
|
msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if msg.Command == consts.PUBLISH_ACK {
|
||||||
|
log.Println("Received PUBLISH_ACK: Message published successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Publisher) waitForResponse(conn net.Conn) Result {
|
||||||
|
msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption)
|
||||||
|
if err != nil {
|
||||||
|
return Result{Error: err}
|
||||||
|
}
|
||||||
|
if msg.Command == consts.RESPONSE {
|
||||||
|
var result Result
|
||||||
|
err = json.Unmarshal(msg.Payload, &result)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("expected RESPONSE, got: %v", msg.Command)
|
||||||
|
return Result{Error: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Publisher) Publish(ctx context.Context, queue string, task Task) error {
|
||||||
|
conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to broker: %w", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
return p.send(ctx, queue, task, conn, consts.PUBLISH)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Publisher) onClose(ctx context.Context, conn net.Conn) error {
|
||||||
|
fmt.Println("Publisher Connection closed", p.id, conn.RemoteAddr())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Publisher) onError(ctx context.Context, conn net.Conn, err error) {
|
||||||
|
fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Publisher) Request(ctx context.Context, queue string, task Task) Result {
|
||||||
|
ctx = SetHeaders(ctx, map[string]string{
|
||||||
|
consts.AwaitResponseKey: "true",
|
||||||
|
})
|
||||||
|
conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to connect to broker: %w", err)
|
||||||
|
return Result{Error: err}
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
err = p.send(ctx, queue, task, conn, consts.PUBLISH)
|
||||||
|
return p.waitForResponse(conn)
|
||||||
|
}
|
Reference in New Issue
Block a user