feat: separate broker

This commit is contained in:
Oarkflow
2024-09-26 19:10:36 +05:45
parent 5db7f5706a
commit 4a67eeefe0
6 changed files with 188 additions and 171 deletions

View File

@@ -1,14 +1,11 @@
package broker package mq
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"math/rand"
"net" "net"
"slices"
"sync"
"time" "time"
"github.com/oarkflow/xid" "github.com/oarkflow/xid"
@@ -26,9 +23,9 @@ type Broker struct {
type Queue struct { type Queue struct {
name string name string
conn map[net.Conn]struct{}
messages *xsync.MapOf[string, *Task] messages *xsync.MapOf[string, *Task]
deferred *xsync.MapOf[string, *Task] deferred *xsync.MapOf[string, *Task]
conn map[net.Conn]struct{}
} }
type Task struct { type Task struct {
@@ -85,14 +82,18 @@ func (b *Broker) NewQueue(qName string) {
} }
} }
func (b *Broker) Send(ctx context.Context, cmd Command) { func (b *Broker) Send(ctx context.Context, cmd Command) error {
queue, ok := b.queues.Get(cmd.Queue) queue, ok := b.queues.Get(cmd.Queue)
if !ok || queue == nil { if !ok || queue == nil {
return return errors.New("invalid queue or not exists")
} }
for client := range queue.conn { for client := range queue.conn {
utils.Write(ctx, client, cmd) err := utils.Write(ctx, client, cmd)
if err != nil {
return err
}
} }
return nil
} }
func (b *Broker) Publish(ctx context.Context, message Task, queueName string) error { func (b *Broker) Publish(ctx context.Context, message Task, queueName string) error {
@@ -106,7 +107,10 @@ func (b *Broker) Publish(ctx context.Context, message Task, queueName string) er
return nil return nil
} }
for client := range queue.conn { for client := range queue.conn {
utils.Write(ctx, client, message) err = utils.Write(ctx, client, message)
if err != nil {
return err
}
} }
return nil return nil
} }
@@ -191,7 +195,10 @@ func (b *Broker) subscribe(ctx context.Context, queueName string, conn net.Conn)
q.deferred = xsync.NewMap[string, *Task]() q.deferred = xsync.NewMap[string, *Task]()
} }
q.deferred.ForEach(func(_ string, message *Task) bool { q.deferred.ForEach(func(_ string, message *Task) bool {
b.Publish(ctx, *message, queueName) err := b.Publish(ctx, *message, queueName)
if err != nil {
return false
}
return true return true
}) })
q.deferred = nil q.deferred = nil
@@ -202,7 +209,9 @@ func (b *Broker) Start(ctx context.Context, addr string) error {
if err != nil { if err != nil {
return err return err
} }
defer listener.Close() defer func() {
_ = listener.Close()
}()
fmt.Println("Broker server started on", addr) fmt.Println("Broker server started on", addr)
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
@@ -213,135 +222,3 @@ func (b *Broker) Start(ctx context.Context, addr string) error {
go utils.ReadFromConn(ctx, conn, b.readMessage) go utils.ReadFromConn(ctx, conn, b.readMessage)
} }
} }
type Consumer struct {
serverAddr string
handlers map[string]Handler
queues []string
conn net.Conn
}
func NewConsumer(serverAddr string, queues ...string) *Consumer {
return &Consumer{
handlers: make(map[string]Handler),
serverAddr: serverAddr,
queues: queues,
}
}
func (c *Consumer) Close() error {
return c.conn.Close()
}
func (c *Consumer) subscribe(queue string) error {
ctx := context.Background()
subscribe := Command{Command: SUBSCRIBE, Queue: queue}
return utils.Write(ctx, c.conn, subscribe)
}
func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result {
handler, exists := c.handlers[msg.CurrentQueue]
if !exists {
return Result{Error: errors.New("No handler for queue " + msg.CurrentQueue)}
}
return handler(ctx, msg)
}
func (c *Consumer) handleCommandMessage(msg Command) error {
switch msg.Command {
case STOP:
return c.Close()
default:
return fmt.Errorf("unknown command in consumer %s", msg.Command)
}
}
func (c *Consumer) handleTaskMessage(ctx context.Context, msg Task) error {
response := c.ProcessTask(ctx, msg)
response.Queue = msg.CurrentQueue
if msg.ID == "" {
response.Error = errors.New("task ID is empty")
response.Command = "error"
} else {
response.Command = "completed"
response.MessageID = msg.ID
}
return utils.Write(ctx, c.conn, response)
}
func (c *Consumer) readMessage(ctx context.Context, message []byte) error {
var cmdMsg Command
var task Task
err := json.Unmarshal(message, &cmdMsg)
if err == nil && cmdMsg.Command != 0 {
return c.handleCommandMessage(cmdMsg)
}
err = json.Unmarshal(message, &task)
if err == nil {
return c.handleTaskMessage(ctx, task)
}
return nil
}
const (
maxRetries = 5
initialDelay = 2 * time.Second
maxBackoff = 30 * time.Second // Upper limit for backoff delay
jitterPercent = 0.5 // 50% jitter
)
func (c *Consumer) AttemptConnect() error {
var conn net.Conn
var err error
delay := initialDelay
for i := 0; i < maxRetries; i++ {
conn, err = net.Dial("tcp", c.serverAddr)
if err == nil {
c.conn = conn
return nil
}
sleepDuration := calculateJitter(delay)
fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.serverAddr, i+1, maxRetries, err, sleepDuration)
time.Sleep(sleepDuration)
delay *= 2
if delay > maxBackoff {
delay = maxBackoff
}
}
return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.serverAddr, maxRetries, err)
}
func calculateJitter(baseDelay time.Duration) time.Duration {
jitter := time.Duration(rand.Float64()*jitterPercent*float64(baseDelay)) - time.Duration(jitterPercent*float64(baseDelay)/2)
return baseDelay + jitter
}
func (c *Consumer) Consume(ctx context.Context, queues ...string) error {
err := c.AttemptConnect()
if err != nil {
return err
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
utils.ReadFromConn(ctx, c.conn, func(ctx context.Context, conn net.Conn, message []byte) error {
return c.readMessage(ctx, message)
})
}()
c.queues = slices.Compact(append(c.queues, queues...))
for _, q := range c.queues {
if err := c.subscribe(q); err != nil {
return fmt.Errorf("failed to connect to server for queue %s: %v", q, err)
}
fmt.Println("Consumer started on", q)
}
wg.Wait()
fmt.Println("Consumer stopped.")
return nil
}
func (c *Consumer) RegisterHandler(queue string, handler Handler) {
c.handlers[queue] = handler
}

