feat: add example

This commit is contained in:
sujit
2024-10-10 20:35:18 +05:45
parent 512cead15f
commit 87454dd9f2
10 changed files with 438 additions and 72 deletions

128
broker.go
View File

@@ -24,8 +24,9 @@ type QueuedTask struct {
} }
type consumer struct { type consumer struct {
id string id string
conn net.Conn state consts.ConsumerState
conn net.Conn
} }
type publisher struct { type publisher struct {
@@ -50,6 +51,10 @@ func NewBroker(opts ...Option) *Broker {
} }
} }
func (b *Broker) Options() Options {
return b.opts
}
func (b *Broker) OnClose(ctx context.Context, conn net.Conn) error { func (b *Broker) OnClose(ctx context.Context, conn net.Conn) error {
consumerID, ok := GetConsumerID(ctx) consumerID, ok := GetConsumerID(ctx)
if ok && consumerID != "" { if ok && consumerID != "" {
@@ -110,6 +115,16 @@ func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Con
b.MessageResponseHandler(ctx, msg) b.MessageResponseHandler(ctx, msg)
case consts.MESSAGE_ACK: case consts.MESSAGE_ACK:
b.MessageAck(ctx, msg) b.MessageAck(ctx, msg)
case consts.MESSAGE_DENY:
b.MessageDeny(ctx, msg)
case consts.CONSUMER_PAUSED:
b.OnConsumerPause(ctx, msg)
case consts.CONSUMER_RESUMED:
b.OnConsumerResume(ctx, msg)
case consts.CONSUMER_STOPPED:
b.OnConsumerStop(ctx, msg)
default:
log.Printf("BROKER - UNKNOWN_COMMAND ~> %s on %s", msg.Command, msg.Queue)
} }
} }
@@ -119,6 +134,43 @@ func (b *Broker) MessageAck(ctx context.Context, msg *codec.Message) {
log.Printf("BROKER - MESSAGE_ACK ~> %s on %s for Task %s", consumerID, msg.Queue, taskID) log.Printf("BROKER - MESSAGE_ACK ~> %s on %s for Task %s", consumerID, msg.Queue, taskID)
} }
func (b *Broker) MessageDeny(ctx context.Context, msg *codec.Message) {
consumerID, _ := GetConsumerID(ctx)
taskID, _ := jsonparser.GetString(msg.Payload, "id")
taskError, _ := jsonparser.GetString(msg.Payload, "error")
log.Printf("BROKER - MESSAGE_DENY ~> %s on %s for Task %s, Error: %s", consumerID, msg.Queue, taskID, taskError)
}
func (b *Broker) OnConsumerPause(ctx context.Context, msg *codec.Message) {
consumerID, _ := GetConsumerID(ctx)
if consumerID != "" {
if con, exists := b.consumers.Get(consumerID); exists {
con.state = consts.ConsumerStatePaused
log.Printf("BROKER - CONSUMER ~> Paused %s", consumerID)
}
}
}
func (b *Broker) OnConsumerStop(ctx context.Context, msg *codec.Message) {
consumerID, _ := GetConsumerID(ctx)
if consumerID != "" {
if con, exists := b.consumers.Get(consumerID); exists {
con.state = consts.ConsumerStateStopped
log.Printf("BROKER - CONSUMER ~> Stopped %s", consumerID)
}
}
}
func (b *Broker) OnConsumerResume(ctx context.Context, msg *codec.Message) {
consumerID, _ := GetConsumerID(ctx)
if consumerID != "" {
if con, exists := b.consumers.Get(consumerID); exists {
con.state = consts.ConsumerStateActive
log.Printf("BROKER - CONSUMER ~> Resumed %s", consumerID)
}
}
}
func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) { func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) {
msg.Command = consts.RESPONSE msg.Command = consts.RESPONSE
b.HandleCallback(ctx, msg) b.HandleCallback(ctx, msg)
@@ -170,7 +222,7 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M
} }
func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) { func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
consumerID := b.addConsumer(ctx, msg.Queue, conn) consumerID := b.AddConsumer(ctx, msg.Queue, conn)
ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers) ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers)
if err := b.send(conn, ack); err != nil { if err := b.send(conn, ack); err != nil {
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err) log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
@@ -181,7 +233,7 @@ func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec
go func() { go func() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
b.removeConsumer(msg.Queue, consumerID) b.RemoveConsumer(consumerID, msg.Queue)
} }
}() }()
} }
@@ -267,7 +319,7 @@ func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Co
return con return con
} }
func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Conn) string { func (b *Broker) AddConsumer(ctx context.Context, queueName string, conn net.Conn) string {
consumerID, ok := GetConsumerID(ctx) consumerID, ok := GetConsumerID(ctx)
q, ok := b.queues.Get(queueName) q, ok := b.queues.Get(queueName)
if !ok { if !ok {
@@ -280,15 +332,66 @@ func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Con
return consumerID return consumerID
} }
func (b *Broker) removeConsumer(queueName, consumerID string) { func (b *Broker) RemoveConsumer(consumerID string, queues ...string) {
if queue, ok := b.queues.Get(queueName); ok { if len(queues) > 0 {
for _, queueName := range queues {
if queue, ok := b.queues.Get(queueName); ok {
con, ok := queue.consumers.Get(consumerID)
if ok {
con.conn.Close()
queue.consumers.Del(consumerID)
}
b.queues.Del(queueName)
}
}
return
}
b.queues.ForEach(func(queueName string, queue *Queue) bool {
con, ok := queue.consumers.Get(consumerID) con, ok := queue.consumers.Get(consumerID)
if ok { if ok {
con.conn.Close() con.conn.Close()
queue.consumers.Del(consumerID) queue.consumers.Del(consumerID)
} }
b.queues.Del(queueName) b.queues.Del(queueName)
return true
})
}
func (b *Broker) handleConsumer(cmd consts.CMD, state consts.ConsumerState, consumerID string, queues ...string) {
fn := func(queue *Queue) {
con, ok := queue.consumers.Get(consumerID)
if ok {
ack := codec.NewMessage(cmd, []byte("{}"), queue.name, map[string]string{consts.ConsumerKey: consumerID})
err := b.send(con.conn, ack)
if err == nil {
con.state = state
}
}
} }
if len(queues) > 0 {
for _, queueName := range queues {
if queue, ok := b.queues.Get(queueName); ok {
fn(queue)
}
}
return
}
b.queues.ForEach(func(queueName string, queue *Queue) bool {
fn(queue)
return true
})
}
func (b *Broker) PauseConsumer(consumerID string, queues ...string) {
b.handleConsumer(consts.CONSUMER_PAUSE, consts.ConsumerStatePaused, consumerID, queues...)
}
func (b *Broker) ResumeConsumer(consumerID string, queues ...string) {
b.handleConsumer(consts.CONSUMER_RESUME, consts.ConsumerStateActive, consumerID, queues...)
}
func (b *Broker) StopConsumer(consumerID string, queues ...string) {
b.handleConsumer(consts.CONSUMER_STOP, consts.ConsumerStateStopped, consumerID, queues...)
} }
func (b *Broker) readMessage(ctx context.Context, c net.Conn) error { func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
@@ -323,13 +426,22 @@ func (b *Broker) dispatchWorker(queue *Queue) {
func (b *Broker) dispatchTaskToConsumer(queue *Queue, task *QueuedTask) bool { func (b *Broker) dispatchTaskToConsumer(queue *Queue, task *QueuedTask) bool {
var consumerFound bool var consumerFound bool
var err error
queue.consumers.ForEach(func(_ string, con *consumer) bool { queue.consumers.ForEach(func(_ string, con *consumer) bool {
if con.state != consts.ConsumerStateActive {
err = fmt.Errorf("consumer %s is not active", con.id)
return false
}
if err := b.send(con.conn, task.Message); err == nil { if err := b.send(con.conn, task.Message); err == nil {
consumerFound = true consumerFound = true
return false // break the loop once a consumer is found return false
} }
return true return true
}) })
if err != nil {
log.Println(err.Error())
return false
}
if !consumerFound { if !consumerFound {
log.Printf("No available consumers for queue %s, retrying...", queue.name) log.Printf("No available consumers for queue %s, retrying...", queue.name)
} }

View File

@@ -6,6 +6,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"sync"
"github.com/oarkflow/mq/consts" "github.com/oarkflow/mq/consts"
) )
@@ -15,7 +16,7 @@ type Message struct {
Queue string `json:"q"` Queue string `json:"q"`
Command consts.CMD `json:"c"` Command consts.CMD `json:"c"`
Payload json.RawMessage `json:"p"` Payload json.RawMessage `json:"p"`
// Metadata map[string]any `json:"m"` m sync.RWMutex
} }
func NewMessage(cmd consts.CMD, payload json.RawMessage, queue string, headers map[string]string) *Message { func NewMessage(cmd consts.CMD, payload json.RawMessage, queue string, headers map[string]string) *Message {
@@ -24,14 +25,13 @@ func NewMessage(cmd consts.CMD, payload json.RawMessage, queue string, headers m
Queue: queue, Queue: queue,
Command: cmd, Command: cmd,
Payload: payload, Payload: payload,
// Metadata: nil,
} }
} }
func (m *Message) Serialize(aesKey, hmacKey []byte, encrypt bool) ([]byte, string, error) { func (m *Message) Serialize(aesKey, hmacKey []byte, encrypt bool) ([]byte, string, error) {
m.m.Lock()
defer m.m.Unlock()
var buf bytes.Buffer var buf bytes.Buffer
// Serialize Headers, Topic, Command, Payload, and Metadata
if err := writeLengthPrefixedJSON(&buf, m.Headers); err != nil { if err := writeLengthPrefixedJSON(&buf, m.Headers); err != nil {
return nil, "", fmt.Errorf("error serializing headers: %v", err) return nil, "", fmt.Errorf("error serializing headers: %v", err)
} }
@@ -44,56 +44,37 @@ func (m *Message) Serialize(aesKey, hmacKey []byte, encrypt bool) ([]byte, strin
if err := writePayload(&buf, aesKey, m.Payload, encrypt); err != nil { if err := writePayload(&buf, aesKey, m.Payload, encrypt); err != nil {
return nil, "", err return nil, "", err
} }
/*if err := writeLengthPrefixedJSON(&buf, m.Metadata); err != nil {
return nil, "", fmt.Errorf("error serializing metadata: %v", err)
}*/
// Calculate HMAC
messageBytes := buf.Bytes() messageBytes := buf.Bytes()
hmacSignature := CalculateHMAC(hmacKey, messageBytes) hmacSignature := CalculateHMAC(hmacKey, messageBytes)
return messageBytes, hmacSignature, nil return messageBytes, hmacSignature, nil
} }
func Deserialize(data, aesKey, hmacKey []byte, receivedHMAC string, decrypt bool) (*Message, error) { func Deserialize(data, aesKey, hmacKey []byte, receivedHMAC string, decrypt bool) (*Message, error) {
if !VerifyHMAC(hmacKey, data, receivedHMAC) { if !VerifyHMAC(hmacKey, data, receivedHMAC) {
return nil, fmt.Errorf("HMAC verification failed") return nil, fmt.Errorf("HMAC verification failed %s", string(hmacKey))
} }
buf := bytes.NewReader(data) buf := bytes.NewReader(data)
// Deserialize Headers, Topic, Command, Payload, and Metadata
headers := make(map[string]string) headers := make(map[string]string)
if err := readLengthPrefixedJSON(buf, &headers); err != nil { if err := readLengthPrefixedJSON(buf, &headers); err != nil {
return nil, fmt.Errorf("error deserializing headers: %v", err) return nil, fmt.Errorf("error deserializing headers: %v", err)
} }
topic, err := readLengthPrefixedString(buf) topic, err := readLengthPrefixedString(buf)
if err != nil { if err != nil {
return nil, fmt.Errorf("error deserializing topic: %v", err) return nil, fmt.Errorf("error deserializing topic: %v", err)
} }
var command consts.CMD var command consts.CMD
if err := binary.Read(buf, binary.LittleEndian, &command); err != nil { if err := binary.Read(buf, binary.LittleEndian, &command); err != nil {
return nil, fmt.Errorf("error deserializing command: %v", err) return nil, fmt.Errorf("error deserializing command: %v", err)
} }
payload, err := readPayload(buf, aesKey, decrypt) payload, err := readPayload(buf, aesKey, decrypt)
if err != nil { if err != nil {
return nil, fmt.Errorf("error deserializing payload: %v", err) return nil, fmt.Errorf("error deserializing payload: %v", err)
} }
/*metadata := make(map[string]any)
if err := readLengthPrefixedJSON(buf, &metadata); err != nil {
return nil, fmt.Errorf("error deserializing metadata: %v", err)
}*/
return &Message{ return &Message{
Headers: headers, Headers: headers,
Queue: topic, Queue: topic,
Command: command, Command: command,
Payload: payload, Payload: payload,
// Metadata: metadata,
}, nil }, nil
} }
@@ -102,11 +83,9 @@ func SendMessage(conn io.Writer, msg *Message, aesKey, hmacKey []byte, encrypt b
if err != nil { if err != nil {
return fmt.Errorf("error serializing message: %v", err) return fmt.Errorf("error serializing message: %v", err)
} }
if err := writeMessageWithHMAC(conn, sentData, hmacSignature); err != nil { if err := writeMessageWithHMAC(conn, sentData, hmacSignature); err != nil {
return fmt.Errorf("error sending message: %v", err) return fmt.Errorf("error sending message: %v", err)
} }
return nil return nil
} }
@@ -115,7 +94,6 @@ func ReadMessage(conn io.Reader, aesKey, hmacKey []byte, decrypt bool) (*Message
if err != nil { if err != nil {
return nil, err return nil, err
} }
return Deserialize(data, aesKey, hmacKey, receivedHMAC, decrypt) return Deserialize(data, aesKey, hmacKey, receivedHMAC, decrypt)
} }
@@ -195,7 +173,6 @@ func readPayload(r *bytes.Reader, aesKey []byte, decrypt bool) (json.RawMessage,
if err != nil { if err != nil {
return nil, err return nil, err
} }
var payloadBytes []byte var payloadBytes []byte
if decrypt { if decrypt {
nonce := make([]byte, 12) nonce := make([]byte, 12)
@@ -209,12 +186,10 @@ func readPayload(r *bytes.Reader, aesKey []byte, decrypt bool) (json.RawMessage,
} else { } else {
payloadBytes = encryptedPayload payloadBytes = encryptedPayload
} }
var payload json.RawMessage var payload json.RawMessage
if err := json.Unmarshal(payloadBytes, &payload); err != nil { if err := json.Unmarshal(payloadBytes, &payload); err != nil {
return nil, fmt.Errorf("error unmarshalling payload: %v", err) return nil, fmt.Errorf("error unmarshalling payload: %v", err)
} }
return payload, nil return payload, nil
} }
func writeMessageWithHMAC(conn io.Writer, messageBytes []byte, hmacSignature string) error { func writeMessageWithHMAC(conn io.Writer, messageBytes []byte, hmacSignature string) error {

View File

@@ -2,7 +2,7 @@ package consts
type CMD byte type CMD byte
func (c CMD) IsValid() bool { return c >= PING && c <= STOP } func (c CMD) IsValid() bool { return c >= PING && c <= CONSUMER_STOP }
const ( const (
PING CMD = iota + 1 PING CMD = iota + 1
@@ -11,13 +11,29 @@ const (
MESSAGE_SEND MESSAGE_SEND
MESSAGE_RESPONSE MESSAGE_RESPONSE
MESSAGE_DENY
MESSAGE_ACK MESSAGE_ACK
MESSAGE_ERROR MESSAGE_ERROR
PUBLISH PUBLISH
PUBLISH_ACK PUBLISH_ACK
RESPONSE RESPONSE
STOP
CONSUMER_PAUSE
CONSUMER_RESUME
CONSUMER_STOP
CONSUMER_PAUSED
CONSUMER_RESUMED
CONSUMER_STOPPED
)
type ConsumerState byte
const (
ConsumerStateActive ConsumerState = iota
ConsumerStatePaused
ConsumerStateStopped
) )
func (c CMD) String() string { func (c CMD) String() string {
@@ -30,6 +46,8 @@ func (c CMD) String() string {
return "SUBSCRIBE_ACK" return "SUBSCRIBE_ACK"
case MESSAGE_SEND: case MESSAGE_SEND:
return "MESSAGE_SEND" return "MESSAGE_SEND"
case MESSAGE_DENY:
return "MESSAGE_DENY"
case MESSAGE_RESPONSE: case MESSAGE_RESPONSE:
return "MESSAGE_RESPONSE" return "MESSAGE_RESPONSE"
case MESSAGE_ERROR: case MESSAGE_ERROR:
@@ -40,8 +58,18 @@ func (c CMD) String() string {
return "PUBLISH" return "PUBLISH"
case PUBLISH_ACK: case PUBLISH_ACK:
return "PUBLISH_ACK" return "PUBLISH_ACK"
case STOP: case CONSUMER_PAUSE:
return "STOP" return "CONSUMER_PAUSE"
case CONSUMER_RESUME:
return "CONSUMER_RESUME"
case CONSUMER_STOP:
return "CONSUMER_STOP"
case CONSUMER_PAUSED:
return "CONSUMER_PAUSED"
case CONSUMER_RESUMED:
return "CONSUMER_RESUMED"
case CONSUMER_STOPPED:
return "CONSUMER_STOPPED"
case RESPONSE: case RESPONSE:
return "RESPONSE" return "RESPONSE"
default: default:

View File

@@ -23,6 +23,7 @@ type Consumer struct {
conn net.Conn conn net.Conn
queue string queue string
opts Options opts Options
pool *Pool
} }
// NewConsumer initializes a new consumer with the provided options. // NewConsumer initializes a new consumer with the provided options.
@@ -44,8 +45,9 @@ func (c *Consumer) receive(conn net.Conn) (*codec.Message, error) {
return codec.ReadMessage(conn, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption) return codec.ReadMessage(conn, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption)
} }
// Close closes the consumer's connection. // Close closes the consumer's connection and stops the worker pool.
func (c *Consumer) Close() error { func (c *Consumer) Close() error {
c.pool.Stop()
return c.conn.Close() return c.conn.Close()
} }
@@ -55,11 +57,10 @@ func (c *Consumer) subscribe(ctx context.Context, queue string) error {
consts.ConsumerKey: c.id, consts.ConsumerKey: c.id,
consts.ContentType: consts.TypeJson, consts.ContentType: consts.TypeJson,
}) })
msg := codec.NewMessage(consts.SUBSCRIBE, nil, queue, headers) msg := codec.NewMessage(consts.SUBSCRIBE, []byte("{}"), queue, headers)
if err := c.send(c.conn, msg); err != nil { if err := c.send(c.conn, msg); err != nil {
return err return err
} }
return c.waitForAck(c.conn) return c.waitForAck(c.conn)
} }
@@ -73,6 +74,30 @@ func (c *Consumer) OnError(_ context.Context, conn net.Conn, err error) {
} }
func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) { func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
switch msg.Command {
case consts.PUBLISH:
c.ConsumeMessage(ctx, msg, conn)
case consts.CONSUMER_PAUSE:
err := c.Pause()
if err != nil {
log.Printf("Unable to pause consumer: %v", err)
}
case consts.CONSUMER_RESUME:
err := c.Resume()
if err != nil {
log.Printf("Unable to resume consumer: %v", err)
}
case consts.CONSUMER_STOP:
err := c.Stop()
if err != nil {
log.Printf("Unable to stop consumer: %v", err)
}
default:
log.Printf("CONSUMER - UNKNOWN_COMMAND ~> %s on %s", msg.Command, msg.Queue)
}
}
func (c *Consumer) ConsumeMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
headers := WithHeaders(ctx, map[string]string{ headers := WithHeaders(ctx, map[string]string{
consts.ConsumerKey: c.id, consts.ConsumerKey: c.id,
consts.ContentType: consts.TypeJson, consts.ContentType: consts.TypeJson,
@@ -89,11 +114,20 @@ func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.C
log.Printf("Error unmarshalling message: %v", err) log.Printf("Error unmarshalling message: %v", err)
return return
} }
ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue}) ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue})
result := c.ProcessTask(ctx, &task) if !c.opts.enableWorkerPool {
err = c.OnResponse(ctx, result) result := c.ProcessTask(ctx, &task)
if err != nil { err = c.OnResponse(ctx, result)
log.Printf("Error on message callback: %v", err) if err != nil {
log.Printf("Error on message callback: %v", err)
}
return
}
// Add the task to the worker pool
if err := c.pool.AddTask(ctx, &task); err != nil {
c.sendDenyMessage(ctx, taskID, msg.Queue, err)
return
} }
} }
@@ -120,6 +154,17 @@ func (c *Consumer) OnResponse(ctx context.Context, result Result) error {
return nil return nil
} }
func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, err error) {
headers := WithHeaders(ctx, map[string]string{
consts.ConsumerKey: c.id,
consts.ContentType: consts.TypeJson,
})
reply := codec.NewMessage(consts.MESSAGE_DENY, []byte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers)
if sendErr := c.send(c.conn, reply); sendErr != nil {
log.Printf("failed to send MESSAGE_DENY for task %s: %v", taskID, sendErr)
}
}
// ProcessTask handles a received task message and invokes the appropriate handler. // ProcessTask handles a received task message and invokes the appropriate handler.
func (c *Consumer) ProcessTask(ctx context.Context, msg *Task) Result { func (c *Consumer) ProcessTask(ctx context.Context, msg *Task) Result {
result := c.handler(ctx, msg) result := c.handler(ctx, msg)
@@ -171,10 +216,11 @@ func (c *Consumer) Consume(ctx context.Context) error {
if err != nil { if err != nil {
return err return err
} }
c.pool = NewPool(c.opts.numOfWorkers, c.opts.queueSize, c.opts.maxMemoryLoad, c.ProcessTask, c.OnResponse, c.conn)
if err := c.subscribe(ctx, c.queue); err != nil { if err := c.subscribe(ctx, c.queue); err != nil {
return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err) return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err)
} }
c.pool.Start(c.opts.numOfWorkers)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go func() { go func() {
@@ -202,3 +248,53 @@ func (c *Consumer) waitForAck(conn net.Conn) error {
} }
return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command) return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command)
} }
// Additional methods for Pause, Resume, and Stop
func (c *Consumer) Pause() error {
if err := c.sendPauseMessage(); err != nil {
return err
}
c.pool.Pause()
return nil
}
func (c *Consumer) sendPauseMessage() error {
headers := WithHeaders(context.Background(), map[string]string{
consts.ConsumerKey: c.id,
})
msg := codec.NewMessage(consts.CONSUMER_PAUSED, nil, c.queue, headers)
return c.send(c.conn, msg)
}
func (c *Consumer) Resume() error {
if err := c.sendResumeMessage(); err != nil {
return err
}
c.pool.Resume()
return nil
}
func (c *Consumer) sendResumeMessage() error {
headers := WithHeaders(context.Background(), map[string]string{
consts.ConsumerKey: c.id,
})
msg := codec.NewMessage(consts.CONSUMER_RESUMED, nil, c.queue, headers)
return c.send(c.conn, msg)
}
func (c *Consumer) Stop() error {
if err := c.sendStopMessage(); err != nil {
return err
}
c.pool.Stop()
return nil
}
func (c *Consumer) sendStopMessage() error {
headers := WithHeaders(context.Background(), map[string]string{
consts.ConsumerKey: c.id,
})
msg := codec.NewMessage(consts.CONSUMER_STOPPED, nil, c.queue, headers)
return c.send(c.conn, msg)
}

