feat: separate broker

This commit is contained in:
Oarkflow
2024-10-05 22:07:20 +05:45
parent 5d759db34c
commit 12c704b01b
5 changed files with 131 additions and 99 deletions

View File

@@ -122,7 +122,10 @@ func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message)
func (b *Broker) Publish(ctx context.Context, task Task, queue string) error { func (b *Broker) Publish(ctx context.Context, task Task, queue string) error {
headers, _ := GetHeaders(ctx) headers, _ := GetHeaders(ctx)
payload, _ := json.Marshal(task) payload, err := json.Marshal(task)
if err != nil {
return err
}
msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers) msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers)
b.broadcastToConsumers(ctx, msg) b.broadcastToConsumers(ctx, msg)
return nil return nil

View File

@@ -92,10 +92,12 @@ func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.C
result := c.ProcessTask(ctx, task) result := c.ProcessTask(ctx, task)
result.MessageID = task.ID result.MessageID = task.ID
result.Queue = msg.Queue result.Queue = msg.Queue
if result.Error != nil { if result.Status == "" {
result.Status = "FAILED" if result.Error != nil {
} else { result.Status = "FAILED"
result.Status = "SUCCESS" } else {
result.Status = "SUCCESS"
}
} }
bt, _ := json.Marshal(result) bt, _ := json.Marshal(result)
reply = codec.NewMessage(consts.MESSAGE_RESPONSE, bt, msg.Queue, headers) reply = codec.NewMessage(consts.MESSAGE_RESPONSE, bt, msg.Queue, headers)

View File

