From 9c8712994d89cdbe4516dd1e15ca319d9106fed2 Mon Sep 17 00:00:00 2001 From: Oarkflow Date: Wed, 9 Oct 2024 11:24:54 +0545 Subject: [PATCH] feat: separate broker --- dag/task_manager.go | 93 ++++++++++++++++------------------------- examples/dag.go | 73 ++++++-------------------------- examples/tasks/tasks.go | 66 +++++++++++++---------------- options.go | 23 ++++++---- util.go | 4 ++ 5 files changed, 98 insertions(+), 161 deletions(-) diff --git a/dag/task_manager.go b/dag/task_manager.go index 3a2347b..e09842c 100644 --- a/dag/task_manager.go +++ b/dag/task_manager.go @@ -6,6 +6,7 @@ import ( "fmt" "sync" "sync/atomic" + "time" "github.com/oarkflow/mq" "github.com/oarkflow/mq/consts" @@ -14,13 +15,13 @@ import ( type TaskManager struct { taskID string dag *DAG - wg sync.WaitGroup mutex sync.Mutex + createdAt time.Time + processedAt time.Time results []mq.Result waitingCallback int64 nodeResults map[string]mq.Result - done chan struct{} - finalResult chan mq.Result // Channel to collect final results + finalResult chan mq.Result } func NewTaskManager(d *DAG, taskID string) *TaskManager { @@ -32,64 +33,45 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager { } } -func (tm *TaskManager) handleSyncTask(ctx context.Context, node *Node, payload json.RawMessage) mq.Result { - tm.done = make(chan struct{}) - tm.wg.Add(1) - go tm.processNode(ctx, node, payload) - go func() { - tm.wg.Wait() - close(tm.done) - }() - select { - case <-ctx.Done(): - return mq.Result{Error: ctx.Err()} - case <-tm.done: - tm.mutex.Lock() - defer tm.mutex.Unlock() - if len(tm.results) == 1 { - return tm.handleResult(ctx, tm.results[0]) - } - return tm.handleResult(ctx, tm.results) - } -} - -func (tm *TaskManager) handleAsyncTask(ctx context.Context, node *Node, payload json.RawMessage) mq.Result { - tm.finalResult = make(chan mq.Result) - tm.wg.Add(1) - go tm.processNode(ctx, node, payload) - go func() { - tm.wg.Wait() - }() - select { - case result := <-tm.finalResult: // Block until a result is available - return result - case <-ctx.Done(): // Handle context cancellation - return mq.Result{Error: ctx.Err()} - } -} - func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result { node, ok := tm.dag.Nodes[nodeID] if !ok { return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)} } - if tm.dag.server.SyncMode() { - return tm.handleSyncTask(ctx, node, payload) + tm.createdAt = time.Now() + tm.finalResult = make(chan mq.Result, 0) + go tm.processNode(ctx, node, payload) + awaitResponse, ok := mq.GetAwaitResponse(ctx) + if awaitResponse != "true" { + go func() { + finalResult := <-tm.finalResult + finalResult.CreatedAt = tm.createdAt + finalResult.ProcessedAt = time.Now() + if tm.dag.server.NotifyHandler() != nil { + tm.dag.server.NotifyHandler()(ctx, finalResult) + } + }() + return mq.Result{CreatedAt: tm.createdAt, TaskID: tm.taskID, Topic: nodeID, Status: "PENDING"} + } else { + finalResult := <-tm.finalResult + finalResult.CreatedAt = tm.createdAt + finalResult.ProcessedAt = time.Now() + if tm.dag.server.NotifyHandler() != nil { + tm.dag.server.NotifyHandler()(ctx, finalResult) + } + return finalResult } - return tm.handleAsyncTask(ctx, node, payload) } func (tm *TaskManager) dispatchFinalResult(ctx context.Context) { - if !tm.dag.server.SyncMode() { - var rs mq.Result - if len(tm.results) == 1 { - rs = tm.handleResult(ctx, tm.results[0]) - } else { - rs = tm.handleResult(ctx, tm.results) - } - if tm.waitingCallback == 0 { - tm.finalResult <- rs - } + var rs mq.Result + if len(tm.results) == 1 { + rs = tm.handleResult(ctx, tm.results[0]) + } else { + rs = tm.handleResult(ctx, tm.results) + } + if tm.waitingCallback == 0 { + tm.finalResult <- rs } } @@ -127,13 +109,11 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq. return result } for _, item := range items { - tm.wg.Add(1) ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) go tm.processNode(ctx, edge.To, item) } case SimpleEdge: if edge.To != nil { - tm.wg.Add(1) ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) go tm.processNode(ctx, edge.To, result.Payload) } @@ -147,10 +127,11 @@ func (tm *TaskManager) handleResult(ctx context.Context, results any) mq.Result switch res := results.(type) { case []mq.Result: aggregatedOutput := make([]json.RawMessage, 0) - status := "" + var status, topic string for i, result := range res { if i == 0 { status = result.Status + topic = result.Topic } if result.Error != nil { return mq.HandleError(ctx, result.Error) @@ -170,6 +151,7 @@ func (tm *TaskManager) handleResult(ctx context.Context, results any) mq.Result TaskID: tm.taskID, Payload: finalOutput, Status: status, + Topic: topic, } case mq.Result: return res @@ -186,7 +168,6 @@ func (tm *TaskManager) appendFinalResult(result mq.Result) { func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) { atomic.AddInt64(&tm.waitingCallback, 1) - defer tm.wg.Done() var result mq.Result select { case <-ctx.Done(): diff --git a/examples/dag.go b/examples/dag.go index c3fa15a..527185c 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -3,6 +3,8 @@ package main import ( "context" "encoding/json" + "github.com/oarkflow/mq/consts" + "github.com/oarkflow/mq/examples/tasks" "io" "net/http" @@ -10,70 +12,23 @@ import ( "github.com/oarkflow/mq/dag" ) -func handler1(ctx context.Context, task *mq.Task) mq.Result { - return mq.Result{Payload: task.Payload} -} - -func handler2(ctx context.Context, task *mq.Task) mq.Result { - var user map[string]any - json.Unmarshal(task.Payload, &user) - return mq.Result{Payload: task.Payload} -} - -func handler3(ctx context.Context, task *mq.Task) mq.Result { - var user map[string]any - json.Unmarshal(task.Payload, &user) - age := int(user["age"].(float64)) - status := "FAIL" - if age > 20 { - status = "PASS" - } - user["status"] = status - resultPayload, _ := json.Marshal(user) - return mq.Result{Payload: resultPayload, Status: status} -} - -func handler4(ctx context.Context, task *mq.Task) mq.Result { - var user map[string]any - json.Unmarshal(task.Payload, &user) - user["final"] = "D" - resultPayload, _ := json.Marshal(user) - return mq.Result{Payload: resultPayload} -} - -func handler5(ctx context.Context, task *mq.Task) mq.Result { - var user map[string]any - json.Unmarshal(task.Payload, &user) - user["salary"] = "E" - resultPayload, _ := json.Marshal(user) - return mq.Result{Payload: resultPayload} -} - -func handler6(ctx context.Context, task *mq.Task) mq.Result { - var user map[string]any - json.Unmarshal(task.Payload, &user) - resultPayload, _ := json.Marshal(map[string]any{"storage": user}) - return mq.Result{Payload: resultPayload} -} - var ( - d = dag.NewDAG(mq.WithSyncMode(true)) + d = dag.NewDAG(mq.WithSyncMode(false), mq.WithNotifyResponse(tasks.NotifyResponse)) // d = dag.NewDAG(mq.WithSyncMode(true), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert")) ) func main() { - d.AddNode("A", handler1, true) - d.AddNode("B", handler2) - d.AddNode("C", handler3) - d.AddNode("D", handler4) - d.AddNode("E", handler5) - d.AddNode("F", handler6) + d.AddNode("A", tasks.Node1, true) + d.AddNode("B", tasks.Node2) + d.AddNode("C", tasks.Node3) + d.AddNode("D", tasks.Node4) + d.AddNode("E", tasks.Node5) + d.AddNode("F", tasks.Node6) d.AddEdge("A", "B", dag.LoopEdge) d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"}) d.AddEdge("B", "C") d.AddEdge("D", "F") d.AddEdge("E", "F") - // fmt.Println(rs.TaskID, "Task", string(rs.Payload)) http.HandleFunc("POST /publish", requestHandler("publish")) http.HandleFunc("POST /request", requestHandler("request")) err := d.Start(context.TODO(), ":8083") @@ -102,14 +57,12 @@ func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Requ return } ctx := context.Background() + if requestType == "request" { + ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"}) + } // ctx = context.WithValue(ctx, "initial_node", "E") rs := d.ProcessTask(ctx, payload) w.Header().Set("Content-Type", "application/json") - result := map[string]any{ - "message_id": rs.TaskID, - "payload": rs.Payload, - "error": rs.Error, - } - json.NewEncoder(w).Encode(result) + json.NewEncoder(w).Encode(rs) } } diff --git a/examples/tasks/tasks.go b/examples/tasks/tasks.go index 3a6f64c..670c408 100644 --- a/examples/tasks/tasks.go +++ b/examples/tasks/tasks.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "log" "github.com/oarkflow/mq" ) @@ -17,53 +18,46 @@ func Node2(ctx context.Context, task *mq.Task) mq.Result { } func Node3(ctx context.Context, task *mq.Task) mq.Result { - var data map[string]any - err := json.Unmarshal(task.Payload, &data) - if err != nil { - return mq.Result{Error: err} + var user map[string]any + json.Unmarshal(task.Payload, &user) + age := int(user["age"].(float64)) + status := "FAIL" + if age > 20 { + status = "PASS" } - data["salary"] = fmt.Sprintf("12000%v", data["user_id"]) - bt, _ := json.Marshal(data) - return mq.Result{Payload: bt, TaskID: task.ID} + user["status"] = status + resultPayload, _ := json.Marshal(user) + return mq.Result{Payload: resultPayload, Status: status} } func Node4(ctx context.Context, task *mq.Task) mq.Result { - var data []map[string]any - err := json.Unmarshal(task.Payload, &data) - if err != nil { - return mq.Result{Error: err} - } - payload := map[string]any{"storage": data} - bt, _ := json.Marshal(payload) - return mq.Result{Payload: bt, TaskID: task.ID} + var user map[string]any + json.Unmarshal(task.Payload, &user) + user["final"] = "D" + resultPayload, _ := json.Marshal(user) + return mq.Result{Payload: resultPayload} } -func CheckCondition(ctx context.Context, task *mq.Task) mq.Result { - var data map[string]any - err := json.Unmarshal(task.Payload, &data) - if err != nil { - return mq.Result{Error: err} - } - var status string - if data["user_id"].(float64) == 2 { - status = "pass" - } else { - status = "fail" - } - return mq.Result{Status: status, Payload: task.Payload, TaskID: task.ID} +func Node5(ctx context.Context, task *mq.Task) mq.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + user["salary"] = "E" + resultPayload, _ := json.Marshal(user) + return mq.Result{Payload: resultPayload} } -func Pass(ctx context.Context, task *mq.Task) mq.Result { - fmt.Println("Pass") - return mq.Result{Payload: task.Payload} -} - -func Fail(ctx context.Context, task *mq.Task) mq.Result { - fmt.Println("Fail") - return mq.Result{Payload: []byte(`{"test2": "asdsa"}`)} +func Node6(ctx context.Context, task *mq.Task) mq.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + resultPayload, _ := json.Marshal(map[string]any{"storage": user}) + return mq.Result{Payload: resultPayload} } func Callback(ctx context.Context, task mq.Result) mq.Result { fmt.Println("Received task", task.TaskID, "Payload", string(task.Payload), task.Error, task.Topic) return mq.Result{} } + +func NotifyResponse(ctx context.Context, result mq.Result) { + log.Printf("DAG Final response: TaskID: %s, Payload: %s, Topic: %s", result.TaskID, result.Payload, result.Topic) +} diff --git a/options.go b/options.go index 096cc18..5d278d2 100644 --- a/options.go +++ b/options.go @@ -8,11 +8,13 @@ import ( ) type Result struct { - Payload json.RawMessage `json:"payload"` - Topic string `json:"topic"` - TaskID string `json:"task_id"` - Error error `json:"error,omitempty"` - Status string `json:"status"` + Payload json.RawMessage `json:"payload"` + Topic string `json:"topic"` + CreatedAt time.Time `json:"created_at"` + ProcessedAt time.Time `json:"processed_at,omitempty"` + TaskID string `json:"task_id"` + Error error `json:"error,omitempty"` + Status string `json:"status"` } func (r Result) Unmarshal(data any) error { @@ -22,10 +24,6 @@ func (r Result) Unmarshal(data any) error { return json.Unmarshal(r.Payload, data) } -func (r Result) String() string { - return string(r.Payload) -} - func HandleError(ctx context.Context, err error, status ...string) Result { st := "Failed" if len(status) > 0 { @@ -63,6 +61,7 @@ type Options struct { brokerAddr string callback []func(context.Context, Result) Result maxRetries int + notifyResponse func(context.Context, Result) initialDelay time.Duration maxBackoff time.Duration jitterPercent float64 @@ -96,6 +95,12 @@ func setupOptions(opts ...Option) Options { return options } +func WithNotifyResponse(handler func(ctx context.Context, result Result)) Option { + return func(opts *Options) { + opts.notifyResponse = handler + } +} + func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option { return func(opts *Options) { opts.aesKey = aesKey diff --git a/util.go b/util.go index 01298fd..c2e9fd2 100644 --- a/util.go +++ b/util.go @@ -15,6 +15,10 @@ func (b *Broker) SyncMode() bool { return b.opts.syncMode } +func (b *Broker) NotifyHandler() func(context.Context, Result) { + return b.opts.notifyResponse +} + func (b *Broker) HandleCallback(ctx context.Context, msg *codec.Message) { if b.opts.callback != nil { var result Result