diff --git a/broker/broker.go b/broker.go similarity index 57% rename from broker/broker.go rename to broker.go index 0ba2b48..c5b325b 100644 --- a/broker/broker.go +++ b/broker.go @@ -1,14 +1,11 @@ -package broker +package mq import ( "context" "encoding/json" "errors" "fmt" - "math/rand" "net" - "slices" - "sync" "time" "github.com/oarkflow/xid" @@ -26,9 +23,9 @@ type Broker struct { type Queue struct { name string + conn map[net.Conn]struct{} messages *xsync.MapOf[string, *Task] deferred *xsync.MapOf[string, *Task] - conn map[net.Conn]struct{} } type Task struct { @@ -85,14 +82,18 @@ func (b *Broker) NewQueue(qName string) { } } -func (b *Broker) Send(ctx context.Context, cmd Command) { +func (b *Broker) Send(ctx context.Context, cmd Command) error { queue, ok := b.queues.Get(cmd.Queue) if !ok || queue == nil { - return + return errors.New("invalid queue or not exists") } for client := range queue.conn { - utils.Write(ctx, client, cmd) + err := utils.Write(ctx, client, cmd) + if err != nil { + return err + } } + return nil } func (b *Broker) Publish(ctx context.Context, message Task, queueName string) error { @@ -106,7 +107,10 @@ func (b *Broker) Publish(ctx context.Context, message Task, queueName string) er return nil } for client := range queue.conn { - utils.Write(ctx, client, message) + err = utils.Write(ctx, client, message) + if err != nil { + return err + } } return nil } @@ -191,7 +195,10 @@ func (b *Broker) subscribe(ctx context.Context, queueName string, conn net.Conn) q.deferred = xsync.NewMap[string, *Task]() } q.deferred.ForEach(func(_ string, message *Task) bool { - b.Publish(ctx, *message, queueName) + err := b.Publish(ctx, *message, queueName) + if err != nil { + return false + } return true }) q.deferred = nil @@ -202,7 +209,9 @@ func (b *Broker) Start(ctx context.Context, addr string) error { if err != nil { return err } - defer listener.Close() + defer func() { + _ = listener.Close() + }() fmt.Println("Broker server started on", addr) for { conn, err := listener.Accept() @@ -213,135 +222,3 @@ func (b *Broker) Start(ctx context.Context, addr string) error { go utils.ReadFromConn(ctx, conn, b.readMessage) } } - -type Consumer struct { - serverAddr string - handlers map[string]Handler - queues []string - conn net.Conn -} - -func NewConsumer(serverAddr string, queues ...string) *Consumer { - return &Consumer{ - handlers: make(map[string]Handler), - serverAddr: serverAddr, - queues: queues, - } -} - -func (c *Consumer) Close() error { - return c.conn.Close() -} - -func (c *Consumer) subscribe(queue string) error { - ctx := context.Background() - subscribe := Command{Command: SUBSCRIBE, Queue: queue} - return utils.Write(ctx, c.conn, subscribe) -} - -func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result { - handler, exists := c.handlers[msg.CurrentQueue] - if !exists { - return Result{Error: errors.New("No handler for queue " + msg.CurrentQueue)} - } - return handler(ctx, msg) -} - -func (c *Consumer) handleCommandMessage(msg Command) error { - switch msg.Command { - case STOP: - return c.Close() - default: - return fmt.Errorf("unknown command in consumer %s", msg.Command) - } -} - -func (c *Consumer) handleTaskMessage(ctx context.Context, msg Task) error { - response := c.ProcessTask(ctx, msg) - response.Queue = msg.CurrentQueue - if msg.ID == "" { - response.Error = errors.New("task ID is empty") - response.Command = "error" - } else { - response.Command = "completed" - response.MessageID = msg.ID - } - return utils.Write(ctx, c.conn, response) -} - -func (c *Consumer) readMessage(ctx context.Context, message []byte) error { - var cmdMsg Command - var task Task - err := json.Unmarshal(message, &cmdMsg) - if err == nil && cmdMsg.Command != 0 { - return c.handleCommandMessage(cmdMsg) - } - err = json.Unmarshal(message, &task) - if err == nil { - return c.handleTaskMessage(ctx, task) - } - return nil -} - -const ( - maxRetries = 5 - initialDelay = 2 * time.Second - maxBackoff = 30 * time.Second // Upper limit for backoff delay - jitterPercent = 0.5 // 50% jitter -) - -func (c *Consumer) AttemptConnect() error { - var conn net.Conn - var err error - delay := initialDelay - for i := 0; i < maxRetries; i++ { - conn, err = net.Dial("tcp", c.serverAddr) - if err == nil { - c.conn = conn - return nil - } - sleepDuration := calculateJitter(delay) - fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.serverAddr, i+1, maxRetries, err, sleepDuration) - time.Sleep(sleepDuration) - delay *= 2 - if delay > maxBackoff { - delay = maxBackoff - } - } - - return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.serverAddr, maxRetries, err) -} - -func calculateJitter(baseDelay time.Duration) time.Duration { - jitter := time.Duration(rand.Float64()*jitterPercent*float64(baseDelay)) - time.Duration(jitterPercent*float64(baseDelay)/2) - return baseDelay + jitter -} - -func (c *Consumer) Consume(ctx context.Context, queues ...string) error { - err := c.AttemptConnect() - if err != nil { - return err - } - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - utils.ReadFromConn(ctx, c.conn, func(ctx context.Context, conn net.Conn, message []byte) error { - return c.readMessage(ctx, message) - }) - }() - c.queues = slices.Compact(append(c.queues, queues...)) - for _, q := range c.queues { - if err := c.subscribe(q); err != nil { - return fmt.Errorf("failed to connect to server for queue %s: %v", q, err) - } - fmt.Println("Consumer started on", q) - } - wg.Wait() - fmt.Println("Consumer stopped.") - return nil -} - -func (c *Consumer) RegisterHandler(queue string, handler Handler) { - c.handlers[queue] = handler -} diff --git a/consumer.go b/consumer.go new file mode 100644 index 0000000..51bf77b --- /dev/null +++ b/consumer.go @@ -0,0 +1,144 @@ +package mq + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "github.com/oarkflow/mq/utils" + "math/rand" + "net" + "slices" + "sync" + "time" +) + +type Consumer struct { + serverAddr string + handlers map[string]Handler + queues []string + conn net.Conn +} + +func NewConsumer(serverAddr string, queues ...string) *Consumer { + return &Consumer{ + handlers: make(map[string]Handler), + serverAddr: serverAddr, + queues: queues, + } +} + +func (c *Consumer) Close() error { + return c.conn.Close() +} + +func (c *Consumer) subscribe(queue string) error { + ctx := context.Background() + subscribe := Command{Command: SUBSCRIBE, Queue: queue} + return utils.Write(ctx, c.conn, subscribe) +} + +func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result { + handler, exists := c.handlers[msg.CurrentQueue] + if !exists { + return Result{Error: errors.New("No handler for queue " + msg.CurrentQueue)} + } + return handler(ctx, msg) +} + +func (c *Consumer) handleCommandMessage(msg Command) error { + switch msg.Command { + case STOP: + return c.Close() + default: + return fmt.Errorf("unknown command in consumer %d", msg.Command) + } +} + +func (c *Consumer) handleTaskMessage(ctx context.Context, msg Task) error { + response := c.ProcessTask(ctx, msg) + response.Queue = msg.CurrentQueue + if msg.ID == "" { + response.Error = errors.New("task ID is empty") + response.Command = "error" + } else { + response.Command = "completed" + response.MessageID = msg.ID + } + return utils.Write(ctx, c.conn, response) +} + +func (c *Consumer) readMessage(ctx context.Context, message []byte) error { + var cmdMsg Command + var task Task + err := json.Unmarshal(message, &cmdMsg) + if err == nil && cmdMsg.Command != 0 { + return c.handleCommandMessage(cmdMsg) + } + err = json.Unmarshal(message, &task) + if err == nil { + return c.handleTaskMessage(ctx, task) + } + return nil +} + +const ( + maxRetries = 5 + initialDelay = 2 * time.Second + maxBackoff = 30 * time.Second // Upper limit for backoff delay + jitterPercent = 0.5 // 50% jitter +) + +func (c *Consumer) AttemptConnect() error { + var conn net.Conn + var err error + delay := initialDelay + for i := 0; i < maxRetries; i++ { + conn, err = net.Dial("tcp", c.serverAddr) + if err == nil { + c.conn = conn + return nil + } + sleepDuration := calculateJitter(delay) + fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.serverAddr, i+1, maxRetries, err, sleepDuration) + time.Sleep(sleepDuration) + delay *= 2 + if delay > maxBackoff { + delay = maxBackoff + } + } + + return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.serverAddr, maxRetries, err) +} + +func calculateJitter(baseDelay time.Duration) time.Duration { + jitter := time.Duration(rand.Float64()*jitterPercent*float64(baseDelay)) - time.Duration(jitterPercent*float64(baseDelay)/2) + return baseDelay + jitter +} + +func (c *Consumer) Consume(ctx context.Context, queues ...string) error { + err := c.AttemptConnect() + if err != nil { + return err + } + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + utils.ReadFromConn(ctx, c.conn, func(ctx context.Context, conn net.Conn, message []byte) error { + return c.readMessage(ctx, message) + }) + }() + c.queues = slices.Compact(append(c.queues, queues...)) + for _, q := range c.queues { + if err := c.subscribe(q); err != nil { + return fmt.Errorf("failed to connect to server for queue %s: %v", q, err) + } + } + wg.Wait() + return nil +} + +func (c *Consumer) RegisterHandler(queue string, handler Handler) { + c.handlers[queue] = handler +} diff --git a/broker/dag.go b/dag.go similarity index 98% rename from broker/dag.go rename to dag.go index 431601e..8603ac3 100644 --- a/broker/dag.go +++ b/dag.go @@ -1,4 +1,4 @@ -package broker +package mq import ( "context" @@ -124,7 +124,6 @@ func (dag *DAG) AddEdge(fromNodeID, toNodeID string) error { return err } dag.edges = append(dag.edges, []string{fromNodeID, toNodeID}) - fmt.Printf("Edge added from %s to %s\n", fromNodeID, toNodeID) return nil } diff --git a/examples/consumer.go b/examples/consumer.go index 6020e2e..0618965 100644 --- a/examples/consumer.go +++ b/examples/consumer.go @@ -3,19 +3,18 @@ package main import ( "context" "fmt" - - "github.com/oarkflow/mq/broker" + "github.com/oarkflow/mq" ) func main() { - consumer := broker.NewConsumer(":8080") - consumer.RegisterHandler("queue1", func(ctx context.Context, task broker.Task) broker.Result { + consumer := mq.NewConsumer(":8080") + consumer.RegisterHandler("queue1", func(ctx context.Context, task mq.Task) mq.Result { fmt.Println("Handling task for queue1:", task.ID) - return broker.Result{Payload: task.Payload, MessageID: task.ID} + return mq.Result{Payload: task.Payload, MessageID: task.ID} }) - consumer.RegisterHandler("queue2", func(ctx context.Context, task broker.Task) broker.Result { + consumer.RegisterHandler("queue2", func(ctx context.Context, task mq.Task) mq.Result { fmt.Println("Handling task for queue2:", task.ID) - return broker.Result{Payload: task.Payload, MessageID: task.ID} + return mq.Result{Payload: task.Payload, MessageID: task.ID} }) consumer.Consume(context.Background(), "queue2", "queue1") } diff --git a/examples/dag.go b/examples/dag.go index 2588f79..0b0aef8 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -4,12 +4,11 @@ import ( "context" "encoding/json" "fmt" + "github.com/oarkflow/mq" "time" - - "github.com/oarkflow/mq/broker" ) -func handleNode1(_ context.Context, task broker.Task) broker.Result { +func handleNode1(_ context.Context, task mq.Task) mq.Result { result := []map[string]string{ {"field": "facility", "item": "item1"}, {"field": "facility", "item": "item2"}, @@ -18,18 +17,18 @@ func handleNode1(_ context.Context, task broker.Task) broker.Result { var payload string err := json.Unmarshal(task.Payload, &payload) if err != nil { - return broker.Result{Status: "fail", Payload: json.RawMessage(`{"field": "node1", "item": "error"}`)} + return mq.Result{Status: "fail", Payload: json.RawMessage(`{"field": "node1", "item": "error"}`)} } fmt.Printf("Processing task at node1: %s\n", string(task.Payload)) bt, _ := json.Marshal(result) - return broker.Result{Status: "completed", Payload: bt} + return mq.Result{Status: "completed", Payload: bt} } -func handleNode2(_ context.Context, task broker.Task) broker.Result { +func handleNode2(_ context.Context, task mq.Task) mq.Result { var payload map[string]string err := json.Unmarshal(task.Payload, &payload) if err != nil { - return broker.Result{Status: "fail", Payload: json.RawMessage(`{"field": "node2", "item": "error"}`)} + return mq.Result{Status: "fail", Payload: json.RawMessage(`{"field": "node2", "item": "error"}`)} } status := "fail" if payload["item"] == "item2" { @@ -37,24 +36,24 @@ func handleNode2(_ context.Context, task broker.Task) broker.Result { } fmt.Printf("Processing task at node2: %s %s\n", payload, status) bt, _ := json.Marshal(payload) - return broker.Result{Status: status, Payload: bt} + return mq.Result{Status: status, Payload: bt} } -func handleNode3(_ context.Context, task broker.Task) broker.Result { +func handleNode3(_ context.Context, task mq.Task) mq.Result { result := `{"field": "node3", "item": %s}` fmt.Printf("Processing task at node3: %s\n", string(task.Payload)) - return broker.Result{Status: "completed", Payload: json.RawMessage(fmt.Sprintf(result, string(task.Payload)))} + return mq.Result{Status: "completed", Payload: json.RawMessage(fmt.Sprintf(result, string(task.Payload)))} } -func handleNode4(_ context.Context, task broker.Task) broker.Result { +func handleNode4(_ context.Context, task mq.Task) mq.Result { result := `{"field": "node4", "item": %s}` fmt.Printf("Processing task at node4: %s\n", string(task.Payload)) - return broker.Result{Status: "completed", Payload: json.RawMessage(fmt.Sprintf(result, string(task.Payload)))} + return mq.Result{Status: "completed", Payload: json.RawMessage(fmt.Sprintf(result, string(task.Payload)))} } func main() { ctx := context.Background() - d := broker.NewDAG(":8082", false) + d := mq.NewDAG(":8082", false) d.AddNode("node1", handleNode1, true) d.AddNode("node2", handleNode2) @@ -75,7 +74,7 @@ func main() { fmt.Println("Error starting DAG:", err) } }() - result := d.ProcessTask(ctx, broker.Task{Payload: []byte(`"Start processing"`)}) + result := d.ProcessTask(ctx, mq.Task{Payload: []byte(`"Start processing"`)}) fmt.Println(string(result.Payload)) time.Sleep(50 * time.Second) } diff --git a/examples/server.go b/examples/server.go index e19ede6..afdcc06 100644 --- a/examples/server.go +++ b/examples/server.go @@ -3,13 +3,12 @@ package main import ( "context" "fmt" + "github.com/oarkflow/mq" "time" - - "github.com/oarkflow/mq/broker" ) func main() { - b := broker.NewBroker(func(ctx context.Context, task *broker.Task) error { + b := mq.NewBroker(func(ctx context.Context, task *mq.Task) error { fmt.Println("Received task", task.ID, string(task.Payload), string(task.Result), task.Error, task.CurrentQueue) return nil }) @@ -17,11 +16,11 @@ func main() { b.NewQueue("queue2") go func() { for i := 0; i < 10; i++ { - b.Publish(context.Background(), broker.Task{ + b.Publish(context.Background(), mq.Task{ ID: fmt.Sprint(i), Payload: []byte(`"Hello"`), }, "queue1") - b.Publish(context.Background(), broker.Task{ + b.Publish(context.Background(), mq.Task{ ID: fmt.Sprint(i), Payload: []byte(`"World"`), }, "queue2")