@@ -26,6 +26,7 @@ type DAG struct {
server *mq.Broker server *mq.Broker
nodes map[string]*mq.Consumer nodes map[string]*mq.Consumer
edges map[string]string edges map[string]string
conditions map[string]map[string]string
loopEdges map[string][]string loopEdges map[string][]string
taskChMap map[string]chan mq.Result taskChMap map[string]chan mq.Result
taskResults map[string]map[string]*taskContext taskResults map[string]map[string]*taskContext
@@ -36,6 +37,7 @@ func New(opts ...mq.Option) *DAG {
d := &DAG{ d := &DAG{
nodes: make(map[string]*mq.Consumer), nodes: make(map[string]*mq.Consumer),
edges: make(map[string]string), edges: make(map[string]string),
conditions: make(map[string]map[string]string),
loopEdges: make(map[string][]string), loopEdges: make(map[string][]string),
taskChMap: make(map[string]chan mq.Result), taskChMap: make(map[string]chan mq.Result),
taskResults: make(map[string]map[string]*taskContext), taskResults: make(map[string]map[string]*taskContext),
@@ -55,6 +57,10 @@ func (d *DAG) AddNode(name string, handler mq.Handler, firstNode ...bool) {
d.nodes[name] = con d.nodes[name] = con
} }
func (d *DAG) AddCondition(fromNode string, conditions map[string]string) {
d.conditions[fromNode] = conditions
}
func (d *DAG) AddEdge(fromNode string, toNodes string) { func (d *DAG) AddEdge(fromNode string, toNodes string) {
d.edges[fromNode] = toNodes d.edges[fromNode] = toNodes
} }
@@ -156,7 +162,7 @@ func (d *DAG) Send(ctx context.Context, payload []byte) mq.Result {
return d.sendSync(ctx, mq.Result{Payload: payload}) return d.sendSync(ctx, mq.Result{Payload: payload})
} }
resultCh := make(chan mq.Result) resultCh := make(chan mq.Result)
result := d.PublishTask(ctx, payload, d.FirstNode) result := d.PublishTask(ctx, payload)
if result.Error != nil { if result.Error != nil {
return result return result
} }
@@ -216,6 +222,21 @@ func (d *DAG) sendSync(ctx context.Context, task mq.Result) mq.Result {
} }
result.Payload = bt result.Payload = bt
} }
if conditions, ok := d.conditions[task.Queue]; ok {
if target, exists := conditions[result.Status]; exists {
ctx = mq.SetHeaders(ctx, map[string]string{
consts.QueueKey: target,
})
result = d.sendSync(ctx, mq.Result{
Payload: result.Payload,
Queue: target,
MessageID: result.MessageID,
})
if result.Error != nil {
return result
}
}
}
if target, ok := d.edges[task.Queue]; ok { if target, ok := d.edges[task.Queue]; ok {
ctx = mq.SetHeaders(ctx, map[string]string{ ctx = mq.SetHeaders(ctx, map[string]string{
consts.QueueKey: target, consts.QueueKey: target,
@@ -277,6 +298,7 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
return mq.Result{Error: task.Error} return mq.Result{Error: task.Error}
} }
triggeredNode, ok := mq.GetTriggerNode(ctx) triggeredNode, ok := mq.GetTriggerNode(ctx)
fmt.Println(task.Queue, triggeredNode, ok)
payload, completed, multipleResults := d.getCompletedResults(task, ok, triggeredNode) payload, completed, multipleResults := d.getCompletedResults(task, ok, triggeredNode)
if loopNodes, exists := d.loopEdges[task.Queue]; exists { if loopNodes, exists := d.loopEdges[task.Queue]; exists {
var items []json.RawMessage var items []json.RawMessage
@@ -308,34 +330,48 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
if multipleResults && completed { if multipleResults && completed {
task.Queue = triggeredNode task.Queue = triggeredNode
} }
ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue}) if conditions, ok := d.conditions[task.Queue]; ok {
edge, exists := d.edges[task.Queue] if target, exists := conditions[task.Status]; exists {
if exists { ctx = mq.SetHeaders(ctx, map[string]string{
d.taskResults[task.MessageID] = map[string]*taskContext{ consts.QueueKey: target,
task.Queue: { consts.TriggerNode: task.Queue,
totalItems: 1, })
}, result := d.PublishTask(ctx, payload, task.MessageID)
} if result.Error != nil {
ctx = mq.SetHeaders(ctx, map[string]string{ return result
consts.QueueKey: edge,
})
result := d.PublishTask(ctx, payload, task.MessageID)
if result.Error != nil {
return result
}
} else if completed {
d.mu.Lock()
if resultCh, ok := d.taskChMap[task.MessageID]; ok {
resultCh <- mq.Result{
Payload: payload,
Queue: task.Queue,
MessageID: task.MessageID,
Status: "done",
} }
delete(d.taskChMap, task.MessageID)
delete(d.taskResults, task.MessageID)
} }
d.mu.Unlock() } else {
ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue})
edge, exists := d.edges[task.Queue]
if exists {
d.taskResults[task.MessageID] = map[string]*taskContext{
task.Queue: {
totalItems: 1,
},
}
ctx = mq.SetHeaders(ctx, map[string]string{
consts.QueueKey: edge,
})
result := d.PublishTask(ctx, payload, task.MessageID)
if result.Error != nil {
return result
}
} else if completed {
d.mu.Lock()
if resultCh, ok := d.taskChMap[task.MessageID]; ok {
resultCh <- mq.Result{
Payload: payload,
Queue: task.Queue,
MessageID: task.MessageID,
Status: "done",
}
delete(d.taskChMap, task.MessageID)
delete(d.taskResults, task.MessageID)
}
d.mu.Unlock()
}
} }
return task return task
} }

View File

@@ -2,67 +2,33 @@ package main
import ( import (
"context" "context"
"encoding/json" "fmt"
"io"
"net/http"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/dag" "github.com/oarkflow/mq/dag"
"github.com/oarkflow/mq/examples/tasks" "github.com/oarkflow/mq/examples/tasks"
"time"
) )
var d *dag.DAG var d *dag.DAG
func main() { func main() {
d = dag.New(mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.crt")) d = dag.New()
d.AddNode("queue1", tasks.Node1) d.AddNode("queue1", tasks.CheckCondition, true)
d.AddNode("queue2", tasks.Node2) d.AddNode("queue2", tasks.Pass)
d.AddNode("queue3", tasks.Node3) d.AddNode("queue3", tasks.Fail)
d.AddNode("queue4", tasks.Node4)
d.AddEdge("queue1", "queue2") d.AddCondition("queue1", map[string]string{"pass": "queue2", "fail": "queue3"})
d.AddLoop("queue2", "queue3")
d.AddEdge("queue2", "queue4")
d.Prepare() d.Prepare()
http.HandleFunc("POST /publish", requestHandler("publish")) go func() {
http.HandleFunc("POST /request", requestHandler("request")) d.Start(context.Background(), ":8081")
err := d.Start(context.TODO(), ":8083") }()
if err != nil { go func() {
panic(err) time.Sleep(3 * time.Second)
} result := d.Send(context.Background(), []byte(`{"user_id": 1}`))
} if result.Error != nil {
panic(result.Error)
}
fmt.Println("Response", string(result.Payload))
}()
func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Request) { time.Sleep(10 * time.Second)
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
return
}
var payload []byte
if r.Body != nil {
defer r.Body.Close()
var err error
payload, err = io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
} else {
http.Error(w, "Empty request body", http.StatusBadRequest)
return
}
var rs mq.Result
if requestType == "request" {
rs = d.Request(context.Background(), payload)
} else {
rs = d.Send(context.Background(), payload)
}
w.Header().Set("Content-Type", "application/json")
result := map[string]any{
"message_id": rs.MessageID,
"payload": string(rs.Payload),
"error": rs.Error,
}
json.NewEncoder(w).Encode(result)
}
} }

