From 12c704b01b0862d06ec8e9620507e5c44d4c37d9 Mon Sep 17 00:00:00 2001 From: Oarkflow Date: Sat, 5 Oct 2024 22:07:20 +0545 Subject: [PATCH] feat: separate broker --- broker.go | 5 ++- consumer.go | 10 +++-- dag/dag.go | 90 ++++++++++++++++++++++++++++------------- examples/dag.go | 72 +++++++++------------------------ examples/tasks/tasks.go | 53 +++++++++++++++++------- 5 files changed, 131 insertions(+), 99 deletions(-) diff --git a/broker.go b/broker.go index dac2fe9..95ef01a 100644 --- a/broker.go +++ b/broker.go @@ -122,7 +122,10 @@ func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) func (b *Broker) Publish(ctx context.Context, task Task, queue string) error { headers, _ := GetHeaders(ctx) - payload, _ := json.Marshal(task) + payload, err := json.Marshal(task) + if err != nil { + return err + } msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers) b.broadcastToConsumers(ctx, msg) return nil diff --git a/consumer.go b/consumer.go index 570c03e..2aa2926 100644 --- a/consumer.go +++ b/consumer.go @@ -92,10 +92,12 @@ func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.C result := c.ProcessTask(ctx, task) result.MessageID = task.ID result.Queue = msg.Queue - if result.Error != nil { - result.Status = "FAILED" - } else { - result.Status = "SUCCESS" + if result.Status == "" { + if result.Error != nil { + result.Status = "FAILED" + } else { + result.Status = "SUCCESS" + } } bt, _ := json.Marshal(result) reply = codec.NewMessage(consts.MESSAGE_RESPONSE, bt, msg.Queue, headers) diff --git a/dag/dag.go b/dag/dag.go index f9ffa70..52307c1 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -26,6 +26,7 @@ type DAG struct { server *mq.Broker nodes map[string]*mq.Consumer edges map[string]string + conditions map[string]map[string]string loopEdges map[string][]string taskChMap map[string]chan mq.Result taskResults map[string]map[string]*taskContext @@ -36,6 +37,7 @@ func New(opts ...mq.Option) *DAG { d := &DAG{ nodes: make(map[string]*mq.Consumer), edges: make(map[string]string), + conditions: make(map[string]map[string]string), loopEdges: make(map[string][]string), taskChMap: make(map[string]chan mq.Result), taskResults: make(map[string]map[string]*taskContext), @@ -55,6 +57,10 @@ func (d *DAG) AddNode(name string, handler mq.Handler, firstNode ...bool) { d.nodes[name] = con } +func (d *DAG) AddCondition(fromNode string, conditions map[string]string) { + d.conditions[fromNode] = conditions +} + func (d *DAG) AddEdge(fromNode string, toNodes string) { d.edges[fromNode] = toNodes } @@ -156,7 +162,7 @@ func (d *DAG) Send(ctx context.Context, payload []byte) mq.Result { return d.sendSync(ctx, mq.Result{Payload: payload}) } resultCh := make(chan mq.Result) - result := d.PublishTask(ctx, payload, d.FirstNode) + result := d.PublishTask(ctx, payload) if result.Error != nil { return result } @@ -216,6 +222,21 @@ func (d *DAG) sendSync(ctx context.Context, task mq.Result) mq.Result { } result.Payload = bt } + if conditions, ok := d.conditions[task.Queue]; ok { + if target, exists := conditions[result.Status]; exists { + ctx = mq.SetHeaders(ctx, map[string]string{ + consts.QueueKey: target, + }) + result = d.sendSync(ctx, mq.Result{ + Payload: result.Payload, + Queue: target, + MessageID: result.MessageID, + }) + if result.Error != nil { + return result + } + } + } if target, ok := d.edges[task.Queue]; ok { ctx = mq.SetHeaders(ctx, map[string]string{ consts.QueueKey: target, @@ -277,6 +298,7 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result { return mq.Result{Error: task.Error} } triggeredNode, ok := mq.GetTriggerNode(ctx) + fmt.Println(task.Queue, triggeredNode, ok) payload, completed, multipleResults := d.getCompletedResults(task, ok, triggeredNode) if loopNodes, exists := d.loopEdges[task.Queue]; exists { var items []json.RawMessage @@ -308,34 +330,48 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result { if multipleResults && completed { task.Queue = triggeredNode } - ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue}) - edge, exists := d.edges[task.Queue] - if exists { - d.taskResults[task.MessageID] = map[string]*taskContext{ - task.Queue: { - totalItems: 1, - }, - } - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: edge, - }) - result := d.PublishTask(ctx, payload, task.MessageID) - if result.Error != nil { - return result - } - } else if completed { - d.mu.Lock() - if resultCh, ok := d.taskChMap[task.MessageID]; ok { - resultCh <- mq.Result{ - Payload: payload, - Queue: task.Queue, - MessageID: task.MessageID, - Status: "done", + if conditions, ok := d.conditions[task.Queue]; ok { + if target, exists := conditions[task.Status]; exists { + ctx = mq.SetHeaders(ctx, map[string]string{ + consts.QueueKey: target, + consts.TriggerNode: task.Queue, + }) + result := d.PublishTask(ctx, payload, task.MessageID) + if result.Error != nil { + return result } - delete(d.taskChMap, task.MessageID) - delete(d.taskResults, task.MessageID) } - d.mu.Unlock() + } else { + ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue}) + edge, exists := d.edges[task.Queue] + if exists { + d.taskResults[task.MessageID] = map[string]*taskContext{ + task.Queue: { + totalItems: 1, + }, + } + ctx = mq.SetHeaders(ctx, map[string]string{ + consts.QueueKey: edge, + }) + result := d.PublishTask(ctx, payload, task.MessageID) + if result.Error != nil { + return result + } + } else if completed { + d.mu.Lock() + if resultCh, ok := d.taskChMap[task.MessageID]; ok { + resultCh <- mq.Result{ + Payload: payload, + Queue: task.Queue, + MessageID: task.MessageID, + Status: "done", + } + delete(d.taskChMap, task.MessageID) + delete(d.taskResults, task.MessageID) + } + d.mu.Unlock() + } } + return task } diff --git a/examples/dag.go b/examples/dag.go index 19054b7..a848832 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -2,67 +2,33 @@ package main import ( "context" - "encoding/json" - "io" - "net/http" - - "github.com/oarkflow/mq" + "fmt" "github.com/oarkflow/mq/dag" "github.com/oarkflow/mq/examples/tasks" + "time" ) var d *dag.DAG func main() { - d = dag.New(mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.crt")) - d.AddNode("queue1", tasks.Node1) - d.AddNode("queue2", tasks.Node2) - d.AddNode("queue3", tasks.Node3) - d.AddNode("queue4", tasks.Node4) + d = dag.New() + d.AddNode("queue1", tasks.CheckCondition, true) + d.AddNode("queue2", tasks.Pass) + d.AddNode("queue3", tasks.Fail) - d.AddEdge("queue1", "queue2") - d.AddLoop("queue2", "queue3") - d.AddEdge("queue2", "queue4") + d.AddCondition("queue1", map[string]string{"pass": "queue2", "fail": "queue3"}) d.Prepare() - http.HandleFunc("POST /publish", requestHandler("publish")) - http.HandleFunc("POST /request", requestHandler("request")) - err := d.Start(context.TODO(), ":8083") - if err != nil { - panic(err) - } -} + go func() { + d.Start(context.Background(), ":8081") + }() + go func() { + time.Sleep(3 * time.Second) + result := d.Send(context.Background(), []byte(`{"user_id": 1}`)) + if result.Error != nil { + panic(result.Error) + } + fmt.Println("Response", string(result.Payload)) + }() -func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - http.Error(w, "Invalid request method", http.StatusMethodNotAllowed) - return - } - var payload []byte - if r.Body != nil { - defer r.Body.Close() - var err error - payload, err = io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Failed to read request body", http.StatusBadRequest) - return - } - } else { - http.Error(w, "Empty request body", http.StatusBadRequest) - return - } - var rs mq.Result - if requestType == "request" { - rs = d.Request(context.Background(), payload) - } else { - rs = d.Send(context.Background(), payload) - } - w.Header().Set("Content-Type", "application/json") - result := map[string]any{ - "message_id": rs.MessageID, - "payload": string(rs.Payload), - "error": rs.Error, - } - json.NewEncoder(w).Encode(result) - } + time.Sleep(10 * time.Second) } diff --git a/examples/tasks/tasks.go b/examples/tasks/tasks.go index 605de66..fd7d534 100644 --- a/examples/tasks/tasks.go +++ b/examples/tasks/tasks.go @@ -4,40 +4,65 @@ import ( "context" "encoding/json" "fmt" - mq2 "github.com/oarkflow/mq" + "github.com/oarkflow/mq" ) -func Node1(ctx context.Context, task mq2.Task) mq2.Result { - return mq2.Result{Payload: task.Payload, MessageID: task.ID} +func Node1(ctx context.Context, task mq.Task) mq.Result { + return mq.Result{Payload: task.Payload, MessageID: task.ID} } -func Node2(ctx context.Context, task mq2.Task) mq2.Result { - return mq2.Result{Payload: task.Payload, MessageID: task.ID} +func Node2(ctx context.Context, task mq.Task) mq.Result { + return mq.Result{Payload: task.Payload, MessageID: task.ID} } -func Node3(ctx context.Context, task mq2.Task) mq2.Result { +func Node3(ctx context.Context, task mq.Task) mq.Result { var data map[string]any err := json.Unmarshal(task.Payload, &data) if err != nil { - return mq2.Result{Error: err} + return mq.Result{Error: err} } data["salary"] = fmt.Sprintf("12000%v", data["user_id"]) bt, _ := json.Marshal(data) - return mq2.Result{Payload: bt, MessageID: task.ID} + return mq.Result{Payload: bt, MessageID: task.ID} } -func Node4(ctx context.Context, task mq2.Task) mq2.Result { +func Node4(ctx context.Context, task mq.Task) mq.Result { var data []map[string]any err := json.Unmarshal(task.Payload, &data) if err != nil { - return mq2.Result{Error: err} + return mq.Result{Error: err} } payload := map[string]any{"storage": data} bt, _ := json.Marshal(payload) - return mq2.Result{Payload: bt, MessageID: task.ID} + return mq.Result{Payload: bt, MessageID: task.ID} } -func Callback(ctx context.Context, task mq2.Result) mq2.Result { - fmt.Println("Received task", task.MessageID, "Payload", string(task.Payload), task.Error, task.Queue) - return mq2.Result{} +func CheckCondition(ctx context.Context, task mq.Task) mq.Result { + var data map[string]any + err := json.Unmarshal(task.Payload, &data) + if err != nil { + return mq.Result{Error: err} + } + var status string + if data["user_id"].(float64) == 2 { + status = "pass" + } else { + status = "fail" + } + return mq.Result{Status: status, Payload: task.Payload, MessageID: task.ID} +} + +func Pass(ctx context.Context, task mq.Task) mq.Result { + fmt.Println("Pass") + return mq.Result{Payload: task.Payload} +} + +func Fail(ctx context.Context, task mq.Task) mq.Result { + fmt.Println("Fail") + return mq.Result{Payload: []byte(`{"test2": "asdsa"}`)} +} + +func Callback(ctx context.Context, task mq.Result) mq.Result { + fmt.Println("Received task", task.MessageID, "Payload", string(task.Payload), task.Error, task.Queue) + return mq.Result{} }