diff --git a/codec/codec.go b/codec/codec.go index 8aa5953..2d397bf 100644 --- a/codec/codec.go +++ b/codec/codec.go @@ -2,13 +2,7 @@ package codec import ( "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/hmac" - "crypto/rand" - "crypto/sha256" "encoding/binary" - "encoding/hex" "encoding/json" "fmt" "io" @@ -18,164 +12,216 @@ import ( type Message struct { Headers map[string]string `json:"h"` - Topic string `json:"t"` + Queue string `json:"q"` Command consts.CMD `json:"c"` Payload json.RawMessage `json:"p"` + // Metadata map[string]any `json:"m"` } -func EncryptPayload(key []byte, plaintext []byte) ([]byte, []byte, error) { - block, err := aes.NewCipher(key) - if err != nil { - return nil, nil, err +func NewMessage(cmd consts.CMD, payload json.RawMessage, queue string, headers map[string]string) *Message { + return &Message{ + Headers: headers, + Queue: queue, + Command: cmd, + Payload: payload, + // Metadata: nil, } - aesGCM, err := cipher.NewGCM(block) - if err != nil { - return nil, nil, err - } - nonce := make([]byte, aesGCM.NonceSize()) - if _, err = io.ReadFull(rand.Reader, nonce); err != nil { - return nil, nil, err - } - ciphertext := aesGCM.Seal(nil, nonce, plaintext, nil) - return ciphertext, nonce, nil } -func DecryptPayload(key []byte, ciphertext []byte, nonce []byte) ([]byte, error) { - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - aesGCM, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - return aesGCM.Open(nil, nonce, ciphertext, nil) -} - -func CalculateHMAC(key []byte, data []byte) string { - h := hmac.New(sha256.New, key) - h.Write(data) - return hex.EncodeToString(h.Sum(nil)) -} - -func VerifyHMAC(key []byte, data []byte, receivedHMAC string) bool { - expectedHMAC := CalculateHMAC(key, data) - return hmac.Equal([]byte(expectedHMAC), []byte(receivedHMAC)) -} - -func (m *Message) Serialize(aesKey []byte, hmacKey []byte) ([]byte, string, error) { +func (m *Message) Serialize(aesKey, hmacKey []byte, encrypt bool) ([]byte, string, error) { var buf bytes.Buffer - headersBytes, err := json.Marshal(m.Headers) - if err != nil { - return nil, "", fmt.Errorf("error marshalling headers: %v", err) + + // Serialize Headers, Queue, Command, Payload, and Metadata + if err := writeLengthPrefixedJSON(&buf, m.Headers); err != nil { + return nil, "", fmt.Errorf("error serializing headers: %v", err) } - headersLen := uint32(len(headersBytes)) - if err := binary.Write(&buf, binary.LittleEndian, headersLen); err != nil { - return nil, "", err - } - buf.Write(headersBytes) - topicBytes := []byte(m.Topic) - topicLen := uint8(len(topicBytes)) - if err := binary.Write(&buf, binary.LittleEndian, topicLen); err != nil { - return nil, "", err - } - buf.Write(topicBytes) - if !m.Command.IsValid() { - return nil, "", fmt.Errorf("invalid command: %s", m.Command) + if err := writeLengthPrefixed(&buf, []byte(m.Queue)); err != nil { + return nil, "", fmt.Errorf("error serializing topic: %v", err) } if err := binary.Write(&buf, binary.LittleEndian, m.Command); err != nil { + return nil, "", fmt.Errorf("error serializing command: %v", err) + } + if err := writePayload(&buf, aesKey, m.Payload, encrypt); err != nil { return nil, "", err } - payloadBytes, err := json.Marshal(m.Payload) - if err != nil { - return nil, "", fmt.Errorf("error marshalling payload: %v", err) - } - encryptedPayload, nonce, err := EncryptPayload(aesKey, payloadBytes) - if err != nil { - return nil, "", err - } - payloadLen := uint32(len(encryptedPayload)) - if err := binary.Write(&buf, binary.LittleEndian, payloadLen); err != nil { - return nil, "", err - } - buf.Write(encryptedPayload) - buf.Write(nonce) + /*if err := writeLengthPrefixedJSON(&buf, m.Metadata); err != nil { + return nil, "", fmt.Errorf("error serializing metadata: %v", err) + }*/ + + // Calculate HMAC messageBytes := buf.Bytes() hmacSignature := CalculateHMAC(hmacKey, messageBytes) + return messageBytes, hmacSignature, nil } -func Deserialize(data []byte, aesKey []byte, hmacKey []byte, receivedHMAC string) (*Message, error) { +func Deserialize(data, aesKey, hmacKey []byte, receivedHMAC string, decrypt bool) (*Message, error) { if !VerifyHMAC(hmacKey, data, receivedHMAC) { return nil, fmt.Errorf("HMAC verification failed") } + buf := bytes.NewReader(data) - var topicLen uint8 - var payloadLen uint32 - var headersLen uint32 - if err := binary.Read(buf, binary.LittleEndian, &headersLen); err != nil { - return nil, err + + // Deserialize Headers, Queue, Command, Payload, and Metadata + headers := make(map[string]string) + if err := readLengthPrefixedJSON(buf, &headers); err != nil { + return nil, fmt.Errorf("error deserializing headers: %v", err) } - headersBytes := make([]byte, headersLen) - if _, err := buf.Read(headersBytes); err != nil { - return nil, err + + topic, err := readLengthPrefixedString(buf) + if err != nil { + return nil, fmt.Errorf("error deserializing topic: %v", err) } - var headers map[string]string - if err := json.Unmarshal(headersBytes, &headers); err != nil { - return nil, err - } - if err := binary.Read(buf, binary.LittleEndian, &topicLen); err != nil { - return nil, err - } - topicBytes := make([]byte, topicLen) - if _, err := buf.Read(topicBytes); err != nil { - return nil, err - } - topic := string(topicBytes) + var command consts.CMD if err := binary.Read(buf, binary.LittleEndian, &command); err != nil { - return nil, err + return nil, fmt.Errorf("error deserializing command: %v", err) } - if !command.IsValid() { - return nil, fmt.Errorf("invalid command: %s", command) + + payload, err := readPayload(buf, aesKey, decrypt) + if err != nil { + return nil, fmt.Errorf("error deserializing payload: %v", err) } - if err := binary.Read(buf, binary.LittleEndian, &payloadLen); err != nil { - return nil, err + + /*metadata := make(map[string]any) + if err := readLengthPrefixedJSON(buf, &metadata); err != nil { + return nil, fmt.Errorf("error deserializing metadata: %v", err) + }*/ + + return &Message{ + Headers: headers, + Queue: topic, + Command: command, + Payload: payload, + // Metadata: metadata, + }, nil +} + +func SendMessage(conn io.Writer, msg *Message, aesKey, hmacKey []byte, encrypt bool) error { + sentData, hmacSignature, err := msg.Serialize(aesKey, hmacKey, encrypt) + if err != nil { + return fmt.Errorf("error serializing message: %v", err) } - encryptedPayload := make([]byte, payloadLen) - if _, err := buf.Read(encryptedPayload); err != nil { - return nil, err + + if err := writeMessageWithHMAC(conn, sentData, hmacSignature); err != nil { + return fmt.Errorf("error sending message: %v", err) } - nonce := make([]byte, 12) - if _, err := buf.Read(nonce); err != nil { - return nil, err - } - payloadBytes, err := DecryptPayload(aesKey, encryptedPayload, nonce) + + return nil +} + +func ReadMessage(conn io.Reader, aesKey, hmacKey []byte, decrypt bool) (*Message, error) { + data, receivedHMAC, err := readMessageWithHMAC(conn) if err != nil { return nil, err } + + return Deserialize(data, aesKey, hmacKey, receivedHMAC, decrypt) +} + +func writeLengthPrefixed(buf *bytes.Buffer, data []byte) error { + length := uint32(len(data)) + if err := binary.Write(buf, binary.LittleEndian, length); err != nil { + return err + } + buf.Write(data) + return nil +} + +func writeLengthPrefixedJSON(buf *bytes.Buffer, v any) error { + data, err := json.Marshal(v) + if err != nil { + return err + } + return writeLengthPrefixed(buf, data) +} + +func readLengthPrefixed(r *bytes.Reader) ([]byte, error) { + var length uint32 + if err := binary.Read(r, binary.LittleEndian, &length); err != nil { + return nil, err + } + data := make([]byte, length) + if _, err := r.Read(data); err != nil { + return nil, err + } + return data, nil +} + +func readLengthPrefixedJSON(r *bytes.Reader, v any) error { + data, err := readLengthPrefixed(r) + if err != nil { + return err + } + return json.Unmarshal(data, v) +} + +func readLengthPrefixedString(r *bytes.Reader) (string, error) { + data, err := readLengthPrefixed(r) + if err != nil { + return "", err + } + return string(data), nil +} + +func writePayload(buf *bytes.Buffer, aesKey []byte, payload json.RawMessage, encrypt bool) error { + payloadBytes, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("error marshalling payload: %v", err) + } + + var encryptedPayload, nonce []byte + if encrypt { + encryptedPayload, nonce, err = EncryptPayload(aesKey, payloadBytes) + if err != nil { + return err + } + } else { + encryptedPayload = payloadBytes + } + + if err := writeLengthPrefixed(buf, encryptedPayload); err != nil { + return err + } + + if encrypt { + buf.Write(nonce) + } + return nil +} + +func readPayload(r *bytes.Reader, aesKey []byte, decrypt bool) (json.RawMessage, error) { + encryptedPayload, err := readLengthPrefixed(r) + if err != nil { + return nil, err + } + + var payloadBytes []byte + if decrypt { + nonce := make([]byte, 12) + if _, err := r.Read(nonce); err != nil { + return nil, err + } + payloadBytes, err = DecryptPayload(aesKey, encryptedPayload, nonce) + if err != nil { + return nil, fmt.Errorf("error decrypting payload: %v", err) + } + } else { + payloadBytes = encryptedPayload + } + var payload json.RawMessage if err := json.Unmarshal(payloadBytes, &payload); err != nil { return nil, fmt.Errorf("error unmarshalling payload: %v", err) } - return &Message{ - Headers: headers, - Topic: topic, - Command: command, - Payload: payload, - }, nil -} -func SendMessage(conn io.Writer, msg *Message, aesKey []byte, hmacKey []byte) error { - sentData, hmacSignature, err := msg.Serialize(aesKey, hmacKey) - if err != nil { - return fmt.Errorf("error serializing message: %v", err) - } - if err := binary.Write(conn, binary.LittleEndian, uint32(len(sentData))); err != nil { + return payload, nil +} +func writeMessageWithHMAC(conn io.Writer, messageBytes []byte, hmacSignature string) error { + if err := binary.Write(conn, binary.LittleEndian, uint32(len(messageBytes))); err != nil { return err } - - if _, err := conn.Write(sentData); err != nil { + if _, err := conn.Write(messageBytes); err != nil { return err } if _, err := conn.Write([]byte(hmacSignature)); err != nil { @@ -184,20 +230,21 @@ func SendMessage(conn io.Writer, msg *Message, aesKey []byte, hmacKey []byte) er return nil } -func ReadMessage(conn io.Reader, aesKey []byte, hmacKey []byte) (*Message, error) { +func readMessageWithHMAC(conn io.Reader) ([]byte, string, error) { var length uint32 if err := binary.Read(conn, binary.LittleEndian, &length); err != nil { - return nil, err + return nil, "", err } data := make([]byte, length) if _, err := io.ReadFull(conn, data); err != nil { - return nil, err + return nil, "", err } hmacBytes := make([]byte, 64) if _, err := io.ReadFull(conn, hmacBytes); err != nil { - return nil, err + return nil, "", err } receivedHMAC := string(hmacBytes) - return Deserialize(data, aesKey, hmacKey, receivedHMAC) + + return data, receivedHMAC, nil } diff --git a/codec/encrypt.go b/codec/encrypt.go new file mode 100644 index 0000000..50d3495 --- /dev/null +++ b/codec/encrypt.go @@ -0,0 +1,51 @@ +package codec + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "io" +) + +func EncryptPayload(key []byte, plaintext []byte) ([]byte, []byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, nil, err + } + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return nil, nil, err + } + nonce := make([]byte, aesGCM.NonceSize()) + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + return nil, nil, err + } + ciphertext := aesGCM.Seal(nil, nonce, plaintext, nil) + return ciphertext, nonce, nil +} + +func DecryptPayload(key []byte, ciphertext []byte, nonce []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + return aesGCM.Open(nil, nonce, ciphertext, nil) +} + +func CalculateHMAC(key []byte, data []byte) string { + h := hmac.New(sha256.New, key) + h.Write(data) + return hex.EncodeToString(h.Sum(nil)) +} + +func VerifyHMAC(key []byte, data []byte, receivedHMAC string) bool { + expectedHMAC := CalculateHMAC(key, data) + return hmac.Equal([]byte(expectedHMAC), []byte(receivedHMAC)) +} diff --git a/consts/constants.go b/consts/constants.go index 71742df..8de4490 100644 --- a/consts/constants.go +++ b/consts/constants.go @@ -8,8 +8,15 @@ const ( PING CMD = iota + 1 SUBSCRIBE SUBSCRIBE_ACK + + MESSAGE_SEND + MESSAGE_RESPONSE + MESSAGE_ACK + MESSAGE_ERROR + PUBLISH PUBLISH_ACK + REQUEST RESPONSE STOP @@ -23,6 +30,14 @@ func (c CMD) String() string { return "SUBSCRIBE" case SUBSCRIBE_ACK: return "SUBSCRIBE_ACK" + case MESSAGE_SEND: + return "MESSAGE_SEND" + case MESSAGE_RESPONSE: + return "MESSAGE_RESPONSE" + case MESSAGE_ERROR: + return "MESSAGE_ERROR" + case MESSAGE_ACK: + return "MESSAGE_ACK" case PUBLISH: return "PUBLISH" case PUBLISH_ACK: @@ -37,10 +52,12 @@ func (c CMD) String() string { } var ( - ConsumerKey = "Consumer-Key" - PublisherKey = "Publisher-Key" - ContentType = "Content-Type" - TypeJson = "application/json" - HeaderKey = "headers" - TriggerNode = "triggerNode" + ConsumerKey = "Consumer-Key" + PublisherKey = "Publisher-Key" + ContentType = "Content-Type" + AwaitResponseKey = "Await-Response" + QueueKey = "Queue" + TypeJson = "application/json" + HeaderKey = "headers" + TriggerNode = "triggerNode" ) diff --git a/examples/message.go b/examples/message.go index 2ee9120..47c8f53 100644 --- a/examples/message.go +++ b/examples/message.go @@ -1,68 +1,266 @@ package main import ( + "context" "encoding/json" "fmt" "log" "net" + "strings" + "sync" "time" + "github.com/oarkflow/mq" "github.com/oarkflow/mq/codec" "github.com/oarkflow/mq/consts" ) -func main() { - aesKey := []byte("thisis32bytekeyforaesencryption1") - hmacKey := []byte("thisisasecrethmackey1") - go func() { - listener, err := net.Listen("tcp", ":8081") - if err != nil { - log.Fatal(err) - } - defer listener.Close() - log.Println("Server is listening on :8080") - for { - conn, err := listener.Accept() - if err != nil { - log.Println("Connection error:", err) - continue +type Broker struct { + aesKey json.RawMessage + hmacKey json.RawMessage + subscribers map[string][]net.Conn + mu sync.RWMutex +} + +func NewBroker(aesKey, hmacKey json.RawMessage) *Broker { + return &Broker{ + aesKey: aesKey, + hmacKey: hmacKey, + subscribers: make(map[string][]net.Conn), + } +} + +func (b *Broker) addSubscriber(topic string, conn net.Conn) { + b.mu.Lock() + defer b.mu.Unlock() + b.subscribers[topic] = append(b.subscribers[topic], conn) +} + +func (b *Broker) removeSubscriber(conn net.Conn) { + b.mu.Lock() + defer b.mu.Unlock() + for topic, conns := range b.subscribers { + for i, c := range conns { + if c == conn { + b.subscribers[topic] = append(conns[:i], conns[i+1:]...) + break } - go func(c net.Conn) { - defer c.Close() - for { - msg, err := codec.ReadMessage(c, aesKey, hmacKey) - if err != nil { - if err.Error() == "EOF" { - log.Println("Client disconnected") - break - } - log.Println("Failed to receive message:", err) - break - } - log.Printf("Received Message:\n Headers: %v\n Topic: %s\n Command: %v\n Payload: %s\n", - msg.Headers, msg.Topic, msg.Command, msg.Payload) - } - }(conn) } - }() - time.Sleep(5 * time.Second) - conn, err := net.Dial("tcp", ":8081") + } +} + +func (b *Broker) broadcastToSubscribers(topic string, msg *codec.Message) { + b.mu.RLock() + defer b.mu.RUnlock() + + subscribers, ok := b.subscribers[topic] + if !ok || len(subscribers) == 0 { + log.Printf("No subscribers for topic: %s", topic) + return + } + + for _, conn := range subscribers { + err := codec.SendMessage(conn, msg, b.aesKey, b.hmacKey, true) + if err != nil { + log.Printf("Error sending message to subscriber: %v", err) + } + } +} + +func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) { + switch msg.Command { + case consts.PUBLISH: + b.broadcastToSubscribers(msg.Queue, msg) + ack := &codec.Message{ + Headers: msg.Headers, + Queue: msg.Queue, + Command: consts.PUBLISH_ACK, + } + if err := codec.SendMessage(conn, ack, b.aesKey, b.hmacKey, true); err != nil { + log.Printf("Error sending PUBLISH_ACK: %v\n", err) + } + case consts.SUBSCRIBE: + b.addSubscriber(msg.Queue, conn) + ack := &codec.Message{ + Headers: msg.Headers, + Queue: msg.Queue, + Command: consts.SUBSCRIBE_ACK, + } + if err := codec.SendMessage(conn, ack, b.aesKey, b.hmacKey, true); err != nil { + log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err) + } + } +} + +func (b *Broker) OnClose(ctx context.Context, conn net.Conn) { + log.Println("Connection closed") + b.removeSubscriber(conn) +} + +func (b *Broker) OnError(ctx context.Context, err error) { + log.Printf("Connection Error: %v\n", err) +} + +func (b *Broker) readMessage(ctx context.Context, c net.Conn) error { + msg, err := codec.ReadMessage(c, b.aesKey, b.hmacKey, true) + if err == nil { + ctx = mq.SetHeaders(ctx, msg.Headers) + b.OnMessage(ctx, msg, c) + return nil + } + if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") { + b.OnClose(ctx, c) + return err + } + b.OnError(ctx, err) + return err +} + +func (b *Broker) Serve(ctx context.Context, addr string) error { + listener, err := net.Listen("tcp", addr) if err != nil { - log.Fatal(err) + return err + } + defer listener.Close() + for { + conn, err := listener.Accept() + if err != nil { + b.OnError(ctx, err) + continue + } + go func(c net.Conn) { + defer c.Close() + for { + err := b.readMessage(ctx, c) + if err != nil { + break + } + } + }(conn) + } +} + +type Publisher struct { + aesKey json.RawMessage + hmacKey json.RawMessage +} + +func NewPublisher(aesKey, hmacKey json.RawMessage) *Publisher { + return &Publisher{aesKey: aesKey, hmacKey: hmacKey} +} + +func (p *Publisher) Publish(ctx context.Context, addr, topic string, payload json.RawMessage) error { + conn, err := net.Dial("tcp", addr) + if err != nil { + return err } defer conn.Close() - headers := map[string]string{"Api-Key": "121323"} - data := map[string]interface{}{"temperature": 23.5, "humidity": 60} - payload, _ := json.Marshal(data) + + headers, _ := mq.GetHeaders(ctx) msg := &codec.Message{ Headers: headers, - Topic: "sensor_data", - Command: consts.SUBSCRIBE, + Queue: topic, + Command: consts.PUBLISH, Payload: payload, } - if err := codec.SendMessage(conn, msg, aesKey, hmacKey); err != nil { - log.Fatalf("Error sending message: %v", err) + if err := codec.SendMessage(conn, msg, p.aesKey, p.hmacKey, true); err != nil { + return err } - fmt.Println("Message sent successfully") - time.Sleep(5 * time.Second) + + return p.waitForAck(conn) +} + +func (p *Publisher) waitForAck(conn net.Conn) error { + msg, err := codec.ReadMessage(conn, p.aesKey, p.hmacKey, true) + if err != nil { + return err + } + if msg.Command == consts.PUBLISH_ACK { + log.Println("Received PUBLISH_ACK: Message published successfully") + return nil + } + return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command) +} + +type Consumer struct { + aesKey json.RawMessage + hmacKey json.RawMessage +} + +func NewConsumer(aesKey, hmacKey json.RawMessage) *Consumer { + return &Consumer{aesKey: aesKey, hmacKey: hmacKey} +} + +func (c *Consumer) Subscribe(ctx context.Context, addr, topic string) error { + conn, err := net.Dial("tcp", addr) + if err != nil { + return err + } + defer conn.Close() + + headers, _ := mq.GetHeaders(ctx) + msg := &codec.Message{ + Headers: headers, + Queue: topic, + Command: consts.SUBSCRIBE, + } + if err := codec.SendMessage(conn, msg, c.aesKey, c.hmacKey, true); err != nil { + return err + } + + err = c.waitForAck(conn) + if err != nil { + return err + } + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for { + msg, err := codec.ReadMessage(conn, c.aesKey, c.hmacKey, true) + if err != nil { + log.Printf("Error reading message: %v\n", err) + break + } + log.Printf("Received task on topic %s: %s\n", msg.Queue, msg.Payload) + } + }() + + wg.Wait() + + return nil +} + +func (c *Consumer) waitForAck(conn net.Conn) error { + msg, err := codec.ReadMessage(conn, c.aesKey, c.hmacKey, true) + if err != nil { + return err + } + if msg.Command == consts.SUBSCRIBE_ACK { + log.Println("Received SUBSCRIBE_ACK: Subscribed successfully") + return nil + } + return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command) +} + +func main() { + addr := ":8081" + aesKey := []byte("thisis32bytekeyforaesencryption1") + hmacKey := []byte("thisisasecrethmackey1") + + broker := NewBroker(aesKey, hmacKey) + publisher := NewPublisher(aesKey, hmacKey) + consumer := NewConsumer(aesKey, hmacKey) + + go broker.Serve(context.Background(), addr) + + time.Sleep(1 * time.Second) + go consumer.Subscribe(context.Background(), addr, "sensor_data") + + time.Sleep(3 * time.Second) + data := map[string]interface{}{"temperature": 23.5, "humidity": 60} + payload, _ := json.Marshal(data) + go publisher.Publish(context.Background(), addr, "sensor_data", payload) + + time.Sleep(10 * time.Second) } diff --git a/examples/msg.go b/examples/msg.go new file mode 100644 index 0000000..bc6bed5 --- /dev/null +++ b/examples/msg.go @@ -0,0 +1,34 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "time" + + v2 "github.com/oarkflow/mq/v2" +) + +func main() { + broker := v2.NewBroker() + go broker.Start(context.Background()) + time.Sleep(1 * time.Second) + + consumer := v2.NewConsumer("consumer-1") + consumer.RegisterHandler("queue-1", func(ctx context.Context, task v2.Task) v2.Result { + fmt.Println("Handling on queue-1", string(task.Payload)) + return v2.Result{Payload: task.Payload} + }) + go func() { + err := consumer.Consume(context.Background()) + if err != nil { + panic(err) + } + }() + + publisher := v2.NewPublisher("publisher-1") + data := map[string]interface{}{"temperature": 23.5, "humidity": 60} + payload, _ := json.Marshal(data) + rs := publisher.Request(context.Background(), "queue-1", v2.Task{Payload: payload}) + fmt.Println("Response:", string(rs.Payload), rs.Error) +} diff --git a/v2/broker.go b/v2/broker.go new file mode 100644 index 0000000..aeba02d --- /dev/null +++ b/v2/broker.go @@ -0,0 +1,332 @@ +package v2 + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "log" + "net" + "strings" + "time" + + "github.com/oarkflow/xsync" + + "github.com/oarkflow/mq/codec" + "github.com/oarkflow/mq/consts" +) + +type consumer struct { + id string + conn net.Conn +} + +func (p *consumer) send(ctx context.Context, cmd any) error { + return nil +} + +type publisher struct { + id string + conn net.Conn +} + +func (p *publisher) send(ctx context.Context, cmd any) error { + return nil +} + +type Handler func(context.Context, Task) Result + +type Broker struct { + queues xsync.IMap[string, *Queue] + consumers xsync.IMap[string, *consumer] + publishers xsync.IMap[string, *publisher] + opts Options +} + +type Queue struct { + name string + consumers xsync.IMap[string, *consumer] + messages xsync.IMap[string, *Task] + deferred xsync.IMap[string, *Task] +} + +func newQueue(name string) *Queue { + return &Queue{ + name: name, + consumers: xsync.NewMap[string, *consumer](), + messages: xsync.NewMap[string, *Task](), + deferred: xsync.NewMap[string, *Task](), + } +} + +func (queue *Queue) send(ctx context.Context, cmd any) { + queue.consumers.ForEach(func(_ string, client *consumer) bool { + err := client.send(ctx, cmd) + if err != nil { + return false + } + return true + }) +} + +type Task struct { + ID string `json:"id"` + Payload json.RawMessage `json:"payload"` + CreatedAt time.Time `json:"created_at"` + ProcessedAt time.Time `json:"processed_at"` + Status string `json:"status"` + Error error `json:"error"` +} + +func NewBroker(opts ...Option) *Broker { + options := defaultOptions() + for _, opt := range opts { + opt(&options) + } + b := &Broker{ + queues: xsync.NewMap[string, *Queue](), + publishers: xsync.NewMap[string, *publisher](), + consumers: xsync.NewMap[string, *consumer](), + opts: options, + } + return b +} + +func (b *Broker) TLSConfig() TLSConfig { + return b.opts.tlsConfig +} + +func (b *Broker) SyncMode() bool { + return b.opts.syncMode +} + +func (b *Broker) OnClose(ctx context.Context, _ net.Conn) error { + consumerID, ok := GetConsumerID(ctx) + if ok && consumerID != "" { + if con, exists := b.consumers.Get(consumerID); exists { + con.conn.Close() + b.consumers.Del(consumerID) + } + b.queues.ForEach(func(_ string, queue *Queue) bool { + queue.consumers.Del(consumerID) + return true + }) + } + publisherID, ok := GetPublisherID(ctx) + if ok && publisherID != "" { + if con, exists := b.publishers.Get(publisherID); exists { + con.conn.Close() + b.publishers.Del(publisherID) + } + } + return nil +} + +func (b *Broker) OnError(_ context.Context, conn net.Conn, err error) { + fmt.Println("Error reading from connection:", err, conn.RemoteAddr()) +} + +func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) { + switch msg.Command { + case consts.PUBLISH: + b.publish(ctx, conn, msg) + case consts.SUBSCRIBE: + b.subscribe(ctx, conn, msg) + case consts.MESSAGE_RESPONSE: + headers, ok := GetHeaders(ctx) + if ok { + if awaitResponse, ok := headers[consts.AwaitResponseKey]; ok && awaitResponse == "true" { + publisherID, exists := headers[consts.PublisherKey] + if exists { + con, ok := b.publishers.Get(publisherID) + if ok { + msg.Command = consts.RESPONSE + err := codec.SendMessage(con.conn, msg, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption) + if err != nil { + panic(err) + } + } + } + } + } + fmt.Println("consumer confirmed", headers, ok, string(msg.Payload)) + case consts.MESSAGE_ACK: + fmt.Println("consumer confirmed") + } +} + +func (b *Broker) readMessage(ctx context.Context, c net.Conn) error { + msg, err := codec.ReadMessage(c, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption) + if err == nil { + ctx = SetHeaders(ctx, msg.Headers) + b.OnMessage(ctx, msg, c) + return nil + } + if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") { + b.OnClose(ctx, c) + return err + } + b.OnError(ctx, c, err) + return err +} + +// Start the broker server with optional TLS support +func (b *Broker) Start(ctx context.Context) error { + var listener net.Listener + var err error + + if b.opts.tlsConfig.UseTLS { + cert, err := tls.LoadX509KeyPair(b.opts.tlsConfig.CertPath, b.opts.tlsConfig.KeyPath) + if err != nil { + return fmt.Errorf("failed to load TLS certificates: %v", err) + } + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + listener, err = tls.Listen("tcp", b.opts.brokerAddr, tlsConfig) + if err != nil { + return fmt.Errorf("failed to start TLS listener: %v", err) + } + log.Println("TLS server started on", b.opts.brokerAddr) + } else { + listener, err = net.Listen("tcp", b.opts.brokerAddr) + if err != nil { + return fmt.Errorf("failed to start TCP listener: %v", err) + } + log.Println("TCP server started on", b.opts.brokerAddr) + } + defer listener.Close() + for { + conn, err := listener.Accept() + if err != nil { + b.OnError(ctx, conn, err) + continue + } + go func(c net.Conn) { + defer c.Close() + for { + err := b.readMessage(ctx, c) + if err != nil { + break + } + } + }(conn) + } +} + +func (b *Broker) NewQueue(qName string) *Queue { + q, ok := b.queues.Get(qName) + if ok { + return q + } + q = newQueue(qName) + b.queues.Set(qName, q) + return q +} + +func (b *Broker) AddMessageToQueue(task *Task, queueName string) (*Queue, *Task, error) { + queue := b.NewQueue(queueName) + if task.ID == "" { + task.ID = NewID() + } + task.CreatedAt = time.Now() + queue.messages.Set(task.ID, task) + return queue, task, nil +} + +func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Conn) string { + consumerID, ok := GetConsumerID(ctx) + q, ok := b.queues.Get(queueName) + if !ok { + q = b.NewQueue(queueName) + } + con := &consumer{id: consumerID, conn: conn} + b.consumers.Set(consumerID, con) + q.consumers.Set(consumerID, con) + log.Printf("Consumer %s joined server on queue %s", consumerID, queueName) + return consumerID +} + +func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Conn) *publisher { + publisherID, ok := GetPublisherID(ctx) + _, ok = b.queues.Get(queueName) + if !ok { + b.NewQueue(queueName) + } + con := &publisher{id: publisherID, conn: conn} + b.publishers.Set(publisherID, con) + return con +} + +func (b *Broker) broadcastToConsumers(ctx context.Context, msg *codec.Message) { + if queue, ok := b.queues.Get(msg.Queue); ok { + queue.consumers.ForEach(func(_ string, con *consumer) bool { + msg.Command = consts.MESSAGE_SEND + if err := codec.SendMessage(con.conn, msg, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption); err != nil { + log.Printf("Error sending Message: %v\n", err) + } + return true + }) + } +} + +func (b *Broker) waitForAck(conn net.Conn) error { + msg, err := codec.ReadMessage(conn, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption) + if err != nil { + return err + } + if msg.Command == consts.MESSAGE_ACK { + log.Println("Received CONSUMER_ACK: Subscribed successfully") + return nil + } + return fmt.Errorf("expected CONSUMER_ACK, got: %v", msg.Command) +} + +func (b *Broker) publish(ctx context.Context, conn net.Conn, msg *codec.Message) { + pub := b.addPublisher(ctx, msg.Queue, conn) + ack := &codec.Message{ + Headers: msg.Headers, + Queue: msg.Queue, + Command: consts.PUBLISH_ACK, + } + if err := codec.SendMessage(conn, ack, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption); err != nil { + log.Printf("Error sending PUBLISH_ACK: %v\n", err) + } + b.broadcastToConsumers(ctx, msg) + go func() { + select { + case <-ctx.Done(): + b.publishers.Del(pub.id) + } + }() +} + +func (b *Broker) subscribe(ctx context.Context, conn net.Conn, msg *codec.Message) { + consumerID := b.addConsumer(ctx, msg.Queue, conn) + ack := &codec.Message{ + Headers: msg.Headers, + Queue: msg.Queue, + Command: consts.SUBSCRIBE_ACK, + } + if err := codec.SendMessage(conn, ack, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption); err != nil { + log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err) + } + go func() { + select { + case <-ctx.Done(): + b.removeConsumer(msg.Queue, consumerID) + } + }() +} + +// Removes connection from the queue and broker +func (b *Broker) removeConsumer(queueName, consumerID string) { + if queue, ok := b.queues.Get(queueName); ok { + con, ok := queue.consumers.Get(consumerID) + if ok { + con.conn.Close() + queue.consumers.Del(consumerID) + } + b.queues.Del(queueName) + } +} diff --git a/v2/consumer.go b/v2/consumer.go new file mode 100644 index 0000000..44c26bf --- /dev/null +++ b/v2/consumer.go @@ -0,0 +1,191 @@ +package v2 + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "net" + "strings" + "sync" + "time" + + "github.com/oarkflow/mq/codec" + "github.com/oarkflow/mq/consts" + "github.com/oarkflow/mq/utils" +) + +// Consumer structure to hold consumer-specific configurations and state. +type Consumer struct { + id string + handlers map[string]Handler + conn net.Conn + queues []string + opts Options +} + +// NewConsumer initializes a new consumer with the provided options. +func NewConsumer(id string, opts ...Option) *Consumer { + options := defaultOptions() + for _, opt := range opts { + opt(&options) + } + b := &Consumer{ + handlers: make(map[string]Handler), + id: id, + opts: options, + } + return b +} + +// Close closes the consumer's connection. +func (c *Consumer) Close() error { + return c.conn.Close() +} + +// Subscribe to a specific queue. +func (c *Consumer) subscribe(ctx context.Context, queue string) error { + headers := WithHeaders(ctx, map[string]string{ + consts.ConsumerKey: c.id, + consts.ContentType: consts.TypeJson, + }) + msg := codec.NewMessage(consts.SUBSCRIBE, nil, queue, headers) + if err := codec.SendMessage(c.conn, msg, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption); err != nil { + return err + } + + return c.waitForAck(c.conn) +} + +func (c *Consumer) OnClose(ctx context.Context, _ net.Conn) error { + fmt.Println("Consumer closed") + return nil +} + +func (c *Consumer) OnError(_ context.Context, conn net.Conn, err error) { + fmt.Println("Error reading from connection:", err, conn.RemoteAddr()) +} + +func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) { + headers := WithHeaders(ctx, map[string]string{ + consts.ConsumerKey: c.id, + consts.ContentType: consts.TypeJson, + }) + reply := codec.NewMessage(consts.MESSAGE_ACK, nil, msg.Queue, headers) + if err := codec.SendMessage(conn, reply, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption); err != nil { + fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err) + } + var task Task + err := json.Unmarshal(msg.Payload, &task) + if err != nil { + log.Println("Error unmarshalling message:", err) + return + } + ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue}) + result := c.ProcessTask(ctx, task) + result.MessageID = task.ID + result.Queue = msg.Queue + if result.Error != nil { + result.Status = "FAILED" + } else { + result.Status = "SUCCESS" + } + bt, _ := json.Marshal(result) + reply = codec.NewMessage(consts.MESSAGE_RESPONSE, bt, msg.Queue, headers) + if err := codec.SendMessage(conn, reply, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption); err != nil { + fmt.Printf("failed to send MESSAGE_RESPONSE for queue %s: %v", msg.Queue, err) + } +} + +// ProcessTask handles a received task message and invokes the appropriate handler. +func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result { + queue, _ := GetQueue(ctx) + handler, exists := c.handlers[queue] + if !exists { + return Result{Error: errors.New("No handler for queue " + queue)} + } + return handler(ctx, msg) +} + +// AttemptConnect tries to establish a connection to the server, with TLS or without, based on the configuration. +func (c *Consumer) AttemptConnect() error { + var err error + delay := c.opts.initialDelay + for i := 0; i < c.opts.maxRetries; i++ { + conn, err := GetConnection(c.opts.brokerAddr, c.opts.tlsConfig) + if err == nil { + c.conn = conn + return nil + } + sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent) + fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration) + time.Sleep(sleepDuration) + delay *= 2 + if delay > c.opts.maxBackoff { + delay = c.opts.maxBackoff + } + } + + return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err) +} + +func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error { + msg, err := codec.ReadMessage(conn, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption) + if err == nil { + ctx = SetHeaders(ctx, msg.Headers) + c.OnMessage(ctx, msg, conn) + return nil + } + if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") { + c.OnClose(ctx, conn) + return err + } + c.OnError(ctx, conn, err) + return err +} + +// Consume starts the consumer to consume tasks from the queues. +func (c *Consumer) Consume(ctx context.Context) error { + err := c.AttemptConnect() + if err != nil { + return err + } + for _, q := range c.queues { + if err := c.subscribe(ctx, q); err != nil { + return fmt.Errorf("failed to connect to server for queue %s: %v", q, err) + } + } + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for { + if err := c.readMessage(ctx, c.conn); err != nil { + log.Println("Error reading message:", err) + break + } + } + }() + + wg.Wait() + return nil +} + +func (c *Consumer) waitForAck(conn net.Conn) error { + msg, err := codec.ReadMessage(conn, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption) + if err != nil { + return err + } + if msg.Command == consts.SUBSCRIBE_ACK { + log.Println("Received SUBSCRIBE_ACK: Subscribed successfully") + return nil + } + return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command) +} + +// RegisterHandler registers a handler for a queue. +func (c *Consumer) RegisterHandler(queue string, handler Handler) { + c.queues = append(c.queues, queue) + c.handlers[queue] = handler +} diff --git a/v2/ctx.go b/v2/ctx.go new file mode 100644 index 0000000..28d2220 --- /dev/null +++ b/v2/ctx.go @@ -0,0 +1,145 @@ +package v2 + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "net" + "os" + + "github.com/oarkflow/xid" + + "github.com/oarkflow/mq/consts" +) + +func IsClosed(conn net.Conn) bool { + _, err := conn.Read(make([]byte, 1)) + if err != nil { + if err == net.ErrClosed { + return true + } + } + return false +} + +func SetHeaders(ctx context.Context, headers map[string]string) context.Context { + hd, ok := GetHeaders(ctx) + if !ok { + hd = make(map[string]string) + } + for key, val := range headers { + hd[key] = val + } + return context.WithValue(ctx, consts.HeaderKey, hd) +} + +func WithHeaders(ctx context.Context, headers map[string]string) map[string]string { + hd, ok := GetHeaders(ctx) + if !ok { + hd = make(map[string]string) + } + for key, val := range headers { + hd[key] = val + } + return hd +} + +func GetHeaders(ctx context.Context) (map[string]string, bool) { + headers, ok := ctx.Value(consts.HeaderKey).(map[string]string) + return headers, ok +} + +func GetHeader(ctx context.Context, key string) (string, bool) { + headers, ok := ctx.Value(consts.HeaderKey).(map[string]string) + if !ok { + return "", false + } + val, ok := headers[key] + return val, ok +} + +func GetContentType(ctx context.Context) (string, bool) { + headers, ok := GetHeaders(ctx) + if !ok { + return "", false + } + contentType, ok := headers[consts.ContentType] + return contentType, ok +} + +func GetQueue(ctx context.Context) (string, bool) { + headers, ok := GetHeaders(ctx) + if !ok { + return "", false + } + contentType, ok := headers[consts.QueueKey] + return contentType, ok +} + +func GetConsumerID(ctx context.Context) (string, bool) { + headers, ok := GetHeaders(ctx) + if !ok { + return "", false + } + contentType, ok := headers[consts.ConsumerKey] + return contentType, ok +} + +func GetTriggerNode(ctx context.Context) (string, bool) { + headers, ok := GetHeaders(ctx) + if !ok { + return "", false + } + contentType, ok := headers[consts.TriggerNode] + return contentType, ok +} + +func GetPublisherID(ctx context.Context) (string, bool) { + headers, ok := GetHeaders(ctx) + if !ok { + return "", false + } + contentType, ok := headers[consts.PublisherKey] + return contentType, ok +} + +func NewID() string { + return xid.New().String() +} + +func createTLSConnection(addr, certPath, keyPath string, caPath ...string) (net.Conn, error) { + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, fmt.Errorf("failed to load client cert/key: %w", err) + } + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + ClientAuth: tls.RequireAndVerifyClientCert, + InsecureSkipVerify: true, + } + if len(caPath) > 0 && caPath[0] != "" { + caCert, err := os.ReadFile(caPath[0]) + if err != nil { + return nil, fmt.Errorf("failed to load CA cert: %w", err) + } + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + tlsConfig.RootCAs = caCertPool + tlsConfig.ClientCAs = caCertPool + } + conn, err := tls.Dial("tcp", addr, tlsConfig) + if err != nil { + return nil, fmt.Errorf("failed to dial TLS connection: %w", err) + } + + return conn, nil +} + +func GetConnection(addr string, config TLSConfig) (net.Conn, error) { + if config.UseTLS { + return createTLSConnection(addr, config.CertPath, config.KeyPath, config.CAPath) + } else { + return net.Dial("tcp", addr) + } +} diff --git a/v2/options.go b/v2/options.go new file mode 100644 index 0000000..12cf261 --- /dev/null +++ b/v2/options.go @@ -0,0 +1,123 @@ +package v2 + +import ( + "context" + "encoding/json" + "time" +) + +type Result struct { + Payload json.RawMessage `json:"payload"` + Queue string `json:"queue"` + MessageID string `json:"message_id"` + Error error `json:"error,omitempty"` + Status string `json:"status"` +} + +type TLSConfig struct { + UseTLS bool + CertPath string + KeyPath string + CAPath string +} + +type Options struct { + syncMode bool + brokerAddr string + callback []func(context.Context, Result) Result + maxRetries int + initialDelay time.Duration + maxBackoff time.Duration + jitterPercent float64 + tlsConfig TLSConfig + aesKey json.RawMessage + hmacKey json.RawMessage + enableEncryption bool +} + +func defaultOptions() Options { + return Options{ + syncMode: false, + brokerAddr: ":8080", + maxRetries: 5, + initialDelay: 2 * time.Second, + maxBackoff: 20 * time.Second, + jitterPercent: 0.5, + } +} + +// Option defines a function type for setting options. +type Option func(*Options) + +func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option { + return func(opts *Options) { + opts.aesKey = aesKey + opts.hmacKey = hmacKey + opts.enableEncryption = enableEncryption + } +} + +// WithBrokerURL - +func WithBrokerURL(url string) Option { + return func(opts *Options) { + opts.brokerAddr = url + } +} + +// WithTLS - Option to enable/disable TLS +func WithTLS(enableTLS bool, certPath, keyPath string) Option { + return func(o *Options) { + o.tlsConfig.UseTLS = enableTLS + o.tlsConfig.CertPath = certPath + o.tlsConfig.KeyPath = keyPath + } +} + +// WithCAPath - Option to enable/disable TLS +func WithCAPath(caPath string) Option { + return func(o *Options) { + o.tlsConfig.CAPath = caPath + } +} + +// WithSyncMode - +func WithSyncMode(mode bool) Option { + return func(opts *Options) { + opts.syncMode = mode + } +} + +// WithMaxRetries - +func WithMaxRetries(val int) Option { + return func(opts *Options) { + opts.maxRetries = val + } +} + +// WithInitialDelay - +func WithInitialDelay(val time.Duration) Option { + return func(opts *Options) { + opts.initialDelay = val + } +} + +// WithMaxBackoff - +func WithMaxBackoff(val time.Duration) Option { + return func(opts *Options) { + opts.maxBackoff = val + } +} + +// WithCallback - +func WithCallback(val ...func(context.Context, Result) Result) Option { + return func(opts *Options) { + opts.callback = val + } +} + +// WithJitterPercent - +func WithJitterPercent(val float64) Option { + return func(opts *Options) { + opts.jitterPercent = val + } +} diff --git a/v2/publisher.go b/v2/publisher.go new file mode 100644 index 0000000..914b610 --- /dev/null +++ b/v2/publisher.go @@ -0,0 +1,106 @@ +package v2 + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net" + "time" + + "github.com/oarkflow/mq/codec" + "github.com/oarkflow/mq/consts" +) + +type Publisher struct { + id string + opts Options +} + +func NewPublisher(id string, opts ...Option) *Publisher { + options := defaultOptions() + for _, opt := range opts { + opt(&options) + } + b := &Publisher{id: id, opts: options} + return b +} + +func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command consts.CMD) error { + headers := WithHeaders(ctx, map[string]string{ + consts.PublisherKey: p.id, + consts.ContentType: consts.TypeJson, + }) + if task.ID == "" { + task.ID = NewID() + } + task.CreatedAt = time.Now() + payload, err := json.Marshal(task) + if err != nil { + return err + } + msg := codec.NewMessage(command, payload, queue, headers) + if err := codec.SendMessage(conn, msg, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption); err != nil { + return err + } + + return p.waitForAck(conn) +} + +func (p *Publisher) waitForAck(conn net.Conn) error { + msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption) + if err != nil { + return err + } + if msg.Command == consts.PUBLISH_ACK { + log.Println("Received PUBLISH_ACK: Message published successfully") + return nil + } + return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command) +} + +func (p *Publisher) waitForResponse(conn net.Conn) Result { + msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption) + if err != nil { + return Result{Error: err} + } + if msg.Command == consts.RESPONSE { + var result Result + err = json.Unmarshal(msg.Payload, &result) + return result + } + err = fmt.Errorf("expected RESPONSE, got: %v", msg.Command) + return Result{Error: err} +} + +func (p *Publisher) Publish(ctx context.Context, queue string, task Task) error { + conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig) + if err != nil { + return fmt.Errorf("failed to connect to broker: %w", err) + } + defer conn.Close() + return p.send(ctx, queue, task, conn, consts.PUBLISH) +} + +func (p *Publisher) onClose(ctx context.Context, conn net.Conn) error { + fmt.Println("Publisher Connection closed", p.id, conn.RemoteAddr()) + return nil +} + +func (p *Publisher) onError(ctx context.Context, conn net.Conn, err error) { + fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr()) +} + +func (p *Publisher) Request(ctx context.Context, queue string, task Task) Result { + ctx = SetHeaders(ctx, map[string]string{ + consts.AwaitResponseKey: "true", + }) + conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig) + if err != nil { + err = fmt.Errorf("failed to connect to broker: %w", err) + return Result{Error: err} + } + defer conn.Close() + err = p.send(ctx, queue, task, conn, consts.PUBLISH) + return p.waitForResponse(conn) +}