diff --git a/broker.go b/broker.go index 98850aa..a220c61 100644 --- a/broker.go +++ b/broker.go @@ -24,8 +24,9 @@ type QueuedTask struct { } type consumer struct { - id string - conn net.Conn + id string + state consts.ConsumerState + conn net.Conn } type publisher struct { @@ -50,6 +51,10 @@ func NewBroker(opts ...Option) *Broker { } } +func (b *Broker) Options() Options { + return b.opts +} + func (b *Broker) OnClose(ctx context.Context, conn net.Conn) error { consumerID, ok := GetConsumerID(ctx) if ok && consumerID != "" { @@ -110,6 +115,16 @@ func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Con b.MessageResponseHandler(ctx, msg) case consts.MESSAGE_ACK: b.MessageAck(ctx, msg) + case consts.MESSAGE_DENY: + b.MessageDeny(ctx, msg) + case consts.CONSUMER_PAUSED: + b.OnConsumerPause(ctx, msg) + case consts.CONSUMER_RESUMED: + b.OnConsumerResume(ctx, msg) + case consts.CONSUMER_STOPPED: + b.OnConsumerStop(ctx, msg) + default: + log.Printf("BROKER - UNKNOWN_COMMAND ~> %s on %s", msg.Command, msg.Queue) } } @@ -119,6 +134,43 @@ func (b *Broker) MessageAck(ctx context.Context, msg *codec.Message) { log.Printf("BROKER - MESSAGE_ACK ~> %s on %s for Task %s", consumerID, msg.Queue, taskID) } +func (b *Broker) MessageDeny(ctx context.Context, msg *codec.Message) { + consumerID, _ := GetConsumerID(ctx) + taskID, _ := jsonparser.GetString(msg.Payload, "id") + taskError, _ := jsonparser.GetString(msg.Payload, "error") + log.Printf("BROKER - MESSAGE_DENY ~> %s on %s for Task %s, Error: %s", consumerID, msg.Queue, taskID, taskError) +} + +func (b *Broker) OnConsumerPause(ctx context.Context, msg *codec.Message) { + consumerID, _ := GetConsumerID(ctx) + if consumerID != "" { + if con, exists := b.consumers.Get(consumerID); exists { + con.state = consts.ConsumerStatePaused + log.Printf("BROKER - CONSUMER ~> Paused %s", consumerID) + } + } +} + +func (b *Broker) OnConsumerStop(ctx context.Context, msg *codec.Message) { + consumerID, _ := GetConsumerID(ctx) + if consumerID != "" { + if con, exists := b.consumers.Get(consumerID); exists { + con.state = consts.ConsumerStateStopped + log.Printf("BROKER - CONSUMER ~> Stopped %s", consumerID) + } + } +} + +func (b *Broker) OnConsumerResume(ctx context.Context, msg *codec.Message) { + consumerID, _ := GetConsumerID(ctx) + if consumerID != "" { + if con, exists := b.consumers.Get(consumerID); exists { + con.state = consts.ConsumerStateActive + log.Printf("BROKER - CONSUMER ~> Resumed %s", consumerID) + } + } +} + func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) { msg.Command = consts.RESPONSE b.HandleCallback(ctx, msg) @@ -170,7 +222,7 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M } func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) { - consumerID := b.addConsumer(ctx, msg.Queue, conn) + consumerID := b.AddConsumer(ctx, msg.Queue, conn) ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers) if err := b.send(conn, ack); err != nil { log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err) @@ -181,7 +233,7 @@ func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec go func() { select { case <-ctx.Done(): - b.removeConsumer(msg.Queue, consumerID) + b.RemoveConsumer(consumerID, msg.Queue) } }() } @@ -267,7 +319,7 @@ func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Co return con } -func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Conn) string { +func (b *Broker) AddConsumer(ctx context.Context, queueName string, conn net.Conn) string { consumerID, ok := GetConsumerID(ctx) q, ok := b.queues.Get(queueName) if !ok { @@ -280,15 +332,66 @@ func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Con return consumerID } -func (b *Broker) removeConsumer(queueName, consumerID string) { - if queue, ok := b.queues.Get(queueName); ok { +func (b *Broker) RemoveConsumer(consumerID string, queues ...string) { + if len(queues) > 0 { + for _, queueName := range queues { + if queue, ok := b.queues.Get(queueName); ok { + con, ok := queue.consumers.Get(consumerID) + if ok { + con.conn.Close() + queue.consumers.Del(consumerID) + } + b.queues.Del(queueName) + } + } + return + } + b.queues.ForEach(func(queueName string, queue *Queue) bool { con, ok := queue.consumers.Get(consumerID) if ok { con.conn.Close() queue.consumers.Del(consumerID) } b.queues.Del(queueName) + return true + }) +} + +func (b *Broker) handleConsumer(cmd consts.CMD, state consts.ConsumerState, consumerID string, queues ...string) { + fn := func(queue *Queue) { + con, ok := queue.consumers.Get(consumerID) + if ok { + ack := codec.NewMessage(cmd, []byte("{}"), queue.name, map[string]string{consts.ConsumerKey: consumerID}) + err := b.send(con.conn, ack) + if err == nil { + con.state = state + } + } } + if len(queues) > 0 { + for _, queueName := range queues { + if queue, ok := b.queues.Get(queueName); ok { + fn(queue) + } + } + return + } + b.queues.ForEach(func(queueName string, queue *Queue) bool { + fn(queue) + return true + }) +} + +func (b *Broker) PauseConsumer(consumerID string, queues ...string) { + b.handleConsumer(consts.CONSUMER_PAUSE, consts.ConsumerStatePaused, consumerID, queues...) +} + +func (b *Broker) ResumeConsumer(consumerID string, queues ...string) { + b.handleConsumer(consts.CONSUMER_RESUME, consts.ConsumerStateActive, consumerID, queues...) +} + +func (b *Broker) StopConsumer(consumerID string, queues ...string) { + b.handleConsumer(consts.CONSUMER_STOP, consts.ConsumerStateStopped, consumerID, queues...) } func (b *Broker) readMessage(ctx context.Context, c net.Conn) error { @@ -323,13 +426,22 @@ func (b *Broker) dispatchWorker(queue *Queue) { func (b *Broker) dispatchTaskToConsumer(queue *Queue, task *QueuedTask) bool { var consumerFound bool + var err error queue.consumers.ForEach(func(_ string, con *consumer) bool { + if con.state != consts.ConsumerStateActive { + err = fmt.Errorf("consumer %s is not active", con.id) + return false + } if err := b.send(con.conn, task.Message); err == nil { consumerFound = true - return false // break the loop once a consumer is found + return false } return true }) + if err != nil { + log.Println(err.Error()) + return false + } if !consumerFound { log.Printf("No available consumers for queue %s, retrying...", queue.name) } diff --git a/codec/codec.go b/codec/codec.go index ac71082..d7c1b8f 100644 --- a/codec/codec.go +++ b/codec/codec.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "io" + "sync" "github.com/oarkflow/mq/consts" ) @@ -15,7 +16,7 @@ type Message struct { Queue string `json:"q"` Command consts.CMD `json:"c"` Payload json.RawMessage `json:"p"` - // Metadata map[string]any `json:"m"` + m sync.RWMutex } func NewMessage(cmd consts.CMD, payload json.RawMessage, queue string, headers map[string]string) *Message { @@ -24,14 +25,13 @@ func NewMessage(cmd consts.CMD, payload json.RawMessage, queue string, headers m Queue: queue, Command: cmd, Payload: payload, - // Metadata: nil, } } func (m *Message) Serialize(aesKey, hmacKey []byte, encrypt bool) ([]byte, string, error) { + m.m.Lock() + defer m.m.Unlock() var buf bytes.Buffer - - // Serialize Headers, Topic, Command, Payload, and Metadata if err := writeLengthPrefixedJSON(&buf, m.Headers); err != nil { return nil, "", fmt.Errorf("error serializing headers: %v", err) } @@ -44,56 +44,37 @@ func (m *Message) Serialize(aesKey, hmacKey []byte, encrypt bool) ([]byte, strin if err := writePayload(&buf, aesKey, m.Payload, encrypt); err != nil { return nil, "", err } - /*if err := writeLengthPrefixedJSON(&buf, m.Metadata); err != nil { - return nil, "", fmt.Errorf("error serializing metadata: %v", err) - }*/ - - // Calculate HMAC messageBytes := buf.Bytes() hmacSignature := CalculateHMAC(hmacKey, messageBytes) - return messageBytes, hmacSignature, nil } func Deserialize(data, aesKey, hmacKey []byte, receivedHMAC string, decrypt bool) (*Message, error) { if !VerifyHMAC(hmacKey, data, receivedHMAC) { - return nil, fmt.Errorf("HMAC verification failed") + return nil, fmt.Errorf("HMAC verification failed %s", string(hmacKey)) } - buf := bytes.NewReader(data) - - // Deserialize Headers, Topic, Command, Payload, and Metadata headers := make(map[string]string) if err := readLengthPrefixedJSON(buf, &headers); err != nil { return nil, fmt.Errorf("error deserializing headers: %v", err) } - topic, err := readLengthPrefixedString(buf) if err != nil { return nil, fmt.Errorf("error deserializing topic: %v", err) } - var command consts.CMD if err := binary.Read(buf, binary.LittleEndian, &command); err != nil { return nil, fmt.Errorf("error deserializing command: %v", err) } - payload, err := readPayload(buf, aesKey, decrypt) if err != nil { return nil, fmt.Errorf("error deserializing payload: %v", err) } - - /*metadata := make(map[string]any) - if err := readLengthPrefixedJSON(buf, &metadata); err != nil { - return nil, fmt.Errorf("error deserializing metadata: %v", err) - }*/ - return &Message{ Headers: headers, Queue: topic, Command: command, Payload: payload, - // Metadata: metadata, }, nil } @@ -102,11 +83,9 @@ func SendMessage(conn io.Writer, msg *Message, aesKey, hmacKey []byte, encrypt b if err != nil { return fmt.Errorf("error serializing message: %v", err) } - if err := writeMessageWithHMAC(conn, sentData, hmacSignature); err != nil { return fmt.Errorf("error sending message: %v", err) } - return nil } @@ -115,7 +94,6 @@ func ReadMessage(conn io.Reader, aesKey, hmacKey []byte, decrypt bool) (*Message if err != nil { return nil, err } - return Deserialize(data, aesKey, hmacKey, receivedHMAC, decrypt) } @@ -195,7 +173,6 @@ func readPayload(r *bytes.Reader, aesKey []byte, decrypt bool) (json.RawMessage, if err != nil { return nil, err } - var payloadBytes []byte if decrypt { nonce := make([]byte, 12) @@ -209,12 +186,10 @@ func readPayload(r *bytes.Reader, aesKey []byte, decrypt bool) (json.RawMessage, } else { payloadBytes = encryptedPayload } - var payload json.RawMessage if err := json.Unmarshal(payloadBytes, &payload); err != nil { return nil, fmt.Errorf("error unmarshalling payload: %v", err) } - return payload, nil } func writeMessageWithHMAC(conn io.Writer, messageBytes []byte, hmacSignature string) error { diff --git a/consts/constants.go b/consts/constants.go index 9abbc6b..c9280b8 100644 --- a/consts/constants.go +++ b/consts/constants.go @@ -2,7 +2,7 @@ package consts type CMD byte -func (c CMD) IsValid() bool { return c >= PING && c <= STOP } +func (c CMD) IsValid() bool { return c >= PING && c <= CONSUMER_STOP } const ( PING CMD = iota + 1 @@ -11,13 +11,29 @@ const ( MESSAGE_SEND MESSAGE_RESPONSE + MESSAGE_DENY MESSAGE_ACK MESSAGE_ERROR PUBLISH PUBLISH_ACK RESPONSE - STOP + + CONSUMER_PAUSE + CONSUMER_RESUME + CONSUMER_STOP + + CONSUMER_PAUSED + CONSUMER_RESUMED + CONSUMER_STOPPED +) + +type ConsumerState byte + +const ( + ConsumerStateActive ConsumerState = iota + ConsumerStatePaused + ConsumerStateStopped ) func (c CMD) String() string { @@ -30,6 +46,8 @@ func (c CMD) String() string { return "SUBSCRIBE_ACK" case MESSAGE_SEND: return "MESSAGE_SEND" + case MESSAGE_DENY: + return "MESSAGE_DENY" case MESSAGE_RESPONSE: return "MESSAGE_RESPONSE" case MESSAGE_ERROR: @@ -40,8 +58,18 @@ func (c CMD) String() string { return "PUBLISH" case PUBLISH_ACK: return "PUBLISH_ACK" - case STOP: - return "STOP" + case CONSUMER_PAUSE: + return "CONSUMER_PAUSE" + case CONSUMER_RESUME: + return "CONSUMER_RESUME" + case CONSUMER_STOP: + return "CONSUMER_STOP" + case CONSUMER_PAUSED: + return "CONSUMER_PAUSED" + case CONSUMER_RESUMED: + return "CONSUMER_RESUMED" + case CONSUMER_STOPPED: + return "CONSUMER_STOPPED" case RESPONSE: return "RESPONSE" default: diff --git a/consumer.go b/consumer.go index f639538..8fc5f1e 100644 --- a/consumer.go +++ b/consumer.go @@ -23,6 +23,7 @@ type Consumer struct { conn net.Conn queue string opts Options + pool *Pool } // NewConsumer initializes a new consumer with the provided options. @@ -44,8 +45,9 @@ func (c *Consumer) receive(conn net.Conn) (*codec.Message, error) { return codec.ReadMessage(conn, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption) } -// Close closes the consumer's connection. +// Close closes the consumer's connection and stops the worker pool. func (c *Consumer) Close() error { + c.pool.Stop() return c.conn.Close() } @@ -55,11 +57,10 @@ func (c *Consumer) subscribe(ctx context.Context, queue string) error { consts.ConsumerKey: c.id, consts.ContentType: consts.TypeJson, }) - msg := codec.NewMessage(consts.SUBSCRIBE, nil, queue, headers) + msg := codec.NewMessage(consts.SUBSCRIBE, []byte("{}"), queue, headers) if err := c.send(c.conn, msg); err != nil { return err } - return c.waitForAck(c.conn) } @@ -73,6 +74,30 @@ func (c *Consumer) OnError(_ context.Context, conn net.Conn, err error) { } func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) { + switch msg.Command { + case consts.PUBLISH: + c.ConsumeMessage(ctx, msg, conn) + case consts.CONSUMER_PAUSE: + err := c.Pause() + if err != nil { + log.Printf("Unable to pause consumer: %v", err) + } + case consts.CONSUMER_RESUME: + err := c.Resume() + if err != nil { + log.Printf("Unable to resume consumer: %v", err) + } + case consts.CONSUMER_STOP: + err := c.Stop() + if err != nil { + log.Printf("Unable to stop consumer: %v", err) + } + default: + log.Printf("CONSUMER - UNKNOWN_COMMAND ~> %s on %s", msg.Command, msg.Queue) + } +} + +func (c *Consumer) ConsumeMessage(ctx context.Context, msg *codec.Message, conn net.Conn) { headers := WithHeaders(ctx, map[string]string{ consts.ConsumerKey: c.id, consts.ContentType: consts.TypeJson, @@ -89,11 +114,20 @@ func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.C log.Printf("Error unmarshalling message: %v", err) return } + ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue}) - result := c.ProcessTask(ctx, &task) - err = c.OnResponse(ctx, result) - if err != nil { - log.Printf("Error on message callback: %v", err) + if !c.opts.enableWorkerPool { + result := c.ProcessTask(ctx, &task) + err = c.OnResponse(ctx, result) + if err != nil { + log.Printf("Error on message callback: %v", err) + } + return + } + // Add the task to the worker pool + if err := c.pool.AddTask(ctx, &task); err != nil { + c.sendDenyMessage(ctx, taskID, msg.Queue, err) + return } } @@ -120,6 +154,17 @@ func (c *Consumer) OnResponse(ctx context.Context, result Result) error { return nil } +func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, err error) { + headers := WithHeaders(ctx, map[string]string{ + consts.ConsumerKey: c.id, + consts.ContentType: consts.TypeJson, + }) + reply := codec.NewMessage(consts.MESSAGE_DENY, []byte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers) + if sendErr := c.send(c.conn, reply); sendErr != nil { + log.Printf("failed to send MESSAGE_DENY for task %s: %v", taskID, sendErr) + } +} + // ProcessTask handles a received task message and invokes the appropriate handler. func (c *Consumer) ProcessTask(ctx context.Context, msg *Task) Result { result := c.handler(ctx, msg) @@ -171,10 +216,11 @@ func (c *Consumer) Consume(ctx context.Context) error { if err != nil { return err } - + c.pool = NewPool(c.opts.numOfWorkers, c.opts.queueSize, c.opts.maxMemoryLoad, c.ProcessTask, c.OnResponse, c.conn) if err := c.subscribe(ctx, c.queue); err != nil { return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err) } + c.pool.Start(c.opts.numOfWorkers) var wg sync.WaitGroup wg.Add(1) go func() { @@ -202,3 +248,53 @@ func (c *Consumer) waitForAck(conn net.Conn) error { } return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command) } + +// Additional methods for Pause, Resume, and Stop + +func (c *Consumer) Pause() error { + if err := c.sendPauseMessage(); err != nil { + return err + } + c.pool.Pause() + return nil +} + +func (c *Consumer) sendPauseMessage() error { + headers := WithHeaders(context.Background(), map[string]string{ + consts.ConsumerKey: c.id, + }) + msg := codec.NewMessage(consts.CONSUMER_PAUSED, nil, c.queue, headers) + return c.send(c.conn, msg) +} + +func (c *Consumer) Resume() error { + if err := c.sendResumeMessage(); err != nil { + return err + } + c.pool.Resume() + return nil +} + +func (c *Consumer) sendResumeMessage() error { + headers := WithHeaders(context.Background(), map[string]string{ + consts.ConsumerKey: c.id, + }) + msg := codec.NewMessage(consts.CONSUMER_RESUMED, nil, c.queue, headers) + return c.send(c.conn, msg) +} + +func (c *Consumer) Stop() error { + if err := c.sendStopMessage(); err != nil { + return err + } + c.pool.Stop() + return nil +} + +func (c *Consumer) sendStopMessage() error { + headers := WithHeaders(context.Background(), map[string]string{ + consts.ConsumerKey: c.id, + }) + msg := codec.NewMessage(consts.CONSUMER_STOPPED, nil, c.queue, headers) + return c.send(c.conn, msg) +} diff --git a/dag/dag.go b/dag/dag.go index 057ed88..70b7a4d 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -4,12 +4,13 @@ import ( "context" "encoding/json" "fmt" - "github.com/oarkflow/xid" "log" "net/http" "sync" "time" + "github.com/oarkflow/xid" + "github.com/oarkflow/mq" ) @@ -49,6 +50,8 @@ type DAG struct { taskContext map[string]*TaskManager conditions map[string]map[string]string mu sync.RWMutex + paused bool + opts []mq.Option } func NewDAG(opts ...mq.Option) *DAG { @@ -59,6 +62,7 @@ func NewDAG(opts ...mq.Option) *DAG { } opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerOnSubscribe(d.onConsumerJoin), mq.WithConsumerOnClose(d.onConsumerClose)) d.server = mq.NewBroker(opts...) + d.opts = opts return d } @@ -95,10 +99,13 @@ func (tm *DAG) Start(ctx context.Context, addr string) error { if con.isReady { go func(con *Node) { time.Sleep(1 * time.Second) - con.consumer.Consume(ctx) + err := con.consumer.Consume(ctx) + if err != nil { + panic(err) + } }(con) } else { - log.Printf("[WARNING] - %s is not ready yet", con.Key) + log.Printf("[WARNING] - Consumer %s is not ready yet", con.Key) } } } @@ -114,7 +121,7 @@ func (tm *DAG) Start(ctx context.Context, addr string) error { func (tm *DAG) AddNode(key string, handler mq.Handler, firstNode ...bool) { tm.mu.Lock() defer tm.mu.Unlock() - con := mq.NewConsumer(key, key, handler) + con := mq.NewConsumer(key, key, handler, tm.opts...) tm.Nodes[key] = &Node{ Key: key, consumer: con, @@ -176,8 +183,11 @@ func (tm *DAG) AddEdge(from, to string, edgeTypes ...EdgeType) { } func (tm *DAG) ProcessTask(ctx context.Context, payload []byte) mq.Result { + if tm.paused { + return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")} + } if !tm.IsReady() { - return mq.Result{Error: fmt.Errorf("DAG is not ready yet")} + return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not ready yet")} } val := ctx.Value("initial_node") initialNode, ok := val.(string) @@ -226,3 +236,27 @@ func (tm *DAG) FindInitialNode() *Node { } return nil } + +func (tm *DAG) Pause() { + tm.paused = true + log.Printf("DAG - PAUSED") +} + +func (tm *DAG) Resume() { + tm.paused = false + log.Printf("DAG - RESUMED") +} + +func (tm *DAG) PauseConsumer(id string) { + if node, ok := tm.Nodes[id]; ok { + node.consumer.Pause() + node.isReady = false + } +} + +func (tm *DAG) ResumeConsumer(id string) { + if node, ok := tm.Nodes[id]; ok { + node.consumer.Resume() + node.isReady = true + } +} diff --git a/examples/dag.go b/examples/dag.go index f095153..ffa46c5 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -15,7 +15,11 @@ import ( ) var ( - d = dag.NewDAG(mq.WithSyncMode(false), mq.WithNotifyResponse(tasks.NotifyResponse)) + d = dag.NewDAG( + mq.WithNotifyResponse(tasks.NotifyResponse), + mq.WithWorkerPool(100, 4, 5000000), + mq.WithSecretKey([]byte("wKWa6GKdBd0njDKNQoInBbh6P0KTjmob")), + ) // d = dag.NewDAG(mq.WithSyncMode(true), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert")) ) @@ -34,6 +38,24 @@ func main() { d.AddEdge("E", "F") http.HandleFunc("POST /publish", requestHandler("publish")) http.HandleFunc("POST /request", requestHandler("request")) + http.HandleFunc("/pause-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) { + id := request.PathValue("id") + if id != "" { + d.PauseConsumer(id) + } + }) + http.HandleFunc("/resume-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) { + id := request.PathValue("id") + if id != "" { + d.ResumeConsumer(id) + } + }) + http.HandleFunc("/pause", func(writer http.ResponseWriter, request *http.Request) { + d.Pause() + }) + http.HandleFunc("/resume", func(writer http.ResponseWriter, request *http.Request) { + d.Resume() + }) err := d.Start(context.TODO(), ":8083") if err != nil { panic(err) diff --git a/examples/hmac.go b/examples/hmac.go new file mode 100644 index 0000000..e37da79 --- /dev/null +++ b/examples/hmac.go @@ -0,0 +1,34 @@ +package main + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "log" +) + +func GenerateSecretKey() (string, error) { + // Create a byte slice to hold 32 random bytes + key := make([]byte, 32) + + // Fill the slice with secure random bytes + _, err := rand.Read(key) + if err != nil { + return "", err + } + + // Encode the byte slice to a Base64 string + secretKey := base64.StdEncoding.EncodeToString(key) + + // Return the first 32 characters + return secretKey[:32], nil +} + +func main() { + secretKey, err := GenerateSecretKey() + if err != nil { + log.Fatalf("Error generating secret key: %v", err) + } + + fmt.Println("Generated Secret Key:", secretKey) +} diff --git a/examples/tasks/tasks.go b/examples/tasks/tasks.go index 670c408..38764b3 100644 --- a/examples/tasks/tasks.go +++ b/examples/tasks/tasks.go @@ -9,17 +9,19 @@ import ( "github.com/oarkflow/mq" ) -func Node1(ctx context.Context, task *mq.Task) mq.Result { +func Node1(_ context.Context, task *mq.Task) mq.Result { + fmt.Println("Node 1", string(task.Payload)) return mq.Result{Payload: task.Payload, TaskID: task.ID} } -func Node2(ctx context.Context, task *mq.Task) mq.Result { +func Node2(_ context.Context, task *mq.Task) mq.Result { + fmt.Println("Node 2", string(task.Payload)) return mq.Result{Payload: task.Payload, TaskID: task.ID} } -func Node3(ctx context.Context, task *mq.Task) mq.Result { +func Node3(_ context.Context, task *mq.Task) mq.Result { var user map[string]any - json.Unmarshal(task.Payload, &user) + _ = json.Unmarshal(task.Payload, &user) age := int(user["age"].(float64)) status := "FAIL" if age > 20 { @@ -30,34 +32,34 @@ func Node3(ctx context.Context, task *mq.Task) mq.Result { return mq.Result{Payload: resultPayload, Status: status} } -func Node4(ctx context.Context, task *mq.Task) mq.Result { +func Node4(_ context.Context, task *mq.Task) mq.Result { var user map[string]any - json.Unmarshal(task.Payload, &user) + _ = json.Unmarshal(task.Payload, &user) user["final"] = "D" resultPayload, _ := json.Marshal(user) return mq.Result{Payload: resultPayload} } -func Node5(ctx context.Context, task *mq.Task) mq.Result { +func Node5(_ context.Context, task *mq.Task) mq.Result { var user map[string]any - json.Unmarshal(task.Payload, &user) + _ = json.Unmarshal(task.Payload, &user) user["salary"] = "E" resultPayload, _ := json.Marshal(user) return mq.Result{Payload: resultPayload} } -func Node6(ctx context.Context, task *mq.Task) mq.Result { +func Node6(_ context.Context, task *mq.Task) mq.Result { var user map[string]any - json.Unmarshal(task.Payload, &user) + _ = json.Unmarshal(task.Payload, &user) resultPayload, _ := json.Marshal(map[string]any{"storage": user}) return mq.Result{Payload: resultPayload} } -func Callback(ctx context.Context, task mq.Result) mq.Result { +func Callback(_ context.Context, task mq.Result) mq.Result { fmt.Println("Received task", task.TaskID, "Payload", string(task.Payload), task.Error, task.Topic) return mq.Result{} } -func NotifyResponse(ctx context.Context, result mq.Result) { - log.Printf("DAG Final response: TaskID: %s, Payload: %s, Topic: %s", result.TaskID, result.Payload, result.Topic) +func NotifyResponse(_ context.Context, result mq.Result) { + log.Printf("DAG - FINAL_RESPONSE ~> TaskID: %s, Payload: %s, Topic: %s", result.TaskID, result.Payload, result.Topic) } diff --git a/options.go b/options.go index 0e509b3..5899e15 100644 --- a/options.go +++ b/options.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "runtime" "time" ) @@ -72,17 +73,22 @@ type Options struct { hmacKey json.RawMessage enableEncryption bool queueSize int + numOfWorkers int + maxMemoryLoad int64 + enableWorkerPool bool } func defaultOptions() Options { return Options{ - syncMode: false, brokerAddr: ":8080", maxRetries: 5, initialDelay: 2 * time.Second, maxBackoff: 20 * time.Second, jitterPercent: 0.5, queueSize: 100, + hmacKey: []byte(`a9f4b9415485b70275673b5920182796ea497b5e093ead844a43ea5d77cbc24f`), + numOfWorkers: runtime.NumCPU(), + maxMemoryLoad: 5000000, } } @@ -103,6 +109,15 @@ func WithNotifyResponse(handler func(ctx context.Context, result Result)) Option } } +func WithWorkerPool(queueSize, numOfWorkers int, maxMemoryLoad int64) Option { + return func(opts *Options) { + opts.enableWorkerPool = true + opts.queueSize = queueSize + opts.numOfWorkers = numOfWorkers + opts.maxMemoryLoad = maxMemoryLoad + } +} + func WithConsumerOnSubscribe(handler func(ctx context.Context, topic, consumerName string)) Option { return func(opts *Options) { opts.consumerOnSubscribe = handler @@ -115,11 +130,16 @@ func WithConsumerOnClose(handler func(ctx context.Context, topic, consumerName s } } -func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option { +func WithSecretKey(aesKey json.RawMessage) Option { return func(opts *Options) { opts.aesKey = aesKey + opts.enableEncryption = true + } +} + +func WithHMACKey(hmacKey json.RawMessage) Option { + return func(opts *Options) { opts.hmacKey = hmacKey - opts.enableEncryption = enableEncryption } } diff --git a/utils/encrypt.go b/utils/encrypt.go new file mode 100644 index 0000000..7c4d958 --- /dev/null +++ b/utils/encrypt.go @@ -0,0 +1,43 @@ +package utils + +import ( + "crypto/rand" + "encoding/base64" + "encoding/hex" + "fmt" +) + +func GenerateHMACKey(length int) (string, error) { + key := make([]byte, length) + _, err := rand.Read(key) + if err != nil { + return "", fmt.Errorf("failed to generate random key: %v", err) + } + return hex.EncodeToString(key), nil +} + +func MustGenerateHMACKey(length int) string { + key, err := GenerateHMACKey(length) + if err != nil { + panic(err) + } + return key +} + +func GenerateSecretKey(length int) (string, error) { + key := make([]byte, length) + _, err := rand.Read(key) + if err != nil { + return "", err + } + secretKey := base64.StdEncoding.EncodeToString(key) + return secretKey[:length], nil +} + +func MustGenerateSecretKey(length int) string { + key, err := GenerateSecretKey(length) + if err != nil { + panic(err) + } + return key +}