144
consumer.go Normal file
View File

@@ -0,0 +1,144 @@
package mq
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/oarkflow/mq/utils"
"math/rand"
"net"
"slices"
"sync"
"time"
)
type Consumer struct {
serverAddr string
handlers map[string]Handler
queues []string
conn net.Conn
}
func NewConsumer(serverAddr string, queues ...string) *Consumer {
return &Consumer{
handlers: make(map[string]Handler),
serverAddr: serverAddr,
queues: queues,
}
}
func (c *Consumer) Close() error {
return c.conn.Close()
}
func (c *Consumer) subscribe(queue string) error {
ctx := context.Background()
subscribe := Command{Command: SUBSCRIBE, Queue: queue}
return utils.Write(ctx, c.conn, subscribe)
}
func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result {
handler, exists := c.handlers[msg.CurrentQueue]
if !exists {
return Result{Error: errors.New("No handler for queue " + msg.CurrentQueue)}
}
return handler(ctx, msg)
}
func (c *Consumer) handleCommandMessage(msg Command) error {
switch msg.Command {
case STOP:
return c.Close()
default:
return fmt.Errorf("unknown command in consumer %d", msg.Command)
}
}
func (c *Consumer) handleTaskMessage(ctx context.Context, msg Task) error {
response := c.ProcessTask(ctx, msg)
response.Queue = msg.CurrentQueue
if msg.ID == "" {
response.Error = errors.New("task ID is empty")
response.Command = "error"
} else {
response.Command = "completed"
response.MessageID = msg.ID
}
return utils.Write(ctx, c.conn, response)
}
func (c *Consumer) readMessage(ctx context.Context, message []byte) error {
var cmdMsg Command
var task Task
err := json.Unmarshal(message, &cmdMsg)
if err == nil && cmdMsg.Command != 0 {
return c.handleCommandMessage(cmdMsg)
}
err = json.Unmarshal(message, &task)
if err == nil {
return c.handleTaskMessage(ctx, task)
}
return nil
}
const (
maxRetries = 5
initialDelay = 2 * time.Second
maxBackoff = 30 * time.Second // Upper limit for backoff delay
jitterPercent = 0.5 // 50% jitter
)
func (c *Consumer) AttemptConnect() error {
var conn net.Conn
var err error
delay := initialDelay
for i := 0; i < maxRetries; i++ {
conn, err = net.Dial("tcp", c.serverAddr)
if err == nil {
c.conn = conn
return nil
}
sleepDuration := calculateJitter(delay)
fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.serverAddr, i+1, maxRetries, err, sleepDuration)
time.Sleep(sleepDuration)
delay *= 2
if delay > maxBackoff {
delay = maxBackoff
}
}
return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.serverAddr, maxRetries, err)
}
func calculateJitter(baseDelay time.Duration) time.Duration {
jitter := time.Duration(rand.Float64()*jitterPercent*float64(baseDelay)) - time.Duration(jitterPercent*float64(baseDelay)/2)
return baseDelay + jitter
}
func (c *Consumer) Consume(ctx context.Context, queues ...string) error {
err := c.AttemptConnect()
if err != nil {
return err
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
utils.ReadFromConn(ctx, c.conn, func(ctx context.Context, conn net.Conn, message []byte) error {
return c.readMessage(ctx, message)
})
}()
c.queues = slices.Compact(append(c.queues, queues...))
for _, q := range c.queues {
if err := c.subscribe(q); err != nil {
return fmt.Errorf("failed to connect to server for queue %s: %v", q, err)
}
}
wg.Wait()
return nil
}
func (c *Consumer) RegisterHandler(queue string, handler Handler) {
c.handlers[queue] = handler
}

