diff --git a/consumer.go b/consumer.go new file mode 100644 index 0000000..a24cd95 --- /dev/null +++ b/consumer.go @@ -0,0 +1,322 @@ +package mq + +import ( + "context" + "fmt" + "log" + "net" + "strings" + "time" + + "github.com/oarkflow/json" + + "github.com/oarkflow/mq/codec" + "github.com/oarkflow/mq/consts" + "github.com/oarkflow/mq/jsonparser" + "github.com/oarkflow/mq/utils" +) + +type Processor interface { + ProcessTask(ctx context.Context, msg *Task) Result + Consume(ctx context.Context) error + Pause(ctx context.Context) error + Resume(ctx context.Context) error + Stop(ctx context.Context) error + Close() error + GetKey() string + SetKey(key string) + GetType() string +} + +type Consumer struct { + conn net.Conn + handler Handler + pool *Pool + opts *Options + id string + queue string +} + +func NewConsumer(id string, queue string, handler Handler, opts ...Option) *Consumer { + options := SetupOptions(opts...) + return &Consumer{ + id: id, + opts: options, + queue: queue, + handler: handler, + } +} + +func (c *Consumer) send(ctx context.Context, conn net.Conn, msg *codec.Message) error { + return codec.SendMessage(ctx, conn, msg) +} + +func (c *Consumer) receive(ctx context.Context, conn net.Conn) (*codec.Message, error) { + return codec.ReadMessage(ctx, conn) +} + +func (c *Consumer) Close() error { + c.pool.Stop() + err := c.conn.Close() + log.Printf("CONSUMER - Connection closed for consumer: %s", c.id) + return err +} + +func (c *Consumer) GetKey() string { + return c.id +} + +func (c *Consumer) GetType() string { + return "consumer" +} + +func (c *Consumer) SetKey(key string) { + c.id = key +} + +func (c *Consumer) Metrics() Metrics { + return c.pool.Metrics() +} + +func (c *Consumer) subscribe(ctx context.Context, queue string) error { + headers := HeadersWithConsumerID(ctx, c.id) + msg := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers) + if err := c.send(ctx, c.conn, msg); err != nil { + return fmt.Errorf("error while trying to subscribe: %v", err) + } + return c.waitForAck(ctx, c.conn) +} + +func (c *Consumer) OnClose(_ context.Context, _ net.Conn) error { + fmt.Println("Consumer closed") + return nil +} + +func (c *Consumer) OnError(_ context.Context, conn net.Conn, err error) { + fmt.Println("Error reading from connection:", err, conn.RemoteAddr()) +} + +func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) error { + switch msg.Command { + case consts.PUBLISH: + c.ConsumeMessage(ctx, msg, conn) + case consts.CONSUMER_PAUSE: + err := c.Pause(ctx) + if err != nil { + log.Printf("Unable to pause consumer: %v", err) + } + return err + case consts.CONSUMER_RESUME: + err := c.Resume(ctx) + if err != nil { + log.Printf("Unable to resume consumer: %v", err) + } + return err + case consts.CONSUMER_STOP: + err := c.Stop(ctx) + if err != nil { + log.Printf("Unable to stop consumer: %v", err) + } + return err + default: + log.Printf("CONSUMER - UNKNOWN_COMMAND ~> %s on %s", msg.Command, msg.Queue) + } + return nil +} + +func (c *Consumer) sendMessageAck(ctx context.Context, msg *codec.Message, conn net.Conn) { + headers := HeadersWithConsumerIDAndQueue(ctx, c.id, msg.Queue) + taskID, _ := jsonparser.GetString(msg.Payload, "id") + reply := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers) + if err := c.send(ctx, conn, reply); err != nil { + fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err) + } +} + +func (c *Consumer) ConsumeMessage(ctx context.Context, msg *codec.Message, conn net.Conn) { + c.sendMessageAck(ctx, msg, conn) + if msg.Payload == nil { + log.Printf("Received empty message payload") + return + } + var task Task + err := json.Unmarshal(msg.Payload, &task) + if err != nil { + log.Printf("Error unmarshalling message: %v", err) + return + } + ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue}) + if err := c.pool.EnqueueTask(ctx, &task, 1); err != nil { + c.sendDenyMessage(ctx, task.ID, msg.Queue, err) + return + } +} + +func (c *Consumer) ProcessTask(ctx context.Context, msg *Task) Result { + defer RecoverPanic(RecoverTitle) + queue, _ := GetQueue(ctx) + if msg.Topic == "" && queue != "" { + msg.Topic = queue + } + result := c.handler(ctx, msg) + result.Topic = msg.Topic + result.TaskID = msg.ID + return result +} + +func (c *Consumer) OnResponse(ctx context.Context, result Result) error { + if result.Status == "PENDING" && c.opts.respondPendingResult { + return nil + } + headers := HeadersWithConsumerIDAndQueue(ctx, c.id, result.Topic) + if result.Status == "" { + if result.Error != nil { + result.Status = "FAILED" + } else { + result.Status = "SUCCESS" + } + } + bt, _ := json.Marshal(result) + reply := codec.NewMessage(consts.MESSAGE_RESPONSE, bt, result.Topic, headers) + if err := c.send(ctx, c.conn, reply); err != nil { + return fmt.Errorf("failed to send MESSAGE_RESPONSE: %v", err) + } + return nil +} + +func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, err error) { + headers := HeadersWithConsumerID(ctx, c.id) + reply := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers) + if sendErr := c.send(ctx, c.conn, reply); sendErr != nil { + log.Printf("failed to send MESSAGE_DENY for task %s: %v", taskID, sendErr) + } +} + +func (c *Consumer) attemptConnect() error { + var err error + delay := c.opts.initialDelay + for i := 0; i < c.opts.maxRetries; i++ { + conn, err := GetConnection(c.opts.brokerAddr, c.opts.tlsConfig) + if err == nil { + c.conn = conn + return nil + } + sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent) + log.Printf("CONSUMER - SUBSCRIBE ~> Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration) + time.Sleep(sleepDuration) + delay *= 2 + if delay > c.opts.maxBackoff { + delay = c.opts.maxBackoff + } + } + + return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err) +} + +func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error { + msg, err := c.receive(ctx, conn) + if err == nil { + ctx = SetHeaders(ctx, msg.Headers) + return c.OnMessage(ctx, msg, conn) + } + if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") { + err1 := c.OnClose(ctx, conn) + if err1 != nil { + return err1 + } + return err + } + c.OnError(ctx, conn, err) + return err +} + +func (c *Consumer) Consume(ctx context.Context) error { + err := c.attemptConnect() + if err != nil { + return err + } + c.pool = NewPool( + c.opts.numOfWorkers, + WithTaskQueueSize(c.opts.queueSize), + WithMaxMemoryLoad(c.opts.maxMemoryLoad), + WithHandler(c.ProcessTask), + WithPoolCallback(c.OnResponse), + WithTaskStorage(c.opts.storage), + ) + if err := c.subscribe(ctx, c.queue); err != nil { + return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err) + } + c.pool.Start(c.opts.numOfWorkers) + // Infinite loop to continuously read messages and reconnect if needed. + for { + select { + case <-ctx.Done(): + log.Println("Context canceled, stopping consumer.") + return nil + default: + if c.opts.ConsumerRateLimiter != nil { + c.opts.ConsumerRateLimiter.Wait() + } + if err := c.readMessage(ctx, c.conn); err != nil { + log.Printf("Error reading message: %v, attempting reconnection...", err) + for { + if ctx.Err() != nil { + return nil + } + if rErr := c.attemptConnect(); rErr != nil { + log.Printf("Reconnection attempt failed: %v", rErr) + time.Sleep(c.opts.initialDelay) + } else { + break + } + } + if err := c.subscribe(ctx, c.queue); err != nil { + log.Printf("Failed to re-subscribe on reconnection: %v", err) + time.Sleep(c.opts.initialDelay) + } + } + } + } +} + +func (c *Consumer) waitForAck(ctx context.Context, conn net.Conn) error { + msg, err := c.receive(ctx, conn) + if err != nil { + return err + } + if msg.Command == consts.SUBSCRIBE_ACK { + log.Printf("CONSUMER - SUBSCRIBE_ACK ~> %s on %s", c.id, msg.Queue) + return nil + } + return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command) +} + +func (c *Consumer) Pause(ctx context.Context) error { + return c.operate(ctx, consts.CONSUMER_PAUSED, c.pool.Pause) +} + +func (c *Consumer) Resume(ctx context.Context) error { + return c.operate(ctx, consts.CONSUMER_RESUMED, c.pool.Resume) +} + +func (c *Consumer) Stop(ctx context.Context) error { + return c.operate(ctx, consts.CONSUMER_STOPPED, c.pool.Stop) +} + +func (c *Consumer) operate(ctx context.Context, cmd consts.CMD, poolOperation func()) error { + if err := c.sendOpsMessage(ctx, cmd); err != nil { + return err + } + poolOperation() + return nil +} + +func (c *Consumer) sendOpsMessage(ctx context.Context, cmd consts.CMD) error { + headers := HeadersWithConsumerID(ctx, c.id) + msg := codec.NewMessage(cmd, nil, c.queue, headers) + return c.send(ctx, c.conn, msg) +} + +func (c *Consumer) Conn() net.Conn { + return c.conn +} diff --git a/mq.go b/mq.go index 443cc00..eedccaf 100644 --- a/mq.go +++ b/mq.go @@ -7,7 +7,6 @@ import ( "log" "net" "strings" - "sync" "time" "github.com/oarkflow/errors" @@ -327,6 +326,12 @@ func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Con } } +func (b *Broker) AdjustConsumerWorkers(noOfWorkers int, consumerID ...string) { + b.consumers.ForEach(func(_ string, c *consumer) bool { + return true + }) +} + func (b *Broker) MessageAck(ctx context.Context, msg *codec.Message) { consumerID, _ := GetConsumerID(ctx) taskID, _ := jsonparser.GetString(msg.Payload, "id") @@ -577,11 +582,11 @@ func (b *Broker) RemoveConsumer(consumerID string, queues ...string) { }) } -func (b *Broker) handleConsumer(ctx context.Context, cmd consts.CMD, state consts.ConsumerState, consumerID string, queues ...string) { +func (b *Broker) handleConsumer(ctx context.Context, cmd consts.CMD, state consts.ConsumerState, consumerID string, payload []byte, queues ...string) { fn := func(queue *Queue) { con, ok := queue.consumers.Get(consumerID) if ok { - ack := codec.NewMessage(cmd, utils.ToByte("{}"), queue.name, map[string]string{consts.ConsumerKey: consumerID}) + ack := codec.NewMessage(cmd, payload, queue.name, map[string]string{consts.ConsumerKey: consumerID}) err := b.send(ctx, con.conn, ack) if err == nil { con.state = state @@ -603,15 +608,15 @@ func (b *Broker) handleConsumer(ctx context.Context, cmd consts.CMD, state const } func (b *Broker) PauseConsumer(ctx context.Context, consumerID string, queues ...string) { - b.handleConsumer(ctx, consts.CONSUMER_PAUSE, consts.ConsumerStatePaused, consumerID, queues...) + b.handleConsumer(ctx, consts.CONSUMER_PAUSE, consts.ConsumerStatePaused, consumerID, utils.ToByte("{}"), queues...) } func (b *Broker) ResumeConsumer(ctx context.Context, consumerID string, queues ...string) { - b.handleConsumer(ctx, consts.CONSUMER_RESUME, consts.ConsumerStateActive, consumerID, queues...) + b.handleConsumer(ctx, consts.CONSUMER_RESUME, consts.ConsumerStateActive, consumerID, utils.ToByte("{}"), queues...) } func (b *Broker) StopConsumer(ctx context.Context, consumerID string, queues ...string) { - b.handleConsumer(ctx, consts.CONSUMER_STOP, consts.ConsumerStateStopped, consumerID, queues...) + b.handleConsumer(ctx, consts.CONSUMER_STOP, consts.ConsumerStateStopped, consumerID, utils.ToByte("{}"), queues...) } func (b *Broker) readMessage(ctx context.Context, c net.Conn) error { @@ -721,486 +726,6 @@ func (b *Broker) SetURL(url string) { b.opts.brokerAddr = url } -type Processor interface { - ProcessTask(ctx context.Context, msg *Task) Result - Consume(ctx context.Context) error - Pause(ctx context.Context) error - Resume(ctx context.Context) error - Stop(ctx context.Context) error - Close() error - GetKey() string - SetKey(key string) - GetType() string -} - -type Consumer struct { - conn net.Conn - handler Handler - pool *Pool - opts *Options - id string - queue string -} - -func NewConsumer(id string, queue string, handler Handler, opts ...Option) *Consumer { - options := SetupOptions(opts...) - return &Consumer{ - id: id, - opts: options, - queue: queue, - handler: handler, - } -} - -func (c *Consumer) send(ctx context.Context, conn net.Conn, msg *codec.Message) error { - return codec.SendMessage(ctx, conn, msg) -} - -func (c *Consumer) receive(ctx context.Context, conn net.Conn) (*codec.Message, error) { - return codec.ReadMessage(ctx, conn) -} - -func (c *Consumer) Close() error { - c.pool.Stop() - err := c.conn.Close() - log.Printf("CONSUMER - Connection closed for consumer: %s", c.id) - return err -} - -func (c *Consumer) GetKey() string { - return c.id -} - -func (c *Consumer) GetType() string { - return "consumer" -} - -func (c *Consumer) SetKey(key string) { - c.id = key -} - -func (c *Consumer) Metrics() Metrics { - return c.pool.Metrics() -} - -func (c *Consumer) subscribe(ctx context.Context, queue string) error { - headers := HeadersWithConsumerID(ctx, c.id) - msg := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers) - if err := c.send(ctx, c.conn, msg); err != nil { - return fmt.Errorf("error while trying to subscribe: %v", err) - } - return c.waitForAck(ctx, c.conn) -} - -func (c *Consumer) OnClose(_ context.Context, _ net.Conn) error { - fmt.Println("Consumer closed") - return nil -} - -func (c *Consumer) OnError(_ context.Context, conn net.Conn, err error) { - fmt.Println("Error reading from connection:", err, conn.RemoteAddr()) -} - -func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) error { - switch msg.Command { - case consts.PUBLISH: - c.ConsumeMessage(ctx, msg, conn) - case consts.CONSUMER_PAUSE: - err := c.Pause(ctx) - if err != nil { - log.Printf("Unable to pause consumer: %v", err) - } - return err - case consts.CONSUMER_RESUME: - err := c.Resume(ctx) - if err != nil { - log.Printf("Unable to resume consumer: %v", err) - } - return err - case consts.CONSUMER_STOP: - err := c.Stop(ctx) - if err != nil { - log.Printf("Unable to stop consumer: %v", err) - } - return err - default: - log.Printf("CONSUMER - UNKNOWN_COMMAND ~> %s on %s", msg.Command, msg.Queue) - } - return nil -} - -func (c *Consumer) sendMessageAck(ctx context.Context, msg *codec.Message, conn net.Conn) { - headers := HeadersWithConsumerIDAndQueue(ctx, c.id, msg.Queue) - taskID, _ := jsonparser.GetString(msg.Payload, "id") - reply := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers) - if err := c.send(ctx, conn, reply); err != nil { - fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err) - } -} - -func (c *Consumer) ConsumeMessage(ctx context.Context, msg *codec.Message, conn net.Conn) { - c.sendMessageAck(ctx, msg, conn) - if msg.Payload == nil { - log.Printf("Received empty message payload") - return - } - var task Task - err := json.Unmarshal(msg.Payload, &task) - if err != nil { - log.Printf("Error unmarshalling message: %v", err) - return - } - ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue}) - if err := c.pool.EnqueueTask(ctx, &task, 1); err != nil { - c.sendDenyMessage(ctx, task.ID, msg.Queue, err) - return - } -} - -func (c *Consumer) ProcessTask(ctx context.Context, msg *Task) Result { - defer RecoverPanic(RecoverTitle) - queue, _ := GetQueue(ctx) - if msg.Topic == "" && queue != "" { - msg.Topic = queue - } - result := c.handler(ctx, msg) - result.Topic = msg.Topic - result.TaskID = msg.ID - return result -} - -func (c *Consumer) OnResponse(ctx context.Context, result Result) error { - if result.Status == "PENDING" && c.opts.respondPendingResult { - return nil - } - headers := HeadersWithConsumerIDAndQueue(ctx, c.id, result.Topic) - if result.Status == "" { - if result.Error != nil { - result.Status = "FAILED" - } else { - result.Status = "SUCCESS" - } - } - bt, _ := json.Marshal(result) - reply := codec.NewMessage(consts.MESSAGE_RESPONSE, bt, result.Topic, headers) - if err := c.send(ctx, c.conn, reply); err != nil { - return fmt.Errorf("failed to send MESSAGE_RESPONSE: %v", err) - } - return nil -} - -func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, err error) { - headers := HeadersWithConsumerID(ctx, c.id) - reply := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers) - if sendErr := c.send(ctx, c.conn, reply); sendErr != nil { - log.Printf("failed to send MESSAGE_DENY for task %s: %v", taskID, sendErr) - } -} - -func (c *Consumer) attemptConnect() error { - var err error - delay := c.opts.initialDelay - for i := 0; i < c.opts.maxRetries; i++ { - conn, err := GetConnection(c.opts.brokerAddr, c.opts.tlsConfig) - if err == nil { - c.conn = conn - return nil - } - sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent) - log.Printf("CONSUMER - SUBSCRIBE ~> Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration) - time.Sleep(sleepDuration) - delay *= 2 - if delay > c.opts.maxBackoff { - delay = c.opts.maxBackoff - } - } - - return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err) -} - -func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error { - msg, err := c.receive(ctx, conn) - if err == nil { - ctx = SetHeaders(ctx, msg.Headers) - return c.OnMessage(ctx, msg, conn) - } - if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") { - err1 := c.OnClose(ctx, conn) - if err1 != nil { - return err1 - } - return err - } - c.OnError(ctx, conn, err) - return err -} - -func (c *Consumer) Consume(ctx context.Context) error { - err := c.attemptConnect() - if err != nil { - return err - } - c.pool = NewPool( - c.opts.numOfWorkers, - WithTaskQueueSize(c.opts.queueSize), - WithMaxMemoryLoad(c.opts.maxMemoryLoad), - WithHandler(c.ProcessTask), - WithPoolCallback(c.OnResponse), - WithTaskStorage(c.opts.storage), - ) - if err := c.subscribe(ctx, c.queue); err != nil { - return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err) - } - c.pool.Start(c.opts.numOfWorkers) - // Infinite loop to continuously read messages and reconnect if needed. - for { - select { - case <-ctx.Done(): - log.Println("Context canceled, stopping consumer.") - return nil - default: - if c.opts.ConsumerRateLimiter != nil { - c.opts.ConsumerRateLimiter.Wait() - } - if err := c.readMessage(ctx, c.conn); err != nil { - log.Printf("Error reading message: %v, attempting reconnection...", err) - // Attempt reconnection loop. - for { - if ctx.Err() != nil { - return nil - } - if rErr := c.attemptConnect(); rErr != nil { - log.Printf("Reconnection attempt failed: %v", rErr) - time.Sleep(c.opts.initialDelay) - } else { - break - } - } - if err := c.subscribe(ctx, c.queue); err != nil { - log.Printf("Failed to re-subscribe on reconnection: %v", err) - time.Sleep(c.opts.initialDelay) - } - } - } - } -} - -func (c *Consumer) waitForAck(ctx context.Context, conn net.Conn) error { - msg, err := c.receive(ctx, conn) - if err != nil { - return err - } - if msg.Command == consts.SUBSCRIBE_ACK { - log.Printf("CONSUMER - SUBSCRIBE_ACK ~> %s on %s", c.id, msg.Queue) - return nil - } - return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command) -} - -func (c *Consumer) Pause(ctx context.Context) error { - return c.operate(ctx, consts.CONSUMER_PAUSED, c.pool.Pause) -} - -func (c *Consumer) Resume(ctx context.Context) error { - return c.operate(ctx, consts.CONSUMER_RESUMED, c.pool.Resume) -} - -func (c *Consumer) Stop(ctx context.Context) error { - return c.operate(ctx, consts.CONSUMER_STOPPED, c.pool.Stop) -} - -func (c *Consumer) operate(ctx context.Context, cmd consts.CMD, poolOperation func()) error { - if err := c.sendOpsMessage(ctx, cmd); err != nil { - return err - } - poolOperation() - return nil -} - -func (c *Consumer) sendOpsMessage(ctx context.Context, cmd consts.CMD) error { - headers := HeadersWithConsumerID(ctx, c.id) - msg := codec.NewMessage(cmd, nil, c.queue, headers) - return c.send(ctx, c.conn, msg) -} - -func (c *Consumer) Conn() net.Conn { - return c.conn -} - -type Publisher struct { - opts *Options - id string - conn net.Conn - connLock sync.Mutex -} - -func NewPublisher(id string, opts ...Option) *Publisher { - options := SetupOptions(opts...) - return &Publisher{ - id: id, - opts: options, - conn: nil, - } -} - -// New method to ensure a persistent connection. -func (p *Publisher) ensureConnection(ctx context.Context) error { - p.connLock.Lock() - defer p.connLock.Unlock() - if p.conn != nil { - return nil - } - var err error - delay := p.opts.initialDelay - for i := 0; i < p.opts.maxRetries; i++ { - var conn net.Conn - conn, err = GetConnection(p.opts.brokerAddr, p.opts.tlsConfig) - if err == nil { - p.conn = conn - return nil - } - sleepDuration := utils.CalculateJitter(delay, p.opts.jitterPercent) - log.Printf("PUBLISHER - ensureConnection failed: %v, attempt %d/%d, retrying in %v...", err, i+1, p.opts.maxRetries, sleepDuration) - time.Sleep(sleepDuration) - delay *= 2 - if delay > p.opts.maxBackoff { - delay = p.opts.maxBackoff - } - } - return fmt.Errorf("failed to connect to broker after retries: %w", err) -} - -// Modified Publish method that uses the persistent connection. -func (p *Publisher) Publish(ctx context.Context, task Task, queue string) error { - // Ensure connection is established. - if err := p.ensureConnection(ctx); err != nil { - return err - } - delay := p.opts.initialDelay - for i := 0; i < p.opts.maxRetries; i++ { - // Use the persistent connection. - p.connLock.Lock() - conn := p.conn - p.connLock.Unlock() - err := p.send(ctx, queue, task, conn, consts.PUBLISH) - if err == nil { - return nil - } - log.Printf("PUBLISHER - Failed publishing: %v, attempt %d/%d, retrying...", err, i+1, p.opts.maxRetries) - // On error, close and reset the connection. - p.connLock.Lock() - if p.conn != nil { - p.conn.Close() - p.conn = nil - } - p.connLock.Unlock() - sleepDuration := utils.CalculateJitter(delay, p.opts.jitterPercent) - time.Sleep(sleepDuration) - delay *= 2 - if delay > p.opts.maxBackoff { - delay = p.opts.maxBackoff - } - // Ensure connection is re-established. - if err := p.ensureConnection(ctx); err != nil { - return err - } - } - return fmt.Errorf("failed to publish after retries") -} - -func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command consts.CMD) error { - headers := WithHeaders(ctx, map[string]string{ - consts.PublisherKey: p.id, - consts.ContentType: consts.TypeJson, - }) - if task.ID == "" { - task.ID = NewID() - } - task.CreatedAt = time.Now() - payload, err := json.Marshal(task) - if err != nil { - return err - } - msg := codec.NewMessage(command, payload, queue, headers) - if err := codec.SendMessage(ctx, conn, msg); err != nil { - return err - } - - return p.waitForAck(ctx, conn) -} - -func (p *Publisher) waitForAck(ctx context.Context, conn net.Conn) error { - msg, err := codec.ReadMessage(ctx, conn) - if err != nil { - return err - } - if msg.Command == consts.PUBLISH_ACK { - taskID, _ := jsonparser.GetString(msg.Payload, "id") - log.Printf("PUBLISHER - PUBLISH_ACK ~> from %s on %s for Task %s", p.id, msg.Queue, taskID) - return nil - } - return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command) -} - -func (p *Publisher) waitForResponse(ctx context.Context, conn net.Conn) Result { - msg, err := codec.ReadMessage(ctx, conn) - if err != nil { - return Result{Error: err} - } - if msg.Command == consts.RESPONSE { - var result Result - err = json.Unmarshal(msg.Payload, &result) - return result - } - err = fmt.Errorf("expected RESPONSE, got: %v", msg.Command) - return Result{Error: err} -} - -func (p *Publisher) onClose(_ context.Context, conn net.Conn) error { - fmt.Println("Publisher Connection closed", p.id, conn.RemoteAddr()) - return nil -} - -func (p *Publisher) onError(_ context.Context, conn net.Conn, err error) { - fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr()) -} - -func (p *Publisher) Request(ctx context.Context, task Task, queue string) Result { - ctx = SetHeaders(ctx, map[string]string{ - consts.AwaitResponseKey: "true", - }) - conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig) - if err != nil { - err = fmt.Errorf("failed to connect to broker: %w", err) - return Result{Error: err} - } - defer func() { - _ = conn.Close() - }() - err = p.send(ctx, queue, task, conn, consts.PUBLISH) - resultCh := make(chan Result) - go func() { - defer close(resultCh) - resultCh <- p.waitForResponse(ctx, conn) - }() - finalResult := <-resultCh - return finalResult -} - -type Queue struct { - consumers storage.IMap[string, *consumer] - tasks chan *QueuedTask // channel to hold tasks - name string -} - -func newQueue(name string, queueSize int) *Queue { - return &Queue{ - name: name, - consumers: memory.New[string, *consumer](), - tasks: make(chan *QueuedTask, queueSize), // buffer size for tasks - } -} - func (b *Broker) NewQueue(name string) *Queue { q := &Queue{ name: name, @@ -1222,76 +747,6 @@ func (b *Broker) NewQueue(name string) *Queue { return q } -type QueueTask struct { - ctx context.Context - payload *Task - priority int - retryCount int - index int -} - -type PriorityQueue []*QueueTask - -func (pq PriorityQueue) Len() int { return len(pq) } -func (pq PriorityQueue) Less(i, j int) bool { - return pq[i].priority > pq[j].priority -} -func (pq PriorityQueue) Swap(i, j int) { - pq[i], pq[j] = pq[j], pq[i] - pq[i].index = i - pq[j].index = j -} -func (pq *PriorityQueue) Push(x interface{}) { - n := len(*pq) - task := x.(*QueueTask) - task.index = n - *pq = append(*pq, task) -} -func (pq *PriorityQueue) Pop() interface{} { - old := *pq - n := len(old) - task := old[n-1] - task.index = -1 - *pq = old[0 : n-1] - return task -} - -type Task struct { - CreatedAt time.Time `json:"created_at"` - ProcessedAt time.Time `json:"processed_at"` - Expiry time.Time `json:"expiry"` - Error error `json:"error"` - ID string `json:"id"` - Topic string `json:"topic"` - Status string `json:"status"` - Payload json.RawMessage `json:"payload"` - dag any -} - -func (t *Task) GetFlow() any { - return t.dag -} - -func NewTask(id string, payload json.RawMessage, nodeKey string, opts ...TaskOption) *Task { - if id == "" { - id = NewID() - } - task := &Task{ID: id, Payload: payload, Topic: nodeKey, CreatedAt: time.Now()} - for _, opt := range opts { - opt(task) - } - return task -} - -// TaskOption defines a function type for setting options. -type TaskOption func(*Task) - -func WithDAG(dag any) TaskOption { - return func(opts *Task) { - opts.dag = dag - } -} - func (b *Broker) TLSConfig() TLSConfig { return b.opts.tlsConfig } diff --git a/options.go b/options.go index 8e1dbd5..5e7e3b4 100644 --- a/options.go +++ b/options.go @@ -280,3 +280,12 @@ func DisableConsumerRateLimit() Option { opts.ConsumerRateLimiter = nil } } + +// TaskOption defines a function type for setting options. +type TaskOption func(*Task) + +func WithDAG(dag any) TaskOption { + return func(opts *Task) { + opts.dag = dag + } +} diff --git a/publisher.go b/publisher.go new file mode 100644 index 0000000..a541158 --- /dev/null +++ b/publisher.go @@ -0,0 +1,177 @@ +package mq + +import ( + "context" + "fmt" + "log" + "net" + "sync" + "time" + + "github.com/oarkflow/json" + + "github.com/oarkflow/mq/codec" + "github.com/oarkflow/mq/consts" + "github.com/oarkflow/mq/jsonparser" + "github.com/oarkflow/mq/utils" +) + +type Publisher struct { + opts *Options + id string + conn net.Conn + connLock sync.Mutex +} + +func NewPublisher(id string, opts ...Option) *Publisher { + options := SetupOptions(opts...) + return &Publisher{ + id: id, + opts: options, + conn: nil, + } +} + +// New method to ensure a persistent connection. +func (p *Publisher) ensureConnection(ctx context.Context) error { + p.connLock.Lock() + defer p.connLock.Unlock() + if p.conn != nil { + return nil + } + var err error + delay := p.opts.initialDelay + for i := 0; i < p.opts.maxRetries; i++ { + var conn net.Conn + conn, err = GetConnection(p.opts.brokerAddr, p.opts.tlsConfig) + if err == nil { + p.conn = conn + return nil + } + sleepDuration := utils.CalculateJitter(delay, p.opts.jitterPercent) + log.Printf("PUBLISHER - ensureConnection failed: %v, attempt %d/%d, retrying in %v...", err, i+1, p.opts.maxRetries, sleepDuration) + time.Sleep(sleepDuration) + delay *= 2 + if delay > p.opts.maxBackoff { + delay = p.opts.maxBackoff + } + } + return fmt.Errorf("failed to connect to broker after retries: %w", err) +} + +// Publish method that uses the persistent connection. +func (p *Publisher) Publish(ctx context.Context, task Task, queue string) error { + // Ensure connection is established. + if err := p.ensureConnection(ctx); err != nil { + return err + } + delay := p.opts.initialDelay + for i := 0; i < p.opts.maxRetries; i++ { + // Use the persistent connection. + p.connLock.Lock() + conn := p.conn + p.connLock.Unlock() + err := p.send(ctx, queue, task, conn, consts.PUBLISH) + if err == nil { + return nil + } + log.Printf("PUBLISHER - Failed publishing: %v, attempt %d/%d, retrying...", err, i+1, p.opts.maxRetries) + // On error, close and reset the connection. + p.connLock.Lock() + if p.conn != nil { + p.conn.Close() + p.conn = nil + } + p.connLock.Unlock() + sleepDuration := utils.CalculateJitter(delay, p.opts.jitterPercent) + time.Sleep(sleepDuration) + delay *= 2 + if delay > p.opts.maxBackoff { + delay = p.opts.maxBackoff + } + // Ensure connection is re-established. + if err := p.ensureConnection(ctx); err != nil { + return err + } + } + return fmt.Errorf("failed to publish after retries") +} + +func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command consts.CMD) error { + headers := WithHeaders(ctx, map[string]string{ + consts.PublisherKey: p.id, + consts.ContentType: consts.TypeJson, + }) + if task.ID == "" { + task.ID = NewID() + } + task.CreatedAt = time.Now() + payload, err := json.Marshal(task) + if err != nil { + return err + } + msg := codec.NewMessage(command, payload, queue, headers) + if err := codec.SendMessage(ctx, conn, msg); err != nil { + return err + } + + return p.waitForAck(ctx, conn) +} + +func (p *Publisher) waitForAck(ctx context.Context, conn net.Conn) error { + msg, err := codec.ReadMessage(ctx, conn) + if err != nil { + return err + } + if msg.Command == consts.PUBLISH_ACK { + taskID, _ := jsonparser.GetString(msg.Payload, "id") + log.Printf("PUBLISHER - PUBLISH_ACK ~> from %s on %s for Task %s", p.id, msg.Queue, taskID) + return nil + } + return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command) +} + +func (p *Publisher) waitForResponse(ctx context.Context, conn net.Conn) Result { + msg, err := codec.ReadMessage(ctx, conn) + if err != nil { + return Result{Error: err} + } + if msg.Command == consts.RESPONSE { + var result Result + err = json.Unmarshal(msg.Payload, &result) + return result + } + err = fmt.Errorf("expected RESPONSE, got: %v", msg.Command) + return Result{Error: err} +} + +func (p *Publisher) onClose(_ context.Context, conn net.Conn) error { + fmt.Println("Publisher Connection closed", p.id, conn.RemoteAddr()) + return nil +} + +func (p *Publisher) onError(_ context.Context, conn net.Conn, err error) { + fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr()) +} + +func (p *Publisher) Request(ctx context.Context, task Task, queue string) Result { + ctx = SetHeaders(ctx, map[string]string{ + consts.AwaitResponseKey: "true", + }) + conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig) + if err != nil { + err = fmt.Errorf("failed to connect to broker: %w", err) + return Result{Error: err} + } + defer func() { + _ = conn.Close() + }() + err = p.send(ctx, queue, task, conn, consts.PUBLISH) + resultCh := make(chan Result) + go func() { + defer close(resultCh) + resultCh <- p.waitForResponse(ctx, conn) + }() + finalResult := <-resultCh + return finalResult +} diff --git a/task.go b/task.go new file mode 100644 index 0000000..e581dc7 --- /dev/null +++ b/task.go @@ -0,0 +1,86 @@ +package mq + +import ( + "context" + "time" + + "github.com/oarkflow/json" + + "github.com/oarkflow/mq/storage" + "github.com/oarkflow/mq/storage/memory" +) + +type Queue struct { + consumers storage.IMap[string, *consumer] + tasks chan *QueuedTask // channel to hold tasks + name string +} + +func newQueue(name string, queueSize int) *Queue { + return &Queue{ + name: name, + consumers: memory.New[string, *consumer](), + tasks: make(chan *QueuedTask, queueSize), // buffer size for tasks + } +} + +type QueueTask struct { + ctx context.Context + payload *Task + priority int + retryCount int + index int +} + +type PriorityQueue []*QueueTask + +func (pq PriorityQueue) Len() int { return len(pq) } +func (pq PriorityQueue) Less(i, j int) bool { + return pq[i].priority > pq[j].priority +} +func (pq PriorityQueue) Swap(i, j int) { + pq[i], pq[j] = pq[j], pq[i] + pq[i].index = i + pq[j].index = j +} +func (pq *PriorityQueue) Push(x interface{}) { + n := len(*pq) + task := x.(*QueueTask) + task.index = n + *pq = append(*pq, task) +} +func (pq *PriorityQueue) Pop() interface{} { + old := *pq + n := len(old) + task := old[n-1] + task.index = -1 + *pq = old[0 : n-1] + return task +} + +type Task struct { + CreatedAt time.Time `json:"created_at"` + ProcessedAt time.Time `json:"processed_at"` + Expiry time.Time `json:"expiry"` + Error error `json:"error"` + ID string `json:"id"` + Topic string `json:"topic"` + Status string `json:"status"` + Payload json.RawMessage `json:"payload"` + dag any +} + +func (t *Task) GetFlow() any { + return t.dag +} + +func NewTask(id string, payload json.RawMessage, nodeKey string, opts ...TaskOption) *Task { + if id == "" { + id = NewID() + } + task := &Task{ID: id, Payload: payload, Topic: nodeKey, CreatedAt: time.Now()} + for _, opt := range opts { + opt(task) + } + return task +}