View File

@@ -4,12 +4,13 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/oarkflow/xid"
"log" "log"
"net/http" "net/http"
"sync" "sync"
"time" "time"
"github.com/oarkflow/xid"
"github.com/oarkflow/mq" "github.com/oarkflow/mq"
) )
@@ -49,6 +50,8 @@ type DAG struct {
taskContext map[string]*TaskManager taskContext map[string]*TaskManager
conditions map[string]map[string]string conditions map[string]map[string]string
mu sync.RWMutex mu sync.RWMutex
paused bool
opts []mq.Option
} }
func NewDAG(opts ...mq.Option) *DAG { func NewDAG(opts ...mq.Option) *DAG {
@@ -59,6 +62,7 @@ func NewDAG(opts ...mq.Option) *DAG {
} }
opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerOnSubscribe(d.onConsumerJoin), mq.WithConsumerOnClose(d.onConsumerClose)) opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerOnSubscribe(d.onConsumerJoin), mq.WithConsumerOnClose(d.onConsumerClose))
d.server = mq.NewBroker(opts...) d.server = mq.NewBroker(opts...)
d.opts = opts
return d return d
} }
@@ -95,10 +99,13 @@ func (tm *DAG) Start(ctx context.Context, addr string) error {
if con.isReady { if con.isReady {
go func(con *Node) { go func(con *Node) {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
con.consumer.Consume(ctx) err := con.consumer.Consume(ctx)
if err != nil {
panic(err)
}
}(con) }(con)
} else { } else {
log.Printf("[WARNING] - %s is not ready yet", con.Key) log.Printf("[WARNING] - Consumer %s is not ready yet", con.Key)
} }
} }
} }
@@ -114,7 +121,7 @@ func (tm *DAG) Start(ctx context.Context, addr string) error {
func (tm *DAG) AddNode(key string, handler mq.Handler, firstNode ...bool) { func (tm *DAG) AddNode(key string, handler mq.Handler, firstNode ...bool) {
tm.mu.Lock() tm.mu.Lock()
defer tm.mu.Unlock() defer tm.mu.Unlock()
con := mq.NewConsumer(key, key, handler) con := mq.NewConsumer(key, key, handler, tm.opts...)
tm.Nodes[key] = &Node{ tm.Nodes[key] = &Node{
Key: key, Key: key,
consumer: con, consumer: con,
@@ -176,8 +183,11 @@ func (tm *DAG) AddEdge(from, to string, edgeTypes ...EdgeType) {
} }
func (tm *DAG) ProcessTask(ctx context.Context, payload []byte) mq.Result { func (tm *DAG) ProcessTask(ctx context.Context, payload []byte) mq.Result {
if tm.paused {
return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")}
}
if !tm.IsReady() { if !tm.IsReady() {
return mq.Result{Error: fmt.Errorf("DAG is not ready yet")} return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not ready yet")}
} }
val := ctx.Value("initial_node") val := ctx.Value("initial_node")
initialNode, ok := val.(string) initialNode, ok := val.(string)
@@ -226,3 +236,27 @@ func (tm *DAG) FindInitialNode() *Node {
} }
return nil return nil
} }
func (tm *DAG) Pause() {
tm.paused = true
log.Printf("DAG - PAUSED")
}
func (tm *DAG) Resume() {
tm.paused = false
log.Printf("DAG - RESUMED")
}
func (tm *DAG) PauseConsumer(id string) {
if node, ok := tm.Nodes[id]; ok {
node.consumer.Pause()
node.isReady = false
}
}
func (tm *DAG) ResumeConsumer(id string) {
if node, ok := tm.Nodes[id]; ok {
node.consumer.Resume()
node.isReady = true
}
}

View File

@@ -15,7 +15,11 @@ import (
) )
var ( var (
d = dag.NewDAG(mq.WithSyncMode(false), mq.WithNotifyResponse(tasks.NotifyResponse)) d = dag.NewDAG(
mq.WithNotifyResponse(tasks.NotifyResponse),
mq.WithWorkerPool(100, 4, 5000000),
mq.WithSecretKey([]byte("wKWa6GKdBd0njDKNQoInBbh6P0KTjmob")),
)
// d = dag.NewDAG(mq.WithSyncMode(true), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert")) // d = dag.NewDAG(mq.WithSyncMode(true), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert"))
) )
@@ -34,6 +38,24 @@ func main() {
d.AddEdge("E", "F") d.AddEdge("E", "F")
http.HandleFunc("POST /publish", requestHandler("publish")) http.HandleFunc("POST /publish", requestHandler("publish"))
http.HandleFunc("POST /request", requestHandler("request")) http.HandleFunc("POST /request", requestHandler("request"))
http.HandleFunc("/pause-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) {
id := request.PathValue("id")
if id != "" {
d.PauseConsumer(id)
}
})
http.HandleFunc("/resume-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) {
id := request.PathValue("id")
if id != "" {
d.ResumeConsumer(id)
}
})
http.HandleFunc("/pause", func(writer http.ResponseWriter, request *http.Request) {
d.Pause()
})
http.HandleFunc("/resume", func(writer http.ResponseWriter, request *http.Request) {
d.Resume()
})
err := d.Start(context.TODO(), ":8083") err := d.Start(context.TODO(), ":8083")
if err != nil { if err != nil {
panic(err) panic(err)

34
examples/hmac.go Normal file
View File

@@ -0,0 +1,34 @@
package main
import (
"crypto/rand"
"encoding/base64"
"fmt"
"log"
)
func GenerateSecretKey() (string, error) {
// Create a byte slice to hold 32 random bytes
key := make([]byte, 32)
// Fill the slice with secure random bytes
_, err := rand.Read(key)
if err != nil {
return "", err
}
// Encode the byte slice to a Base64 string
secretKey := base64.StdEncoding.EncodeToString(key)
// Return the first 32 characters
return secretKey[:32], nil
}
func main() {
secretKey, err := GenerateSecretKey()
if err != nil {
log.Fatalf("Error generating secret key: %v", err)
}
fmt.Println("Generated Secret Key:", secretKey)
}

View File

@@ -9,17 +9,19 @@ import (
"github.com/oarkflow/mq" "github.com/oarkflow/mq"
) )
func Node1(ctx context.Context, task *mq.Task) mq.Result { func Node1(_ context.Context, task *mq.Task) mq.Result {
fmt.Println("Node 1", string(task.Payload))
return mq.Result{Payload: task.Payload, TaskID: task.ID} return mq.Result{Payload: task.Payload, TaskID: task.ID}
} }
func Node2(ctx context.Context, task *mq.Task) mq.Result { func Node2(_ context.Context, task *mq.Task) mq.Result {
fmt.Println("Node 2", string(task.Payload))
return mq.Result{Payload: task.Payload, TaskID: task.ID} return mq.Result{Payload: task.Payload, TaskID: task.ID}
} }
func Node3(ctx context.Context, task *mq.Task) mq.Result { func Node3(_ context.Context, task *mq.Task) mq.Result {
var user map[string]any var user map[string]any
json.Unmarshal(task.Payload, &user) _ = json.Unmarshal(task.Payload, &user)
age := int(user["age"].(float64)) age := int(user["age"].(float64))
status := "FAIL" status := "FAIL"
if age > 20 { if age > 20 {
@@ -30,34 +32,34 @@ func Node3(ctx context.Context, task *mq.Task) mq.Result {
return mq.Result{Payload: resultPayload, Status: status} return mq.Result{Payload: resultPayload, Status: status}
} }
func Node4(ctx context.Context, task *mq.Task) mq.Result { func Node4(_ context.Context, task *mq.Task) mq.Result {
var user map[string]any var user map[string]any
json.Unmarshal(task.Payload, &user) _ = json.Unmarshal(task.Payload, &user)
user["final"] = "D" user["final"] = "D"
resultPayload, _ := json.Marshal(user) resultPayload, _ := json.Marshal(user)
return mq.Result{Payload: resultPayload} return mq.Result{Payload: resultPayload}
} }
func Node5(ctx context.Context, task *mq.Task) mq.Result { func Node5(_ context.Context, task *mq.Task) mq.Result {
var user map[string]any var user map[string]any
json.Unmarshal(task.Payload, &user) _ = json.Unmarshal(task.Payload, &user)
user["salary"] = "E" user["salary"] = "E"
resultPayload, _ := json.Marshal(user) resultPayload, _ := json.Marshal(user)
return mq.Result{Payload: resultPayload} return mq.Result{Payload: resultPayload}
} }
func Node6(ctx context.Context, task *mq.Task) mq.Result { func Node6(_ context.Context, task *mq.Task) mq.Result {
var user map[string]any var user map[string]any
json.Unmarshal(task.Payload, &user) _ = json.Unmarshal(task.Payload, &user)
resultPayload, _ := json.Marshal(map[string]any{"storage": user}) resultPayload, _ := json.Marshal(map[string]any{"storage": user})
return mq.Result{Payload: resultPayload} return mq.Result{Payload: resultPayload}
} }
func Callback(ctx context.Context, task mq.Result) mq.Result { func Callback(_ context.Context, task mq.Result) mq.Result {
fmt.Println("Received task", task.TaskID, "Payload", string(task.Payload), task.Error, task.Topic) fmt.Println("Received task", task.TaskID, "Payload", string(task.Payload), task.Error, task.Topic)
return mq.Result{} return mq.Result{}
} }
func NotifyResponse(ctx context.Context, result mq.Result) { func NotifyResponse(_ context.Context, result mq.Result) {
log.Printf("DAG Final response: TaskID: %s, Payload: %s, Topic: %s", result.TaskID, result.Payload, result.Topic) log.Printf("DAG - FINAL_RESPONSE ~> TaskID: %s, Payload: %s, Topic: %s", result.TaskID, result.Payload, result.Topic)
} }

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"runtime"
"time" "time"
) )
@@ -72,17 +73,22 @@ type Options struct {
hmacKey json.RawMessage hmacKey json.RawMessage
enableEncryption bool enableEncryption bool
queueSize int queueSize int
numOfWorkers int
maxMemoryLoad int64
enableWorkerPool bool
} }
func defaultOptions() Options { func defaultOptions() Options {
return Options{ return Options{
syncMode: false,
brokerAddr: ":8080", brokerAddr: ":8080",
maxRetries: 5, maxRetries: 5,
initialDelay: 2 * time.Second, initialDelay: 2 * time.Second,
maxBackoff: 20 * time.Second, maxBackoff: 20 * time.Second,
jitterPercent: 0.5, jitterPercent: 0.5,
queueSize: 100, queueSize: 100,
hmacKey: []byte(`a9f4b9415485b70275673b5920182796ea497b5e093ead844a43ea5d77cbc24f`),
numOfWorkers: runtime.NumCPU(),
maxMemoryLoad: 5000000,
} }
} }
@@ -103,6 +109,15 @@ func WithNotifyResponse(handler func(ctx context.Context, result Result)) Option
} }
} }
func WithWorkerPool(queueSize, numOfWorkers int, maxMemoryLoad int64) Option {
return func(opts *Options) {
opts.enableWorkerPool = true
opts.queueSize = queueSize
opts.numOfWorkers = numOfWorkers
opts.maxMemoryLoad = maxMemoryLoad
}
}
func WithConsumerOnSubscribe(handler func(ctx context.Context, topic, consumerName string)) Option { func WithConsumerOnSubscribe(handler func(ctx context.Context, topic, consumerName string)) Option {
return func(opts *Options) { return func(opts *Options) {
opts.consumerOnSubscribe = handler opts.consumerOnSubscribe = handler
@@ -115,11 +130,16 @@ func WithConsumerOnClose(handler func(ctx context.Context, topic, consumerName s
} }
} }
func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option { func WithSecretKey(aesKey json.RawMessage) Option {
return func(opts *Options) { return func(opts *Options) {
opts.aesKey = aesKey opts.aesKey = aesKey
opts.enableEncryption = true
}
}
func WithHMACKey(hmacKey json.RawMessage) Option {
return func(opts *Options) {
opts.hmacKey = hmacKey opts.hmacKey = hmacKey
opts.enableEncryption = enableEncryption
} }
} }

43
utils/encrypt.go Normal file
View File

@@ -0,0 +1,43 @@
package utils
import (
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
)
func GenerateHMACKey(length int) (string, error) {
key := make([]byte, length)
_, err := rand.Read(key)
if err != nil {
return "", fmt.Errorf("failed to generate random key: %v", err)
}
return hex.EncodeToString(key), nil
}
func MustGenerateHMACKey(length int) string {
key, err := GenerateHMACKey(length)
if err != nil {
panic(err)
}
return key
}
func GenerateSecretKey(length int) (string, error) {
key := make([]byte, length)
_, err := rand.Read(key)
if err != nil {
return "", err
}
secretKey := base64.StdEncoding.EncodeToString(key)
return secretKey[:length], nil
}
func MustGenerateSecretKey(length int) string {
key, err := GenerateSecretKey(length)
if err != nil {
panic(err)
}
return key
}