mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-03 23:26:28 +08:00
feat: separate broker
This commit is contained in:
@@ -1,14 +1,11 @@
|
||||
package broker
|
||||
package mq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/xid"
|
||||
@@ -26,9 +23,9 @@ type Broker struct {
|
||||
|
||||
type Queue struct {
|
||||
name string
|
||||
conn map[net.Conn]struct{}
|
||||
messages *xsync.MapOf[string, *Task]
|
||||
deferred *xsync.MapOf[string, *Task]
|
||||
conn map[net.Conn]struct{}
|
||||
}
|
||||
|
||||
type Task struct {
|
||||
@@ -85,14 +82,18 @@ func (b *Broker) NewQueue(qName string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) Send(ctx context.Context, cmd Command) {
|
||||
func (b *Broker) Send(ctx context.Context, cmd Command) error {
|
||||
queue, ok := b.queues.Get(cmd.Queue)
|
||||
if !ok || queue == nil {
|
||||
return
|
||||
return errors.New("invalid queue or not exists")
|
||||
}
|
||||
for client := range queue.conn {
|
||||
utils.Write(ctx, client, cmd)
|
||||
err := utils.Write(ctx, client, cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Broker) Publish(ctx context.Context, message Task, queueName string) error {
|
||||
@@ -106,7 +107,10 @@ func (b *Broker) Publish(ctx context.Context, message Task, queueName string) er
|
||||
return nil
|
||||
}
|
||||
for client := range queue.conn {
|
||||
utils.Write(ctx, client, message)
|
||||
err = utils.Write(ctx, client, message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -191,7 +195,10 @@ func (b *Broker) subscribe(ctx context.Context, queueName string, conn net.Conn)
|
||||
q.deferred = xsync.NewMap[string, *Task]()
|
||||
}
|
||||
q.deferred.ForEach(func(_ string, message *Task) bool {
|
||||
b.Publish(ctx, *message, queueName)
|
||||
err := b.Publish(ctx, *message, queueName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
q.deferred = nil
|
||||
@@ -202,7 +209,9 @@ func (b *Broker) Start(ctx context.Context, addr string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer listener.Close()
|
||||
defer func() {
|
||||
_ = listener.Close()
|
||||
}()
|
||||
fmt.Println("Broker server started on", addr)
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
@@ -213,135 +222,3 @@ func (b *Broker) Start(ctx context.Context, addr string) error {
|
||||
go utils.ReadFromConn(ctx, conn, b.readMessage)
|
||||
}
|
||||
}
|
||||
|
||||
type Consumer struct {
|
||||
serverAddr string
|
||||
handlers map[string]Handler
|
||||
queues []string
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func NewConsumer(serverAddr string, queues ...string) *Consumer {
|
||||
return &Consumer{
|
||||
handlers: make(map[string]Handler),
|
||||
serverAddr: serverAddr,
|
||||
queues: queues,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Consumer) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *Consumer) subscribe(queue string) error {
|
||||
ctx := context.Background()
|
||||
subscribe := Command{Command: SUBSCRIBE, Queue: queue}
|
||||
return utils.Write(ctx, c.conn, subscribe)
|
||||
}
|
||||
|
||||
func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result {
|
||||
handler, exists := c.handlers[msg.CurrentQueue]
|
||||
if !exists {
|
||||
return Result{Error: errors.New("No handler for queue " + msg.CurrentQueue)}
|
||||
}
|
||||
return handler(ctx, msg)
|
||||
}
|
||||
|
||||
func (c *Consumer) handleCommandMessage(msg Command) error {
|
||||
switch msg.Command {
|
||||
case STOP:
|
||||
return c.Close()
|
||||
default:
|
||||
return fmt.Errorf("unknown command in consumer %s", msg.Command)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Consumer) handleTaskMessage(ctx context.Context, msg Task) error {
|
||||
response := c.ProcessTask(ctx, msg)
|
||||
response.Queue = msg.CurrentQueue
|
||||
if msg.ID == "" {
|
||||
response.Error = errors.New("task ID is empty")
|
||||
response.Command = "error"
|
||||
} else {
|
||||
response.Command = "completed"
|
||||
response.MessageID = msg.ID
|
||||
}
|
||||
return utils.Write(ctx, c.conn, response)
|
||||
}
|
||||
|
||||
func (c *Consumer) readMessage(ctx context.Context, message []byte) error {
|
||||
var cmdMsg Command
|
||||
var task Task
|
||||
err := json.Unmarshal(message, &cmdMsg)
|
||||
if err == nil && cmdMsg.Command != 0 {
|
||||
return c.handleCommandMessage(cmdMsg)
|
||||
}
|
||||
err = json.Unmarshal(message, &task)
|
||||
if err == nil {
|
||||
return c.handleTaskMessage(ctx, task)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
maxRetries = 5
|
||||
initialDelay = 2 * time.Second
|
||||
maxBackoff = 30 * time.Second // Upper limit for backoff delay
|
||||
jitterPercent = 0.5 // 50% jitter
|
||||
)
|
||||
|
||||
func (c *Consumer) AttemptConnect() error {
|
||||
var conn net.Conn
|
||||
var err error
|
||||
delay := initialDelay
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
conn, err = net.Dial("tcp", c.serverAddr)
|
||||
if err == nil {
|
||||
c.conn = conn
|
||||
return nil
|
||||
}
|
||||
sleepDuration := calculateJitter(delay)
|
||||
fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.serverAddr, i+1, maxRetries, err, sleepDuration)
|
||||
time.Sleep(sleepDuration)
|
||||
delay *= 2
|
||||
if delay > maxBackoff {
|
||||
delay = maxBackoff
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.serverAddr, maxRetries, err)
|
||||
}
|
||||
|
||||
func calculateJitter(baseDelay time.Duration) time.Duration {
|
||||
jitter := time.Duration(rand.Float64()*jitterPercent*float64(baseDelay)) - time.Duration(jitterPercent*float64(baseDelay)/2)
|
||||
return baseDelay + jitter
|
||||
}
|
||||
|
||||
func (c *Consumer) Consume(ctx context.Context, queues ...string) error {
|
||||
err := c.AttemptConnect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
utils.ReadFromConn(ctx, c.conn, func(ctx context.Context, conn net.Conn, message []byte) error {
|
||||
return c.readMessage(ctx, message)
|
||||
})
|
||||
}()
|
||||
c.queues = slices.Compact(append(c.queues, queues...))
|
||||
for _, q := range c.queues {
|
||||
if err := c.subscribe(q); err != nil {
|
||||
return fmt.Errorf("failed to connect to server for queue %s: %v", q, err)
|
||||
}
|
||||
fmt.Println("Consumer started on", q)
|
||||
}
|
||||
wg.Wait()
|
||||
fmt.Println("Consumer stopped.")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Consumer) RegisterHandler(queue string, handler Handler) {
|
||||
c.handlers[queue] = handler
|
||||
}
|
144
consumer.go
Normal file
144
consumer.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package mq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/oarkflow/mq/utils"
|
||||
"math/rand"
|
||||
"net"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Consumer struct {
|
||||
serverAddr string
|
||||
handlers map[string]Handler
|
||||
queues []string
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func NewConsumer(serverAddr string, queues ...string) *Consumer {
|
||||
return &Consumer{
|
||||
handlers: make(map[string]Handler),
|
||||
serverAddr: serverAddr,
|
||||
queues: queues,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Consumer) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *Consumer) subscribe(queue string) error {
|
||||
ctx := context.Background()
|
||||
subscribe := Command{Command: SUBSCRIBE, Queue: queue}
|
||||
return utils.Write(ctx, c.conn, subscribe)
|
||||
}
|
||||
|
||||
func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result {
|
||||
handler, exists := c.handlers[msg.CurrentQueue]
|
||||
if !exists {
|
||||
return Result{Error: errors.New("No handler for queue " + msg.CurrentQueue)}
|
||||
}
|
||||
return handler(ctx, msg)
|
||||
}
|
||||
|
||||
func (c *Consumer) handleCommandMessage(msg Command) error {
|
||||
switch msg.Command {
|
||||
case STOP:
|
||||
return c.Close()
|
||||
default:
|
||||
return fmt.Errorf("unknown command in consumer %d", msg.Command)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Consumer) handleTaskMessage(ctx context.Context, msg Task) error {
|
||||
response := c.ProcessTask(ctx, msg)
|
||||
response.Queue = msg.CurrentQueue
|
||||
if msg.ID == "" {
|
||||
response.Error = errors.New("task ID is empty")
|
||||
response.Command = "error"
|
||||
} else {
|
||||
response.Command = "completed"
|
||||
response.MessageID = msg.ID
|
||||
}
|
||||
return utils.Write(ctx, c.conn, response)
|
||||
}
|
||||
|
||||
func (c *Consumer) readMessage(ctx context.Context, message []byte) error {
|
||||
var cmdMsg Command
|
||||
var task Task
|
||||
err := json.Unmarshal(message, &cmdMsg)
|
||||
if err == nil && cmdMsg.Command != 0 {
|
||||
return c.handleCommandMessage(cmdMsg)
|
||||
}
|
||||
err = json.Unmarshal(message, &task)
|
||||
if err == nil {
|
||||
return c.handleTaskMessage(ctx, task)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
maxRetries = 5
|
||||
initialDelay = 2 * time.Second
|
||||
maxBackoff = 30 * time.Second // Upper limit for backoff delay
|
||||
jitterPercent = 0.5 // 50% jitter
|
||||
)
|
||||
|
||||
func (c *Consumer) AttemptConnect() error {
|
||||
var conn net.Conn
|
||||
var err error
|
||||
delay := initialDelay
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
conn, err = net.Dial("tcp", c.serverAddr)
|
||||
if err == nil {
|
||||
c.conn = conn
|
||||
return nil
|
||||
}
|
||||
sleepDuration := calculateJitter(delay)
|
||||
fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.serverAddr, i+1, maxRetries, err, sleepDuration)
|
||||
time.Sleep(sleepDuration)
|
||||
delay *= 2
|
||||
if delay > maxBackoff {
|
||||
delay = maxBackoff
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.serverAddr, maxRetries, err)
|
||||
}
|
||||
|
||||
func calculateJitter(baseDelay time.Duration) time.Duration {
|
||||
jitter := time.Duration(rand.Float64()*jitterPercent*float64(baseDelay)) - time.Duration(jitterPercent*float64(baseDelay)/2)
|
||||
return baseDelay + jitter
|
||||
}
|
||||
|
||||
func (c *Consumer) Consume(ctx context.Context, queues ...string) error {
|
||||
err := c.AttemptConnect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
utils.ReadFromConn(ctx, c.conn, func(ctx context.Context, conn net.Conn, message []byte) error {
|
||||
return c.readMessage(ctx, message)
|
||||
})
|
||||
}()
|
||||
c.queues = slices.Compact(append(c.queues, queues...))
|
||||
for _, q := range c.queues {
|
||||
if err := c.subscribe(q); err != nil {
|
||||
return fmt.Errorf("failed to connect to server for queue %s: %v", q, err)
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Consumer) RegisterHandler(queue string, handler Handler) {
|
||||
c.handlers[queue] = handler
|
||||
}
|
@@ -1,4 +1,4 @@
|
||||
package broker
|
||||
package mq
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -124,7 +124,6 @@ func (dag *DAG) AddEdge(fromNodeID, toNodeID string) error {
|
||||
return err
|
||||
}
|
||||
dag.edges = append(dag.edges, []string{fromNodeID, toNodeID})
|
||||
fmt.Printf("Edge added from %s to %s\n", fromNodeID, toNodeID)
|
||||
return nil
|
||||
}
|
||||
|
@@ -3,19 +3,18 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/oarkflow/mq/broker"
|
||||
"github.com/oarkflow/mq"
|
||||
)
|
||||
|
||||
func main() {
|
||||
consumer := broker.NewConsumer(":8080")
|
||||
consumer.RegisterHandler("queue1", func(ctx context.Context, task broker.Task) broker.Result {
|
||||
consumer := mq.NewConsumer(":8080")
|
||||
consumer.RegisterHandler("queue1", func(ctx context.Context, task mq.Task) mq.Result {
|
||||
fmt.Println("Handling task for queue1:", task.ID)
|
||||
return broker.Result{Payload: task.Payload, MessageID: task.ID}
|
||||
return mq.Result{Payload: task.Payload, MessageID: task.ID}
|
||||
})
|
||||
consumer.RegisterHandler("queue2", func(ctx context.Context, task broker.Task) broker.Result {
|
||||
consumer.RegisterHandler("queue2", func(ctx context.Context, task mq.Task) mq.Result {
|
||||
fmt.Println("Handling task for queue2:", task.ID)
|
||||
return broker.Result{Payload: task.Payload, MessageID: task.ID}
|
||||
return mq.Result{Payload: task.Payload, MessageID: task.ID}
|
||||
})
|
||||
consumer.Consume(context.Background(), "queue2", "queue1")
|
||||
}
|
||||
|
@@ -4,12 +4,11 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/oarkflow/mq"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq/broker"
|
||||
)
|
||||
|
||||
func handleNode1(_ context.Context, task broker.Task) broker.Result {
|
||||
func handleNode1(_ context.Context, task mq.Task) mq.Result {
|
||||
result := []map[string]string{
|
||||
{"field": "facility", "item": "item1"},
|
||||
{"field": "facility", "item": "item2"},
|
||||
@@ -18,18 +17,18 @@ func handleNode1(_ context.Context, task broker.Task) broker.Result {
|
||||
var payload string
|
||||
err := json.Unmarshal(task.Payload, &payload)
|
||||
if err != nil {
|
||||
return broker.Result{Status: "fail", Payload: json.RawMessage(`{"field": "node1", "item": "error"}`)}
|
||||
return mq.Result{Status: "fail", Payload: json.RawMessage(`{"field": "node1", "item": "error"}`)}
|
||||
}
|
||||
fmt.Printf("Processing task at node1: %s\n", string(task.Payload))
|
||||
bt, _ := json.Marshal(result)
|
||||
return broker.Result{Status: "completed", Payload: bt}
|
||||
return mq.Result{Status: "completed", Payload: bt}
|
||||
}
|
||||
|
||||
func handleNode2(_ context.Context, task broker.Task) broker.Result {
|
||||
func handleNode2(_ context.Context, task mq.Task) mq.Result {
|
||||
var payload map[string]string
|
||||
err := json.Unmarshal(task.Payload, &payload)
|
||||
if err != nil {
|
||||
return broker.Result{Status: "fail", Payload: json.RawMessage(`{"field": "node2", "item": "error"}`)}
|
||||
return mq.Result{Status: "fail", Payload: json.RawMessage(`{"field": "node2", "item": "error"}`)}
|
||||
}
|
||||
status := "fail"
|
||||
if payload["item"] == "item2" {
|
||||
@@ -37,24 +36,24 @@ func handleNode2(_ context.Context, task broker.Task) broker.Result {
|
||||
}
|
||||
fmt.Printf("Processing task at node2: %s %s\n", payload, status)
|
||||
bt, _ := json.Marshal(payload)
|
||||
return broker.Result{Status: status, Payload: bt}
|
||||
return mq.Result{Status: status, Payload: bt}
|
||||
}
|
||||
|
||||
func handleNode3(_ context.Context, task broker.Task) broker.Result {
|
||||
func handleNode3(_ context.Context, task mq.Task) mq.Result {
|
||||
result := `{"field": "node3", "item": %s}`
|
||||
fmt.Printf("Processing task at node3: %s\n", string(task.Payload))
|
||||
return broker.Result{Status: "completed", Payload: json.RawMessage(fmt.Sprintf(result, string(task.Payload)))}
|
||||
return mq.Result{Status: "completed", Payload: json.RawMessage(fmt.Sprintf(result, string(task.Payload)))}
|
||||
}
|
||||
|
||||
func handleNode4(_ context.Context, task broker.Task) broker.Result {
|
||||
func handleNode4(_ context.Context, task mq.Task) mq.Result {
|
||||
result := `{"field": "node4", "item": %s}`
|
||||
fmt.Printf("Processing task at node4: %s\n", string(task.Payload))
|
||||
return broker.Result{Status: "completed", Payload: json.RawMessage(fmt.Sprintf(result, string(task.Payload)))}
|
||||
return mq.Result{Status: "completed", Payload: json.RawMessage(fmt.Sprintf(result, string(task.Payload)))}
|
||||
}
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
d := broker.NewDAG(":8082", false)
|
||||
d := mq.NewDAG(":8082", false)
|
||||
|
||||
d.AddNode("node1", handleNode1, true)
|
||||
d.AddNode("node2", handleNode2)
|
||||
@@ -75,7 +74,7 @@ func main() {
|
||||
fmt.Println("Error starting DAG:", err)
|
||||
}
|
||||
}()
|
||||
result := d.ProcessTask(ctx, broker.Task{Payload: []byte(`"Start processing"`)})
|
||||
result := d.ProcessTask(ctx, mq.Task{Payload: []byte(`"Start processing"`)})
|
||||
fmt.Println(string(result.Payload))
|
||||
time.Sleep(50 * time.Second)
|
||||
}
|
||||
|
@@ -3,13 +3,12 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/oarkflow/mq"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq/broker"
|
||||
)
|
||||
|
||||
func main() {
|
||||
b := broker.NewBroker(func(ctx context.Context, task *broker.Task) error {
|
||||
b := mq.NewBroker(func(ctx context.Context, task *mq.Task) error {
|
||||
fmt.Println("Received task", task.ID, string(task.Payload), string(task.Result), task.Error, task.CurrentQueue)
|
||||
return nil
|
||||
})
|
||||
@@ -17,11 +16,11 @@ func main() {
|
||||
b.NewQueue("queue2")
|
||||
go func() {
|
||||
for i := 0; i < 10; i++ {
|
||||
b.Publish(context.Background(), broker.Task{
|
||||
b.Publish(context.Background(), mq.Task{
|
||||
ID: fmt.Sprint(i),
|
||||
Payload: []byte(`"Hello"`),
|
||||
}, "queue1")
|
||||
b.Publish(context.Background(), broker.Task{
|
||||
b.Publish(context.Background(), mq.Task{
|
||||
ID: fmt.Sprint(i),
|
||||
Payload: []byte(`"World"`),
|
||||
}, "queue2")
|
||||
|
Reference in New Issue
Block a user