init: publisher

This commit is contained in:
sujit
2024-09-27 17:44:27 +05:45
parent 3b80ab6ba6
commit 6657728f6c
8 changed files with 189 additions and 49 deletions

View File

@@ -32,10 +32,10 @@ func (p *publisher) send(ctx context.Context, cmd any) error {
type Handler func(context.Context, Task) Result type Handler func(context.Context, Task) Result
type Broker struct { type Broker struct {
queues xsync.IMap[string, *Queue] queues xsync.IMap[string, *Queue]
taskCallback func(context.Context, *Task) error consumers xsync.IMap[string, *consumer]
consumers xsync.IMap[string, *consumer] publishers xsync.IMap[string, *publisher]
publishers xsync.IMap[string, *publisher] opts Options
} }
type Queue struct { type Queue struct {
@@ -93,15 +93,29 @@ type Result struct {
Status string `json:"status"` Status string `json:"status"`
} }
func NewBroker(callback ...func(context.Context, *Task) error) *Broker { func NewBroker(opts ...Option) *Broker {
options := defaultOptions()
for _, opt := range opts {
opt(&options)
}
broker := &Broker{ broker := &Broker{
queues: xsync.NewMap[string, *Queue](), queues: xsync.NewMap[string, *Queue](),
publishers: xsync.NewMap[string, *publisher](), publishers: xsync.NewMap[string, *publisher](),
consumers: xsync.NewMap[string, *consumer](), consumers: xsync.NewMap[string, *consumer](),
} }
if len(callback) > 0 {
broker.taskCallback = callback[0] if options.messageHandler == nil {
options.messageHandler = broker.readMessage
} }
if options.closeHandler == nil {
options.closeHandler = broker.onClose
}
if options.errorHandler == nil {
options.errorHandler = broker.onError
}
broker.opts = options
return broker return broker
} }
@@ -163,7 +177,7 @@ func (b *Broker) Start(ctx context.Context, addr string) error {
fmt.Println("Error accepting connection:", err) fmt.Println("Error accepting connection:", err)
continue continue
} }
go ReadFromConn(ctx, conn, b.readMessage, b.onClose, b.onError) go ReadFromConn(ctx, conn, b.opts.messageHandler, b.opts.closeHandler, b.opts.errorHandler)
} }
} }
@@ -221,8 +235,13 @@ func (b *Broker) HandleProcessedMessage(ctx context.Context, clientMsg Result) e
if clientMsg.Error != nil { if clientMsg.Error != nil {
msg.Status = "error" msg.Status = "error"
} }
if b.taskCallback != nil { for _, callback := range b.opts.callback {
return b.taskCallback(ctx, msg) if callback != nil {
err := callback(ctx, msg)
if err != nil {
return err
}
}
} }
} }
} }

View File

