From a5bc632b1ee584477b9b4dd9a63c154426d01543 Mon Sep 17 00:00:00 2001 From: sujit Date: Mon, 14 Oct 2024 05:54:01 +0545 Subject: [PATCH] feat: update --- broker.go | 4 +- codec/codec.go | 274 +++++++-------------------------------- consumer.go | 6 +- dag/dag.go | 15 +++ examples/dag_consumer.go | 8 +- examples/server.go | 1 + examples/tasks/tasks.go | 4 + go.mod | 3 + go.sum | 12 ++ publisher.go | 6 +- 10 files changed, 92 insertions(+), 241 deletions(-) diff --git a/broker.go b/broker.go index 60b7bd5..57bd824 100644 --- a/broker.go +++ b/broker.go @@ -283,11 +283,11 @@ func (b *Broker) Start(ctx context.Context) error { } func (b *Broker) send(conn net.Conn, msg *codec.Message) error { - return codec.SendMessage(conn, msg, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption) + return codec.SendMessage(conn, msg) } func (b *Broker) receive(c net.Conn) (*codec.Message, error) { - return codec.ReadMessage(c, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption) + return codec.ReadMessage(c) } func (b *Broker) broadcastToConsumers(ctx context.Context, msg *codec.Message) { diff --git a/codec/codec.go b/codec/codec.go index bb8ee41..172f11b 100644 --- a/codec/codec.go +++ b/codec/codec.go @@ -1,32 +1,26 @@ package codec import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/hmac" - "crypto/rand" - "crypto/sha256" "encoding/binary" - "encoding/hex" - "encoding/json" - "fmt" - "io" + "net" "sync" + "github.com/vmihailenco/msgpack/v5" + "github.com/oarkflow/mq/consts" - "github.com/oarkflow/mq/utils" ) +// Message represents the structure of our message. type Message struct { - Headers map[string]string `json:"h"` - Queue string `json:"q"` - Command consts.CMD `json:"c"` - Payload json.RawMessage `json:"p"` + Headers map[string]string `msgpack:"h"` + Queue string `msgpack:"q"` + Command consts.CMD `msgpack:"c"` + Payload []byte `msgpack:"p"` // Using []byte instead of json.RawMessage m sync.RWMutex } -func NewMessage(cmd consts.CMD, payload json.RawMessage, queue string, headers map[string]string) *Message { +// NewMessage creates a new Message instance. +func NewMessage(cmd consts.CMD, payload []byte, queue string, headers map[string]string) *Message { return &Message{ Headers: headers, Queue: queue, @@ -35,243 +29,63 @@ func NewMessage(cmd consts.CMD, payload json.RawMessage, queue string, headers m } } -func (m *Message) Serialize(aesKey, hmacKey []byte, encrypt bool) ([]byte, string, error) { - m.m.Lock() - defer m.m.Unlock() - buf := bytes.NewBuffer(make([]byte, 0, 512)) - if err := writeLengthPrefixedJSON(buf, m.Headers); err != nil { - return nil, "", fmt.Errorf("error serializing headers: %v", err) - } - if err := writeLengthPrefixed(buf, utils.ToByte(m.Queue)); err != nil { - return nil, "", fmt.Errorf("error serializing queue: %v", err) - } - if err := binary.Write(buf, binary.LittleEndian, m.Command); err != nil { - return nil, "", fmt.Errorf("error serializing command: %v", err) - } - if err := writePayload(buf, aesKey, m.Payload, encrypt); err != nil { - return nil, "", fmt.Errorf("error serializing payload: %v", err) - } - messageBytes := buf.Bytes() - hmacSignature := CalculateHMAC(hmacKey, messageBytes) - return messageBytes, hmacSignature, nil -} +// Serialize encodes the Message to a byte slice using MessagePack. +func (m *Message) Serialize() ([]byte, error) { + m.m.RLock() + defer m.m.RUnlock() -func Deserialize(data, aesKey, hmacKey []byte, receivedHMAC string, decrypt bool) (*Message, error) { - if !VerifyHMAC(hmacKey, data, receivedHMAC) { - return nil, fmt.Errorf("HMAC verification failed") - } - buf := bytes.NewReader(data) - headers := make(map[string]string) - - if err := readLengthPrefixedJSON(buf, &headers); err != nil { - return nil, fmt.Errorf("error deserializing headers: %v", err) - } - queue, err := readLengthPrefixedString(buf) - if err != nil { - return nil, fmt.Errorf("error deserializing queue: %v", err) - } - var command consts.CMD - if err := binary.Read(buf, binary.LittleEndian, &command); err != nil { - return nil, fmt.Errorf("error deserializing command: %v", err) - } - payload, err := readPayload(buf, aesKey, decrypt) - if err != nil { - return nil, fmt.Errorf("error deserializing payload: %v", err) - } - return &Message{ - Headers: headers, - Queue: queue, - Command: command, - Payload: payload, - }, nil -} - -func SendMessage(conn io.Writer, msg *Message, aesKey, hmacKey []byte, encrypt bool) error { - sentData, hmacSignature, err := msg.Serialize(aesKey, hmacKey, encrypt) - if err != nil { - return fmt.Errorf("error serializing message: %v", err) - } - if err := writeMessageWithHMAC(conn, sentData, hmacSignature); err != nil { - return fmt.Errorf("error sending message: %v", err) - } - return nil -} - -func ReadMessage(conn io.Reader, aesKey, hmacKey []byte, decrypt bool) (*Message, error) { - data, receivedHMAC, err := readMessageWithHMAC(conn) + data, err := msgpack.Marshal(m) if err != nil { return nil, err } - return Deserialize(data, aesKey, hmacKey, receivedHMAC, decrypt) -} -func writeLengthPrefixed(buf *bytes.Buffer, data []byte) error { - length := uint32(len(data)) - if err := binary.Write(buf, binary.LittleEndian, length); err != nil { - return err - } - buf.Write(data) - return nil -} - -func writeLengthPrefixedJSON(buf *bytes.Buffer, v any) error { - data, err := json.Marshal(v) - if err != nil { - return err - } - return writeLengthPrefixed(buf, data) -} - -func readLengthPrefixed(r *bytes.Reader) ([]byte, error) { - var length uint32 - if err := binary.Read(r, binary.LittleEndian, &length); err != nil { - return nil, err - } - data := make([]byte, length) - if _, err := r.Read(data); err != nil { - return nil, err - } return data, nil } -func readLengthPrefixedJSON(r *bytes.Reader, v any) error { - data, err := readLengthPrefixed(r) - if err != nil { - return err - } - return json.Unmarshal(data, v) -} - -func readLengthPrefixedString(r *bytes.Reader) (string, error) { - data, err := readLengthPrefixed(r) - if err != nil { - return "", err - } - return utils.FromByte(data), nil -} - -func writePayload(buf *bytes.Buffer, aesKey []byte, payload json.RawMessage, encrypt bool) error { - payloadBytes, err := json.Marshal(payload) - if err != nil { - return fmt.Errorf("error marshalling payload: %v", err) - } - var encryptedPayload, nonce []byte - if encrypt { - encryptedPayload, nonce, err = EncryptPayload(aesKey, payloadBytes) - if err != nil { - return err - } - } else { - encryptedPayload = payloadBytes - } - if err := writeLengthPrefixed(buf, encryptedPayload); err != nil { - return err - } - if encrypt { - buf.Write(nonce) - } - return nil -} - -func readPayload(r *bytes.Reader, aesKey []byte, decrypt bool) (json.RawMessage, error) { - encryptedPayload, err := readLengthPrefixed(r) - if err != nil { +// Deserialize decodes a byte slice to a Message instance using MessagePack. +func Deserialize(data []byte) (*Message, error) { + var msg Message + if err := msgpack.Unmarshal(data, &msg); err != nil { return nil, err } - var payloadBytes []byte - if decrypt { - nonce := make([]byte, 12) - if _, err := r.Read(nonce); err != nil { - return nil, err - } - payloadBytes, err = DecryptPayload(aesKey, encryptedPayload, nonce) - if err != nil { - return nil, fmt.Errorf("error decrypting payload: %v", err) - } - } else { - payloadBytes = encryptedPayload - } - var payload json.RawMessage - if err := json.Unmarshal(payloadBytes, &payload); err != nil { - return nil, fmt.Errorf("error unmarshalling payload: %v", err) - } - return payload, nil + + return &msg, nil } -func writeMessageWithHMAC(conn io.Writer, messageBytes []byte, hmacSignature string) error { - if err := binary.Write(conn, binary.LittleEndian, uint32(len(messageBytes))); err != nil { - return err - } - if _, err := conn.Write(messageBytes); err != nil { - return err - } - hmacBytes, err := hex.DecodeString(hmacSignature) +// SendMessage sends a Message over a net.Conn. +func SendMessage(conn net.Conn, msg *Message) error { + data, err := msg.Serialize() if err != nil { return err } - if _, err := conn.Write(hmacBytes); err != nil { + + // Send the length of the data followed by the data itself + length := make([]byte, 4) + binary.BigEndian.PutUint32(length, uint32(len(data))) + + if _, err := conn.Write(length); err != nil { + return err + } + if _, err := conn.Write(data); err != nil { return err } return nil } -func readMessageWithHMAC(conn io.Reader) ([]byte, string, error) { - var length uint32 - if err := binary.Read(conn, binary.LittleEndian, &length); err != nil { - return nil, "", err +// ReadMessage receives a Message from a net.Conn. +func ReadMessage(conn net.Conn) (*Message, error) { + // Read the length of the incoming message + lengthBytes := make([]byte, 4) + if _, err := conn.Read(lengthBytes); err != nil { + return nil, err } + length := binary.BigEndian.Uint32(lengthBytes) + + // Read the actual message data data := make([]byte, length) - if _, err := io.ReadFull(conn, data); err != nil { - return nil, "", err - } - hmacBytes := make([]byte, 32) - if _, err := io.ReadFull(conn, hmacBytes); err != nil { - return nil, "", err - } - receivedHMAC := hex.EncodeToString(hmacBytes) - return data, receivedHMAC, nil -} - -func EncryptPayload(key []byte, plaintext []byte) ([]byte, []byte, error) { - block, err := aes.NewCipher(key) - if err != nil { - return nil, nil, err - } - aesGCM, err := cipher.NewGCM(block) - if err != nil { - return nil, nil, err - } - nonce := make([]byte, aesGCM.NonceSize()) - if _, err := io.ReadFull(rand.Reader, nonce); err != nil { - return nil, nil, err - } - ciphertext := aesGCM.Seal(nil, nonce, plaintext, nil) - return ciphertext, nonce, nil -} - -func DecryptPayload(key []byte, ciphertext []byte, nonce []byte) ([]byte, error) { - block, err := aes.NewCipher(key) - if err != nil { + if _, err := conn.Read(data); err != nil { return nil, err } - aesGCM, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil) - if err != nil { - return nil, err - } - return plaintext, nil -} -func CalculateHMAC(key []byte, data []byte) string { - h := hmac.New(sha256.New, key) - h.Write(data) - return hex.EncodeToString(h.Sum(nil)) -} - -func VerifyHMAC(key []byte, data []byte, receivedHMAC string) bool { - expectedHMAC := CalculateHMAC(key, data) - return hmac.Equal(utils.ToByte(receivedHMAC), utils.ToByte(expectedHMAC)) + return Deserialize(data) } diff --git a/consumer.go b/consumer.go index d0ba460..936f703 100644 --- a/consumer.go +++ b/consumer.go @@ -46,11 +46,11 @@ func NewConsumer(id string, queue string, handler Handler, opts ...Option) *Cons } func (c *Consumer) send(conn net.Conn, msg *codec.Message) error { - return codec.SendMessage(conn, msg, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption) + return codec.SendMessage(conn, msg) } func (c *Consumer) receive(conn net.Conn) (*codec.Message, error) { - return codec.ReadMessage(conn, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption) + return codec.ReadMessage(conn) } func (c *Consumer) Close() error { @@ -66,7 +66,7 @@ func (c *Consumer) subscribe(ctx context.Context, queue string) error { headers := HeadersWithConsumerID(ctx, c.id) msg := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers) if err := c.send(c.conn, msg); err != nil { - return err + return fmt.Errorf("error while trying to subscribe: %v", err) } return c.waitForAck(c.conn) } diff --git a/dag/dag.go b/dag/dag.go index fb53e25..3724c43 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -184,6 +184,21 @@ func (tm *DAG) Start(ctx context.Context, addr string) error { return http.ListenAndServe(addr, nil) } +func (tm *DAG) AddDAGNode(name string, key string, dag *DAG, firstNode ...bool) { + dag.AssignTopic(key) + tm.mu.Lock() + defer tm.mu.Unlock() + tm.nodes[key] = &Node{ + Name: name, + Key: key, + processor: dag, + isReady: true, + } + if len(firstNode) > 0 && firstNode[0] { + tm.startNode = key + } +} + func (tm *DAG) AddNode(name, key string, handler mq.Handler, firstNode ...bool) { tm.mu.Lock() defer tm.mu.Unlock() diff --git a/examples/dag_consumer.go b/examples/dag_consumer.go index 388e8d2..5faac02 100644 --- a/examples/dag_consumer.go +++ b/examples/dag_consumer.go @@ -13,7 +13,6 @@ func main() { d := dag.NewDAG("Sample DAG", "sample-dag", mq.WithSyncMode(true), mq.WithNotifyResponse(tasks.NotifyResponse), - mq.WithSecretKey([]byte("wKWa6GKdBd0njDKNQoInBbh6P0KTjmob")), ) d.AddNode("C", "C", tasks.Node3, true) d.AddNode("D", "D", tasks.Node4) @@ -27,6 +26,9 @@ func main() { d.AddEdge("Label 2", "D", "F") d.AddEdge("Label 3", "E", "F") d.AddEdge("Label 4", "F", "G", "H") - d.AssignTopic("queue1") - d.Consume(context.Background()) + d.AssignTopic("D") + err := d.Consume(context.Background()) + if err != nil { + panic(err) + } } diff --git a/examples/server.go b/examples/server.go index 2c4dcb5..c755959 100644 --- a/examples/server.go +++ b/examples/server.go @@ -2,6 +2,7 @@ package main import ( "context" + mq2 "github.com/oarkflow/mq" "github.com/oarkflow/mq/examples/tasks" diff --git a/examples/tasks/tasks.go b/examples/tasks/tasks.go index 08d1b93..574179b 100644 --- a/examples/tasks/tasks.go +++ b/examples/tasks/tasks.go @@ -79,3 +79,7 @@ func Callback(_ context.Context, task mq.Result) mq.Result { func NotifyResponse(_ context.Context, result mq.Result) { log.Printf("DAG - FINAL_RESPONSE ~> TaskID: %s, Payload: %s, Topic: %s, Error: %s", result.TaskID, result.Payload, result.Topic, result.Error) } + +func NotifySubDAGResponse(_ context.Context, result mq.Result) { + log.Printf("SUB DAG - FINAL_RESPONSE ~> TaskID: %s, Payload: %s, Topic: %s, Error: %s", result.TaskID, result.Payload, result.Topic, result.Error) +} diff --git a/go.mod b/go.mod index be6dc44..7c6f2cc 100644 --- a/go.mod +++ b/go.mod @@ -5,4 +5,7 @@ go 1.23.0 require ( github.com/oarkflow/xid v1.2.5 github.com/oarkflow/xsync v0.0.5 + github.com/vmihailenco/msgpack/v5 v5.4.1 ) + +require github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect diff --git a/go.sum b/go.sum index e0e25be..b879d58 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,16 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/oarkflow/xid v1.2.5 h1:6RcNJm9+oZ/B647gkME9trCzhpxGQaSdNoD56Vmkeho= github.com/oarkflow/xid v1.2.5/go.mod h1:jG4YBh+swbjlWApGWDBYnsJEa7hi3CCpmuqhB3RAxVo= github.com/oarkflow/xsync v0.0.5 h1:7HBQjmDus4YFLQFC5D197TB4c2YJTVwsTFuqk5zWKBM= github.com/oarkflow/xsync v0.0.5/go.mod h1:KAaEc506OEd3ISxfhgUBKxk8eQzkz+mb0JkpGGd/QwU= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= +github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/publisher.go b/publisher.go index 4ce1b0a..44e16b3 100644 --- a/publisher.go +++ b/publisher.go @@ -37,7 +37,7 @@ func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net. return err } msg := codec.NewMessage(command, payload, queue, headers) - if err := codec.SendMessage(conn, msg, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption); err != nil { + if err := codec.SendMessage(conn, msg); err != nil { return err } @@ -45,7 +45,7 @@ func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net. } func (p *Publisher) waitForAck(conn net.Conn) error { - msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption) + msg, err := codec.ReadMessage(conn) if err != nil { return err } @@ -58,7 +58,7 @@ func (p *Publisher) waitForAck(conn net.Conn) error { } func (p *Publisher) waitForResponse(conn net.Conn) Result { - msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption) + msg, err := codec.ReadMessage(conn) if err != nil { return Result{Error: err} }