View File

@@ -1,4 +1,4 @@
package broker package mq
import ( import (
"context" "context"
@@ -124,7 +124,6 @@ func (dag *DAG) AddEdge(fromNodeID, toNodeID string) error {
return err return err
} }
dag.edges = append(dag.edges, []string{fromNodeID, toNodeID}) dag.edges = append(dag.edges, []string{fromNodeID, toNodeID})
fmt.Printf("Edge added from %s to %s\n", fromNodeID, toNodeID)
return nil return nil
} }

View File

@@ -3,19 +3,18 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/broker"
) )
func main() { func main() {
consumer := broker.NewConsumer(":8080") consumer := mq.NewConsumer(":8080")
consumer.RegisterHandler("queue1", func(ctx context.Context, task broker.Task) broker.Result { consumer.RegisterHandler("queue1", func(ctx context.Context, task mq.Task) mq.Result {
fmt.Println("Handling task for queue1:", task.ID) fmt.Println("Handling task for queue1:", task.ID)
return broker.Result{Payload: task.Payload, MessageID: task.ID} return mq.Result{Payload: task.Payload, MessageID: task.ID}
}) })
consumer.RegisterHandler("queue2", func(ctx context.Context, task broker.Task) broker.Result { consumer.RegisterHandler("queue2", func(ctx context.Context, task mq.Task) mq.Result {
fmt.Println("Handling task for queue2:", task.ID) fmt.Println("Handling task for queue2:", task.ID)
return broker.Result{Payload: task.Payload, MessageID: task.ID} return mq.Result{Payload: task.Payload, MessageID: task.ID}
}) })
consumer.Consume(context.Background(), "queue2", "queue1") consumer.Consume(context.Background(), "queue2", "queue1")
} }

View File