@@ -7,26 +7,40 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
"slices"
"sync" "sync"
"time" "time"
) )
type Consumer struct { type Consumer struct {
id string id string
serverAddr string handlers map[string]Handler
handlers map[string]Handler conn net.Conn
queues []string queues []string
conn net.Conn opts Options
} }
func NewConsumer(id, serverAddr string, queues ...string) *Consumer { func NewConsumer(id string, opts ...Option) *Consumer {
return &Consumer{ options := defaultOptions()
handlers: make(map[string]Handler), for _, opt := range opts {
serverAddr: serverAddr, opt(&options)
queues: queues,
id: id,
} }
con := &Consumer{
handlers: make(map[string]Handler),
id: id,
}
if options.messageHandler == nil {
options.messageHandler = con.readConn
}
if options.closeHandler == nil {
options.closeHandler = con.onClose
}
if options.errorHandler == nil {
options.errorHandler = con.onError
}
con.opts = options
return con
} }
func (c *Consumer) Close() error { func (c *Consumer) Close() error {
@@ -107,13 +121,13 @@ func (c *Consumer) AttemptConnect() error {
var err error var err error
delay := initialDelay delay := initialDelay
for i := 0; i < maxRetries; i++ { for i := 0; i < maxRetries; i++ {
conn, err = net.Dial("tcp", c.serverAddr) conn, err = net.Dial("tcp", c.opts.brokerAddr)
if err == nil { if err == nil {
c.conn = conn c.conn = conn
return nil return nil
} }
sleepDuration := calculateJitter(delay) sleepDuration := calculateJitter(delay)
fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.serverAddr, i+1, maxRetries, err, sleepDuration) fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, maxRetries, err, sleepDuration)
time.Sleep(sleepDuration) time.Sleep(sleepDuration)
delay *= 2 delay *= 2
if delay > maxBackoff { if delay > maxBackoff {
@@ -121,7 +135,7 @@ func (c *Consumer) AttemptConnect() error {
} }
} }
return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.serverAddr, maxRetries, err) return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, maxRetries, err)
} }
func calculateJitter(baseDelay time.Duration) time.Duration { func calculateJitter(baseDelay time.Duration) time.Duration {
@@ -142,7 +156,7 @@ func (c *Consumer) onError(ctx context.Context, conn net.Conn, err error) {
fmt.Println("Error reading from consumer connection:", err, conn.RemoteAddr()) fmt.Println("Error reading from consumer connection:", err, conn.RemoteAddr())
} }
func (c *Consumer) Consume(ctx context.Context, queues ...string) error { func (c *Consumer) Consume(ctx context.Context) error {
err := c.AttemptConnect() err := c.AttemptConnect()
if err != nil { if err != nil {
return err return err
@@ -151,10 +165,9 @@ func (c *Consumer) Consume(ctx context.Context, queues ...string) error {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
ReadFromConn(ctx, c.conn, c.readConn, c.onClose, c.onError) ReadFromConn(ctx, c.conn, c.opts.messageHandler, c.opts.closeHandler, c.opts.errorHandler)
fmt.Println("Stopping consumer") fmt.Println("Stopping consumer")
}() }()
c.queues = slices.Compact(append(c.queues, queues...))
for _, q := range c.queues { for _, q := range c.queues {
if err := c.subscribe(q); err != nil { if err := c.subscribe(q); err != nil {
return fmt.Errorf("failed to connect to server for queue %s: %v", q, err) return fmt.Errorf("failed to connect to server for queue %s: %v", q, err)
@@ -165,5 +178,6 @@ func (c *Consumer) Consume(ctx context.Context, queues ...string) error {
} }
func (c *Consumer) RegisterHandler(queue string, handler Handler) { func (c *Consumer) RegisterHandler(queue string, handler Handler) {
c.queues = append(c.queues, queue)
c.handlers[queue] = handler c.handlers[queue] = handler
} }

4
dag.go
View File

@@ -55,7 +55,7 @@ func NewDAG(brokerAddr string, syncMode bool) *DAG {
conditions: make(map[string]map[string]string), conditions: make(map[string]map[string]string),
syncMode: syncMode, syncMode: syncMode,
} }
dag.broker = NewBroker(dag.TaskCallback) dag.broker = NewBroker(WithCallback(dag.TaskCallback))
return dag return dag
} }
@@ -103,7 +103,7 @@ func (dag *DAG) TaskCallback(ctx context.Context, task *Task) error {
} }
func (dag *DAG) AddNode(queue string, handler Handler, firstNode ...bool) { func (dag *DAG) AddNode(queue string, handler Handler, firstNode ...bool) {
consumer := NewConsumer(dag.brokerAddr, queue) consumer := NewConsumer(dag.brokerAddr)
consumer.RegisterHandler(queue, handler) consumer.RegisterHandler(queue, handler)
dag.broker.NewQueue(queue) dag.broker.NewQueue(queue)
n := &node{ n := &node{

View File

@@ -8,7 +8,7 @@ import (
) )
func main() { func main() {
consumer := mq.NewConsumer("consumer-1", ":8080") consumer := mq.NewConsumer("consumer-1")
consumer.RegisterHandler("queue1", func(ctx context.Context, task mq.Task) mq.Result { consumer.RegisterHandler("queue1", func(ctx context.Context, task mq.Task) mq.Result {
fmt.Println("Handling task for queue1:", string(task.Payload)) fmt.Println("Handling task for queue1:", string(task.Payload))
return mq.Result{Payload: []byte(`{"task": 123}`), MessageID: task.ID} return mq.Result{Payload: []byte(`{"task": 123}`), MessageID: task.ID}
@@ -17,5 +17,5 @@ func main() {
fmt.Println("Handling task for queue2:", task.ID) fmt.Println("Handling task for queue2:", task.ID)
return mq.Result{Payload: task.Payload, MessageID: task.ID} return mq.Result{Payload: task.Payload, MessageID: task.ID}
}) })
consumer.Consume(context.Background(), "queue2", "queue1") consumer.Consume(context.Background())
} }

View File

@@ -4,17 +4,13 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"github.com/oarkflow/mq" "github.com/oarkflow/mq"
) )
func main() { func main() {
// Fire-and-Forget Example publishAsync()
err := publishAsync() publishSync()
if err != nil {
log.Fatalf("Failed to publish async: %v", err)
}
} }
// publishAsync sends a task in Fire-and-Forget (async) mode // publishAsync sends a task in Fire-and-Forget (async) mode
@@ -27,7 +23,7 @@ func publishAsync() error {
} }
// Create publisher and send the task without waiting for a result // Create publisher and send the task without waiting for a result
publisher := mq.NewPublisher("publish-1", ":8080") publisher := mq.NewPublisher("publish-1")
err := publisher.Publish(context.Background(), "queue1", task) err := publisher.Publish(context.Background(), "queue1", task)
if err != nil { if err != nil {
return fmt.Errorf("failed to publish async task: %w", err) return fmt.Errorf("failed to publish async task: %w", err)
@@ -47,7 +43,7 @@ func publishSync() error {
} }
// Create publisher and send the task, waiting for the result // Create publisher and send the task, waiting for the result
publisher := mq.NewPublisher("publish-2", ":8080") publisher := mq.NewPublisher("publish-2")
result, err := publisher.Request(context.Background(), "queue1", task) result, err := publisher.Request(context.Background(), "queue1", task)
if err != nil { if err != nil {
return fmt.Errorf("failed to publish sync task: %w", err) return fmt.Errorf("failed to publish sync task: %w", err)

View File

@@ -8,10 +8,10 @@ import (
) )
func main() { func main() {
b := mq.NewBroker(func(ctx context.Context, task *mq.Task) error { b := mq.NewBroker(mq.WithCallback(func(ctx context.Context, task *mq.Task) error {
fmt.Println("Received task", task.ID, "Payload", string(task.Payload), "Result", string(task.Result), task.Error, task.CurrentQueue) fmt.Println("Received task", task.ID, "Payload", string(task.Payload), "Result", string(task.Result), task.Error, task.CurrentQueue)
return nil return nil
}) }))
b.NewQueue("queue1") b.NewQueue("queue1")
b.NewQueue("queue2") b.NewQueue("queue2")
b.Start(context.Background(), ":8080") b.Start(context.Background(), ":8080")

94
options.go Normal file
View File

@@ -0,0 +1,94 @@
package mq
import (
"context"
"time"
)
type Options struct {
brokerAddr string
messageHandler MessageHandler
closeHandler CloseHandler
errorHandler ErrorHandler
callback []func(context.Context, *Task) error
maxRetries int
initialDelay time.Duration
maxBackoff time.Duration
jitterPercent float64
}
func defaultOptions() Options {
return Options{
brokerAddr: ":8080",
maxRetries: 5,
initialDelay: 2 * time.Second,
maxBackoff: 20 * time.Second,
jitterPercent: 0.5,
}
}
// Option defines a function type for setting options.
type Option func(*Options)
// WithBrokerURL -
func WithBrokerURL(url string) Option {
return func(opts *Options) {
opts.brokerAddr = url
}
}
// WithMaxRetries -
func WithMaxRetries(val int) Option {
return func(opts *Options) {
opts.maxRetries = val
}
}
// WithInitialDelay -
func WithInitialDelay(val time.Duration) Option {
return func(opts *Options) {
opts.initialDelay = val
}
}
// WithMaxBackoff -
func WithMaxBackoff(val time.Duration) Option {
return func(opts *Options) {
opts.maxBackoff = val
}
}
// WithCallback -
func WithCallback(val ...func(context.Context, *Task) error) Option {
return func(opts *Options) {
opts.callback = val
}
}
// WithJitterPercent -
func WithJitterPercent(val float64) Option {
return func(opts *Options) {
opts.jitterPercent = val
}
}
// WithMessageHandler sets a custom MessageHandler.
func WithMessageHandler(handler MessageHandler) Option {
return func(opts *Options) {
opts.messageHandler = handler
}
}
// WithErrorHandler sets a custom ErrorHandler.
func WithErrorHandler(handler ErrorHandler) Option {
return func(opts *Options) {
opts.errorHandler = handler
}
}
// WithCloseHandler sets a custom CloseHandler.
func WithCloseHandler(handler CloseHandler) Option {
return func(opts *Options) {
opts.closeHandler = handler
}
}

View File

@@ -7,16 +7,33 @@ import (
) )
type Publisher struct { type Publisher struct {
id string id string
brokerAddr string opts Options
} }
func NewPublisher(id, brokerAddr string) *Publisher { func NewPublisher(id string, opts ...Option) *Publisher {
return &Publisher{brokerAddr: brokerAddr, id: id} options := defaultOptions()
for _, opt := range opts {
opt(&options)
}
pub := &Publisher{id: id}
if options.messageHandler == nil {
options.messageHandler = pub.readConn
}
if options.closeHandler == nil {
options.closeHandler = pub.onClose
}
if options.errorHandler == nil {
options.errorHandler = pub.onError
}
pub.opts = options
return pub
} }
func (p *Publisher) Publish(ctx context.Context, queue string, task Task) error { func (p *Publisher) Publish(ctx context.Context, queue string, task Task) error {
conn, err := net.Dial("tcp", p.brokerAddr) conn, err := net.Dial("tcp", p.opts.brokerAddr)
if err != nil { if err != nil {
return fmt.Errorf("failed to connect to broker: %w", err) return fmt.Errorf("failed to connect to broker: %w", err)
} }
@@ -50,7 +67,7 @@ func (p *Publisher) onError(ctx context.Context, conn net.Conn, err error) {
} }
func (p *Publisher) Request(ctx context.Context, queue string, task Task) (Result, error) { func (p *Publisher) Request(ctx context.Context, queue string, task Task) (Result, error) {
conn, err := net.Dial("tcp", p.brokerAddr) conn, err := net.Dial("tcp", p.opts.brokerAddr)
if err != nil { if err != nil {
return Result{}, fmt.Errorf("failed to connect to broker: %w", err) return Result{}, fmt.Errorf("failed to connect to broker: %w", err)
} }
@@ -71,6 +88,6 @@ func (p *Publisher) Request(ctx context.Context, queue string, task Task) (Resul
if err != nil { if err != nil {
return result, err return result, err
} }
ReadFromConn(ctx, conn, p.readConn, p.onClose, p.onError) ReadFromConn(ctx, conn, p.opts.messageHandler, p.opts.closeHandler, p.opts.errorHandler)
return result, nil return result, nil
} }