mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-06 16:36:53 +08:00
feat: implement TLS support
This commit is contained in:
12
broker.go
12
broker.go
@@ -11,6 +11,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/xsync"
|
||||
|
||||
"github.com/oarkflow/mq/consts"
|
||||
)
|
||||
|
||||
type consumer struct {
|
||||
@@ -78,7 +80,7 @@ type Task struct {
|
||||
|
||||
type Command struct {
|
||||
ID string `json:"id"`
|
||||
Command CMD `json:"command"`
|
||||
Command consts.CMD `json:"command"`
|
||||
Queue string `json:"queue"`
|
||||
MessageID string `json:"message_id"`
|
||||
Payload json.RawMessage `json:"payload,omitempty"` // Used for carrying the task payload
|
||||
@@ -265,7 +267,7 @@ func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Con
|
||||
consumerID, ok := GetConsumerID(ctx)
|
||||
defer func() {
|
||||
cmd := Command{
|
||||
Command: SUBSCRIBE_ACK,
|
||||
Command: consts.SUBSCRIBE_ACK,
|
||||
Queue: queueName,
|
||||
Error: "",
|
||||
}
|
||||
@@ -335,7 +337,7 @@ func (b *Broker) handleTaskMessage(ctx context.Context, _ net.Conn, msg Result)
|
||||
|
||||
func (b *Broker) publish(ctx context.Context, conn net.Conn, msg Command) error {
|
||||
status := "PUBLISH"
|
||||
if msg.Command == REQUEST {
|
||||
if msg.Command == consts.REQUEST {
|
||||
status = "REQUEST"
|
||||
}
|
||||
b.addPublisher(ctx, msg.Queue, conn)
|
||||
@@ -360,10 +362,10 @@ func (b *Broker) publish(ctx context.Context, conn net.Conn, msg Command) error
|
||||
|
||||
func (b *Broker) handleCommandMessage(ctx context.Context, conn net.Conn, msg Command) error {
|
||||
switch msg.Command {
|
||||
case SUBSCRIBE:
|
||||
case consts.SUBSCRIBE:
|
||||
b.subscribe(ctx, msg.Queue, conn)
|
||||
return nil
|
||||
case PUBLISH, REQUEST:
|
||||
case consts.PUBLISH, consts.REQUEST:
|
||||
return b.publish(ctx, conn, msg)
|
||||
default:
|
||||
return fmt.Errorf("unknown command: %d", msg.Command)
|
||||
|
203
codec/codec.go
Normal file
203
codec/codec.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/oarkflow/mq/consts"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Headers map[string]string `json:"h"`
|
||||
Topic string `json:"t"`
|
||||
Command consts.CMD `json:"c"`
|
||||
Payload json.RawMessage `json:"p"`
|
||||
}
|
||||
|
||||
func EncryptPayload(key []byte, plaintext []byte) ([]byte, []byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
aesGCM, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
nonce := make([]byte, aesGCM.NonceSize())
|
||||
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
ciphertext := aesGCM.Seal(nil, nonce, plaintext, nil)
|
||||
return ciphertext, nonce, nil
|
||||
}
|
||||
|
||||
func DecryptPayload(key []byte, ciphertext []byte, nonce []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aesGCM, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return aesGCM.Open(nil, nonce, ciphertext, nil)
|
||||
}
|
||||
|
||||
func CalculateHMAC(key []byte, data []byte) string {
|
||||
h := hmac.New(sha256.New, key)
|
||||
h.Write(data)
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func VerifyHMAC(key []byte, data []byte, receivedHMAC string) bool {
|
||||
expectedHMAC := CalculateHMAC(key, data)
|
||||
return hmac.Equal([]byte(expectedHMAC), []byte(receivedHMAC))
|
||||
}
|
||||
|
||||
func (m *Message) Serialize(aesKey []byte, hmacKey []byte) ([]byte, string, error) {
|
||||
var buf bytes.Buffer
|
||||
headersBytes, err := json.Marshal(m.Headers)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error marshalling headers: %v", err)
|
||||
}
|
||||
headersLen := uint32(len(headersBytes))
|
||||
if err := binary.Write(&buf, binary.LittleEndian, headersLen); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
buf.Write(headersBytes)
|
||||
topicBytes := []byte(m.Topic)
|
||||
topicLen := uint8(len(topicBytes))
|
||||
if err := binary.Write(&buf, binary.LittleEndian, topicLen); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
buf.Write(topicBytes)
|
||||
if !m.Command.IsValid() {
|
||||
return nil, "", fmt.Errorf("invalid command: %s", m.Command)
|
||||
}
|
||||
if err := binary.Write(&buf, binary.LittleEndian, m.Command); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
payloadBytes, err := json.Marshal(m.Payload)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error marshalling payload: %v", err)
|
||||
}
|
||||
encryptedPayload, nonce, err := EncryptPayload(aesKey, payloadBytes)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
payloadLen := uint32(len(encryptedPayload))
|
||||
if err := binary.Write(&buf, binary.LittleEndian, payloadLen); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
buf.Write(encryptedPayload)
|
||||
buf.Write(nonce)
|
||||
messageBytes := buf.Bytes()
|
||||
hmacSignature := CalculateHMAC(hmacKey, messageBytes)
|
||||
return messageBytes, hmacSignature, nil
|
||||
}
|
||||
|
||||
func Deserialize(data []byte, aesKey []byte, hmacKey []byte, receivedHMAC string) (*Message, error) {
|
||||
if !VerifyHMAC(hmacKey, data, receivedHMAC) {
|
||||
return nil, fmt.Errorf("HMAC verification failed")
|
||||
}
|
||||
buf := bytes.NewReader(data)
|
||||
var topicLen uint8
|
||||
var payloadLen uint32
|
||||
var headersLen uint32
|
||||
if err := binary.Read(buf, binary.LittleEndian, &headersLen); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
headersBytes := make([]byte, headersLen)
|
||||
if _, err := buf.Read(headersBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var headers map[string]string
|
||||
if err := json.Unmarshal(headersBytes, &headers); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Read(buf, binary.LittleEndian, &topicLen); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
topicBytes := make([]byte, topicLen)
|
||||
if _, err := buf.Read(topicBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
topic := string(topicBytes)
|
||||
var command consts.CMD
|
||||
if err := binary.Read(buf, binary.LittleEndian, &command); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !command.IsValid() {
|
||||
return nil, fmt.Errorf("invalid command: %s", command)
|
||||
}
|
||||
if err := binary.Read(buf, binary.LittleEndian, &payloadLen); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
encryptedPayload := make([]byte, payloadLen)
|
||||
if _, err := buf.Read(encryptedPayload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nonce := make([]byte, 12)
|
||||
if _, err := buf.Read(nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payloadBytes, err := DecryptPayload(aesKey, encryptedPayload, nonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var payload json.RawMessage
|
||||
if err := json.Unmarshal(payloadBytes, &payload); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling payload: %v", err)
|
||||
}
|
||||
return &Message{
|
||||
Headers: headers,
|
||||
Topic: topic,
|
||||
Command: command,
|
||||
Payload: payload,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func SendMessage(conn io.Writer, msg *Message, aesKey []byte, hmacKey []byte) error {
|
||||
sentData, hmacSignature, err := msg.Serialize(aesKey, hmacKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error serializing message: %v", err)
|
||||
}
|
||||
if err := binary.Write(conn, binary.LittleEndian, uint32(len(sentData))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := conn.Write(sentData); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := conn.Write([]byte(hmacSignature)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ReadMessage(conn io.Reader, aesKey []byte, hmacKey []byte) (*Message, error) {
|
||||
var length uint32
|
||||
if err := binary.Read(conn, binary.LittleEndian, &length); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data := make([]byte, length)
|
||||
if _, err := io.ReadFull(conn, data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hmacBytes := make([]byte, 64)
|
||||
if _, err := io.ReadFull(conn, hmacBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
receivedHMAC := string(hmacBytes)
|
||||
return Deserialize(data, aesKey, hmacKey, receivedHMAC)
|
||||
}
|
23
constants.go
23
constants.go
@@ -1,23 +0,0 @@
|
||||
package mq
|
||||
|
||||
type CMD byte
|
||||
|
||||
func (c CMD) IsValid() bool { return c > SUBSCRIBE && c < STOP }
|
||||
|
||||
const (
|
||||
SUBSCRIBE CMD = iota + 1
|
||||
SUBSCRIBE_ACK
|
||||
PUBLISH
|
||||
REQUEST
|
||||
RESPONSE
|
||||
STOP
|
||||
)
|
||||
|
||||
var (
|
||||
ConsumerKey = "Consumer-Key"
|
||||
PublisherKey = "Publisher-Key"
|
||||
ContentType = "Content-Type"
|
||||
TypeJson = "application/json"
|
||||
HeaderKey = "headers"
|
||||
TriggerNode = "triggerNode"
|
||||
)
|
37
consts/constants.go
Normal file
37
consts/constants.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package consts
|
||||
|
||||
type CMD byte
|
||||
|
||||
func (c CMD) IsValid() bool { return c >= PING && c <= STOP }
|
||||
|
||||
const (
|
||||
PING CMD = iota + 1
|
||||
SUBSCRIBE
|
||||
SUBSCRIBE_ACK
|
||||
PUBLISH
|
||||
REQUEST
|
||||
RESPONSE
|
||||
STOP
|
||||
)
|
||||
|
||||
func (c CMD) String() string {
|
||||
switch c {
|
||||
case PING:
|
||||
return "PING"
|
||||
case SUBSCRIBE:
|
||||
return "SUBSCRIBE"
|
||||
case STOP:
|
||||
return "STOP"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
ConsumerKey = "Consumer-Key"
|
||||
PublisherKey = "Publisher-Key"
|
||||
ContentType = "Content-Type"
|
||||
TypeJson = "application/json"
|
||||
HeaderKey = "headers"
|
||||
TriggerNode = "triggerNode"
|
||||
)
|
11
consumer.go
11
consumer.go
@@ -10,6 +10,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq/consts"
|
||||
"github.com/oarkflow/mq/utils"
|
||||
)
|
||||
|
||||
@@ -45,11 +46,11 @@ func (c *Consumer) Close() error {
|
||||
func (c *Consumer) subscribe(queue string) error {
|
||||
ctx := context.Background()
|
||||
ctx = SetHeaders(ctx, map[string]string{
|
||||
ConsumerKey: c.id,
|
||||
ContentType: TypeJson,
|
||||
consts.ConsumerKey: c.id,
|
||||
consts.ContentType: consts.TypeJson,
|
||||
})
|
||||
subscribe := Command{
|
||||
Command: SUBSCRIBE,
|
||||
Command: consts.SUBSCRIBE,
|
||||
Queue: queue,
|
||||
ID: NewID(),
|
||||
}
|
||||
@@ -68,9 +69,9 @@ func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result {
|
||||
// Handle command message sent by the server.
|
||||
func (c *Consumer) handleCommandMessage(msg Command) error {
|
||||
switch msg.Command {
|
||||
case STOP:
|
||||
case consts.STOP:
|
||||
return c.Close()
|
||||
case SUBSCRIBE_ACK:
|
||||
case consts.SUBSCRIBE_ACK:
|
||||
log.Printf("Consumer %s subscribed to queue %s\n", c.id, msg.Queue)
|
||||
return nil
|
||||
default:
|
||||
|
14
ctx.go
14
ctx.go
@@ -14,6 +14,8 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/oarkflow/xid"
|
||||
|
||||
"github.com/oarkflow/mq/consts"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
@@ -51,11 +53,11 @@ func SetHeaders(ctx context.Context, headers map[string]string) context.Context
|
||||
for key, val := range headers {
|
||||
hd[key] = val
|
||||
}
|
||||
return context.WithValue(ctx, HeaderKey, hd)
|
||||
return context.WithValue(ctx, consts.HeaderKey, hd)
|
||||
}
|
||||
|
||||
func GetHeaders(ctx context.Context) (map[string]string, bool) {
|
||||
headers, ok := ctx.Value(HeaderKey).(map[string]string)
|
||||
headers, ok := ctx.Value(consts.HeaderKey).(map[string]string)
|
||||
return headers, ok
|
||||
}
|
||||
|
||||
@@ -64,7 +66,7 @@ func GetContentType(ctx context.Context) (string, bool) {
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
contentType, ok := headers[ContentType]
|
||||
contentType, ok := headers[consts.ContentType]
|
||||
return contentType, ok
|
||||
}
|
||||
|
||||
@@ -73,7 +75,7 @@ func GetConsumerID(ctx context.Context) (string, bool) {
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
contentType, ok := headers[ConsumerKey]
|
||||
contentType, ok := headers[consts.ConsumerKey]
|
||||
return contentType, ok
|
||||
}
|
||||
|
||||
@@ -82,7 +84,7 @@ func GetTriggerNode(ctx context.Context) (string, bool) {
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
contentType, ok := headers[TriggerNode]
|
||||
contentType, ok := headers[consts.TriggerNode]
|
||||
return contentType, ok
|
||||
}
|
||||
|
||||
@@ -91,7 +93,7 @@ func GetPublisherID(ctx context.Context) (string, bool) {
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
contentType, ok := headers[PublisherKey]
|
||||
contentType, ok := headers[consts.PublisherKey]
|
||||
return contentType, ok
|
||||
}
|
||||
|
||||
|
@@ -9,6 +9,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/oarkflow/mq"
|
||||
"github.com/oarkflow/mq/consts"
|
||||
)
|
||||
|
||||
type taskContext struct {
|
||||
@@ -260,7 +261,7 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
|
||||
},
|
||||
}
|
||||
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.Queue})
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue})
|
||||
for _, loopNode := range loopNodes {
|
||||
for _, item := range items {
|
||||
rs := d.PublishTask(ctx, item, loopNode, task.MessageID)
|
||||
@@ -275,7 +276,7 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
|
||||
if multipleResults && completed {
|
||||
task.Queue = triggeredNode
|
||||
}
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.Queue})
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue})
|
||||
edge, exists := d.edges[task.Queue]
|
||||
if exists {
|
||||
d.taskResults[task.MessageID] = map[string]*taskContext{
|
||||
|
68
examples/message.go
Normal file
68
examples/message.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq/codec"
|
||||
"github.com/oarkflow/mq/consts"
|
||||
)
|
||||
|
||||
func main() {
|
||||
aesKey := []byte("thisis32bytekeyforaesencryption1")
|
||||
hmacKey := []byte("thisisasecrethmackey1")
|
||||
go func() {
|
||||
listener, err := net.Listen("tcp", ":8081")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer listener.Close()
|
||||
log.Println("Server is listening on :8080")
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Println("Connection error:", err)
|
||||
continue
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
for {
|
||||
msg, err := codec.ReadMessage(c, aesKey, hmacKey)
|
||||
if err != nil {
|
||||
if err.Error() == "EOF" {
|
||||
log.Println("Client disconnected")
|
||||
break
|
||||
}
|
||||
log.Println("Failed to receive message:", err)
|
||||
break
|
||||
}
|
||||
log.Printf("Received Message:\n Headers: %v\n Topic: %s\n Command: %v\n Payload: %s\n",
|
||||
msg.Headers, msg.Topic, msg.Command, msg.Payload)
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
time.Sleep(5 * time.Second)
|
||||
conn, err := net.Dial("tcp", ":8081")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
headers := map[string]string{"Api-Key": "121323"}
|
||||
data := map[string]interface{}{"temperature": 23.5, "humidity": 60}
|
||||
payload, _ := json.Marshal(data)
|
||||
msg := &codec.Message{
|
||||
Headers: headers,
|
||||
Topic: "sensor_data",
|
||||
Command: consts.SUBSCRIBE,
|
||||
Payload: payload,
|
||||
}
|
||||
if err := codec.SendMessage(conn, msg, aesKey, hmacKey); err != nil {
|
||||
log.Fatalf("Error sending message: %v", err)
|
||||
}
|
||||
fmt.Println("Message sent successfully")
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
12
publisher.go
12
publisher.go
@@ -5,6 +5,8 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/oarkflow/mq/consts"
|
||||
)
|
||||
|
||||
type Publisher struct {
|
||||
@@ -22,10 +24,10 @@ func NewPublisher(id string, opts ...Option) *Publisher {
|
||||
return b
|
||||
}
|
||||
|
||||
func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command CMD) error {
|
||||
func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command consts.CMD) error {
|
||||
ctx = SetHeaders(ctx, map[string]string{
|
||||
PublisherKey: p.id,
|
||||
ContentType: TypeJson,
|
||||
consts.PublisherKey: p.id,
|
||||
consts.ContentType: consts.TypeJson,
|
||||
})
|
||||
cmd := Command{
|
||||
ID: NewID(),
|
||||
@@ -43,7 +45,7 @@ func (p *Publisher) Publish(ctx context.Context, queue string, task Task) error
|
||||
return fmt.Errorf("failed to connect to broker: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
return p.send(ctx, queue, task, conn, PUBLISH)
|
||||
return p.send(ctx, queue, task, conn, consts.PUBLISH)
|
||||
}
|
||||
|
||||
func (p *Publisher) onClose(ctx context.Context, conn net.Conn) error {
|
||||
@@ -62,7 +64,7 @@ func (p *Publisher) Request(ctx context.Context, queue string, task Task) (Resul
|
||||
}
|
||||
defer conn.Close()
|
||||
var result Result
|
||||
err = p.send(ctx, queue, task, conn, REQUEST)
|
||||
err = p.send(ctx, queue, task, conn, consts.REQUEST)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
Reference in New Issue
Block a user