@@ -4,12 +4,11 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/oarkflow/mq"
"time" "time"
"github.com/oarkflow/mq/broker"
) )
func handleNode1(_ context.Context, task broker.Task) broker.Result { func handleNode1(_ context.Context, task mq.Task) mq.Result {
result := []map[string]string{ result := []map[string]string{
{"field": "facility", "item": "item1"}, {"field": "facility", "item": "item1"},
{"field": "facility", "item": "item2"}, {"field": "facility", "item": "item2"},
@@ -18,18 +17,18 @@ func handleNode1(_ context.Context, task broker.Task) broker.Result {
var payload string var payload string
err := json.Unmarshal(task.Payload, &payload) err := json.Unmarshal(task.Payload, &payload)
if err != nil { if err != nil {
return broker.Result{Status: "fail", Payload: json.RawMessage(`{"field": "node1", "item": "error"}`)} return mq.Result{Status: "fail", Payload: json.RawMessage(`{"field": "node1", "item": "error"}`)}
} }
fmt.Printf("Processing task at node1: %s\n", string(task.Payload)) fmt.Printf("Processing task at node1: %s\n", string(task.Payload))
bt, _ := json.Marshal(result) bt, _ := json.Marshal(result)
return broker.Result{Status: "completed", Payload: bt} return mq.Result{Status: "completed", Payload: bt}
} }
func handleNode2(_ context.Context, task broker.Task) broker.Result { func handleNode2(_ context.Context, task mq.Task) mq.Result {
var payload map[string]string var payload map[string]string
err := json.Unmarshal(task.Payload, &payload) err := json.Unmarshal(task.Payload, &payload)
if err != nil { if err != nil {
return broker.Result{Status: "fail", Payload: json.RawMessage(`{"field": "node2", "item": "error"}`)} return mq.Result{Status: "fail", Payload: json.RawMessage(`{"field": "node2", "item": "error"}`)}
} }
status := "fail" status := "fail"
if payload["item"] == "item2" { if payload["item"] == "item2" {
@@ -37,24 +36,24 @@ func handleNode2(_ context.Context, task broker.Task) broker.Result {
} }
fmt.Printf("Processing task at node2: %s %s\n", payload, status) fmt.Printf("Processing task at node2: %s %s\n", payload, status)
bt, _ := json.Marshal(payload) bt, _ := json.Marshal(payload)
return broker.Result{Status: status, Payload: bt} return mq.Result{Status: status, Payload: bt}
} }
func handleNode3(_ context.Context, task broker.Task) broker.Result { func handleNode3(_ context.Context, task mq.Task) mq.Result {
result := `{"field": "node3", "item": %s}` result := `{"field": "node3", "item": %s}`
fmt.Printf("Processing task at node3: %s\n", string(task.Payload)) fmt.Printf("Processing task at node3: %s\n", string(task.Payload))
return broker.Result{Status: "completed", Payload: json.RawMessage(fmt.Sprintf(result, string(task.Payload)))} return mq.Result{Status: "completed", Payload: json.RawMessage(fmt.Sprintf(result, string(task.Payload)))}
} }
func handleNode4(_ context.Context, task broker.Task) broker.Result { func handleNode4(_ context.Context, task mq.Task) mq.Result {
result := `{"field": "node4", "item": %s}` result := `{"field": "node4", "item": %s}`
fmt.Printf("Processing task at node4: %s\n", string(task.Payload)) fmt.Printf("Processing task at node4: %s\n", string(task.Payload))
return broker.Result{Status: "completed", Payload: json.RawMessage(fmt.Sprintf(result, string(task.Payload)))} return mq.Result{Status: "completed", Payload: json.RawMessage(fmt.Sprintf(result, string(task.Payload)))}
} }
func main() { func main() {
ctx := context.Background() ctx := context.Background()
d := broker.NewDAG(":8082", false) d := mq.NewDAG(":8082", false)
d.AddNode("node1", handleNode1, true) d.AddNode("node1", handleNode1, true)
d.AddNode("node2", handleNode2) d.AddNode("node2", handleNode2)
@@ -75,7 +74,7 @@ func main() {
fmt.Println("Error starting DAG:", err) fmt.Println("Error starting DAG:", err)
} }
}() }()
result := d.ProcessTask(ctx, broker.Task{Payload: []byte(`"Start processing"`)}) result := d.ProcessTask(ctx, mq.Task{Payload: []byte(`"Start processing"`)})
fmt.Println(string(result.Payload)) fmt.Println(string(result.Payload))
time.Sleep(50 * time.Second) time.Sleep(50 * time.Second)
} }

View File

@@ -3,13 +3,12 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/oarkflow/mq"
"time" "time"
"github.com/oarkflow/mq/broker"
) )
func main() { func main() {
b := broker.NewBroker(func(ctx context.Context, task *broker.Task) error { b := mq.NewBroker(func(ctx context.Context, task *mq.Task) error {
fmt.Println("Received task", task.ID, string(task.Payload), string(task.Result), task.Error, task.CurrentQueue) fmt.Println("Received task", task.ID, string(task.Payload), string(task.Result), task.Error, task.CurrentQueue)
return nil return nil
}) })
@@ -17,11 +16,11 @@ func main() {
b.NewQueue("queue2") b.NewQueue("queue2")
go func() { go func() {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
b.Publish(context.Background(), broker.Task{ b.Publish(context.Background(), mq.Task{
ID: fmt.Sprint(i), ID: fmt.Sprint(i),
Payload: []byte(`"Hello"`), Payload: []byte(`"Hello"`),
}, "queue1") }, "queue1")
b.Publish(context.Background(), broker.Task{ b.Publish(context.Background(), mq.Task{
ID: fmt.Sprint(i), ID: fmt.Sprint(i),
Payload: []byte(`"World"`), Payload: []byte(`"World"`),
}, "queue2") }, "queue2")