diff --git a/broker.go b/broker.go index c3fcc29..9e6282d 100644 --- a/broker.go +++ b/broker.go @@ -32,10 +32,10 @@ func (p *publisher) send(ctx context.Context, cmd any) error { type Handler func(context.Context, Task) Result type Broker struct { - queues xsync.IMap[string, *Queue] - taskCallback func(context.Context, *Task) error - consumers xsync.IMap[string, *consumer] - publishers xsync.IMap[string, *publisher] + queues xsync.IMap[string, *Queue] + consumers xsync.IMap[string, *consumer] + publishers xsync.IMap[string, *publisher] + opts Options } type Queue struct { @@ -93,15 +93,29 @@ type Result struct { Status string `json:"status"` } -func NewBroker(callback ...func(context.Context, *Task) error) *Broker { +func NewBroker(opts ...Option) *Broker { + options := defaultOptions() + for _, opt := range opts { + opt(&options) + } broker := &Broker{ queues: xsync.NewMap[string, *Queue](), publishers: xsync.NewMap[string, *publisher](), consumers: xsync.NewMap[string, *consumer](), } - if len(callback) > 0 { - broker.taskCallback = callback[0] + + if options.messageHandler == nil { + options.messageHandler = broker.readMessage } + + if options.closeHandler == nil { + options.closeHandler = broker.onClose + } + + if options.errorHandler == nil { + options.errorHandler = broker.onError + } + broker.opts = options return broker } @@ -163,7 +177,7 @@ func (b *Broker) Start(ctx context.Context, addr string) error { fmt.Println("Error accepting connection:", err) continue } - go ReadFromConn(ctx, conn, b.readMessage, b.onClose, b.onError) + go ReadFromConn(ctx, conn, b.opts.messageHandler, b.opts.closeHandler, b.opts.errorHandler) } } @@ -221,8 +235,13 @@ func (b *Broker) HandleProcessedMessage(ctx context.Context, clientMsg Result) e if clientMsg.Error != nil { msg.Status = "error" } - if b.taskCallback != nil { - return b.taskCallback(ctx, msg) + for _, callback := range b.opts.callback { + if callback != nil { + err := callback(ctx, msg) + if err != nil { + return err + } + } } } } diff --git a/consumer.go b/consumer.go index 977074a..563f723 100644 --- a/consumer.go +++ b/consumer.go @@ -7,26 +7,40 @@ import ( "fmt" "math/rand" "net" - "slices" "sync" "time" ) type Consumer struct { - id string - serverAddr string - handlers map[string]Handler - queues []string - conn net.Conn + id string + handlers map[string]Handler + conn net.Conn + queues []string + opts Options } -func NewConsumer(id, serverAddr string, queues ...string) *Consumer { - return &Consumer{ - handlers: make(map[string]Handler), - serverAddr: serverAddr, - queues: queues, - id: id, +func NewConsumer(id string, opts ...Option) *Consumer { + options := defaultOptions() + for _, opt := range opts { + opt(&options) } + con := &Consumer{ + handlers: make(map[string]Handler), + id: id, + } + if options.messageHandler == nil { + options.messageHandler = con.readConn + } + + if options.closeHandler == nil { + options.closeHandler = con.onClose + } + + if options.errorHandler == nil { + options.errorHandler = con.onError + } + con.opts = options + return con } func (c *Consumer) Close() error { @@ -107,13 +121,13 @@ func (c *Consumer) AttemptConnect() error { var err error delay := initialDelay for i := 0; i < maxRetries; i++ { - conn, err = net.Dial("tcp", c.serverAddr) + conn, err = net.Dial("tcp", c.opts.brokerAddr) if err == nil { c.conn = conn return nil } sleepDuration := calculateJitter(delay) - fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.serverAddr, i+1, maxRetries, err, sleepDuration) + fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, maxRetries, err, sleepDuration) time.Sleep(sleepDuration) delay *= 2 if delay > maxBackoff { @@ -121,7 +135,7 @@ func (c *Consumer) AttemptConnect() error { } } - return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.serverAddr, maxRetries, err) + return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, maxRetries, err) } func calculateJitter(baseDelay time.Duration) time.Duration { @@ -142,7 +156,7 @@ func (c *Consumer) onError(ctx context.Context, conn net.Conn, err error) { fmt.Println("Error reading from consumer connection:", err, conn.RemoteAddr()) } -func (c *Consumer) Consume(ctx context.Context, queues ...string) error { +func (c *Consumer) Consume(ctx context.Context) error { err := c.AttemptConnect() if err != nil { return err @@ -151,10 +165,9 @@ func (c *Consumer) Consume(ctx context.Context, queues ...string) error { wg.Add(1) go func() { defer wg.Done() - ReadFromConn(ctx, c.conn, c.readConn, c.onClose, c.onError) + ReadFromConn(ctx, c.conn, c.opts.messageHandler, c.opts.closeHandler, c.opts.errorHandler) fmt.Println("Stopping consumer") }() - c.queues = slices.Compact(append(c.queues, queues...)) for _, q := range c.queues { if err := c.subscribe(q); err != nil { return fmt.Errorf("failed to connect to server for queue %s: %v", q, err) @@ -165,5 +178,6 @@ func (c *Consumer) Consume(ctx context.Context, queues ...string) error { } func (c *Consumer) RegisterHandler(queue string, handler Handler) { + c.queues = append(c.queues, queue) c.handlers[queue] = handler } diff --git a/dag.go b/dag.go index 263d501..e5b38b3 100644 --- a/dag.go +++ b/dag.go @@ -55,7 +55,7 @@ func NewDAG(brokerAddr string, syncMode bool) *DAG { conditions: make(map[string]map[string]string), syncMode: syncMode, } - dag.broker = NewBroker(dag.TaskCallback) + dag.broker = NewBroker(WithCallback(dag.TaskCallback)) return dag } @@ -103,7 +103,7 @@ func (dag *DAG) TaskCallback(ctx context.Context, task *Task) error { } func (dag *DAG) AddNode(queue string, handler Handler, firstNode ...bool) { - consumer := NewConsumer(dag.brokerAddr, queue) + consumer := NewConsumer(dag.brokerAddr) consumer.RegisterHandler(queue, handler) dag.broker.NewQueue(queue) n := &node{ diff --git a/examples/consumer.go b/examples/consumer.go index bdd31f0..fa0d23d 100644 --- a/examples/consumer.go +++ b/examples/consumer.go @@ -8,7 +8,7 @@ import ( ) func main() { - consumer := mq.NewConsumer("consumer-1", ":8080") + consumer := mq.NewConsumer("consumer-1") consumer.RegisterHandler("queue1", func(ctx context.Context, task mq.Task) mq.Result { fmt.Println("Handling task for queue1:", string(task.Payload)) return mq.Result{Payload: []byte(`{"task": 123}`), MessageID: task.ID} @@ -17,5 +17,5 @@ func main() { fmt.Println("Handling task for queue2:", task.ID) return mq.Result{Payload: task.Payload, MessageID: task.ID} }) - consumer.Consume(context.Background(), "queue2", "queue1") + consumer.Consume(context.Background()) } diff --git a/examples/publisher.go b/examples/publisher.go index 3a2e601..2a1896f 100644 --- a/examples/publisher.go +++ b/examples/publisher.go @@ -4,17 +4,13 @@ import ( "context" "encoding/json" "fmt" - "log" "github.com/oarkflow/mq" ) func main() { - // Fire-and-Forget Example - err := publishAsync() - if err != nil { - log.Fatalf("Failed to publish async: %v", err) - } + publishAsync() + publishSync() } // publishAsync sends a task in Fire-and-Forget (async) mode @@ -27,7 +23,7 @@ func publishAsync() error { } // Create publisher and send the task without waiting for a result - publisher := mq.NewPublisher("publish-1", ":8080") + publisher := mq.NewPublisher("publish-1") err := publisher.Publish(context.Background(), "queue1", task) if err != nil { return fmt.Errorf("failed to publish async task: %w", err) @@ -47,7 +43,7 @@ func publishSync() error { } // Create publisher and send the task, waiting for the result - publisher := mq.NewPublisher("publish-2", ":8080") + publisher := mq.NewPublisher("publish-2") result, err := publisher.Request(context.Background(), "queue1", task) if err != nil { return fmt.Errorf("failed to publish sync task: %w", err) diff --git a/examples/server.go b/examples/server.go index b95d855..6d42a0b 100644 --- a/examples/server.go +++ b/examples/server.go @@ -8,10 +8,10 @@ import ( ) func main() { - b := mq.NewBroker(func(ctx context.Context, task *mq.Task) error { + b := mq.NewBroker(mq.WithCallback(func(ctx context.Context, task *mq.Task) error { fmt.Println("Received task", task.ID, "Payload", string(task.Payload), "Result", string(task.Result), task.Error, task.CurrentQueue) return nil - }) + })) b.NewQueue("queue1") b.NewQueue("queue2") b.Start(context.Background(), ":8080") diff --git a/options.go b/options.go new file mode 100644 index 0000000..dc539e6 --- /dev/null +++ b/options.go @@ -0,0 +1,94 @@ +package mq + +import ( + "context" + "time" +) + +type Options struct { + brokerAddr string + messageHandler MessageHandler + closeHandler CloseHandler + errorHandler ErrorHandler + callback []func(context.Context, *Task) error + maxRetries int + initialDelay time.Duration + maxBackoff time.Duration + jitterPercent float64 +} + +func defaultOptions() Options { + return Options{ + brokerAddr: ":8080", + maxRetries: 5, + initialDelay: 2 * time.Second, + maxBackoff: 20 * time.Second, + jitterPercent: 0.5, + } +} + +// Option defines a function type for setting options. +type Option func(*Options) + +// WithBrokerURL - +func WithBrokerURL(url string) Option { + return func(opts *Options) { + opts.brokerAddr = url + } +} + +// WithMaxRetries - +func WithMaxRetries(val int) Option { + return func(opts *Options) { + opts.maxRetries = val + } +} + +// WithInitialDelay - +func WithInitialDelay(val time.Duration) Option { + return func(opts *Options) { + opts.initialDelay = val + } +} + +// WithMaxBackoff - +func WithMaxBackoff(val time.Duration) Option { + return func(opts *Options) { + opts.maxBackoff = val + } +} + +// WithCallback - +func WithCallback(val ...func(context.Context, *Task) error) Option { + return func(opts *Options) { + opts.callback = val + } +} + +// WithJitterPercent - +func WithJitterPercent(val float64) Option { + return func(opts *Options) { + opts.jitterPercent = val + } +} + +// WithMessageHandler sets a custom MessageHandler. +func WithMessageHandler(handler MessageHandler) Option { + return func(opts *Options) { + opts.messageHandler = handler + } +} + +// WithErrorHandler sets a custom ErrorHandler. +func WithErrorHandler(handler ErrorHandler) Option { + return func(opts *Options) { + opts.errorHandler = handler + } +} + +// WithCloseHandler sets a custom CloseHandler. +func WithCloseHandler(handler CloseHandler) Option { + return func(opts *Options) { + opts.closeHandler = handler + } +} diff --git a/publisher.go b/publisher.go index 598a537..5a71c62 100644 --- a/publisher.go +++ b/publisher.go @@ -7,16 +7,33 @@ import ( ) type Publisher struct { - id string - brokerAddr string + id string + opts Options } -func NewPublisher(id, brokerAddr string) *Publisher { - return &Publisher{brokerAddr: brokerAddr, id: id} +func NewPublisher(id string, opts ...Option) *Publisher { + options := defaultOptions() + for _, opt := range opts { + opt(&options) + } + pub := &Publisher{id: id} + if options.messageHandler == nil { + options.messageHandler = pub.readConn + } + + if options.closeHandler == nil { + options.closeHandler = pub.onClose + } + + if options.errorHandler == nil { + options.errorHandler = pub.onError + } + pub.opts = options + return pub } func (p *Publisher) Publish(ctx context.Context, queue string, task Task) error { - conn, err := net.Dial("tcp", p.brokerAddr) + conn, err := net.Dial("tcp", p.opts.brokerAddr) if err != nil { return fmt.Errorf("failed to connect to broker: %w", err) } @@ -50,7 +67,7 @@ func (p *Publisher) onError(ctx context.Context, conn net.Conn, err error) { } func (p *Publisher) Request(ctx context.Context, queue string, task Task) (Result, error) { - conn, err := net.Dial("tcp", p.brokerAddr) + conn, err := net.Dial("tcp", p.opts.brokerAddr) if err != nil { return Result{}, fmt.Errorf("failed to connect to broker: %w", err) } @@ -71,6 +88,6 @@ func (p *Publisher) Request(ctx context.Context, queue string, task Task) (Resul if err != nil { return result, err } - ReadFromConn(ctx, conn, p.readConn, p.onClose, p.onError) + ReadFromConn(ctx, conn, p.opts.messageHandler, p.opts.closeHandler, p.opts.errorHandler) return result, nil }