diff --git a/broker.go b/broker.go index caf4ff2..dac2fe9 100644 --- a/broker.go +++ b/broker.go @@ -4,37 +4,35 @@ import ( "context" "crypto/tls" "encoding/json" - "errors" "fmt" "log" "net" + "strings" "time" "github.com/oarkflow/xsync" + "github.com/oarkflow/mq/codec" "github.com/oarkflow/mq/consts" + "github.com/oarkflow/mq/jsonparser" + "github.com/oarkflow/mq/utils" ) +type QueuedTask struct { + Message *codec.Message + RetryCount int +} + type consumer struct { id string conn net.Conn } -func (p *consumer) send(ctx context.Context, cmd any) error { - return Write(ctx, p.conn, cmd) -} - type publisher struct { id string conn net.Conn } -func (p *publisher) send(ctx context.Context, cmd any) error { - return Write(ctx, p.conn, cmd) -} - -type Handler func(context.Context, Task) Result - type Broker struct { queues xsync.IMap[string, *Queue] consumers xsync.IMap[string, *consumer] @@ -42,100 +40,17 @@ type Broker struct { opts Options } -type Queue struct { - name string - consumers xsync.IMap[string, *consumer] - messages xsync.IMap[string, *Task] - deferred xsync.IMap[string, *Task] -} - -func newQueue(name string) *Queue { - return &Queue{ - name: name, - consumers: xsync.NewMap[string, *consumer](), - messages: xsync.NewMap[string, *Task](), - deferred: xsync.NewMap[string, *Task](), - } -} - -func (queue *Queue) send(ctx context.Context, cmd any) { - queue.consumers.ForEach(func(_ string, client *consumer) bool { - err := client.send(ctx, cmd) - if err != nil { - return false - } - return true - }) -} - -type Task struct { - ID string `json:"id"` - Payload json.RawMessage `json:"payload"` - CreatedAt time.Time `json:"created_at"` - ProcessedAt time.Time `json:"processed_at"` - CurrentQueue string `json:"current_queue"` - Status string `json:"status"` - Error error `json:"error"` -} - -type Command struct { - ID string `json:"id"` - Command consts.CMD `json:"command"` - Queue string `json:"queue"` - MessageID string `json:"message_id"` - Payload json.RawMessage `json:"payload,omitempty"` // Used for carrying the task payload - Error string `json:"error,omitempty"` -} - -type Result struct { - Command string `json:"command"` - Payload json.RawMessage `json:"payload"` - Queue string `json:"queue"` - MessageID string `json:"message_id"` - Error error `json:"error"` - Status string `json:"status"` -} - func NewBroker(opts ...Option) *Broker { - options := defaultOptions() - for _, opt := range opts { - opt(&options) - } - b := &Broker{ + options := setupOptions(opts...) + return &Broker{ queues: xsync.NewMap[string, *Queue](), publishers: xsync.NewMap[string, *publisher](), consumers: xsync.NewMap[string, *consumer](), + opts: options, } - b.opts = defaultHandlers(options, b.onMessage, b.onClose, b.onError) - return b } -func (b *Broker) Send(ctx context.Context, cmd Command) error { - queue, ok := b.queues.Get(cmd.Queue) - if !ok || queue == nil { - return errors.New("invalid queue or not exists") - } - queue.send(ctx, cmd) - return nil -} - -func (b *Broker) TLSConfig() TLSConfig { - return b.opts.tlsConfig -} - -func (b *Broker) SyncMode() bool { - return b.opts.syncMode -} - -func (b *Broker) sendToPublisher(ctx context.Context, publisherID string, result Result) error { - pub, ok := b.publishers.Get(publisherID) - if !ok { - return nil - } - return pub.send(ctx, result) -} - -func (b *Broker) onClose(ctx context.Context, _ net.Conn) error { +func (b *Broker) OnClose(ctx context.Context, _ net.Conn) error { consumerID, ok := GetConsumerID(ctx) if ok && consumerID != "" { if con, exists := b.consumers.Get(consumerID); exists { @@ -157,11 +72,94 @@ func (b *Broker) onClose(ctx context.Context, _ net.Conn) error { return nil } -func (b *Broker) onError(_ context.Context, conn net.Conn, err error) { +func (b *Broker) OnError(_ context.Context, conn net.Conn, err error) { fmt.Println("Error reading from connection:", err, conn.RemoteAddr()) } -// Start the broker server with optional TLS support +func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) { + switch msg.Command { + case consts.PUBLISH: + b.PublishHandler(ctx, conn, msg) + case consts.SUBSCRIBE: + b.SubscribeHandler(ctx, conn, msg) + case consts.MESSAGE_RESPONSE: + b.MessageResponseHandler(ctx, msg) + case consts.MESSAGE_ACK: + b.MessageAck(ctx, msg) + } +} + +func (b *Broker) MessageAck(ctx context.Context, msg *codec.Message) { + consumerID, _ := GetConsumerID(ctx) + taskID, _ := jsonparser.GetString(msg.Payload, "id") + log.Printf("BROKER - MESSAGE_ACK ~> %s on %s for Task %s", consumerID, msg.Queue, taskID) +} + +func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) { + msg.Command = consts.RESPONSE + headers, ok := GetHeaders(ctx) + if !ok { + return + } + b.HandleCallback(ctx, msg) + awaitResponse, ok := headers[consts.AwaitResponseKey] + if !(ok && awaitResponse == "true") { + return + } + publisherID, exists := headers[consts.PublisherKey] + if !exists { + return + } + con, ok := b.publishers.Get(publisherID) + if !ok { + return + } + err := b.send(con.conn, msg) + if err != nil { + panic(err) + } +} + +func (b *Broker) Publish(ctx context.Context, task Task, queue string) error { + headers, _ := GetHeaders(ctx) + payload, _ := json.Marshal(task) + msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers) + b.broadcastToConsumers(ctx, msg) + return nil +} + +func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.Message) { + pub := b.addPublisher(ctx, msg.Queue, conn) + taskID, _ := jsonparser.GetString(msg.Payload, "id") + log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID) + + ack := codec.NewMessage(consts.PUBLISH_ACK, []byte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers) + if err := b.send(conn, ack); err != nil { + log.Printf("Error sending PUBLISH_ACK: %v\n", err) + } + b.broadcastToConsumers(ctx, msg) + go func() { + select { + case <-ctx.Done(): + b.publishers.Del(pub.id) + } + }() +} + +func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) { + consumerID := b.addConsumer(ctx, msg.Queue, conn) + ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers) + if err := b.send(conn, ack); err != nil { + log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err) + } + go func() { + select { + case <-ctx.Done(): + b.removeConsumer(msg.Queue, consumerID) + } + }() +} + func (b *Broker) Start(ctx context.Context) error { var listener net.Listener var err error @@ -178,113 +176,61 @@ func (b *Broker) Start(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to start TLS listener: %v", err) } - log.Println("TLS server started on", b.opts.brokerAddr) + log.Println("BROKER - RUNNING_TLS ~> started on", b.opts.brokerAddr) } else { listener, err = net.Listen("tcp", b.opts.brokerAddr) if err != nil { return fmt.Errorf("failed to start TCP listener: %v", err) } - log.Println("TCP server started on", b.opts.brokerAddr) + log.Println("BROKER - RUNNING ~> started on", b.opts.brokerAddr) } defer listener.Close() for { conn, err := listener.Accept() if err != nil { - fmt.Println("Error accepting connection:", err) + b.OnError(ctx, conn, err) continue } - go ReadFromConn(ctx, conn, Handlers{ - MessageHandler: b.opts.messageHandler, - CloseHandler: b.opts.closeHandler, - ErrorHandler: b.opts.errorHandler, - }) - } -} - -func (b *Broker) Publish(ctx context.Context, message Task, queueName string) Result { - queue, task, err := b.AddMessageToQueue(&message, queueName) - if err != nil { - return Result{Error: err} - } - result := Result{ - Command: "PUBLISH", - Payload: message.Payload, - Queue: queueName, - MessageID: task.ID, - } - if queue.consumers.Size() == 0 { - queue.deferred.Set(NewID(), &message) - fmt.Println("task deferred as no consumers are connected", queueName) - return result - } - queue.send(ctx, message) - return result -} - -func (b *Broker) NewQueue(qName string) *Queue { - q, ok := b.queues.Get(qName) - if ok { - return q - } - q = newQueue(qName) - b.queues.Set(qName, q) - return q -} - -func (b *Broker) AddMessageToQueue(task *Task, queueName string) (*Queue, *Task, error) { - queue := b.NewQueue(queueName) - if task.ID == "" { - task.ID = NewID() - } - if queueName != "" { - task.CurrentQueue = queueName - } - task.CreatedAt = time.Now() - queue.messages.Set(task.ID, task) - return queue, task, nil -} - -func (b *Broker) HandleProcessedMessage(ctx context.Context, result Result) error { - publisherID, ok := GetPublisherID(ctx) - if ok && publisherID != "" { - err := b.sendToPublisher(ctx, publisherID, result) - if err != nil { - return err - } - } - for _, callback := range b.opts.callback { - if callback != nil { - rs := callback(ctx, result) - if rs.Error != nil { - return rs.Error + go func(c net.Conn) { + defer c.Close() + for { + err := b.readMessage(ctx, c) + if err != nil { + break + } } - } + }(conn) } - return nil } -func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Conn) string { - consumerID, ok := GetConsumerID(ctx) - defer func() { - cmd := Command{ - Command: consts.SUBSCRIBE_ACK, - Queue: queueName, - Error: "", - } - Write(ctx, conn, cmd) - log.Printf("Consumer %s joined server on queue %s", consumerID, queueName) - }() - q, ok := b.queues.Get(queueName) - if !ok { - q = b.NewQueue(queueName) - } - con := &consumer{id: consumerID, conn: conn} - b.consumers.Set(consumerID, con) - q.consumers.Set(consumerID, con) - return consumerID +func (b *Broker) send(conn net.Conn, msg *codec.Message) error { + return codec.SendMessage(conn, msg, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption) } -func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Conn) string { +func (b *Broker) receive(c net.Conn) (*codec.Message, error) { + return codec.ReadMessage(c, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption) +} + +func (b *Broker) broadcastToConsumers(ctx context.Context, msg *codec.Message) { + if queue, ok := b.queues.Get(msg.Queue); ok { + task := &QueuedTask{Message: msg, RetryCount: 0} + queue.tasks <- task + } +} + +func (b *Broker) waitForConsumerAck(conn net.Conn) error { + msg, err := b.receive(conn) + if err != nil { + return err + } + if msg.Command == consts.MESSAGE_ACK { + log.Println("Received CONSUMER_ACK: Subscribed successfully") + return nil + } + return fmt.Errorf("expected CONSUMER_ACK, got: %v", msg.Command) +} + +func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Conn) *publisher { publisherID, ok := GetPublisherID(ctx) _, ok = b.queues.Get(queueName) if !ok { @@ -292,20 +238,22 @@ func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Co } con := &publisher{id: publisherID, conn: conn} b.publishers.Set(publisherID, con) - return publisherID + return con } -func (b *Broker) subscribe(ctx context.Context, queueName string, conn net.Conn) { - consumerID := b.addConsumer(ctx, queueName, conn) - go func() { - select { - case <-ctx.Done(): - b.removeConsumer(queueName, consumerID) - } - }() +func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Conn) string { + consumerID, ok := GetConsumerID(ctx) + q, ok := b.queues.Get(queueName) + if !ok { + q = b.NewQueue(queueName) + } + con := &consumer{id: consumerID, conn: conn} + b.consumers.Set(consumerID, con) + q.consumers.Set(consumerID, con) + log.Printf("BROKER - SUBSCRIBE ~> %s on %s", consumerID, queueName) + return consumerID } -// Removes connection from the queue and broker func (b *Broker) removeConsumer(queueName, consumerID string) { if queue, ok := b.queues.Get(queueName); ok { con, ok := queue.consumers.Get(consumerID) @@ -317,57 +265,59 @@ func (b *Broker) removeConsumer(queueName, consumerID string) { } } -func (b *Broker) onMessage(ctx context.Context, conn net.Conn, message []byte) error { - var cmdMsg Command - var resultMsg Result - err := json.Unmarshal(message, &cmdMsg) +func (b *Broker) readMessage(ctx context.Context, c net.Conn) error { + msg, err := b.receive(c) if err == nil { - return b.handleCommandMessage(ctx, conn, cmdMsg) - } - err = json.Unmarshal(message, &resultMsg) - if err == nil { - return b.handleTaskMessage(ctx, conn, resultMsg) - } - return nil -} - -func (b *Broker) handleTaskMessage(ctx context.Context, _ net.Conn, msg Result) error { - return b.HandleProcessedMessage(ctx, msg) -} - -func (b *Broker) publish(ctx context.Context, conn net.Conn, msg Command) error { - status := "PUBLISH" - if msg.Command == consts.REQUEST { - status = "REQUEST" - } - b.addPublisher(ctx, msg.Queue, conn) - task := Task{ - ID: msg.MessageID, - Payload: msg.Payload, - CreatedAt: time.Now(), - CurrentQueue: msg.Queue, - } - result := b.Publish(ctx, task, msg.Queue) - if result.Error != nil { - return result.Error - } - if task.ID != "" { - result.Status = status - result.MessageID = task.ID - result.Queue = msg.Queue - return Write(ctx, conn, result) - } - return nil -} - -func (b *Broker) handleCommandMessage(ctx context.Context, conn net.Conn, msg Command) error { - switch msg.Command { - case consts.SUBSCRIBE: - b.subscribe(ctx, msg.Queue, conn) + ctx = SetHeaders(ctx, msg.Headers) + b.OnMessage(ctx, msg, c) return nil - case consts.PUBLISH, consts.REQUEST: - return b.publish(ctx, conn, msg) - default: - return fmt.Errorf("unknown command: %d", msg.Command) + } + if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") { + b.OnClose(ctx, c) + return err + } + b.OnError(ctx, c, err) + return err +} + +func (b *Broker) dispatchWorker(queue *Queue) { + delay := b.opts.initialDelay + for task := range queue.tasks { + success := false + for !success && task.RetryCount <= b.opts.maxRetries { + if b.dispatchTaskToConsumer(queue, task) { + success = true + } else { + task.RetryCount++ + delay = b.backoffRetry(queue, task, delay) + } + } } } + +func (b *Broker) dispatchTaskToConsumer(queue *Queue, task *QueuedTask) bool { + var consumerFound bool + queue.consumers.ForEach(func(_ string, con *consumer) bool { + if err := b.send(con.conn, task.Message); err == nil { + consumerFound = true + return false // break the loop once a consumer is found + } + return true + }) + if !consumerFound { + log.Printf("No available consumers for queue %s, retrying...", queue.name) + } + return consumerFound +} + +func (b *Broker) backoffRetry(queue *Queue, task *QueuedTask, delay time.Duration) time.Duration { + backoffDuration := utils.CalculateJitter(delay, b.opts.jitterPercent) + log.Printf("Backing off for %v before retrying task for queue %s", backoffDuration, task.Message.Queue) + time.Sleep(backoffDuration) + queue.tasks <- task + delay *= 2 + if delay > b.opts.maxBackoff { + delay = b.opts.maxBackoff + } + return delay +} diff --git a/consumer.go b/consumer.go index 981262f..570c03e 100644 --- a/consumer.go +++ b/consumer.go @@ -7,10 +7,13 @@ import ( "fmt" "log" "net" + "strings" "sync" "time" + "github.com/oarkflow/mq/codec" "github.com/oarkflow/mq/consts" + "github.com/oarkflow/mq/jsonparser" "github.com/oarkflow/mq/utils" ) @@ -25,16 +28,20 @@ type Consumer struct { // NewConsumer initializes a new consumer with the provided options. func NewConsumer(id string, opts ...Option) *Consumer { - options := defaultOptions() - for _, opt := range opts { - opt(&options) - } - b := &Consumer{ + options := setupOptions(opts...) + return &Consumer{ handlers: make(map[string]Handler), id: id, + opts: options, } - b.opts = defaultHandlers(options, b.onMessage, b.onClose, b.onError) - return b +} + +func (c *Consumer) send(conn net.Conn, msg *codec.Message) error { + return codec.SendMessage(conn, msg, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption) +} + +func (c *Consumer) receive(conn net.Conn) (*codec.Message, error) { + return codec.ReadMessage(conn, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption) } // Close closes the consumer's connection. @@ -43,90 +50,82 @@ func (c *Consumer) Close() error { } // Subscribe to a specific queue. -func (c *Consumer) subscribe(queue string) error { - ctx := context.Background() - ctx = SetHeaders(ctx, map[string]string{ +func (c *Consumer) subscribe(ctx context.Context, queue string) error { + headers := WithHeaders(ctx, map[string]string{ consts.ConsumerKey: c.id, consts.ContentType: consts.TypeJson, }) - subscribe := Command{ - Command: consts.SUBSCRIBE, - Queue: queue, - ID: NewID(), + msg := codec.NewMessage(consts.SUBSCRIBE, nil, queue, headers) + if err := c.send(c.conn, msg); err != nil { + return err + } + + return c.waitForAck(c.conn) +} + +func (c *Consumer) OnClose(ctx context.Context, _ net.Conn) error { + fmt.Println("Consumer closed") + return nil +} + +func (c *Consumer) OnError(_ context.Context, conn net.Conn, err error) { + fmt.Println("Error reading from connection:", err, conn.RemoteAddr()) +} + +func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) { + headers := WithHeaders(ctx, map[string]string{ + consts.ConsumerKey: c.id, + consts.ContentType: consts.TypeJson, + }) + taskID, _ := jsonparser.GetString(msg.Payload, "id") + reply := codec.NewMessage(consts.MESSAGE_ACK, []byte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers) + if err := c.send(conn, reply); err != nil { + fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err) + } + var task Task + err := json.Unmarshal(msg.Payload, &task) + if err != nil { + log.Println("Error unmarshalling message:", err) + return + } + ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue}) + result := c.ProcessTask(ctx, task) + result.MessageID = task.ID + result.Queue = msg.Queue + if result.Error != nil { + result.Status = "FAILED" + } else { + result.Status = "SUCCESS" + } + bt, _ := json.Marshal(result) + reply = codec.NewMessage(consts.MESSAGE_RESPONSE, bt, msg.Queue, headers) + if err := c.send(conn, reply); err != nil { + fmt.Printf("failed to send MESSAGE_RESPONSE for queue %s: %v", msg.Queue, err) } - return Write(ctx, c.conn, subscribe) } // ProcessTask handles a received task message and invokes the appropriate handler. func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result { - handler, exists := c.handlers[msg.CurrentQueue] + queue, _ := GetQueue(ctx) + handler, exists := c.handlers[queue] if !exists { - return Result{Error: errors.New("No handler for queue " + msg.CurrentQueue)} + return Result{Error: errors.New("No handler for queue " + queue)} } return handler(ctx, msg) } -// Handle command message sent by the server. -func (c *Consumer) handleCommandMessage(msg Command) error { - switch msg.Command { - case consts.STOP: - return c.Close() - case consts.SUBSCRIBE_ACK: - log.Printf("Consumer %s subscribed to queue %s\n", c.id, msg.Queue) - return nil - default: - return fmt.Errorf("unknown command in consumer %d", msg.Command) - } -} - -// Handle task message sent by the server. -func (c *Consumer) handleTaskMessage(ctx context.Context, msg Task) error { - response := c.ProcessTask(ctx, msg) - response.Queue = msg.CurrentQueue - if msg.ID == "" { - response.Error = errors.New("task ID is empty") - response.Command = "error" - } else { - response.Command = "completed" - response.MessageID = msg.ID - } - return c.sendResult(ctx, response) -} - -// Send the result of task processing back to the server. -func (c *Consumer) sendResult(ctx context.Context, response Result) error { - return Write(ctx, c.conn, response) -} - -// Read and handle incoming messages. -func (c *Consumer) readMessage(ctx context.Context, message []byte) error { - var cmdMsg Command - var task Task - err := json.Unmarshal(message, &cmdMsg) - if err == nil && cmdMsg.Command != 0 { - return c.handleCommandMessage(cmdMsg) - } - err = json.Unmarshal(message, &task) - if err == nil { - return c.handleTaskMessage(ctx, task) - } - return nil -} - // AttemptConnect tries to establish a connection to the server, with TLS or without, based on the configuration. func (c *Consumer) AttemptConnect() error { var err error delay := c.opts.initialDelay - for i := 0; i < c.opts.maxRetries; i++ { conn, err := GetConnection(c.opts.brokerAddr, c.opts.tlsConfig) if err == nil { c.conn = conn return nil } - sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent) - fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration) + log.Printf("CONSUMER - SUBSCRIBE ~> Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration) time.Sleep(sleepDuration) delay *= 2 if delay > c.opts.maxBackoff { @@ -137,20 +136,19 @@ func (c *Consumer) AttemptConnect() error { return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err) } -// onMessage reads incoming messages from the connection. -func (c *Consumer) onMessage(ctx context.Context, conn net.Conn, message []byte) error { - return c.readMessage(ctx, message) -} - -// onClose handles connection close event. -func (c *Consumer) onClose(ctx context.Context, conn net.Conn) error { - fmt.Println("Consumer Connection closed", c.id, conn.RemoteAddr()) - return nil -} - -// onError handles errors while reading from the connection. -func (c *Consumer) onError(ctx context.Context, conn net.Conn, err error) { - fmt.Println("Error reading from consumer connection:", err, conn.RemoteAddr()) +func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error { + msg, err := c.receive(conn) + if err == nil { + ctx = SetHeaders(ctx, msg.Headers) + c.OnMessage(ctx, msg, conn) + return nil + } + if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") { + c.OnClose(ctx, conn) + return err + } + c.OnError(ctx, conn, err) + return err } // Consume starts the consumer to consume tasks from the queues. @@ -159,26 +157,39 @@ func (c *Consumer) Consume(ctx context.Context) error { if err != nil { return err } + for _, q := range c.queues { + if err := c.subscribe(ctx, q); err != nil { + return fmt.Errorf("failed to connect to server for queue %s: %v", q, err) + } + } var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() - ReadFromConn(ctx, c.conn, Handlers{ - MessageHandler: c.opts.messageHandler, - CloseHandler: c.opts.closeHandler, - ErrorHandler: c.opts.errorHandler, - }) - fmt.Println("Stopping consumer") - }() - for _, q := range c.queues { - if err := c.subscribe(q); err != nil { - return fmt.Errorf("failed to connect to server for queue %s: %v", q, err) + for { + if err := c.readMessage(ctx, c.conn); err != nil { + log.Println("Error reading message:", err) + break + } } - } + }() + wg.Wait() return nil } +func (c *Consumer) waitForAck(conn net.Conn) error { + msg, err := c.receive(conn) + if err != nil { + return err + } + if msg.Command == consts.SUBSCRIBE_ACK { + log.Printf("CONSUMER - SUBSCRIBE_ACK ~> %s on %s", c.id, msg.Queue) + return nil + } + return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command) +} + // RegisterHandler registers a handler for a queue. func (c *Consumer) RegisterHandler(queue string, handler Handler) { c.queues = append(c.queues, queue) diff --git a/ctx.go b/ctx.go index 05afc3d..2817b3e 100644 --- a/ctx.go +++ b/ctx.go @@ -1,36 +1,31 @@ package mq import ( - "bufio" - "bytes" "context" "crypto/tls" "crypto/x509" "encoding/json" "fmt" - "io" "net" "os" - "strings" + "time" "github.com/oarkflow/xid" - "github.com/oarkflow/mq/codec" "github.com/oarkflow/mq/consts" ) -type MessageHandler func(context.Context, net.Conn, []byte) error - -type CloseHandler func(context.Context, net.Conn) error - -type ErrorHandler func(context.Context, net.Conn, error) - -type Handlers struct { - MessageHandler MessageHandler - CloseHandler CloseHandler - ErrorHandler ErrorHandler +type Task struct { + ID string `json:"id"` + Payload json.RawMessage `json:"payload"` + CreatedAt time.Time `json:"created_at"` + ProcessedAt time.Time `json:"processed_at"` + Status string `json:"status"` + Error error `json:"error"` } +type Handler func(context.Context, Task) Result + func IsClosed(conn net.Conn) bool { _, err := conn.Read(make([]byte, 1)) if err != nil { @@ -52,11 +47,31 @@ func SetHeaders(ctx context.Context, headers map[string]string) context.Context return context.WithValue(ctx, consts.HeaderKey, hd) } +func WithHeaders(ctx context.Context, headers map[string]string) map[string]string { + hd, ok := GetHeaders(ctx) + if !ok { + hd = make(map[string]string) + } + for key, val := range headers { + hd[key] = val + } + return hd +} + func GetHeaders(ctx context.Context) (map[string]string, bool) { headers, ok := ctx.Value(consts.HeaderKey).(map[string]string) return headers, ok } +func GetHeader(ctx context.Context, key string) (string, bool) { + headers, ok := ctx.Value(consts.HeaderKey).(map[string]string) + if !ok { + return "", false + } + val, ok := headers[key] + return val, ok +} + func GetContentType(ctx context.Context) (string, bool) { headers, ok := GetHeaders(ctx) if !ok { @@ -66,6 +81,15 @@ func GetContentType(ctx context.Context) (string, bool) { return contentType, ok } +func GetQueue(ctx context.Context) (string, bool) { + headers, ok := GetHeaders(ctx) + if !ok { + return "", false + } + contentType, ok := headers[consts.QueueKey] + return contentType, ok +} + func GetConsumerID(ctx context.Context) (string, bool) { headers, ok := GetHeaders(ctx) if !ok { @@ -93,70 +117,6 @@ func GetPublisherID(ctx context.Context) (string, bool) { return contentType, ok } -func Write(ctx context.Context, conn net.Conn, data any) error { - msg := codec.Message{Headers: make(map[string]string)} - if headers, ok := GetHeaders(ctx); ok { - msg.Headers = headers - } - dataBytes, err := json.Marshal(data) - if err != nil { - return err - } - msg.Payload = dataBytes - messageBytes, err := json.Marshal(msg) - if err != nil { - return err - } - _, err = conn.Write(append(messageBytes, '\n')) - return err -} - -func ReadFromConn(ctx context.Context, conn net.Conn, handlers Handlers) { - defer func() { - if handlers.CloseHandler != nil { - if err := handlers.CloseHandler(ctx, conn); err != nil { - fmt.Println("Error in close handler:", err) - } - } - conn.Close() - }() - reader := bufio.NewReader(conn) - for { - messageBytes, err := reader.ReadBytes('\n') - if err != nil { - if err == io.EOF || IsClosed(conn) || strings.Contains(err.Error(), "closed network connection") { - break - } - if handlers.ErrorHandler != nil { - handlers.ErrorHandler(ctx, conn, err) - } - continue - } - messageBytes = bytes.TrimSpace(messageBytes) - if len(messageBytes) == 0 { - continue - } - var msg codec.Message - err = json.Unmarshal(messageBytes, &msg) - if err != nil { - if handlers.ErrorHandler != nil { - handlers.ErrorHandler(ctx, conn, err) - } - continue - } - ctx = SetHeaders(ctx, msg.Headers) - if handlers.MessageHandler != nil { - err = handlers.MessageHandler(ctx, conn, msg.Payload) - if err != nil { - if handlers.ErrorHandler != nil { - handlers.ErrorHandler(ctx, conn, err) - } - continue - } - } - } -} - func NewID() string { return xid.New().String() } diff --git a/dag/dag.go b/dag/dag.go index 4bbde72..7ee7cc5 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -3,13 +3,13 @@ package dag import ( "context" "encoding/json" - "fmt" + "github.com/oarkflow/mq/consts" "log" "net/http" "sync" + "time" "github.com/oarkflow/mq" - "github.com/oarkflow/mq/consts" ) type taskContext struct { @@ -76,12 +76,20 @@ func (d *DAG) Start(ctx context.Context, addr string) error { if d.server.SyncMode() { return nil } - for _, con := range d.nodes { - go con.Consume(ctx) - } go func() { - d.server.Start(ctx) + err := d.server.Start(ctx) + if err != nil { + panic(err) + } }() + for _, con := range d.nodes { + go func(con *mq.Consumer) { + err := con.Consume(ctx) + if err != nil { + panic(err) + } + }(con) + } log.Printf("HTTP server started on %s", addr) config := d.server.TLSConfig() if config.UseTLS { @@ -90,16 +98,6 @@ func (d *DAG) Start(ctx context.Context, addr string) error { return http.ListenAndServe(addr, nil) } -func (d *DAG) PublishTask(ctx context.Context, payload []byte, queueName string, taskID ...string) mq.Result { - task := mq.Task{ - Payload: payload, - } - if len(taskID) > 0 { - task.ID = taskID[0] - } - return d.server.Publish(ctx, task, queueName) -} - func (d *DAG) FindFirstNode() (string, bool) { inDegree := make(map[string]int) for n, _ := range d.nodes { @@ -121,86 +119,23 @@ func (d *DAG) FindFirstNode() (string, bool) { return "", false } -func (d *DAG) Request(ctx context.Context, payload []byte) mq.Result { - return d.sendSync(ctx, mq.Result{Payload: payload}) -} - -func (d *DAG) Send(ctx context.Context, payload []byte) mq.Result { - if d.FirstNode == "" { - return mq.Result{Error: fmt.Errorf("initial node not defined")} +func (d *DAG) PublishTask(ctx context.Context, payload json.RawMessage, taskID ...string) error { + queue, ok := mq.GetQueue(ctx) + if !ok { + queue = d.FirstNode } - if d.server.SyncMode() { - return d.sendSync(ctx, mq.Result{Payload: payload}) + var id string + if len(taskID) > 0 { + id = taskID[0] + } else { + id = mq.NewID() } - resultCh := make(chan mq.Result) - result := d.PublishTask(ctx, payload, d.FirstNode) - if result.Error != nil { - return result + task := mq.Task{ + ID: id, + Payload: payload, + CreatedAt: time.Now(), } - d.mu.Lock() - d.taskChMap[result.MessageID] = resultCh - d.mu.Unlock() - finalResult := <-resultCh - return finalResult -} - -func (d *DAG) processNode(ctx context.Context, task mq.Result) mq.Result { - if con, ok := d.nodes[task.Queue]; ok { - return con.ProcessTask(ctx, mq.Task{ - ID: task.MessageID, - Payload: task.Payload, - CurrentQueue: task.Queue, - }) - } - return mq.Result{Error: fmt.Errorf("no consumer to process %s", task.Queue)} -} - -func (d *DAG) sendSync(ctx context.Context, task mq.Result) mq.Result { - if task.MessageID == "" { - task.MessageID = mq.NewID() - } - if task.Queue == "" { - task.Queue = d.FirstNode - } - result := d.processNode(ctx, task) - if result.Error != nil { - return result - } - for _, target := range d.loopEdges[task.Queue] { - var items, results []json.RawMessage - if err := json.Unmarshal(result.Payload, &items); err != nil { - return mq.Result{Error: err} - } - for _, item := range items { - result = d.sendSync(ctx, mq.Result{ - Command: result.Command, - Payload: item, - Queue: target, - MessageID: result.MessageID, - }) - if result.Error != nil { - return result - } - results = append(results, result.Payload) - } - bt, err := json.Marshal(results) - if err != nil { - return mq.Result{Error: err} - } - result.Payload = bt - } - if target, ok := d.edges[task.Queue]; ok { - result = d.sendSync(ctx, mq.Result{ - Command: result.Command, - Payload: result.Payload, - Queue: target, - MessageID: result.MessageID, - }) - if result.Error != nil { - return result - } - } - return result + return d.server.Publish(ctx, task, queue) } func (d *DAG) getCompletedResults(task mq.Result, ok bool, triggeredNode string) ([]byte, bool, bool) { @@ -264,9 +199,12 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result { ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue}) for _, loopNode := range loopNodes { for _, item := range items { - rs := d.PublishTask(ctx, item, loopNode, task.MessageID) - if rs.Error != nil { - return rs + ctx = mq.SetHeaders(ctx, map[string]string{ + consts.QueueKey: loopNode, + }) + err := d.PublishTask(ctx, item, task.MessageID) + if err != nil { + return mq.Result{Error: err} } } } @@ -284,15 +222,14 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result { totalItems: 1, }, } - rs := d.PublishTask(ctx, payload, edge, task.MessageID) - if rs.Error != nil { - return rs + err := d.PublishTask(ctx, payload, edge, task.MessageID) + if err != nil { + return mq.Result{Error: err} } } else if completed { d.mu.Lock() if resultCh, ok := d.taskChMap[task.MessageID]; ok { resultCh <- mq.Result{ - Command: "complete", Payload: payload, Queue: task.Queue, MessageID: task.MessageID, diff --git a/examples/consumer.go b/examples/consumer.go index cf2e280..a312348 100644 --- a/examples/consumer.go +++ b/examples/consumer.go @@ -2,9 +2,9 @@ package main import ( "context" + "github.com/oarkflow/mq" "github.com/oarkflow/mq/examples/tasks" - mq "github.com/oarkflow/mq/v2" ) func main() { diff --git a/examples/dag.go b/examples/dag.go index 976a0f6..80e1d91 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -2,19 +2,15 @@ package main import ( "context" - "encoding/json" - "io" - "net/http" - - "github.com/oarkflow/mq" "github.com/oarkflow/mq/dag" "github.com/oarkflow/mq/examples/tasks" + "time" ) var d *dag.DAG func main() { - d = dag.New(mq.WithTLS(true, "server.crt", "server.key"), mq.WithCAPath("ca.crt")) + d = dag.New() d.AddNode("queue1", tasks.Node1) d.AddNode("queue2", tasks.Node2) d.AddNode("queue3", tasks.Node3) @@ -24,45 +20,14 @@ func main() { d.AddLoop("queue2", "queue3") d.AddEdge("queue2", "queue4") d.Prepare() - http.HandleFunc("POST /publish", requestHandler("publish")) - http.HandleFunc("POST /request", requestHandler("request")) - err := d.Start(context.TODO(), ":8083") + go func() { + d.Start(context.Background(), ":8081") + }() + time.Sleep(5 * time.Second) + err := d.PublishTask(context.Background(), []byte(`{"tast": 123}`)) if err != nil { panic(err) } -} -func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - http.Error(w, "Invalid request method", http.StatusMethodNotAllowed) - return - } - var payload []byte - if r.Body != nil { - defer r.Body.Close() - var err error - payload, err = io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Failed to read request body", http.StatusBadRequest) - return - } - } else { - http.Error(w, "Empty request body", http.StatusBadRequest) - return - } - var rs mq.Result - if requestType == "request" { - rs = d.Request(context.Background(), payload) - } else { - rs = d.Send(context.Background(), payload) - } - w.Header().Set("Content-Type", "application/json") - result := map[string]any{ - "message_id": rs.MessageID, - "payload": string(rs.Payload), - "error": rs.Error, - } - json.NewEncoder(w).Encode(result) - } + time.Sleep(10 * time.Second) } diff --git a/examples/publisher.go b/examples/publisher.go index 480e231..44c94a1 100644 --- a/examples/publisher.go +++ b/examples/publisher.go @@ -3,17 +3,16 @@ package main import ( "context" "fmt" + mq2 "github.com/oarkflow/mq" "time" - - mq "github.com/oarkflow/mq/v2" ) func main() { payload := []byte(`{"message":"Message Publisher \n Task"}`) - task := mq.Task{ + task := mq2.Task{ Payload: payload, } - publisher := mq.NewPublisher("publish-1") + publisher := mq2.NewPublisher("publish-1") // publisher := mq.NewPublisher("publish-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key")) err := publisher.Publish(context.Background(), task, "queue1") if err != nil { @@ -21,7 +20,7 @@ func main() { } fmt.Println("Async task published successfully") payload = []byte(`{"message":"Fire-and-Forget \n Task"}`) - task = mq.Task{ + task = mq2.Task{ Payload: payload, } for i := 0; i < 100; i++ { diff --git a/examples/server.go b/examples/server.go index 6758b55..2c4dcb5 100644 --- a/examples/server.go +++ b/examples/server.go @@ -2,13 +2,13 @@ package main import ( "context" + mq2 "github.com/oarkflow/mq" "github.com/oarkflow/mq/examples/tasks" - mq "github.com/oarkflow/mq/v2" ) func main() { - b := mq.NewBroker(mq.WithCallback(tasks.Callback)) + b := mq2.NewBroker(mq2.WithCallback(tasks.Callback)) // b := mq.NewBroker(mq.WithCallback(tasks.Callback), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert")) b.NewQueue("queue1") b.NewQueue("queue2") diff --git a/examples/tasks/tasks.go b/examples/tasks/tasks.go index 186efb3..014871d 100644 --- a/examples/tasks/tasks.go +++ b/examples/tasks/tasks.go @@ -4,42 +4,41 @@ import ( "context" "encoding/json" "fmt" - - mq "github.com/oarkflow/mq/v2" + mq2 "github.com/oarkflow/mq" ) -func Node1(ctx context.Context, task mq.Task) mq.Result { +func Node1(ctx context.Context, task mq2.Task) mq2.Result { fmt.Println("Processing queue1", task.ID) - return mq.Result{Payload: task.Payload, MessageID: task.ID} + return mq2.Result{Payload: task.Payload, MessageID: task.ID} } -func Node2(ctx context.Context, task mq.Task) mq.Result { - return mq.Result{Payload: task.Payload, MessageID: task.ID} +func Node2(ctx context.Context, task mq2.Task) mq2.Result { + return mq2.Result{Payload: task.Payload, MessageID: task.ID} } -func Node3(ctx context.Context, task mq.Task) mq.Result { +func Node3(ctx context.Context, task mq2.Task) mq2.Result { var data map[string]any err := json.Unmarshal(task.Payload, &data) if err != nil { - return mq.Result{Error: err} + return mq2.Result{Error: err} } data["salary"] = fmt.Sprintf("12000%v", data["user_id"]) bt, _ := json.Marshal(data) - return mq.Result{Payload: bt, MessageID: task.ID} + return mq2.Result{Payload: bt, MessageID: task.ID} } -func Node4(ctx context.Context, task mq.Task) mq.Result { +func Node4(ctx context.Context, task mq2.Task) mq2.Result { var data []map[string]any err := json.Unmarshal(task.Payload, &data) if err != nil { - return mq.Result{Error: err} + return mq2.Result{Error: err} } payload := map[string]any{"storage": data} bt, _ := json.Marshal(payload) - return mq.Result{Payload: bt, MessageID: task.ID} + return mq2.Result{Payload: bt, MessageID: task.ID} } -func Callback(ctx context.Context, task mq.Result) mq.Result { +func Callback(ctx context.Context, task mq2.Result) mq2.Result { fmt.Println("Received task", task.MessageID, "Payload", string(task.Payload), task.Error, task.Queue) - return mq.Result{} + return mq2.Result{} } diff --git a/options.go b/options.go index 206d6c9..76ea6ea 100644 --- a/options.go +++ b/options.go @@ -2,9 +2,18 @@ package mq import ( "context" + "encoding/json" "time" ) +type Result struct { + Payload json.RawMessage `json:"payload"` + Queue string `json:"queue"` + MessageID string `json:"message_id"` + Error error `json:"error,omitempty"` + Status string `json:"status"` +} + type TLSConfig struct { UseTLS bool CertPath string @@ -13,17 +22,18 @@ type TLSConfig struct { } type Options struct { - syncMode bool - brokerAddr string - messageHandler MessageHandler - closeHandler CloseHandler - errorHandler ErrorHandler - callback []func(context.Context, Result) Result - maxRetries int - initialDelay time.Duration - maxBackoff time.Duration - jitterPercent float64 - tlsConfig TLSConfig + syncMode bool + brokerAddr string + callback []func(context.Context, Result) Result + maxRetries int + initialDelay time.Duration + maxBackoff time.Duration + jitterPercent float64 + tlsConfig TLSConfig + aesKey json.RawMessage + hmacKey json.RawMessage + enableEncryption bool + queueSize int } func defaultOptions() Options { @@ -34,27 +44,29 @@ func defaultOptions() Options { initialDelay: 2 * time.Second, maxBackoff: 20 * time.Second, jitterPercent: 0.5, + queueSize: 100, } } -func defaultHandlers(options Options, onMessage MessageHandler, onClose CloseHandler, onError ErrorHandler) Options { - if options.messageHandler == nil { - options.messageHandler = onMessage - } - - if options.closeHandler == nil { - options.closeHandler = onClose - } - - if options.errorHandler == nil { - options.errorHandler = onError - } - return options -} - // Option defines a function type for setting options. type Option func(*Options) +func setupOptions(opts ...Option) Options { + options := defaultOptions() + for _, opt := range opts { + opt(&options) + } + return options +} + +func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option { + return func(opts *Options) { + opts.aesKey = aesKey + opts.hmacKey = hmacKey + opts.enableEncryption = enableEncryption + } +} + // WithBrokerURL - func WithBrokerURL(url string) Option { return func(opts *Options) { @@ -119,24 +131,3 @@ func WithJitterPercent(val float64) Option { opts.jitterPercent = val } } - -// WithMessageHandler sets a custom MessageHandler. -func WithMessageHandler(handler MessageHandler) Option { - return func(opts *Options) { - opts.messageHandler = handler - } -} - -// WithErrorHandler sets a custom ErrorHandler. -func WithErrorHandler(handler ErrorHandler) Option { - return func(opts *Options) { - opts.errorHandler = handler - } -} - -// WithCloseHandler sets a custom CloseHandler. -func WithCloseHandler(handler CloseHandler) Option { - return func(opts *Options) { - opts.closeHandler = handler - } -} diff --git a/publisher.go b/publisher.go index 3dc040c..41c2b04 100644 --- a/publisher.go +++ b/publisher.go @@ -4,9 +4,13 @@ import ( "context" "encoding/json" "fmt" + "log" "net" + "time" + "github.com/oarkflow/mq/codec" "github.com/oarkflow/mq/consts" + "github.com/oarkflow/mq/jsonparser" ) type Publisher struct { @@ -15,31 +19,59 @@ type Publisher struct { } func NewPublisher(id string, opts ...Option) *Publisher { - options := defaultOptions() - for _, opt := range opts { - opt(&options) - } - b := &Publisher{id: id} - b.opts = defaultHandlers(options, nil, b.onClose, b.onError) - return b + options := setupOptions(opts...) + return &Publisher{id: id, opts: options} } func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command consts.CMD) error { - ctx = SetHeaders(ctx, map[string]string{ + headers := WithHeaders(ctx, map[string]string{ consts.PublisherKey: p.id, consts.ContentType: consts.TypeJson, }) - cmd := Command{ - ID: NewID(), - Command: command, - Queue: queue, - MessageID: task.ID, - Payload: task.Payload, + if task.ID == "" { + task.ID = NewID() } - return Write(ctx, conn, cmd) + task.CreatedAt = time.Now() + payload, err := json.Marshal(task) + if err != nil { + return err + } + msg := codec.NewMessage(command, payload, queue, headers) + if err := codec.SendMessage(conn, msg, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption); err != nil { + return err + } + + return p.waitForAck(conn) } -func (p *Publisher) Publish(ctx context.Context, queue string, task Task) error { +func (p *Publisher) waitForAck(conn net.Conn) error { + msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption) + if err != nil { + return err + } + if msg.Command == consts.PUBLISH_ACK { + taskID, _ := jsonparser.GetString(msg.Payload, "id") + log.Printf("PUBLISHER - PUBLISH_ACK ~> from %s on %s for Task %s", p.id, msg.Queue, taskID) + return nil + } + return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command) +} + +func (p *Publisher) waitForResponse(conn net.Conn) Result { + msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption) + if err != nil { + return Result{Error: err} + } + if msg.Command == consts.RESPONSE { + var result Result + err = json.Unmarshal(msg.Payload, &result) + return result + } + err = fmt.Errorf("expected RESPONSE, got: %v", msg.Command) + return Result{Error: err} +} + +func (p *Publisher) Publish(ctx context.Context, task Task, queue string) error { conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig) if err != nil { return fmt.Errorf("failed to connect to broker: %w", err) @@ -57,30 +89,22 @@ func (p *Publisher) onError(ctx context.Context, conn net.Conn, err error) { fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr()) } -func (p *Publisher) Request(ctx context.Context, queue string, task Task) (Result, error) { +func (p *Publisher) Request(ctx context.Context, queue string, task Task) Result { + ctx = SetHeaders(ctx, map[string]string{ + consts.AwaitResponseKey: "true", + }) conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig) if err != nil { - return Result{Error: err}, fmt.Errorf("failed to connect to broker: %w", err) + err = fmt.Errorf("failed to connect to broker: %w", err) + return Result{Error: err} } defer conn.Close() - var result Result - err = p.send(ctx, queue, task, conn, consts.REQUEST) - if err != nil { - return result, err - } - if p.opts.messageHandler == nil { - p.opts.messageHandler = func(ctx context.Context, conn net.Conn, message []byte) error { - err := json.Unmarshal(message, &result) - if err != nil { - return err - } - return conn.Close() - } - } - ReadFromConn(ctx, conn, Handlers{ - MessageHandler: p.opts.messageHandler, - CloseHandler: p.opts.closeHandler, - ErrorHandler: p.opts.errorHandler, - }) - return result, nil + err = p.send(ctx, queue, task, conn, consts.PUBLISH) + resultCh := make(chan Result) + go func() { + defer close(resultCh) + resultCh <- p.waitForResponse(conn) + }() + finalResult := <-resultCh + return finalResult } diff --git a/v2/queue.go b/queue.go similarity index 98% rename from v2/queue.go rename to queue.go index 8edb1c7..bde4805 100644 --- a/v2/queue.go +++ b/queue.go @@ -1,4 +1,4 @@ -package v2 +package mq import ( "github.com/oarkflow/xsync" diff --git a/v2/util.go b/util.go similarity index 97% rename from v2/util.go rename to util.go index bf04aea..01298fd 100644 --- a/v2/util.go +++ b/util.go @@ -1,4 +1,4 @@ -package v2 +package mq import ( "context" diff --git a/v2/broker.go b/v2/broker.go deleted file mode 100644 index 648514c..0000000 --- a/v2/broker.go +++ /dev/null @@ -1,323 +0,0 @@ -package v2 - -import ( - "context" - "crypto/tls" - "fmt" - "log" - "net" - "strings" - "time" - - "github.com/oarkflow/xsync" - - "github.com/oarkflow/mq/codec" - "github.com/oarkflow/mq/consts" - "github.com/oarkflow/mq/jsonparser" - "github.com/oarkflow/mq/utils" -) - -type QueuedTask struct { - Message *codec.Message - RetryCount int -} - -type consumer struct { - id string - conn net.Conn -} - -type publisher struct { - id string - conn net.Conn -} - -type Broker struct { - queues xsync.IMap[string, *Queue] - consumers xsync.IMap[string, *consumer] - publishers xsync.IMap[string, *publisher] - opts Options -} - -func NewBroker(opts ...Option) *Broker { - options := setupOptions(opts...) - return &Broker{ - queues: xsync.NewMap[string, *Queue](), - publishers: xsync.NewMap[string, *publisher](), - consumers: xsync.NewMap[string, *consumer](), - opts: options, - } -} - -func (b *Broker) OnClose(ctx context.Context, _ net.Conn) error { - consumerID, ok := GetConsumerID(ctx) - if ok && consumerID != "" { - if con, exists := b.consumers.Get(consumerID); exists { - con.conn.Close() - b.consumers.Del(consumerID) - } - b.queues.ForEach(func(_ string, queue *Queue) bool { - queue.consumers.Del(consumerID) - return true - }) - } - publisherID, ok := GetPublisherID(ctx) - if ok && publisherID != "" { - if con, exists := b.publishers.Get(publisherID); exists { - con.conn.Close() - b.publishers.Del(publisherID) - } - } - return nil -} - -func (b *Broker) OnError(_ context.Context, conn net.Conn, err error) { - fmt.Println("Error reading from connection:", err, conn.RemoteAddr()) -} - -func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) { - switch msg.Command { - case consts.PUBLISH: - b.PublishHandler(ctx, conn, msg) - case consts.SUBSCRIBE: - b.SubscribeHandler(ctx, conn, msg) - case consts.MESSAGE_RESPONSE: - b.MessageResponseHandler(ctx, msg) - case consts.MESSAGE_ACK: - b.MessageAck(ctx, msg) - } -} - -func (b *Broker) MessageAck(ctx context.Context, msg *codec.Message) { - consumerID, _ := GetConsumerID(ctx) - taskID, _ := jsonparser.GetString(msg.Payload, "id") - log.Printf("BROKER - MESSAGE_ACK ~> %s on %s for Task %s", consumerID, msg.Queue, taskID) -} - -func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) { - msg.Command = consts.RESPONSE - headers, ok := GetHeaders(ctx) - if !ok { - return - } - b.HandleCallback(ctx, msg) - awaitResponse, ok := headers[consts.AwaitResponseKey] - if !(ok && awaitResponse == "true") { - return - } - publisherID, exists := headers[consts.PublisherKey] - if !exists { - return - } - con, ok := b.publishers.Get(publisherID) - if !ok { - return - } - err := b.send(con.conn, msg) - if err != nil { - panic(err) - } -} - -func (b *Broker) Publish(ctx context.Context, task Task, queue string) error { - headers, _ := GetHeaders(ctx) - msg := codec.NewMessage(consts.PUBLISH, task.Payload, queue, headers) - b.broadcastToConsumers(ctx, msg) - return nil -} - -func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.Message) { - pub := b.addPublisher(ctx, msg.Queue, conn) - taskID, _ := jsonparser.GetString(msg.Payload, "id") - log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID) - - ack := codec.NewMessage(consts.PUBLISH_ACK, []byte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers) - if err := b.send(conn, ack); err != nil { - log.Printf("Error sending PUBLISH_ACK: %v\n", err) - } - b.broadcastToConsumers(ctx, msg) - go func() { - select { - case <-ctx.Done(): - b.publishers.Del(pub.id) - } - }() -} - -func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) { - consumerID := b.addConsumer(ctx, msg.Queue, conn) - ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers) - if err := b.send(conn, ack); err != nil { - log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err) - } - go func() { - select { - case <-ctx.Done(): - b.removeConsumer(msg.Queue, consumerID) - } - }() -} - -func (b *Broker) Start(ctx context.Context) error { - var listener net.Listener - var err error - - if b.opts.tlsConfig.UseTLS { - cert, err := tls.LoadX509KeyPair(b.opts.tlsConfig.CertPath, b.opts.tlsConfig.KeyPath) - if err != nil { - return fmt.Errorf("failed to load TLS certificates: %v", err) - } - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, - } - listener, err = tls.Listen("tcp", b.opts.brokerAddr, tlsConfig) - if err != nil { - return fmt.Errorf("failed to start TLS listener: %v", err) - } - log.Println("BROKER - RUNNING_TLS ~> started on", b.opts.brokerAddr) - } else { - listener, err = net.Listen("tcp", b.opts.brokerAddr) - if err != nil { - return fmt.Errorf("failed to start TCP listener: %v", err) - } - log.Println("BROKER - RUNNING ~> started on", b.opts.brokerAddr) - } - defer listener.Close() - for { - conn, err := listener.Accept() - if err != nil { - b.OnError(ctx, conn, err) - continue - } - go func(c net.Conn) { - defer c.Close() - for { - err := b.readMessage(ctx, c) - if err != nil { - break - } - } - }(conn) - } -} - -func (b *Broker) send(conn net.Conn, msg *codec.Message) error { - return codec.SendMessage(conn, msg, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption) -} - -func (b *Broker) receive(c net.Conn) (*codec.Message, error) { - return codec.ReadMessage(c, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption) -} - -func (b *Broker) broadcastToConsumers(ctx context.Context, msg *codec.Message) { - if queue, ok := b.queues.Get(msg.Queue); ok { - task := &QueuedTask{Message: msg, RetryCount: 0} - queue.tasks <- task - log.Printf("Task enqueued for queue %s", msg.Queue) - } -} - -func (b *Broker) waitForConsumerAck(conn net.Conn) error { - msg, err := b.receive(conn) - if err != nil { - return err - } - if msg.Command == consts.MESSAGE_ACK { - log.Println("Received CONSUMER_ACK: Subscribed successfully") - return nil - } - return fmt.Errorf("expected CONSUMER_ACK, got: %v", msg.Command) -} - -func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Conn) *publisher { - publisherID, ok := GetPublisherID(ctx) - _, ok = b.queues.Get(queueName) - if !ok { - b.NewQueue(queueName) - } - con := &publisher{id: publisherID, conn: conn} - b.publishers.Set(publisherID, con) - return con -} - -func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Conn) string { - consumerID, ok := GetConsumerID(ctx) - q, ok := b.queues.Get(queueName) - if !ok { - q = b.NewQueue(queueName) - } - con := &consumer{id: consumerID, conn: conn} - b.consumers.Set(consumerID, con) - q.consumers.Set(consumerID, con) - log.Printf("BROKER - SUBSCRIBE ~> %s on %s", consumerID, queueName) - return consumerID -} - -func (b *Broker) removeConsumer(queueName, consumerID string) { - if queue, ok := b.queues.Get(queueName); ok { - con, ok := queue.consumers.Get(consumerID) - if ok { - con.conn.Close() - queue.consumers.Del(consumerID) - } - b.queues.Del(queueName) - } -} - -func (b *Broker) readMessage(ctx context.Context, c net.Conn) error { - msg, err := b.receive(c) - if err == nil { - ctx = SetHeaders(ctx, msg.Headers) - b.OnMessage(ctx, msg, c) - return nil - } - if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") { - b.OnClose(ctx, c) - return err - } - b.OnError(ctx, c, err) - return err -} - -func (b *Broker) dispatchWorker(queue *Queue) { - delay := b.opts.initialDelay - for task := range queue.tasks { - success := false - for !success && task.RetryCount <= b.opts.maxRetries { - if b.dispatchTaskToConsumer(queue, task) { - success = true - } else { - task.RetryCount++ - delay = b.backoffRetry(queue, task, delay) - } - } - } -} - -func (b *Broker) dispatchTaskToConsumer(queue *Queue, task *QueuedTask) bool { - var consumerFound bool - queue.consumers.ForEach(func(_ string, con *consumer) bool { - if err := b.send(con.conn, task.Message); err == nil { - consumerFound = true - log.Printf("Task dispatched to consumer %s on queue %s", con.id, queue.name) - return false // break the loop once a consumer is found - } - return true - }) - if !consumerFound { - log.Printf("No available consumers for queue %s, retrying...", queue.name) - } - return consumerFound -} - -func (b *Broker) backoffRetry(queue *Queue, task *QueuedTask, delay time.Duration) time.Duration { - backoffDuration := utils.CalculateJitter(delay, b.opts.jitterPercent) - log.Printf("Backing off for %v before retrying task for queue %s", backoffDuration, task.Message.Queue) - time.Sleep(backoffDuration) - queue.tasks <- task - delay *= 2 - if delay > b.opts.maxBackoff { - delay = b.opts.maxBackoff - } - return delay -} diff --git a/v2/consumer.go b/v2/consumer.go deleted file mode 100644 index 7410a5a..0000000 --- a/v2/consumer.go +++ /dev/null @@ -1,197 +0,0 @@ -package v2 - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "log" - "net" - "strings" - "sync" - "time" - - "github.com/oarkflow/mq/codec" - "github.com/oarkflow/mq/consts" - "github.com/oarkflow/mq/jsonparser" - "github.com/oarkflow/mq/utils" -) - -// Consumer structure to hold consumer-specific configurations and state. -type Consumer struct { - id string - handlers map[string]Handler - conn net.Conn - queues []string - opts Options -} - -// NewConsumer initializes a new consumer with the provided options. -func NewConsumer(id string, opts ...Option) *Consumer { - options := setupOptions(opts...) - return &Consumer{ - handlers: make(map[string]Handler), - id: id, - opts: options, - } -} - -func (c *Consumer) send(conn net.Conn, msg *codec.Message) error { - return codec.SendMessage(conn, msg, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption) -} - -func (c *Consumer) receive(conn net.Conn) (*codec.Message, error) { - return codec.ReadMessage(conn, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption) -} - -// Close closes the consumer's connection. -func (c *Consumer) Close() error { - return c.conn.Close() -} - -// Subscribe to a specific queue. -func (c *Consumer) subscribe(ctx context.Context, queue string) error { - headers := WithHeaders(ctx, map[string]string{ - consts.ConsumerKey: c.id, - consts.ContentType: consts.TypeJson, - }) - msg := codec.NewMessage(consts.SUBSCRIBE, nil, queue, headers) - if err := c.send(c.conn, msg); err != nil { - return err - } - - return c.waitForAck(c.conn) -} - -func (c *Consumer) OnClose(ctx context.Context, _ net.Conn) error { - fmt.Println("Consumer closed") - return nil -} - -func (c *Consumer) OnError(_ context.Context, conn net.Conn, err error) { - fmt.Println("Error reading from connection:", err, conn.RemoteAddr()) -} - -func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) { - headers := WithHeaders(ctx, map[string]string{ - consts.ConsumerKey: c.id, - consts.ContentType: consts.TypeJson, - }) - taskID, _ := jsonparser.GetString(msg.Payload, "id") - reply := codec.NewMessage(consts.MESSAGE_ACK, []byte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers) - if err := c.send(conn, reply); err != nil { - fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err) - } - var task Task - err := json.Unmarshal(msg.Payload, &task) - if err != nil { - log.Println("Error unmarshalling message:", err) - return - } - ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue}) - result := c.ProcessTask(ctx, task) - result.MessageID = task.ID - result.Queue = msg.Queue - if result.Error != nil { - result.Status = "FAILED" - } else { - result.Status = "SUCCESS" - } - bt, _ := json.Marshal(result) - reply = codec.NewMessage(consts.MESSAGE_RESPONSE, bt, msg.Queue, headers) - if err := c.send(conn, reply); err != nil { - fmt.Printf("failed to send MESSAGE_RESPONSE for queue %s: %v", msg.Queue, err) - } -} - -// ProcessTask handles a received task message and invokes the appropriate handler. -func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result { - queue, _ := GetQueue(ctx) - handler, exists := c.handlers[queue] - if !exists { - return Result{Error: errors.New("No handler for queue " + queue)} - } - return handler(ctx, msg) -} - -// AttemptConnect tries to establish a connection to the server, with TLS or without, based on the configuration. -func (c *Consumer) AttemptConnect() error { - var err error - delay := c.opts.initialDelay - for i := 0; i < c.opts.maxRetries; i++ { - conn, err := GetConnection(c.opts.brokerAddr, c.opts.tlsConfig) - if err == nil { - c.conn = conn - return nil - } - sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent) - log.Printf("CONSUMER - SUBSCRIBE ~> Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration) - time.Sleep(sleepDuration) - delay *= 2 - if delay > c.opts.maxBackoff { - delay = c.opts.maxBackoff - } - } - - return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err) -} - -func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error { - msg, err := c.receive(conn) - if err == nil { - ctx = SetHeaders(ctx, msg.Headers) - c.OnMessage(ctx, msg, conn) - return nil - } - if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") { - c.OnClose(ctx, conn) - return err - } - c.OnError(ctx, conn, err) - return err -} - -// Consume starts the consumer to consume tasks from the queues. -func (c *Consumer) Consume(ctx context.Context) error { - err := c.AttemptConnect() - if err != nil { - return err - } - for _, q := range c.queues { - if err := c.subscribe(ctx, q); err != nil { - return fmt.Errorf("failed to connect to server for queue %s: %v", q, err) - } - } - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - for { - if err := c.readMessage(ctx, c.conn); err != nil { - log.Println("Error reading message:", err) - break - } - } - }() - - wg.Wait() - return nil -} - -func (c *Consumer) waitForAck(conn net.Conn) error { - msg, err := c.receive(conn) - if err != nil { - return err - } - if msg.Command == consts.SUBSCRIBE_ACK { - log.Printf("CONSUMER - SUBSCRIBE_ACK ~> %s on %s", c.id, msg.Queue) - return nil - } - return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command) -} - -// RegisterHandler registers a handler for a queue. -func (c *Consumer) RegisterHandler(queue string, handler Handler) { - c.queues = append(c.queues, queue) - c.handlers[queue] = handler -} diff --git a/v2/ctx.go b/v2/ctx.go deleted file mode 100644 index 75476bd..0000000 --- a/v2/ctx.go +++ /dev/null @@ -1,158 +0,0 @@ -package v2 - -import ( - "context" - "crypto/tls" - "crypto/x509" - "encoding/json" - "fmt" - "net" - "os" - "time" - - "github.com/oarkflow/xid" - - "github.com/oarkflow/mq/consts" -) - -type Task struct { - ID string `json:"id"` - Payload json.RawMessage `json:"payload"` - CreatedAt time.Time `json:"created_at"` - ProcessedAt time.Time `json:"processed_at"` - Status string `json:"status"` - Error error `json:"error"` -} - -type Handler func(context.Context, Task) Result - -func IsClosed(conn net.Conn) bool { - _, err := conn.Read(make([]byte, 1)) - if err != nil { - if err == net.ErrClosed { - return true - } - } - return false -} - -func SetHeaders(ctx context.Context, headers map[string]string) context.Context { - hd, ok := GetHeaders(ctx) - if !ok { - hd = make(map[string]string) - } - for key, val := range headers { - hd[key] = val - } - return context.WithValue(ctx, consts.HeaderKey, hd) -} - -func WithHeaders(ctx context.Context, headers map[string]string) map[string]string { - hd, ok := GetHeaders(ctx) - if !ok { - hd = make(map[string]string) - } - for key, val := range headers { - hd[key] = val - } - return hd -} - -func GetHeaders(ctx context.Context) (map[string]string, bool) { - headers, ok := ctx.Value(consts.HeaderKey).(map[string]string) - return headers, ok -} - -func GetHeader(ctx context.Context, key string) (string, bool) { - headers, ok := ctx.Value(consts.HeaderKey).(map[string]string) - if !ok { - return "", false - } - val, ok := headers[key] - return val, ok -} - -func GetContentType(ctx context.Context) (string, bool) { - headers, ok := GetHeaders(ctx) - if !ok { - return "", false - } - contentType, ok := headers[consts.ContentType] - return contentType, ok -} - -func GetQueue(ctx context.Context) (string, bool) { - headers, ok := GetHeaders(ctx) - if !ok { - return "", false - } - contentType, ok := headers[consts.QueueKey] - return contentType, ok -} - -func GetConsumerID(ctx context.Context) (string, bool) { - headers, ok := GetHeaders(ctx) - if !ok { - return "", false - } - contentType, ok := headers[consts.ConsumerKey] - return contentType, ok -} - -func GetTriggerNode(ctx context.Context) (string, bool) { - headers, ok := GetHeaders(ctx) - if !ok { - return "", false - } - contentType, ok := headers[consts.TriggerNode] - return contentType, ok -} - -func GetPublisherID(ctx context.Context) (string, bool) { - headers, ok := GetHeaders(ctx) - if !ok { - return "", false - } - contentType, ok := headers[consts.PublisherKey] - return contentType, ok -} - -func NewID() string { - return xid.New().String() -} - -func createTLSConnection(addr, certPath, keyPath string, caPath ...string) (net.Conn, error) { - cert, err := tls.LoadX509KeyPair(certPath, keyPath) - if err != nil { - return nil, fmt.Errorf("failed to load client cert/key: %w", err) - } - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, - ClientAuth: tls.RequireAndVerifyClientCert, - InsecureSkipVerify: true, - } - if len(caPath) > 0 && caPath[0] != "" { - caCert, err := os.ReadFile(caPath[0]) - if err != nil { - return nil, fmt.Errorf("failed to load CA cert: %w", err) - } - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) - tlsConfig.RootCAs = caCertPool - tlsConfig.ClientCAs = caCertPool - } - conn, err := tls.Dial("tcp", addr, tlsConfig) - if err != nil { - return nil, fmt.Errorf("failed to dial TLS connection: %w", err) - } - - return conn, nil -} - -func GetConnection(addr string, config TLSConfig) (net.Conn, error) { - if config.UseTLS { - return createTLSConnection(addr, config.CertPath, config.KeyPath, config.CAPath) - } else { - return net.Dial("tcp", addr) - } -} diff --git a/v2/options.go b/v2/options.go deleted file mode 100644 index eac7f53..0000000 --- a/v2/options.go +++ /dev/null @@ -1,133 +0,0 @@ -package v2 - -import ( - "context" - "encoding/json" - "time" -) - -type Result struct { - Payload json.RawMessage `json:"payload"` - Queue string `json:"queue"` - MessageID string `json:"message_id"` - Error error `json:"error,omitempty"` - Status string `json:"status"` -} - -type TLSConfig struct { - UseTLS bool - CertPath string - KeyPath string - CAPath string -} - -type Options struct { - syncMode bool - brokerAddr string - callback []func(context.Context, Result) Result - maxRetries int - initialDelay time.Duration - maxBackoff time.Duration - jitterPercent float64 - tlsConfig TLSConfig - aesKey json.RawMessage - hmacKey json.RawMessage - enableEncryption bool - queueSize int -} - -func defaultOptions() Options { - return Options{ - syncMode: false, - brokerAddr: ":8080", - maxRetries: 5, - initialDelay: 2 * time.Second, - maxBackoff: 20 * time.Second, - jitterPercent: 0.5, - queueSize: 100, - } -} - -// Option defines a function type for setting options. -type Option func(*Options) - -func setupOptions(opts ...Option) Options { - options := defaultOptions() - for _, opt := range opts { - opt(&options) - } - return options -} - -func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option { - return func(opts *Options) { - opts.aesKey = aesKey - opts.hmacKey = hmacKey - opts.enableEncryption = enableEncryption - } -} - -// WithBrokerURL - -func WithBrokerURL(url string) Option { - return func(opts *Options) { - opts.brokerAddr = url - } -} - -// WithTLS - Option to enable/disable TLS -func WithTLS(enableTLS bool, certPath, keyPath string) Option { - return func(o *Options) { - o.tlsConfig.UseTLS = enableTLS - o.tlsConfig.CertPath = certPath - o.tlsConfig.KeyPath = keyPath - } -} - -// WithCAPath - Option to enable/disable TLS -func WithCAPath(caPath string) Option { - return func(o *Options) { - o.tlsConfig.CAPath = caPath - } -} - -// WithSyncMode - -func WithSyncMode(mode bool) Option { - return func(opts *Options) { - opts.syncMode = mode - } -} - -// WithMaxRetries - -func WithMaxRetries(val int) Option { - return func(opts *Options) { - opts.maxRetries = val - } -} - -// WithInitialDelay - -func WithInitialDelay(val time.Duration) Option { - return func(opts *Options) { - opts.initialDelay = val - } -} - -// WithMaxBackoff - -func WithMaxBackoff(val time.Duration) Option { - return func(opts *Options) { - opts.maxBackoff = val - } -} - -// WithCallback - -func WithCallback(val ...func(context.Context, Result) Result) Option { - return func(opts *Options) { - opts.callback = val - } -} - -// WithJitterPercent - -func WithJitterPercent(val float64) Option { - return func(opts *Options) { - opts.jitterPercent = val - } -} diff --git a/v2/publisher.go b/v2/publisher.go deleted file mode 100644 index ed16c58..0000000 --- a/v2/publisher.go +++ /dev/null @@ -1,110 +0,0 @@ -package v2 - -import ( - "context" - "encoding/json" - "fmt" - "log" - "net" - "time" - - "github.com/oarkflow/mq/codec" - "github.com/oarkflow/mq/consts" - "github.com/oarkflow/mq/jsonparser" -) - -type Publisher struct { - id string - opts Options -} - -func NewPublisher(id string, opts ...Option) *Publisher { - options := setupOptions(opts...) - return &Publisher{id: id, opts: options} -} - -func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command consts.CMD) error { - headers := WithHeaders(ctx, map[string]string{ - consts.PublisherKey: p.id, - consts.ContentType: consts.TypeJson, - }) - if task.ID == "" { - task.ID = NewID() - } - task.CreatedAt = time.Now() - payload, err := json.Marshal(task) - if err != nil { - return err - } - msg := codec.NewMessage(command, payload, queue, headers) - if err := codec.SendMessage(conn, msg, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption); err != nil { - return err - } - - return p.waitForAck(conn) -} - -func (p *Publisher) waitForAck(conn net.Conn) error { - msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption) - if err != nil { - return err - } - if msg.Command == consts.PUBLISH_ACK { - taskID, _ := jsonparser.GetString(msg.Payload, "id") - log.Printf("PUBLISHER - PUBLISH_ACK ~> from %s on %s for Task %s", p.id, msg.Queue, taskID) - return nil - } - return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command) -} - -func (p *Publisher) waitForResponse(conn net.Conn) Result { - msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption) - if err != nil { - return Result{Error: err} - } - if msg.Command == consts.RESPONSE { - var result Result - err = json.Unmarshal(msg.Payload, &result) - return result - } - err = fmt.Errorf("expected RESPONSE, got: %v", msg.Command) - return Result{Error: err} -} - -func (p *Publisher) Publish(ctx context.Context, task Task, queue string) error { - conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig) - if err != nil { - return fmt.Errorf("failed to connect to broker: %w", err) - } - defer conn.Close() - return p.send(ctx, queue, task, conn, consts.PUBLISH) -} - -func (p *Publisher) onClose(ctx context.Context, conn net.Conn) error { - fmt.Println("Publisher Connection closed", p.id, conn.RemoteAddr()) - return nil -} - -func (p *Publisher) onError(ctx context.Context, conn net.Conn, err error) { - fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr()) -} - -func (p *Publisher) Request(ctx context.Context, queue string, task Task) Result { - ctx = SetHeaders(ctx, map[string]string{ - consts.AwaitResponseKey: "true", - }) - conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig) - if err != nil { - err = fmt.Errorf("failed to connect to broker: %w", err) - return Result{Error: err} - } - defer conn.Close() - err = p.send(ctx, queue, task, conn, consts.PUBLISH) - resultCh := make(chan Result) - go func() { - defer close(resultCh) - resultCh <- p.waitForResponse(conn) - }() - finalResult := <-resultCh - return finalResult -}