diff --git a/broker.go b/broker.go index 5eb3a91..caf4ff2 100644 --- a/broker.go +++ b/broker.go @@ -11,6 +11,8 @@ import ( "time" "github.com/oarkflow/xsync" + + "github.com/oarkflow/mq/consts" ) type consumer struct { @@ -78,7 +80,7 @@ type Task struct { type Command struct { ID string `json:"id"` - Command CMD `json:"command"` + Command consts.CMD `json:"command"` Queue string `json:"queue"` MessageID string `json:"message_id"` Payload json.RawMessage `json:"payload,omitempty"` // Used for carrying the task payload @@ -265,7 +267,7 @@ func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Con consumerID, ok := GetConsumerID(ctx) defer func() { cmd := Command{ - Command: SUBSCRIBE_ACK, + Command: consts.SUBSCRIBE_ACK, Queue: queueName, Error: "", } @@ -335,7 +337,7 @@ func (b *Broker) handleTaskMessage(ctx context.Context, _ net.Conn, msg Result) func (b *Broker) publish(ctx context.Context, conn net.Conn, msg Command) error { status := "PUBLISH" - if msg.Command == REQUEST { + if msg.Command == consts.REQUEST { status = "REQUEST" } b.addPublisher(ctx, msg.Queue, conn) @@ -360,10 +362,10 @@ func (b *Broker) publish(ctx context.Context, conn net.Conn, msg Command) error func (b *Broker) handleCommandMessage(ctx context.Context, conn net.Conn, msg Command) error { switch msg.Command { - case SUBSCRIBE: + case consts.SUBSCRIBE: b.subscribe(ctx, msg.Queue, conn) return nil - case PUBLISH, REQUEST: + case consts.PUBLISH, consts.REQUEST: return b.publish(ctx, conn, msg) default: return fmt.Errorf("unknown command: %d", msg.Command) diff --git a/codec/codec.go b/codec/codec.go new file mode 100644 index 0000000..8aa5953 --- /dev/null +++ b/codec/codec.go @@ -0,0 +1,203 @@ +package codec + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "encoding/json" + "fmt" + "io" + + "github.com/oarkflow/mq/consts" +) + +type Message struct { + Headers map[string]string `json:"h"` + Topic string `json:"t"` + Command consts.CMD `json:"c"` + Payload json.RawMessage `json:"p"` +} + +func EncryptPayload(key []byte, plaintext []byte) ([]byte, []byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, nil, err + } + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return nil, nil, err + } + nonce := make([]byte, aesGCM.NonceSize()) + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + return nil, nil, err + } + ciphertext := aesGCM.Seal(nil, nonce, plaintext, nil) + return ciphertext, nonce, nil +} + +func DecryptPayload(key []byte, ciphertext []byte, nonce []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + return aesGCM.Open(nil, nonce, ciphertext, nil) +} + +func CalculateHMAC(key []byte, data []byte) string { + h := hmac.New(sha256.New, key) + h.Write(data) + return hex.EncodeToString(h.Sum(nil)) +} + +func VerifyHMAC(key []byte, data []byte, receivedHMAC string) bool { + expectedHMAC := CalculateHMAC(key, data) + return hmac.Equal([]byte(expectedHMAC), []byte(receivedHMAC)) +} + +func (m *Message) Serialize(aesKey []byte, hmacKey []byte) ([]byte, string, error) { + var buf bytes.Buffer + headersBytes, err := json.Marshal(m.Headers) + if err != nil { + return nil, "", fmt.Errorf("error marshalling headers: %v", err) + } + headersLen := uint32(len(headersBytes)) + if err := binary.Write(&buf, binary.LittleEndian, headersLen); err != nil { + return nil, "", err + } + buf.Write(headersBytes) + topicBytes := []byte(m.Topic) + topicLen := uint8(len(topicBytes)) + if err := binary.Write(&buf, binary.LittleEndian, topicLen); err != nil { + return nil, "", err + } + buf.Write(topicBytes) + if !m.Command.IsValid() { + return nil, "", fmt.Errorf("invalid command: %s", m.Command) + } + if err := binary.Write(&buf, binary.LittleEndian, m.Command); err != nil { + return nil, "", err + } + payloadBytes, err := json.Marshal(m.Payload) + if err != nil { + return nil, "", fmt.Errorf("error marshalling payload: %v", err) + } + encryptedPayload, nonce, err := EncryptPayload(aesKey, payloadBytes) + if err != nil { + return nil, "", err + } + payloadLen := uint32(len(encryptedPayload)) + if err := binary.Write(&buf, binary.LittleEndian, payloadLen); err != nil { + return nil, "", err + } + buf.Write(encryptedPayload) + buf.Write(nonce) + messageBytes := buf.Bytes() + hmacSignature := CalculateHMAC(hmacKey, messageBytes) + return messageBytes, hmacSignature, nil +} + +func Deserialize(data []byte, aesKey []byte, hmacKey []byte, receivedHMAC string) (*Message, error) { + if !VerifyHMAC(hmacKey, data, receivedHMAC) { + return nil, fmt.Errorf("HMAC verification failed") + } + buf := bytes.NewReader(data) + var topicLen uint8 + var payloadLen uint32 + var headersLen uint32 + if err := binary.Read(buf, binary.LittleEndian, &headersLen); err != nil { + return nil, err + } + headersBytes := make([]byte, headersLen) + if _, err := buf.Read(headersBytes); err != nil { + return nil, err + } + var headers map[string]string + if err := json.Unmarshal(headersBytes, &headers); err != nil { + return nil, err + } + if err := binary.Read(buf, binary.LittleEndian, &topicLen); err != nil { + return nil, err + } + topicBytes := make([]byte, topicLen) + if _, err := buf.Read(topicBytes); err != nil { + return nil, err + } + topic := string(topicBytes) + var command consts.CMD + if err := binary.Read(buf, binary.LittleEndian, &command); err != nil { + return nil, err + } + if !command.IsValid() { + return nil, fmt.Errorf("invalid command: %s", command) + } + if err := binary.Read(buf, binary.LittleEndian, &payloadLen); err != nil { + return nil, err + } + encryptedPayload := make([]byte, payloadLen) + if _, err := buf.Read(encryptedPayload); err != nil { + return nil, err + } + nonce := make([]byte, 12) + if _, err := buf.Read(nonce); err != nil { + return nil, err + } + payloadBytes, err := DecryptPayload(aesKey, encryptedPayload, nonce) + if err != nil { + return nil, err + } + var payload json.RawMessage + if err := json.Unmarshal(payloadBytes, &payload); err != nil { + return nil, fmt.Errorf("error unmarshalling payload: %v", err) + } + return &Message{ + Headers: headers, + Topic: topic, + Command: command, + Payload: payload, + }, nil +} + +func SendMessage(conn io.Writer, msg *Message, aesKey []byte, hmacKey []byte) error { + sentData, hmacSignature, err := msg.Serialize(aesKey, hmacKey) + if err != nil { + return fmt.Errorf("error serializing message: %v", err) + } + if err := binary.Write(conn, binary.LittleEndian, uint32(len(sentData))); err != nil { + return err + } + + if _, err := conn.Write(sentData); err != nil { + return err + } + if _, err := conn.Write([]byte(hmacSignature)); err != nil { + return err + } + return nil +} + +func ReadMessage(conn io.Reader, aesKey []byte, hmacKey []byte) (*Message, error) { + var length uint32 + if err := binary.Read(conn, binary.LittleEndian, &length); err != nil { + return nil, err + } + data := make([]byte, length) + if _, err := io.ReadFull(conn, data); err != nil { + return nil, err + } + + hmacBytes := make([]byte, 64) + if _, err := io.ReadFull(conn, hmacBytes); err != nil { + return nil, err + } + receivedHMAC := string(hmacBytes) + return Deserialize(data, aesKey, hmacKey, receivedHMAC) +} diff --git a/constants.go b/constants.go deleted file mode 100644 index cc8d132..0000000 --- a/constants.go +++ /dev/null @@ -1,23 +0,0 @@ -package mq - -type CMD byte - -func (c CMD) IsValid() bool { return c > SUBSCRIBE && c < STOP } - -const ( - SUBSCRIBE CMD = iota + 1 - SUBSCRIBE_ACK - PUBLISH - REQUEST - RESPONSE - STOP -) - -var ( - ConsumerKey = "Consumer-Key" - PublisherKey = "Publisher-Key" - ContentType = "Content-Type" - TypeJson = "application/json" - HeaderKey = "headers" - TriggerNode = "triggerNode" -) diff --git a/consts/constants.go b/consts/constants.go new file mode 100644 index 0000000..64d9e0f --- /dev/null +++ b/consts/constants.go @@ -0,0 +1,37 @@ +package consts + +type CMD byte + +func (c CMD) IsValid() bool { return c >= PING && c <= STOP } + +const ( + PING CMD = iota + 1 + SUBSCRIBE + SUBSCRIBE_ACK + PUBLISH + REQUEST + RESPONSE + STOP +) + +func (c CMD) String() string { + switch c { + case PING: + return "PING" + case SUBSCRIBE: + return "SUBSCRIBE" + case STOP: + return "STOP" + default: + return "UNKNOWN" + } +} + +var ( + ConsumerKey = "Consumer-Key" + PublisherKey = "Publisher-Key" + ContentType = "Content-Type" + TypeJson = "application/json" + HeaderKey = "headers" + TriggerNode = "triggerNode" +) diff --git a/consumer.go b/consumer.go index 40d457e..981262f 100644 --- a/consumer.go +++ b/consumer.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/oarkflow/mq/consts" "github.com/oarkflow/mq/utils" ) @@ -45,11 +46,11 @@ func (c *Consumer) Close() error { func (c *Consumer) subscribe(queue string) error { ctx := context.Background() ctx = SetHeaders(ctx, map[string]string{ - ConsumerKey: c.id, - ContentType: TypeJson, + consts.ConsumerKey: c.id, + consts.ContentType: consts.TypeJson, }) subscribe := Command{ - Command: SUBSCRIBE, + Command: consts.SUBSCRIBE, Queue: queue, ID: NewID(), } @@ -68,9 +69,9 @@ func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result { // Handle command message sent by the server. func (c *Consumer) handleCommandMessage(msg Command) error { switch msg.Command { - case STOP: + case consts.STOP: return c.Close() - case SUBSCRIBE_ACK: + case consts.SUBSCRIBE_ACK: log.Printf("Consumer %s subscribed to queue %s\n", c.id, msg.Queue) return nil default: diff --git a/ctx.go b/ctx.go index a6312ad..220b8b4 100644 --- a/ctx.go +++ b/ctx.go @@ -14,6 +14,8 @@ import ( "strings" "github.com/oarkflow/xid" + + "github.com/oarkflow/mq/consts" ) type Message struct { @@ -51,11 +53,11 @@ func SetHeaders(ctx context.Context, headers map[string]string) context.Context for key, val := range headers { hd[key] = val } - return context.WithValue(ctx, HeaderKey, hd) + return context.WithValue(ctx, consts.HeaderKey, hd) } func GetHeaders(ctx context.Context) (map[string]string, bool) { - headers, ok := ctx.Value(HeaderKey).(map[string]string) + headers, ok := ctx.Value(consts.HeaderKey).(map[string]string) return headers, ok } @@ -64,7 +66,7 @@ func GetContentType(ctx context.Context) (string, bool) { if !ok { return "", false } - contentType, ok := headers[ContentType] + contentType, ok := headers[consts.ContentType] return contentType, ok } @@ -73,7 +75,7 @@ func GetConsumerID(ctx context.Context) (string, bool) { if !ok { return "", false } - contentType, ok := headers[ConsumerKey] + contentType, ok := headers[consts.ConsumerKey] return contentType, ok } @@ -82,7 +84,7 @@ func GetTriggerNode(ctx context.Context) (string, bool) { if !ok { return "", false } - contentType, ok := headers[TriggerNode] + contentType, ok := headers[consts.TriggerNode] return contentType, ok } @@ -91,7 +93,7 @@ func GetPublisherID(ctx context.Context) (string, bool) { if !ok { return "", false } - contentType, ok := headers[PublisherKey] + contentType, ok := headers[consts.PublisherKey] return contentType, ok } diff --git a/dag/dag.go b/dag/dag.go index 6da584d..4bbde72 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/oarkflow/mq" + "github.com/oarkflow/mq/consts" ) type taskContext struct { @@ -260,7 +261,7 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result { }, } - ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.Queue}) + ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue}) for _, loopNode := range loopNodes { for _, item := range items { rs := d.PublishTask(ctx, item, loopNode, task.MessageID) @@ -275,7 +276,7 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result { if multipleResults && completed { task.Queue = triggeredNode } - ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.Queue}) + ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue}) edge, exists := d.edges[task.Queue] if exists { d.taskResults[task.MessageID] = map[string]*taskContext{ diff --git a/examples/message.go b/examples/message.go new file mode 100644 index 0000000..2ee9120 --- /dev/null +++ b/examples/message.go @@ -0,0 +1,68 @@ +package main + +import ( + "encoding/json" + "fmt" + "log" + "net" + "time" + + "github.com/oarkflow/mq/codec" + "github.com/oarkflow/mq/consts" +) + +func main() { + aesKey := []byte("thisis32bytekeyforaesencryption1") + hmacKey := []byte("thisisasecrethmackey1") + go func() { + listener, err := net.Listen("tcp", ":8081") + if err != nil { + log.Fatal(err) + } + defer listener.Close() + log.Println("Server is listening on :8080") + for { + conn, err := listener.Accept() + if err != nil { + log.Println("Connection error:", err) + continue + } + go func(c net.Conn) { + defer c.Close() + for { + msg, err := codec.ReadMessage(c, aesKey, hmacKey) + if err != nil { + if err.Error() == "EOF" { + log.Println("Client disconnected") + break + } + log.Println("Failed to receive message:", err) + break + } + log.Printf("Received Message:\n Headers: %v\n Topic: %s\n Command: %v\n Payload: %s\n", + msg.Headers, msg.Topic, msg.Command, msg.Payload) + } + }(conn) + } + }() + time.Sleep(5 * time.Second) + conn, err := net.Dial("tcp", ":8081") + if err != nil { + log.Fatal(err) + } + defer conn.Close() + headers := map[string]string{"Api-Key": "121323"} + data := map[string]interface{}{"temperature": 23.5, "humidity": 60} + payload, _ := json.Marshal(data) + msg := &codec.Message{ + Headers: headers, + Topic: "sensor_data", + Command: consts.SUBSCRIBE, + Payload: payload, + } + if err := codec.SendMessage(conn, msg, aesKey, hmacKey); err != nil { + log.Fatalf("Error sending message: %v", err) + } + fmt.Println("Message sent successfully") + time.Sleep(5 * time.Second) +} diff --git a/publisher.go b/publisher.go index 9970b6e..3dc040c 100644 --- a/publisher.go +++ b/publisher.go @@ -5,6 +5,8 @@ import ( "encoding/json" "fmt" "net" + + "github.com/oarkflow/mq/consts" ) type Publisher struct { @@ -22,10 +24,10 @@ func NewPublisher(id string, opts ...Option) *Publisher { return b } -func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command CMD) error { +func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command consts.CMD) error { ctx = SetHeaders(ctx, map[string]string{ - PublisherKey: p.id, - ContentType: TypeJson, + consts.PublisherKey: p.id, + consts.ContentType: consts.TypeJson, }) cmd := Command{ ID: NewID(), @@ -43,7 +45,7 @@ func (p *Publisher) Publish(ctx context.Context, queue string, task Task) error return fmt.Errorf("failed to connect to broker: %w", err) } defer conn.Close() - return p.send(ctx, queue, task, conn, PUBLISH) + return p.send(ctx, queue, task, conn, consts.PUBLISH) } func (p *Publisher) onClose(ctx context.Context, conn net.Conn) error { @@ -62,7 +64,7 @@ func (p *Publisher) Request(ctx context.Context, queue string, task Task) (Resul } defer conn.Close() var result Result - err = p.send(ctx, queue, task, conn, REQUEST) + err = p.send(ctx, queue, task, conn, consts.REQUEST) if err != nil { return result, err }