From ff2922eddffe4d7150574631a1ca08e3cd27613d Mon Sep 17 00:00:00 2001 From: sujit Date: Tue, 8 Oct 2024 18:34:21 +0545 Subject: [PATCH] feat: add example --- ctx.go | 5 +++ dag/dag.go | 2 +- examples/dag.go | 6 ++-- examples/dag_v2.go | 79 +++++++++++++++++++++++++++++++++++----------- options.go | 3 -- v2/dag.go | 40 +++++++++++++++++++++-- v2/task_manager.go | 33 ++++++++++++------- 7 files changed, 129 insertions(+), 39 deletions(-) diff --git a/ctx.go b/ctx.go index b1e66d1..47bdfe1 100644 --- a/ctx.go +++ b/ctx.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "os" + "sync" "time" "github.com/oarkflow/xid" @@ -38,7 +39,11 @@ func IsClosed(conn net.Conn) bool { return false } +var m = sync.RWMutex{} + func SetHeaders(ctx context.Context, headers map[string]string) context.Context { + m.Lock() + defer m.Unlock() hd, ok := GetHeaders(ctx) if !ok { hd = make(map[string]string) diff --git a/dag/dag.go b/dag/dag.go index 0f2b51f..8139bc5 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -176,7 +176,7 @@ func (d *DAG) Send(ctx context.Context, payload []byte) mq.Result { func (d *DAG) processNode(ctx context.Context, task mq.Result) mq.Result { if con, ok := d.nodes[task.Topic]; ok { - return con.ProcessTask(ctx, mq.Task{ + return con.ProcessTask(ctx, &mq.Task{ ID: task.TaskID, Payload: task.Payload, }) diff --git a/examples/dag.go b/examples/dag.go index f2e66cd..59f8431 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -1,5 +1,6 @@ package main +/* import ( "context" "encoding/json" @@ -46,13 +47,13 @@ func main() { }() time.Sleep(10 * time.Second) - /*d.Prepare() + d.Prepare() http.HandleFunc("POST /publish", requestHandler("publish")) http.HandleFunc("POST /request", requestHandler("request")) err := d.Start(context.TODO(), ":8083") if err != nil { panic(err) - }*/ + } } func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Request) { @@ -89,3 +90,4 @@ func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Requ json.NewEncoder(w).Encode(result) } } +*/ diff --git a/examples/dag_v2.go b/examples/dag_v2.go index b3a7b34..9bcf5ba 100644 --- a/examples/dag_v2.go +++ b/examples/dag_v2.go @@ -4,19 +4,21 @@ import ( "context" "encoding/json" "fmt" + "io" + "net/http" "github.com/oarkflow/mq" "github.com/oarkflow/mq/v2" ) func handler1(ctx context.Context, task *mq.Task) mq.Result { - return mq.Result{Payload: task.Payload, Ctx: ctx} + return mq.Result{Payload: task.Payload} } func handler2(ctx context.Context, task *mq.Task) mq.Result { var user map[string]any json.Unmarshal(task.Payload, &user) - return mq.Result{Payload: task.Payload, Ctx: ctx} + return mq.Result{Payload: task.Payload} } func handler3(ctx context.Context, task *mq.Task) mq.Result { @@ -29,7 +31,7 @@ func handler3(ctx context.Context, task *mq.Task) mq.Result { } user["status"] = status resultPayload, _ := json.Marshal(user) - return mq.Result{Payload: resultPayload, Status: status, Ctx: ctx} + return mq.Result{Payload: resultPayload, Status: status} } func handler4(ctx context.Context, task *mq.Task) mq.Result { @@ -37,7 +39,7 @@ func handler4(ctx context.Context, task *mq.Task) mq.Result { json.Unmarshal(task.Payload, &user) user["final"] = "D" resultPayload, _ := json.Marshal(user) - return mq.Result{Payload: resultPayload, Ctx: ctx} + return mq.Result{Payload: resultPayload} } func handler5(ctx context.Context, task *mq.Task) mq.Result { @@ -45,34 +47,73 @@ func handler5(ctx context.Context, task *mq.Task) mq.Result { json.Unmarshal(task.Payload, &user) user["salary"] = "E" resultPayload, _ := json.Marshal(user) - return mq.Result{Payload: resultPayload, Ctx: ctx} + return mq.Result{Payload: resultPayload} } func handler6(ctx context.Context, task *mq.Task) mq.Result { var user map[string]any json.Unmarshal(task.Payload, &user) resultPayload, _ := json.Marshal(map[string]any{"storage": user}) - return mq.Result{Payload: resultPayload, Ctx: ctx} + return mq.Result{Payload: resultPayload} } +var ( + d = v2.NewDAG(mq.WithSyncMode(true)) +) + func main() { - dag := v2.NewDAG() - dag.AddNode("A", handler1) - dag.AddNode("B", handler2) - dag.AddNode("C", handler3) - dag.AddNode("D", handler4) - dag.AddNode("E", handler5) - dag.AddNode("F", handler6) - dag.AddEdge("A", "B", v2.LoopEdge) - dag.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"}) - dag.AddEdge("B", "C") - dag.AddEdge("D", "F") - dag.AddEdge("E", "F") + d.AddNode("A", handler1) + d.AddNode("B", handler2) + d.AddNode("C", handler3) + d.AddNode("D", handler4) + d.AddNode("E", handler5) + d.AddNode("F", handler6) + d.AddEdge("A", "B", v2.LoopEdge) + d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"}) + d.AddEdge("B", "C") + d.AddEdge("D", "F") + d.AddEdge("E", "F") initialPayload, _ := json.Marshal([]map[string]any{ {"user_id": 1, "age": 12}, {"user_id": 2, "age": 34}, }) - rs := dag.ProcessTask(context.Background(), "A", initialPayload) + rs := d.ProcessTask(context.Background(), "A", initialPayload) fmt.Println(string(rs.Payload)) + http.HandleFunc("POST /publish", requestHandler("publish")) + http.HandleFunc("POST /request", requestHandler("request")) + err := d.Start(context.TODO(), ":8083") + if err != nil { + panic(err) + } +} + +func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Invalid request method", http.StatusMethodNotAllowed) + return + } + var payload []byte + if r.Body != nil { + defer r.Body.Close() + var err error + payload, err = io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + } else { + http.Error(w, "Empty request body", http.StatusBadRequest) + return + } + rs := d.ProcessTask(context.Background(), "A", payload) + w.Header().Set("Content-Type", "application/json") + result := map[string]any{ + "message_id": rs.TaskID, + "payload": string(rs.Payload), + "error": rs.Error, + } + json.NewEncoder(w).Encode(result) + } } diff --git a/options.go b/options.go index 520c8fe..096cc18 100644 --- a/options.go +++ b/options.go @@ -8,7 +8,6 @@ import ( ) type Result struct { - Ctx context.Context Payload json.RawMessage `json:"payload"` Topic string `json:"topic"` TaskID string `json:"task_id"` @@ -38,7 +37,6 @@ func HandleError(ctx context.Context, err error, status ...string) Result { return Result{ Status: st, Error: err, - Ctx: ctx, } } @@ -50,7 +48,6 @@ func (r Result) WithData(status string, data []byte) Result { Status: status, Payload: data, Error: nil, - Ctx: r.Ctx, } } diff --git a/v2/dag.go b/v2/dag.go index 8a1d5ce..9c61695 100644 --- a/v2/dag.go +++ b/v2/dag.go @@ -3,6 +3,8 @@ package v2 import ( "context" "encoding/json" + "log" + "net/http" "sync" "github.com/oarkflow/xid" @@ -44,17 +46,51 @@ type Edge struct { type DAG struct { Nodes map[string]*Node + server *mq.Broker taskContext map[string]*TaskManager conditions map[string]map[string]string mu sync.RWMutex } -func NewDAG() *DAG { - return &DAG{ +func NewDAG(opts ...mq.Option) *DAG { + d := &DAG{ Nodes: make(map[string]*Node), taskContext: make(map[string]*TaskManager), conditions: make(map[string]map[string]string), } + opts = append(opts, mq.WithCallback(d.onTaskCallback)) + d.server = mq.NewBroker(opts...) + return d +} + +func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result { + if taskContext, ok := tm.taskContext[result.TaskID]; ok { + return taskContext.handleCallback(ctx, result) + } + return mq.Result{} +} + +func (tm *DAG) Start(ctx context.Context, addr string) error { + if tm.server.SyncMode() { + return nil + } + go func() { + err := tm.server.Start(ctx) + if err != nil { + panic(err) + } + }() + for _, con := range tm.Nodes { + go func(con *Node) { + con.consumer.Consume(ctx) + }(con) + } + log.Printf("HTTP server started on %s", addr) + config := tm.server.TLSConfig() + if config.UseTLS { + return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil) + } + return http.ListenAndServe(addr, nil) } func (tm *DAG) AddNode(key string, handler mq.Handler) { diff --git a/v2/task_manager.go b/v2/task_manager.go index 83c06a5..435d163 100644 --- a/v2/task_manager.go +++ b/v2/task_manager.go @@ -46,13 +46,18 @@ func (tm *TaskManager) processTask(ctx context.Context, nodeID string, task *mq. tm.mutex.Lock() defer tm.mutex.Unlock() if len(tm.results) == 1 { - return tm.callback(ctx, tm.results[0]) + return tm.handleResult(ctx, tm.results[0]) } - return tm.callback(ctx, tm.results) + return tm.handleResult(ctx, tm.results) } } -func (tm *TaskManager) callback(ctx context.Context, results any) mq.Result { +func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.Result { + fmt.Println(string(result.Payload), result.Topic, result.TaskID) + return mq.Result{} +} + +func (tm *TaskManager) handleResult(ctx context.Context, results any) mq.Result { var rs mq.Result switch res := results.(type) { case []mq.Result: @@ -101,17 +106,21 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *mq.Tas return default: ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: node.Key}) - result = node.consumer.ProcessTask(ctx, task) - result.Topic = node.Key - if result.Error != nil { - tm.appendFinalResult(result) - return + if tm.dag.server.SyncMode() { + result = node.consumer.ProcessTask(ctx, task) + result.Topic = node.Key + if result.Error != nil { + tm.appendFinalResult(result) + return + } + } else { + err := tm.dag.server.Publish(ctx, *task, node.Key) + if err != nil { + tm.appendFinalResult(mq.Result{Error: err}) + return + } } } - if result.Ctx == nil { - result.Ctx = ctx - } - ctx = result.Ctx tm.mutex.Lock() task.Results[node.Key] = result tm.mutex.Unlock()