View File

@@ -4,40 +4,65 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
mq2 "github.com/oarkflow/mq" "github.com/oarkflow/mq"
) )
func Node1(ctx context.Context, task mq2.Task) mq2.Result { func Node1(ctx context.Context, task mq.Task) mq.Result {
return mq2.Result{Payload: task.Payload, MessageID: task.ID} return mq.Result{Payload: task.Payload, MessageID: task.ID}
} }
func Node2(ctx context.Context, task mq2.Task) mq2.Result { func Node2(ctx context.Context, task mq.Task) mq.Result {
return mq2.Result{Payload: task.Payload, MessageID: task.ID} return mq.Result{Payload: task.Payload, MessageID: task.ID}
} }
func Node3(ctx context.Context, task mq2.Task) mq2.Result { func Node3(ctx context.Context, task mq.Task) mq.Result {
var data map[string]any var data map[string]any
err := json.Unmarshal(task.Payload, &data) err := json.Unmarshal(task.Payload, &data)
if err != nil { if err != nil {
return mq2.Result{Error: err} return mq.Result{Error: err}
} }
data["salary"] = fmt.Sprintf("12000%v", data["user_id"]) data["salary"] = fmt.Sprintf("12000%v", data["user_id"])
bt, _ := json.Marshal(data) bt, _ := json.Marshal(data)
return mq2.Result{Payload: bt, MessageID: task.ID} return mq.Result{Payload: bt, MessageID: task.ID}
} }
func Node4(ctx context.Context, task mq2.Task) mq2.Result { func Node4(ctx context.Context, task mq.Task) mq.Result {
var data []map[string]any var data []map[string]any
err := json.Unmarshal(task.Payload, &data) err := json.Unmarshal(task.Payload, &data)
if err != nil { if err != nil {
return mq2.Result{Error: err} return mq.Result{Error: err}
} }
payload := map[string]any{"storage": data} payload := map[string]any{"storage": data}
bt, _ := json.Marshal(payload) bt, _ := json.Marshal(payload)
return mq2.Result{Payload: bt, MessageID: task.ID} return mq.Result{Payload: bt, MessageID: task.ID}
} }
func Callback(ctx context.Context, task mq2.Result) mq2.Result { func CheckCondition(ctx context.Context, task mq.Task) mq.Result {
fmt.Println("Received task", task.MessageID, "Payload", string(task.Payload), task.Error, task.Queue) var data map[string]any
return mq2.Result{} err := json.Unmarshal(task.Payload, &data)
if err != nil {
return mq.Result{Error: err}
}
var status string
if data["user_id"].(float64) == 2 {
status = "pass"
} else {
status = "fail"
}
return mq.Result{Status: status, Payload: task.Payload, MessageID: task.ID}
}
func Pass(ctx context.Context, task mq.Task) mq.Result {
fmt.Println("Pass")
return mq.Result{Payload: task.Payload}
}
func Fail(ctx context.Context, task mq.Task) mq.Result {
fmt.Println("Fail")
return mq.Result{Payload: []byte(`{"test2": "asdsa"}`)}
}
func Callback(ctx context.Context, task mq.Result) mq.Result {
fmt.Println("Received task", task.MessageID, "Payload", string(task.Payload), task.Error, task.Queue)
return mq.Result{}
} }