diff --git a/broker.go b/broker.go index 1342ab4..1e59c64 100644 --- a/broker.go +++ b/broker.go @@ -208,7 +208,7 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M taskID, _ := jsonparser.GetString(msg.Payload, "id") log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID) - ack := codec.NewMessage(consts.PUBLISH_ACK, []byte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers) + ack := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers) if err := b.send(conn, ack); err != nil { log.Printf("Error sending PUBLISH_ACK: %v\n", err) } @@ -361,7 +361,7 @@ func (b *Broker) handleConsumer(cmd consts.CMD, state consts.ConsumerState, cons fn := func(queue *Queue) { con, ok := queue.consumers.Get(consumerID) if ok { - ack := codec.NewMessage(cmd, []byte("{}"), queue.name, map[string]string{consts.ConsumerKey: consumerID}) + ack := codec.NewMessage(cmd, utils.ToByte("{}"), queue.name, map[string]string{consts.ConsumerKey: consumerID}) err := b.send(con.conn, ack) if err == nil { con.state = state diff --git a/codec/codec.go b/codec/codec.go index d7c1b8f..bb8ee41 100644 --- a/codec/codec.go +++ b/codec/codec.go @@ -2,13 +2,20 @@ package codec import ( "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" "encoding/binary" + "encoding/hex" "encoding/json" "fmt" "io" "sync" "github.com/oarkflow/mq/consts" + "github.com/oarkflow/mq/utils" ) type Message struct { @@ -31,18 +38,18 @@ func NewMessage(cmd consts.CMD, payload json.RawMessage, queue string, headers m func (m *Message) Serialize(aesKey, hmacKey []byte, encrypt bool) ([]byte, string, error) { m.m.Lock() defer m.m.Unlock() - var buf bytes.Buffer - if err := writeLengthPrefixedJSON(&buf, m.Headers); err != nil { + buf := bytes.NewBuffer(make([]byte, 0, 512)) + if err := writeLengthPrefixedJSON(buf, m.Headers); err != nil { return nil, "", fmt.Errorf("error serializing headers: %v", err) } - if err := writeLengthPrefixed(&buf, []byte(m.Queue)); err != nil { - return nil, "", fmt.Errorf("error serializing topic: %v", err) + if err := writeLengthPrefixed(buf, utils.ToByte(m.Queue)); err != nil { + return nil, "", fmt.Errorf("error serializing queue: %v", err) } - if err := binary.Write(&buf, binary.LittleEndian, m.Command); err != nil { + if err := binary.Write(buf, binary.LittleEndian, m.Command); err != nil { return nil, "", fmt.Errorf("error serializing command: %v", err) } - if err := writePayload(&buf, aesKey, m.Payload, encrypt); err != nil { - return nil, "", err + if err := writePayload(buf, aesKey, m.Payload, encrypt); err != nil { + return nil, "", fmt.Errorf("error serializing payload: %v", err) } messageBytes := buf.Bytes() hmacSignature := CalculateHMAC(hmacKey, messageBytes) @@ -51,16 +58,17 @@ func (m *Message) Serialize(aesKey, hmacKey []byte, encrypt bool) ([]byte, strin func Deserialize(data, aesKey, hmacKey []byte, receivedHMAC string, decrypt bool) (*Message, error) { if !VerifyHMAC(hmacKey, data, receivedHMAC) { - return nil, fmt.Errorf("HMAC verification failed %s", string(hmacKey)) + return nil, fmt.Errorf("HMAC verification failed") } buf := bytes.NewReader(data) headers := make(map[string]string) + if err := readLengthPrefixedJSON(buf, &headers); err != nil { return nil, fmt.Errorf("error deserializing headers: %v", err) } - topic, err := readLengthPrefixedString(buf) + queue, err := readLengthPrefixedString(buf) if err != nil { - return nil, fmt.Errorf("error deserializing topic: %v", err) + return nil, fmt.Errorf("error deserializing queue: %v", err) } var command consts.CMD if err := binary.Read(buf, binary.LittleEndian, &command); err != nil { @@ -72,7 +80,7 @@ func Deserialize(data, aesKey, hmacKey []byte, receivedHMAC string, decrypt bool } return &Message{ Headers: headers, - Queue: topic, + Queue: queue, Command: command, Payload: payload, }, nil @@ -139,7 +147,7 @@ func readLengthPrefixedString(r *bytes.Reader) (string, error) { if err != nil { return "", err } - return string(data), nil + return utils.FromByte(data), nil } func writePayload(buf *bytes.Buffer, aesKey []byte, payload json.RawMessage, encrypt bool) error { @@ -147,7 +155,6 @@ func writePayload(buf *bytes.Buffer, aesKey []byte, payload json.RawMessage, enc if err != nil { return fmt.Errorf("error marshalling payload: %v", err) } - var encryptedPayload, nonce []byte if encrypt { encryptedPayload, nonce, err = EncryptPayload(aesKey, payloadBytes) @@ -157,11 +164,9 @@ func writePayload(buf *bytes.Buffer, aesKey []byte, payload json.RawMessage, enc } else { encryptedPayload = payloadBytes } - if err := writeLengthPrefixed(buf, encryptedPayload); err != nil { return err } - if encrypt { buf.Write(nonce) } @@ -192,6 +197,7 @@ func readPayload(r *bytes.Reader, aesKey []byte, decrypt bool) (json.RawMessage, } return payload, nil } + func writeMessageWithHMAC(conn io.Writer, messageBytes []byte, hmacSignature string) error { if err := binary.Write(conn, binary.LittleEndian, uint32(len(messageBytes))); err != nil { return err @@ -199,7 +205,11 @@ func writeMessageWithHMAC(conn io.Writer, messageBytes []byte, hmacSignature str if _, err := conn.Write(messageBytes); err != nil { return err } - if _, err := conn.Write([]byte(hmacSignature)); err != nil { + hmacBytes, err := hex.DecodeString(hmacSignature) + if err != nil { + return err + } + if _, err := conn.Write(hmacBytes); err != nil { return err } return nil @@ -214,12 +224,54 @@ func readMessageWithHMAC(conn io.Reader) ([]byte, string, error) { if _, err := io.ReadFull(conn, data); err != nil { return nil, "", err } - - hmacBytes := make([]byte, 64) + hmacBytes := make([]byte, 32) if _, err := io.ReadFull(conn, hmacBytes); err != nil { return nil, "", err } - receivedHMAC := string(hmacBytes) - + receivedHMAC := hex.EncodeToString(hmacBytes) return data, receivedHMAC, nil } + +func EncryptPayload(key []byte, plaintext []byte) ([]byte, []byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, nil, err + } + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return nil, nil, err + } + nonce := make([]byte, aesGCM.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, nil, err + } + ciphertext := aesGCM.Seal(nil, nonce, plaintext, nil) + return ciphertext, nonce, nil +} + +func DecryptPayload(key []byte, ciphertext []byte, nonce []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, err + } + return plaintext, nil +} + +func CalculateHMAC(key []byte, data []byte) string { + h := hmac.New(sha256.New, key) + h.Write(data) + return hex.EncodeToString(h.Sum(nil)) +} + +func VerifyHMAC(key []byte, data []byte, receivedHMAC string) bool { + expectedHMAC := CalculateHMAC(key, data) + return hmac.Equal(utils.ToByte(receivedHMAC), utils.ToByte(expectedHMAC)) +} diff --git a/codec/encrypt.go b/codec/encrypt.go deleted file mode 100644 index 50d3495..0000000 --- a/codec/encrypt.go +++ /dev/null @@ -1,51 +0,0 @@ -package codec - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/hmac" - "crypto/rand" - "crypto/sha256" - "encoding/hex" - "io" -) - -func EncryptPayload(key []byte, plaintext []byte) ([]byte, []byte, error) { - block, err := aes.NewCipher(key) - if err != nil { - return nil, nil, err - } - aesGCM, err := cipher.NewGCM(block) - if err != nil { - return nil, nil, err - } - nonce := make([]byte, aesGCM.NonceSize()) - if _, err = io.ReadFull(rand.Reader, nonce); err != nil { - return nil, nil, err - } - ciphertext := aesGCM.Seal(nil, nonce, plaintext, nil) - return ciphertext, nonce, nil -} - -func DecryptPayload(key []byte, ciphertext []byte, nonce []byte) ([]byte, error) { - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - aesGCM, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - return aesGCM.Open(nil, nonce, ciphertext, nil) -} - -func CalculateHMAC(key []byte, data []byte) string { - h := hmac.New(sha256.New, key) - h.Write(data) - return hex.EncodeToString(h.Sum(nil)) -} - -func VerifyHMAC(key []byte, data []byte, receivedHMAC string) bool { - expectedHMAC := CalculateHMAC(key, data) - return hmac.Equal([]byte(expectedHMAC), []byte(receivedHMAC)) -} diff --git a/consumer.go b/consumer.go index 43ed186..9918295 100644 --- a/consumer.go +++ b/consumer.go @@ -57,7 +57,7 @@ func (c *Consumer) subscribe(ctx context.Context, queue string) error { consts.ConsumerKey: c.id, consts.ContentType: consts.TypeJson, }) - msg := codec.NewMessage(consts.SUBSCRIBE, []byte("{}"), queue, headers) + msg := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers) if err := c.send(c.conn, msg); err != nil { return err } @@ -104,7 +104,7 @@ func (c *Consumer) ConsumeMessage(ctx context.Context, msg *codec.Message, conn consts.QueueKey: msg.Queue, }) taskID, _ := jsonparser.GetString(msg.Payload, "id") - reply := codec.NewMessage(consts.MESSAGE_ACK, []byte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers) + reply := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers) if err := c.send(conn, reply); err != nil { fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err) } @@ -158,7 +158,7 @@ func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, er consts.ConsumerKey: c.id, consts.ContentType: consts.TypeJson, }) - reply := codec.NewMessage(consts.MESSAGE_DENY, []byte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers) + reply := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers) if sendErr := c.send(c.conn, reply); sendErr != nil { log.Printf("failed to send MESSAGE_DENY for task %s: %v", taskID, sendErr) } diff --git a/dag/dag.go b/dag/dag.go index 70b7a4d..1f164ba 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -44,8 +44,8 @@ type Edge struct { } type DAG struct { - FirstNode string - Nodes map[string]*Node + startNode string + nodes map[string]*Node server *mq.Broker taskContext map[string]*TaskManager conditions map[string]map[string]string @@ -56,7 +56,7 @@ type DAG struct { func NewDAG(opts ...mq.Option) *DAG { d := &DAG{ - Nodes: make(map[string]*Node), + nodes: make(map[string]*Node), taskContext: make(map[string]*TaskManager), conditions: make(map[string]map[string]string), } @@ -74,19 +74,27 @@ func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result { } func (tm *DAG) onConsumerJoin(_ context.Context, topic, _ string) { - if node, ok := tm.Nodes[topic]; ok { + if node, ok := tm.nodes[topic]; ok { log.Printf("DAG - CONSUMER ~> ready on %s", topic) node.isReady = true } } func (tm *DAG) onConsumerClose(_ context.Context, topic, _ string) { - if node, ok := tm.Nodes[topic]; ok { + if node, ok := tm.nodes[topic]; ok { log.Printf("DAG - CONSUMER ~> down on %s", topic) node.isReady = false } } +func (tm *DAG) SetStartNode(node string) { + tm.startNode = node +} + +func (tm *DAG) GetStartNode() string { + return tm.startNode +} + func (tm *DAG) Start(ctx context.Context, addr string) error { if !tm.server.SyncMode() { go func() { @@ -95,7 +103,7 @@ func (tm *DAG) Start(ctx context.Context, addr string) error { panic(err) } }() - for _, con := range tm.Nodes { + for _, con := range tm.nodes { if con.isReady { go func(con *Node) { time.Sleep(1 * time.Second) @@ -122,13 +130,13 @@ func (tm *DAG) AddNode(key string, handler mq.Handler, firstNode ...bool) { tm.mu.Lock() defer tm.mu.Unlock() con := mq.NewConsumer(key, key, handler, tm.opts...) - tm.Nodes[key] = &Node{ + tm.nodes[key] = &Node{ Key: key, consumer: con, isReady: true, } if len(firstNode) > 0 && firstNode[0] { - tm.FirstNode = key + tm.startNode = key } } @@ -138,11 +146,11 @@ func (tm *DAG) AddDeferredNode(key string, firstNode ...bool) error { } tm.mu.Lock() defer tm.mu.Unlock() - tm.Nodes[key] = &Node{ + tm.nodes[key] = &Node{ Key: key, } if len(firstNode) > 0 && firstNode[0] { - tm.FirstNode = key + tm.startNode = key } return nil } @@ -150,7 +158,7 @@ func (tm *DAG) AddDeferredNode(key string, firstNode ...bool) error { func (tm *DAG) IsReady() bool { tm.mu.Lock() defer tm.mu.Unlock() - for _, node := range tm.Nodes { + for _, node := range tm.nodes { if !node.isReady { return false } @@ -167,11 +175,11 @@ func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) { func (tm *DAG) AddEdge(from, to string, edgeTypes ...EdgeType) { tm.mu.Lock() defer tm.mu.Unlock() - fromNode, ok := tm.Nodes[from] + fromNode, ok := tm.nodes[from] if !ok { return } - toNode, ok := tm.Nodes[to] + toNode, ok := tm.nodes[to] if !ok { return } @@ -183,25 +191,28 @@ func (tm *DAG) AddEdge(from, to string, edgeTypes ...EdgeType) { } func (tm *DAG) ProcessTask(ctx context.Context, payload []byte) mq.Result { + tm.mu.RLock() // lock when reading `paused` if tm.paused { + tm.mu.RUnlock() return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")} } + tm.mu.RUnlock() if !tm.IsReady() { return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not ready yet")} } val := ctx.Value("initial_node") initialNode, ok := val.(string) if !ok { - if tm.FirstNode == "" { + if tm.startNode == "" { firstNode := tm.FindInitialNode() if firstNode != nil { - tm.FirstNode = firstNode.Key + tm.startNode = firstNode.Key } } - if tm.FirstNode == "" { + if tm.startNode == "" { return mq.Result{Error: fmt.Errorf("initial node not found")} } - initialNode = tm.FirstNode + initialNode = tm.startNode } tm.mu.Lock() defer tm.mu.Unlock() @@ -214,7 +225,7 @@ func (tm *DAG) ProcessTask(ctx context.Context, payload []byte) mq.Result { func (tm *DAG) FindInitialNode() *Node { incomingEdges := make(map[string]bool) connectedNodes := make(map[string]bool) - for _, node := range tm.Nodes { + for _, node := range tm.nodes { for _, edge := range node.Edges { if edge.Type.IsValid() { connectedNodes[node.Key] = true @@ -229,7 +240,7 @@ func (tm *DAG) FindInitialNode() *Node { } } } - for nodeID, node := range tm.Nodes { + for nodeID, node := range tm.nodes { if !incomingEdges[nodeID] && connectedNodes[nodeID] { return node } @@ -238,24 +249,28 @@ func (tm *DAG) FindInitialNode() *Node { } func (tm *DAG) Pause() { + tm.mu.Lock() // lock when modifying `paused` + defer tm.mu.Unlock() tm.paused = true log.Printf("DAG - PAUSED") } func (tm *DAG) Resume() { + tm.mu.Lock() // lock when modifying `paused` + defer tm.mu.Unlock() tm.paused = false log.Printf("DAG - RESUMED") } func (tm *DAG) PauseConsumer(id string) { - if node, ok := tm.Nodes[id]; ok { + if node, ok := tm.nodes[id]; ok { node.consumer.Pause() node.isReady = false } } func (tm *DAG) ResumeConsumer(id string) { - if node, ok := tm.Nodes[id]; ok { + if node, ok := tm.nodes[id]; ok { node.consumer.Resume() node.isReady = true } diff --git a/dag/task_manager.go b/dag/task_manager.go index e09842c..bceafb5 100644 --- a/dag/task_manager.go +++ b/dag/task_manager.go @@ -33,8 +33,13 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager { } } +func (tm *TaskManager) updateTS(result *mq.Result) { + result.CreatedAt = tm.createdAt + result.ProcessedAt = time.Now() +} + func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result { - node, ok := tm.dag.Nodes[nodeID] + node, ok := tm.dag.nodes[nodeID] if !ok { return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)} } @@ -45,8 +50,7 @@ func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload j if awaitResponse != "true" { go func() { finalResult := <-tm.finalResult - finalResult.CreatedAt = tm.createdAt - finalResult.ProcessedAt = time.Now() + tm.updateTS(&finalResult) if tm.dag.server.NotifyHandler() != nil { tm.dag.server.NotifyHandler()(ctx, finalResult) } @@ -54,8 +58,7 @@ func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload j return mq.Result{CreatedAt: tm.createdAt, TaskID: tm.taskID, Topic: nodeID, Status: "PENDING"} } else { finalResult := <-tm.finalResult - finalResult.CreatedAt = tm.createdAt - finalResult.ProcessedAt = time.Now() + tm.updateTS(&finalResult) if tm.dag.server.NotifyHandler() != nil { tm.dag.server.NotifyHandler()(ctx, finalResult) } @@ -79,7 +82,7 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq. if result.Topic != "" { atomic.AddInt64(&tm.waitingCallback, -1) } - node, ok := tm.dag.Nodes[result.Topic] + node, ok := tm.dag.nodes[result.Topic] if !ok { return result } @@ -88,7 +91,7 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq. if result.Status != "" { if conditions, ok := tm.dag.conditions[result.Topic]; ok { if targetNodeKey, ok := conditions[result.Status]; ok { - if targetNode, ok := tm.dag.Nodes[targetNodeKey]; ok { + if targetNode, ok := tm.dag.nodes[targetNodeKey]; ok { edges = append(edges, Edge{From: node, To: targetNode}) } } @@ -147,12 +150,7 @@ func (tm *TaskManager) handleResult(ctx context.Context, results any) mq.Result if err != nil { return mq.HandleError(ctx, err) } - return mq.Result{ - TaskID: tm.taskID, - Payload: finalOutput, - Status: status, - Topic: topic, - } + return mq.Result{TaskID: tm.taskID, Payload: finalOutput, Status: status, Topic: topic} case mq.Result: return res } diff --git a/examples/hmac.go b/examples/hmac.go index e37da79..e8e0038 100644 --- a/examples/hmac.go +++ b/examples/hmac.go @@ -2,33 +2,25 @@ package main import ( "crypto/rand" - "encoding/base64" + "encoding/hex" "fmt" - "log" ) -func GenerateSecretKey() (string, error) { - // Create a byte slice to hold 32 random bytes - key := make([]byte, 32) - - // Fill the slice with secure random bytes +func generateHMACKey() ([]byte, error) { + key := make([]byte, 32) // 32 bytes = 256 bits _, err := rand.Read(key) if err != nil { - return "", err + return nil, err } - - // Encode the byte slice to a Base64 string - secretKey := base64.StdEncoding.EncodeToString(key) - - // Return the first 32 characters - return secretKey[:32], nil + return key, nil } func main() { - secretKey, err := GenerateSecretKey() + hmacKey, err := generateHMACKey() if err != nil { - log.Fatalf("Error generating secret key: %v", err) + fmt.Println("Error generating HMAC key:", err) + return } - fmt.Println("Generated Secret Key:", secretKey) + fmt.Println("HMAC Key (hex):", hex.EncodeToString(hmacKey)) } diff --git a/options.go b/options.go index 5dfc9df..92de05f 100644 --- a/options.go +++ b/options.go @@ -86,7 +86,7 @@ func defaultOptions() Options { maxBackoff: 20 * time.Second, jitterPercent: 0.5, queueSize: 100, - hmacKey: []byte(`a9f4b9415485b70275673b5920182796ea497b5e093ead844a43ea5d77cbc24f`), + hmacKey: []byte(`475f3adc6be9ee6f5357020e2922ff5b8f971598e175878e617d19df584bc648`), numOfWorkers: runtime.NumCPU(), maxMemoryLoad: 5000000, } diff --git a/utils/str.go b/utils/str.go new file mode 100644 index 0000000..e9f2855 --- /dev/null +++ b/utils/str.go @@ -0,0 +1,16 @@ +package utils + +import ( + "unsafe" +) + +func ToByte(s string) []byte { + p := unsafe.StringData(s) + b := unsafe.Slice(p, len(s)) + return b +} + +func FromByte(b []byte) string { + p := unsafe.SliceData(b) + return unsafe.String(p, len(b)) +}