mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-04 07:37:05 +08:00
feat: update
This commit is contained in:
@@ -283,11 +283,11 @@ func (b *Broker) Start(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (b *Broker) send(conn net.Conn, msg *codec.Message) error {
|
||||
return codec.SendMessage(conn, msg, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption)
|
||||
return codec.SendMessage(conn, msg)
|
||||
}
|
||||
|
||||
func (b *Broker) receive(c net.Conn) (*codec.Message, error) {
|
||||
return codec.ReadMessage(c, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption)
|
||||
return codec.ReadMessage(c)
|
||||
}
|
||||
|
||||
func (b *Broker) broadcastToConsumers(ctx context.Context, msg *codec.Message) {
|
||||
|
274
codec/codec.go
274
codec/codec.go
@@ -1,32 +1,26 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
|
||||
"github.com/oarkflow/mq/consts"
|
||||
"github.com/oarkflow/mq/utils"
|
||||
)
|
||||
|
||||
// Message represents the structure of our message.
|
||||
type Message struct {
|
||||
Headers map[string]string `json:"h"`
|
||||
Queue string `json:"q"`
|
||||
Command consts.CMD `json:"c"`
|
||||
Payload json.RawMessage `json:"p"`
|
||||
Headers map[string]string `msgpack:"h"`
|
||||
Queue string `msgpack:"q"`
|
||||
Command consts.CMD `msgpack:"c"`
|
||||
Payload []byte `msgpack:"p"` // Using []byte instead of json.RawMessage
|
||||
m sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMessage(cmd consts.CMD, payload json.RawMessage, queue string, headers map[string]string) *Message {
|
||||
// NewMessage creates a new Message instance.
|
||||
func NewMessage(cmd consts.CMD, payload []byte, queue string, headers map[string]string) *Message {
|
||||
return &Message{
|
||||
Headers: headers,
|
||||
Queue: queue,
|
||||
@@ -35,243 +29,63 @@ func NewMessage(cmd consts.CMD, payload json.RawMessage, queue string, headers m
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Message) Serialize(aesKey, hmacKey []byte, encrypt bool) ([]byte, string, error) {
|
||||
m.m.Lock()
|
||||
defer m.m.Unlock()
|
||||
buf := bytes.NewBuffer(make([]byte, 0, 512))
|
||||
if err := writeLengthPrefixedJSON(buf, m.Headers); err != nil {
|
||||
return nil, "", fmt.Errorf("error serializing headers: %v", err)
|
||||
}
|
||||
if err := writeLengthPrefixed(buf, utils.ToByte(m.Queue)); err != nil {
|
||||
return nil, "", fmt.Errorf("error serializing queue: %v", err)
|
||||
}
|
||||
if err := binary.Write(buf, binary.LittleEndian, m.Command); err != nil {
|
||||
return nil, "", fmt.Errorf("error serializing command: %v", err)
|
||||
}
|
||||
if err := writePayload(buf, aesKey, m.Payload, encrypt); err != nil {
|
||||
return nil, "", fmt.Errorf("error serializing payload: %v", err)
|
||||
}
|
||||
messageBytes := buf.Bytes()
|
||||
hmacSignature := CalculateHMAC(hmacKey, messageBytes)
|
||||
return messageBytes, hmacSignature, nil
|
||||
}
|
||||
// Serialize encodes the Message to a byte slice using MessagePack.
|
||||
func (m *Message) Serialize() ([]byte, error) {
|
||||
m.m.RLock()
|
||||
defer m.m.RUnlock()
|
||||
|
||||
func Deserialize(data, aesKey, hmacKey []byte, receivedHMAC string, decrypt bool) (*Message, error) {
|
||||
if !VerifyHMAC(hmacKey, data, receivedHMAC) {
|
||||
return nil, fmt.Errorf("HMAC verification failed")
|
||||
}
|
||||
buf := bytes.NewReader(data)
|
||||
headers := make(map[string]string)
|
||||
|
||||
if err := readLengthPrefixedJSON(buf, &headers); err != nil {
|
||||
return nil, fmt.Errorf("error deserializing headers: %v", err)
|
||||
}
|
||||
queue, err := readLengthPrefixedString(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error deserializing queue: %v", err)
|
||||
}
|
||||
var command consts.CMD
|
||||
if err := binary.Read(buf, binary.LittleEndian, &command); err != nil {
|
||||
return nil, fmt.Errorf("error deserializing command: %v", err)
|
||||
}
|
||||
payload, err := readPayload(buf, aesKey, decrypt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error deserializing payload: %v", err)
|
||||
}
|
||||
return &Message{
|
||||
Headers: headers,
|
||||
Queue: queue,
|
||||
Command: command,
|
||||
Payload: payload,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func SendMessage(conn io.Writer, msg *Message, aesKey, hmacKey []byte, encrypt bool) error {
|
||||
sentData, hmacSignature, err := msg.Serialize(aesKey, hmacKey, encrypt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error serializing message: %v", err)
|
||||
}
|
||||
if err := writeMessageWithHMAC(conn, sentData, hmacSignature); err != nil {
|
||||
return fmt.Errorf("error sending message: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ReadMessage(conn io.Reader, aesKey, hmacKey []byte, decrypt bool) (*Message, error) {
|
||||
data, receivedHMAC, err := readMessageWithHMAC(conn)
|
||||
data, err := msgpack.Marshal(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Deserialize(data, aesKey, hmacKey, receivedHMAC, decrypt)
|
||||
}
|
||||
|
||||
func writeLengthPrefixed(buf *bytes.Buffer, data []byte) error {
|
||||
length := uint32(len(data))
|
||||
if err := binary.Write(buf, binary.LittleEndian, length); err != nil {
|
||||
return err
|
||||
}
|
||||
buf.Write(data)
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeLengthPrefixedJSON(buf *bytes.Buffer, v any) error {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeLengthPrefixed(buf, data)
|
||||
}
|
||||
|
||||
func readLengthPrefixed(r *bytes.Reader) ([]byte, error) {
|
||||
var length uint32
|
||||
if err := binary.Read(r, binary.LittleEndian, &length); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data := make([]byte, length)
|
||||
if _, err := r.Read(data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func readLengthPrefixedJSON(r *bytes.Reader, v any) error {
|
||||
data, err := readLengthPrefixed(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
func readLengthPrefixedString(r *bytes.Reader) (string, error) {
|
||||
data, err := readLengthPrefixed(r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return utils.FromByte(data), nil
|
||||
}
|
||||
|
||||
func writePayload(buf *bytes.Buffer, aesKey []byte, payload json.RawMessage, encrypt bool) error {
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshalling payload: %v", err)
|
||||
}
|
||||
var encryptedPayload, nonce []byte
|
||||
if encrypt {
|
||||
encryptedPayload, nonce, err = EncryptPayload(aesKey, payloadBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
encryptedPayload = payloadBytes
|
||||
}
|
||||
if err := writeLengthPrefixed(buf, encryptedPayload); err != nil {
|
||||
return err
|
||||
}
|
||||
if encrypt {
|
||||
buf.Write(nonce)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readPayload(r *bytes.Reader, aesKey []byte, decrypt bool) (json.RawMessage, error) {
|
||||
encryptedPayload, err := readLengthPrefixed(r)
|
||||
if err != nil {
|
||||
// Deserialize decodes a byte slice to a Message instance using MessagePack.
|
||||
func Deserialize(data []byte) (*Message, error) {
|
||||
var msg Message
|
||||
if err := msgpack.Unmarshal(data, &msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var payloadBytes []byte
|
||||
if decrypt {
|
||||
nonce := make([]byte, 12)
|
||||
if _, err := r.Read(nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payloadBytes, err = DecryptPayload(aesKey, encryptedPayload, nonce)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decrypting payload: %v", err)
|
||||
}
|
||||
} else {
|
||||
payloadBytes = encryptedPayload
|
||||
}
|
||||
var payload json.RawMessage
|
||||
if err := json.Unmarshal(payloadBytes, &payload); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling payload: %v", err)
|
||||
}
|
||||
return payload, nil
|
||||
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
func writeMessageWithHMAC(conn io.Writer, messageBytes []byte, hmacSignature string) error {
|
||||
if err := binary.Write(conn, binary.LittleEndian, uint32(len(messageBytes))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := conn.Write(messageBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
hmacBytes, err := hex.DecodeString(hmacSignature)
|
||||
// SendMessage sends a Message over a net.Conn.
|
||||
func SendMessage(conn net.Conn, msg *Message) error {
|
||||
data, err := msg.Serialize()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := conn.Write(hmacBytes); err != nil {
|
||||
|
||||
// Send the length of the data followed by the data itself
|
||||
length := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(length, uint32(len(data)))
|
||||
|
||||
if _, err := conn.Write(length); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readMessageWithHMAC(conn io.Reader) ([]byte, string, error) {
|
||||
var length uint32
|
||||
if err := binary.Read(conn, binary.LittleEndian, &length); err != nil {
|
||||
return nil, "", err
|
||||
// ReadMessage receives a Message from a net.Conn.
|
||||
func ReadMessage(conn net.Conn) (*Message, error) {
|
||||
// Read the length of the incoming message
|
||||
lengthBytes := make([]byte, 4)
|
||||
if _, err := conn.Read(lengthBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
length := binary.BigEndian.Uint32(lengthBytes)
|
||||
|
||||
// Read the actual message data
|
||||
data := make([]byte, length)
|
||||
if _, err := io.ReadFull(conn, data); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
hmacBytes := make([]byte, 32)
|
||||
if _, err := io.ReadFull(conn, hmacBytes); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
receivedHMAC := hex.EncodeToString(hmacBytes)
|
||||
return data, receivedHMAC, nil
|
||||
}
|
||||
|
||||
func EncryptPayload(key []byte, plaintext []byte) ([]byte, []byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
aesGCM, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
nonce := make([]byte, aesGCM.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
ciphertext := aesGCM.Seal(nil, nonce, plaintext, nil)
|
||||
return ciphertext, nonce, nil
|
||||
}
|
||||
|
||||
func DecryptPayload(key []byte, ciphertext []byte, nonce []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
if _, err := conn.Read(data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aesGCM, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
func CalculateHMAC(key []byte, data []byte) string {
|
||||
h := hmac.New(sha256.New, key)
|
||||
h.Write(data)
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func VerifyHMAC(key []byte, data []byte, receivedHMAC string) bool {
|
||||
expectedHMAC := CalculateHMAC(key, data)
|
||||
return hmac.Equal(utils.ToByte(receivedHMAC), utils.ToByte(expectedHMAC))
|
||||
return Deserialize(data)
|
||||
}
|
||||
|
@@ -46,11 +46,11 @@ func NewConsumer(id string, queue string, handler Handler, opts ...Option) *Cons
|
||||
}
|
||||
|
||||
func (c *Consumer) send(conn net.Conn, msg *codec.Message) error {
|
||||
return codec.SendMessage(conn, msg, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption)
|
||||
return codec.SendMessage(conn, msg)
|
||||
}
|
||||
|
||||
func (c *Consumer) receive(conn net.Conn) (*codec.Message, error) {
|
||||
return codec.ReadMessage(conn, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption)
|
||||
return codec.ReadMessage(conn)
|
||||
}
|
||||
|
||||
func (c *Consumer) Close() error {
|
||||
@@ -66,7 +66,7 @@ func (c *Consumer) subscribe(ctx context.Context, queue string) error {
|
||||
headers := HeadersWithConsumerID(ctx, c.id)
|
||||
msg := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers)
|
||||
if err := c.send(c.conn, msg); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("error while trying to subscribe: %v", err)
|
||||
}
|
||||
return c.waitForAck(c.conn)
|
||||
}
|
||||
|
15
dag/dag.go
15
dag/dag.go
@@ -184,6 +184,21 @@ func (tm *DAG) Start(ctx context.Context, addr string) error {
|
||||
return http.ListenAndServe(addr, nil)
|
||||
}
|
||||
|
||||
func (tm *DAG) AddDAGNode(name string, key string, dag *DAG, firstNode ...bool) {
|
||||
dag.AssignTopic(key)
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
tm.nodes[key] = &Node{
|
||||
Name: name,
|
||||
Key: key,
|
||||
processor: dag,
|
||||
isReady: true,
|
||||
}
|
||||
if len(firstNode) > 0 && firstNode[0] {
|
||||
tm.startNode = key
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *DAG) AddNode(name, key string, handler mq.Handler, firstNode ...bool) {
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
|
@@ -13,7 +13,6 @@ func main() {
|
||||
d := dag.NewDAG("Sample DAG", "sample-dag",
|
||||
mq.WithSyncMode(true),
|
||||
mq.WithNotifyResponse(tasks.NotifyResponse),
|
||||
mq.WithSecretKey([]byte("wKWa6GKdBd0njDKNQoInBbh6P0KTjmob")),
|
||||
)
|
||||
d.AddNode("C", "C", tasks.Node3, true)
|
||||
d.AddNode("D", "D", tasks.Node4)
|
||||
@@ -27,6 +26,9 @@ func main() {
|
||||
d.AddEdge("Label 2", "D", "F")
|
||||
d.AddEdge("Label 3", "E", "F")
|
||||
d.AddEdge("Label 4", "F", "G", "H")
|
||||
d.AssignTopic("queue1")
|
||||
d.Consume(context.Background())
|
||||
d.AssignTopic("D")
|
||||
err := d.Consume(context.Background())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
mq2 "github.com/oarkflow/mq"
|
||||
|
||||
"github.com/oarkflow/mq/examples/tasks"
|
||||
|
@@ -79,3 +79,7 @@ func Callback(_ context.Context, task mq.Result) mq.Result {
|
||||
func NotifyResponse(_ context.Context, result mq.Result) {
|
||||
log.Printf("DAG - FINAL_RESPONSE ~> TaskID: %s, Payload: %s, Topic: %s, Error: %s", result.TaskID, result.Payload, result.Topic, result.Error)
|
||||
}
|
||||
|
||||
func NotifySubDAGResponse(_ context.Context, result mq.Result) {
|
||||
log.Printf("SUB DAG - FINAL_RESPONSE ~> TaskID: %s, Payload: %s, Topic: %s, Error: %s", result.TaskID, result.Payload, result.Topic, result.Error)
|
||||
}
|
||||
|
3
go.mod
3
go.mod
@@ -5,4 +5,7 @@ go 1.23.0
|
||||
require (
|
||||
github.com/oarkflow/xid v1.2.5
|
||||
github.com/oarkflow/xsync v0.0.5
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1
|
||||
)
|
||||
|
||||
require github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
|
12
go.sum
12
go.sum
@@ -1,4 +1,16 @@
|
||||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/oarkflow/xid v1.2.5 h1:6RcNJm9+oZ/B647gkME9trCzhpxGQaSdNoD56Vmkeho=
|
||||
github.com/oarkflow/xid v1.2.5/go.mod h1:jG4YBh+swbjlWApGWDBYnsJEa7hi3CCpmuqhB3RAxVo=
|
||||
github.com/oarkflow/xsync v0.0.5 h1:7HBQjmDus4YFLQFC5D197TB4c2YJTVwsTFuqk5zWKBM=
|
||||
github.com/oarkflow/xsync v0.0.5/go.mod h1:KAaEc506OEd3ISxfhgUBKxk8eQzkz+mb0JkpGGd/QwU=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
@@ -37,7 +37,7 @@ func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.
|
||||
return err
|
||||
}
|
||||
msg := codec.NewMessage(command, payload, queue, headers)
|
||||
if err := codec.SendMessage(conn, msg, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption); err != nil {
|
||||
if err := codec.SendMessage(conn, msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.
|
||||
}
|
||||
|
||||
func (p *Publisher) waitForAck(conn net.Conn) error {
|
||||
msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption)
|
||||
msg, err := codec.ReadMessage(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -58,7 +58,7 @@ func (p *Publisher) waitForAck(conn net.Conn) error {
|
||||
}
|
||||
|
||||
func (p *Publisher) waitForResponse(conn net.Conn) Result {
|
||||
msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption)
|
||||
msg, err := codec.ReadMessage(conn)
|
||||
if err != nil {
|
||||
return Result{Error: err}
|
||||
}
|
||||
|
Reference in New Issue
Block a user