From a7466ba791b1a40563c351b9876e2f2ebd949896 Mon Sep 17 00:00:00 2001 From: sujit Date: Tue, 8 Oct 2024 11:15:44 +0545 Subject: [PATCH 01/17] feat: add example --- dag/dag.go | 430 ++++++++------------------------------------ dag/task_manager.go | 158 ++++++++++++++++ examples/dag.go | 2 +- 3 files changed, 234 insertions(+), 356 deletions(-) create mode 100644 dag/task_manager.go diff --git a/dag/dag.go b/dag/dag.go index c0745ef..64666c3 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -1,382 +1,102 @@ -package dag +package v2 import ( "context" "encoding/json" - "fmt" - "log" - "net/http" "sync" - "time" - "github.com/oarkflow/mq/consts" - - "github.com/oarkflow/mq" + "github.com/oarkflow/xid" ) -type taskContext struct { - totalItems int - completed int - results []json.RawMessage - result json.RawMessage - multipleResults bool +type Handler func(ctx context.Context, task *Task) Result + +type Result struct { + TaskID string `json:"task_id"` + NodeKey string `json:"node_key"` + Payload json.RawMessage `json:"payload"` + Status string `json:"status"` + Error error `json:"error"` +} + +type Task struct { + ID string `json:"id"` + NodeKey string `json:"node_key"` + Payload json.RawMessage `json:"payload"` + Results map[string]Result `json:"results"` +} + +type Node struct { + Key string + Edges []Edge + handler Handler +} + +type EdgeType int + +const ( + SimpleEdge EdgeType = iota + LoopEdge + ConditionEdge +) + +type Edge struct { + From *Node + To *Node + Type EdgeType + Condition func(result Result) bool } type DAG struct { - FirstNode string - server *mq.Broker - nodes map[string]*mq.Consumer - edges map[string]string - conditions map[string]map[string]string - loopEdges map[string][]string - taskChMap map[string]chan mq.Result - taskResults map[string]map[string]*taskContext - mu sync.Mutex + Nodes map[string]*Node + taskContext map[string]*TaskManager + mu sync.RWMutex } -func New(opts ...mq.Option) *DAG { - d := &DAG{ - nodes: make(map[string]*mq.Consumer), - edges: make(map[string]string), - conditions: make(map[string]map[string]string), - loopEdges: make(map[string][]string), - taskChMap: make(map[string]chan mq.Result), - taskResults: make(map[string]map[string]*taskContext), - } - opts = append(opts, mq.WithCallback(d.TaskCallback)) - d.server = mq.NewBroker(opts...) - return d -} - -func (d *DAG) AddNode(name string, handler mq.Handler, firstNode ...bool) { - tlsConfig := d.server.TLSConfig() - con := mq.NewConsumer(name, mq.WithTLS(tlsConfig.UseTLS, tlsConfig.CertPath, tlsConfig.KeyPath), mq.WithCAPath(tlsConfig.CAPath)) - if len(firstNode) > 0 { - d.FirstNode = name - } - con.RegisterHandler(name, handler) - d.nodes[name] = con -} - -func (d *DAG) AddCondition(fromNode string, conditions map[string]string) { - d.conditions[fromNode] = conditions -} - -func (d *DAG) AddEdge(fromNode string, toNodes string) { - d.edges[fromNode] = toNodes -} - -func (d *DAG) AddLoop(fromNode string, toNode ...string) { - d.loopEdges[fromNode] = toNode -} - -func (d *DAG) Prepare() { - if d.FirstNode == "" { - firstNode, ok := d.FindFirstNode() - if ok && firstNode != "" { - d.FirstNode = firstNode - } +func NewDAG() *DAG { + return &DAG{ + Nodes: make(map[string]*Node), + taskContext: make(map[string]*TaskManager), } } -func (d *DAG) Start(ctx context.Context, addr string) error { - d.Prepare() - if d.server.SyncMode() { - return nil +func (tm *DAG) AddNode(key string, handler Handler) { + tm.mu.Lock() + defer tm.mu.Unlock() + tm.Nodes[key] = &Node{ + Key: key, + handler: handler, } - go func() { - err := d.server.Start(ctx) - if err != nil { - panic(err) - } - }() - for _, con := range d.nodes { - go func(con *mq.Consumer) { - con.Consume(ctx) - }(con) - } - log.Printf("HTTP server started on %s", addr) - config := d.server.TLSConfig() - if config.UseTLS { - return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil) - } - return http.ListenAndServe(addr, nil) } -func (d *DAG) PublishTask(ctx context.Context, payload json.RawMessage, taskID ...string) mq.Result { - queue, ok := mq.GetQueue(ctx) +func (tm *DAG) AddEdge(from, to string, edgeType EdgeType) { + tm.mu.Lock() + defer tm.mu.Unlock() + fromNode, ok := tm.Nodes[from] if !ok { - queue = d.FirstNode + return } - var id string - if len(taskID) > 0 { - id = taskID[0] - } else { - id = mq.NewID() + toNode, ok := tm.Nodes[to] + if !ok { + return } - task := mq.Task{ - ID: id, - Payload: payload, - CreatedAt: time.Now(), - } - err := d.server.Publish(ctx, task, queue) - if err != nil { - return mq.Result{Error: err} - } - return mq.Result{ - Payload: payload, - Queue: queue, - MessageID: id, - } -} - -func (d *DAG) FindFirstNode() (string, bool) { - inDegree := make(map[string]int) - for n, _ := range d.nodes { - inDegree[n] = 0 - } - for _, outNode := range d.edges { - inDegree[outNode]++ - } - for _, targets := range d.loopEdges { - for _, outNode := range targets { - inDegree[outNode]++ - } - } - for n, count := range inDegree { - if count == 0 { - return n, true - } - } - return "", false -} - -func (d *DAG) Request(ctx context.Context, payload []byte) mq.Result { - return d.sendSync(ctx, mq.Result{Payload: payload}) -} - -func (d *DAG) Send(ctx context.Context, payload []byte) mq.Result { - if d.FirstNode == "" { - return mq.Result{Error: fmt.Errorf("initial node not defined")} - } - if d.server.SyncMode() { - return d.sendSync(ctx, mq.Result{Payload: payload}) - } - resultCh := make(chan mq.Result) - result := d.PublishTask(ctx, payload) - if result.Error != nil { - return result - } - d.mu.Lock() - d.taskChMap[result.MessageID] = resultCh - d.mu.Unlock() - finalResult := <-resultCh - return finalResult -} - -func (d *DAG) processNode(ctx context.Context, task mq.Result) mq.Result { - if con, ok := d.nodes[task.Queue]; ok { - return con.ProcessTask(ctx, mq.Task{ - ID: task.MessageID, - Payload: task.Payload, - }) - } - return mq.Result{Error: fmt.Errorf("no consumer to process %s", task.Queue)} -} - -func (d *DAG) sendSync(ctx context.Context, task mq.Result) mq.Result { - if task.MessageID == "" { - task.MessageID = mq.NewID() - } - if task.Queue == "" { - task.Queue = d.FirstNode - } - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: task.Queue, + fromNode.Edges = append(fromNode.Edges, Edge{ + From: fromNode, + To: toNode, + Type: edgeType, }) - result := d.processNode(ctx, task) - if result.Error != nil { - return result - } - for _, target := range d.loopEdges[task.Queue] { - var items, results []json.RawMessage - if err := json.Unmarshal(result.Payload, &items); err != nil { - return mq.Result{Error: err} - } - for _, item := range items { - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: target, - }) - result = d.sendSync(ctx, mq.Result{ - Payload: item, - Queue: target, - MessageID: result.MessageID, - }) - if result.Error != nil { - return result - } - results = append(results, result.Payload) - } - bt, err := json.Marshal(results) - if err != nil { - return mq.Result{Error: err} - } - result.Payload = bt - } - if conditions, ok := d.conditions[task.Queue]; ok { - if target, exists := conditions[result.Status]; exists { - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: target, - }) - result = d.sendSync(ctx, mq.Result{ - Payload: result.Payload, - Queue: target, - MessageID: result.MessageID, - }) - if result.Error != nil { - return result - } - } - } - if target, ok := d.edges[task.Queue]; ok { - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: target, - }) - result = d.sendSync(ctx, mq.Result{ - Payload: result.Payload, - Queue: target, - MessageID: result.MessageID, - }) - if result.Error != nil { - return result - } - } - return result } -func (d *DAG) getCompletedResults(task mq.Result, ok bool, triggeredNode string) ([]byte, bool, bool) { - var result any - var payload []byte - completed := false - multipleResults := false - if ok && triggeredNode != "" { - taskResults, ok := d.taskResults[task.MessageID] - if ok { - nodeResult, exists := taskResults[triggeredNode] - if exists { - multipleResults = nodeResult.multipleResults - nodeResult.completed++ - if nodeResult.completed == nodeResult.totalItems { - completed = true - } - if multipleResults { - nodeResult.results = append(nodeResult.results, task.Payload) - if completed { - result = nodeResult.results - } - } else { - nodeResult.result = task.Payload - if completed { - result = nodeResult.result - } - } - } - if completed { - delete(taskResults, triggeredNode) - } - } +func (tm *DAG) ProcessTask(ctx context.Context, node string, payload []byte) Result { + tm.mu.Lock() + defer tm.mu.Unlock() + taskID := xid.New().String() + task := &Task{ + ID: taskID, + NodeKey: node, + Payload: payload, + Results: make(map[string]Result), } - if completed { - payload, _ = json.Marshal(result) - } else { - payload = task.Payload - } - return payload, completed, multipleResults -} - -func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result { - if task.Error != nil { - return mq.Result{Error: task.Error} - } - triggeredNode, ok := mq.GetTriggerNode(ctx) - payload, completed, multipleResults := d.getCompletedResults(task, ok, triggeredNode) - if loopNodes, exists := d.loopEdges[task.Queue]; exists { - var items []json.RawMessage - if err := json.Unmarshal(payload, &items); err != nil { - return mq.Result{Error: task.Error} - } - d.taskResults[task.MessageID] = map[string]*taskContext{ - task.Queue: { - totalItems: len(items), - multipleResults: true, - }, - } - - ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue}) - for _, loopNode := range loopNodes { - for _, item := range items { - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: loopNode, - }) - result := d.PublishTask(ctx, item, task.MessageID) - if result.Error != nil { - return result - } - } - } - - return task - } - if multipleResults && completed { - task.Queue = triggeredNode - } - if conditions, ok := d.conditions[task.Queue]; ok { - if target, exists := conditions[task.Status]; exists { - d.taskResults[task.MessageID] = map[string]*taskContext{ - task.Queue: { - totalItems: len(conditions), - }, - } - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: target, - consts.TriggerNode: task.Queue, - }) - result := d.PublishTask(ctx, payload, task.MessageID) - if result.Error != nil { - return result - } - } - } else { - ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue}) - edge, exists := d.edges[task.Queue] - if exists { - d.taskResults[task.MessageID] = map[string]*taskContext{ - task.Queue: { - totalItems: 1, - }, - } - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: edge, - }) - result := d.PublishTask(ctx, payload, task.MessageID) - if result.Error != nil { - return result - } - } else if completed { - d.mu.Lock() - if resultCh, ok := d.taskChMap[task.MessageID]; ok { - resultCh <- mq.Result{ - Payload: payload, - Queue: task.Queue, - MessageID: task.MessageID, - Status: "done", - } - delete(d.taskChMap, task.MessageID) - delete(d.taskResults, task.MessageID) - } - d.mu.Unlock() - } - } - - return task + manager := NewTaskManager(tm) + tm.taskContext[taskID] = manager + return manager.processTask(ctx, node, task) } diff --git a/dag/task_manager.go b/dag/task_manager.go new file mode 100644 index 0000000..ae92306 --- /dev/null +++ b/dag/task_manager.go @@ -0,0 +1,158 @@ +package v2 + +import ( + "context" + "encoding/json" + "fmt" + "sync" +) + +type TaskManager struct { + dag *DAG + wg sync.WaitGroup + mutex sync.Mutex + results []Result + nodeResults map[string]Result // Store results per node for future reference + done chan struct{} +} + +func NewTaskManager(d *DAG) *TaskManager { + return &TaskManager{ + dag: d, + results: make([]Result, 0), + done: make(chan struct{}), + } +} + +func (tm *TaskManager) processTask(ctx context.Context, nodeID string, task *Task) Result { + node, ok := tm.dag.Nodes[nodeID] + if !ok { + return Result{Error: fmt.Errorf("nodeID %s not found", nodeID)} + } + tm.wg.Add(1) + go tm.processNode(ctx, node, task, nil) + go func() { + tm.wg.Wait() + close(tm.done) + }() + select { + case <-ctx.Done(): + return Result{Error: ctx.Err()} + case <-tm.done: + tm.mutex.Lock() + defer tm.mutex.Unlock() + if len(tm.results) == 1 { + return tm.callback(tm.results[0]) + } + return tm.callback(tm.results) + } +} + +func (tm *TaskManager) callback(results any) Result { + var rs Result + switch res := results.(type) { + case []Result: + aggregatedOutput := make([]json.RawMessage, 0) + for i, result := range res { + if i == 0 { + rs.TaskID = result.TaskID + } + var item json.RawMessage + err := json.Unmarshal(result.Payload, &item) + if err != nil { + rs.Error = err + return rs + } + aggregatedOutput = append(aggregatedOutput, item) + } + finalOutput, err := json.Marshal(aggregatedOutput) + if err != nil { + rs.Error = err + return rs + } + rs.Payload = finalOutput + case Result: + rs.TaskID = res.TaskID + var item json.RawMessage + err := json.Unmarshal(res.Payload, &item) + if err != nil { + rs.Error = err + return rs + } + finalOutput, err := json.Marshal(item) + if err != nil { + rs.Error = err + return rs + } + rs.Payload = finalOutput + } + return rs +} + +func (tm *TaskManager) appendFinalResult(result Result) { + tm.mutex.Lock() + tm.results = append(tm.results, result) + tm.nodeResults[result.NodeKey] = result // Store result by node key + tm.mutex.Unlock() +} + +func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *Task, parentNode *Node) { + defer tm.wg.Done() + var result Result + select { + case <-ctx.Done(): + result = Result{TaskID: task.ID, NodeKey: node.Key, Error: ctx.Err()} + tm.appendFinalResult(result) + return + default: + result = node.handler(ctx, task) + if result.Error != nil { // Exit the flow on error + tm.appendFinalResult(result) + return + } + } + tm.mutex.Lock() + task.Results[node.Key] = result // Store intermediate results + tm.mutex.Unlock() + if len(node.Edges) == 0 { + if parentNode != nil { + tm.appendFinalResult(result) + } + return + } + for _, edge := range node.Edges { + switch edge.Type { + case LoopEdge: + var items []json.RawMessage + err := json.Unmarshal(task.Payload, &items) + if err != nil { + tm.appendFinalResult(Result{TaskID: task.ID, NodeKey: node.Key, Error: err}) + return + } + for _, item := range items { + loopTask := &Task{ + ID: task.ID, + NodeKey: edge.From.Key, + Payload: item, + Results: task.Results, + } + tm.wg.Add(1) + go tm.processNode(ctx, edge.To, loopTask, node) + } + case ConditionEdge: + if edge.Condition(result) && edge.To != nil { + tm.wg.Add(1) + go tm.processNode(ctx, edge.To, task, node) + } else if parentNode != nil { + tm.appendFinalResult(result) + } + case SimpleEdge: + if edge.To != nil { + tm.wg.Add(1) + go tm.processNode(ctx, edge.To, task, node) + } else if parentNode != nil { + tm.appendFinalResult(result) + } + } + } +} diff --git a/examples/dag.go b/examples/dag.go index d255407..992b303 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -16,7 +16,7 @@ import ( var d *dag.DAG func main() { - d = dag.New(mq.WithSyncMode(false), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.crt")) + d = dag.New(mq.WithSyncMode(true), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.crt")) d.AddNode("queue1", tasks.Node1, true) d.AddNode("queue2", tasks.Node2) d.AddNode("queue3", tasks.Node3) From c9c5ac9946134f4e320e14f44ab84ae858e74dbf Mon Sep 17 00:00:00 2001 From: Oarkflow Date: Tue, 8 Oct 2024 11:51:31 +0545 Subject: [PATCH 02/17] feat: separate broker --- dag/dag.go | 430 +++++++++++++++++++++++++++++------- examples/dag_v2.go | 296 ++++--------------------- v2/dag.go | 102 +++++++++ {dag => v2}/task_manager.go | 7 +- 4 files changed, 502 insertions(+), 333 deletions(-) create mode 100644 v2/dag.go rename {dag => v2}/task_manager.go (96%) diff --git a/dag/dag.go b/dag/dag.go index 64666c3..c0745ef 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -1,102 +1,382 @@ -package v2 +package dag import ( "context" "encoding/json" + "fmt" + "log" + "net/http" "sync" + "time" - "github.com/oarkflow/xid" + "github.com/oarkflow/mq/consts" + + "github.com/oarkflow/mq" ) -type Handler func(ctx context.Context, task *Task) Result - -type Result struct { - TaskID string `json:"task_id"` - NodeKey string `json:"node_key"` - Payload json.RawMessage `json:"payload"` - Status string `json:"status"` - Error error `json:"error"` -} - -type Task struct { - ID string `json:"id"` - NodeKey string `json:"node_key"` - Payload json.RawMessage `json:"payload"` - Results map[string]Result `json:"results"` -} - -type Node struct { - Key string - Edges []Edge - handler Handler -} - -type EdgeType int - -const ( - SimpleEdge EdgeType = iota - LoopEdge - ConditionEdge -) - -type Edge struct { - From *Node - To *Node - Type EdgeType - Condition func(result Result) bool +type taskContext struct { + totalItems int + completed int + results []json.RawMessage + result json.RawMessage + multipleResults bool } type DAG struct { - Nodes map[string]*Node - taskContext map[string]*TaskManager - mu sync.RWMutex + FirstNode string + server *mq.Broker + nodes map[string]*mq.Consumer + edges map[string]string + conditions map[string]map[string]string + loopEdges map[string][]string + taskChMap map[string]chan mq.Result + taskResults map[string]map[string]*taskContext + mu sync.Mutex } -func NewDAG() *DAG { - return &DAG{ - Nodes: make(map[string]*Node), - taskContext: make(map[string]*TaskManager), +func New(opts ...mq.Option) *DAG { + d := &DAG{ + nodes: make(map[string]*mq.Consumer), + edges: make(map[string]string), + conditions: make(map[string]map[string]string), + loopEdges: make(map[string][]string), + taskChMap: make(map[string]chan mq.Result), + taskResults: make(map[string]map[string]*taskContext), + } + opts = append(opts, mq.WithCallback(d.TaskCallback)) + d.server = mq.NewBroker(opts...) + return d +} + +func (d *DAG) AddNode(name string, handler mq.Handler, firstNode ...bool) { + tlsConfig := d.server.TLSConfig() + con := mq.NewConsumer(name, mq.WithTLS(tlsConfig.UseTLS, tlsConfig.CertPath, tlsConfig.KeyPath), mq.WithCAPath(tlsConfig.CAPath)) + if len(firstNode) > 0 { + d.FirstNode = name + } + con.RegisterHandler(name, handler) + d.nodes[name] = con +} + +func (d *DAG) AddCondition(fromNode string, conditions map[string]string) { + d.conditions[fromNode] = conditions +} + +func (d *DAG) AddEdge(fromNode string, toNodes string) { + d.edges[fromNode] = toNodes +} + +func (d *DAG) AddLoop(fromNode string, toNode ...string) { + d.loopEdges[fromNode] = toNode +} + +func (d *DAG) Prepare() { + if d.FirstNode == "" { + firstNode, ok := d.FindFirstNode() + if ok && firstNode != "" { + d.FirstNode = firstNode + } } } -func (tm *DAG) AddNode(key string, handler Handler) { - tm.mu.Lock() - defer tm.mu.Unlock() - tm.Nodes[key] = &Node{ - Key: key, - handler: handler, +func (d *DAG) Start(ctx context.Context, addr string) error { + d.Prepare() + if d.server.SyncMode() { + return nil } + go func() { + err := d.server.Start(ctx) + if err != nil { + panic(err) + } + }() + for _, con := range d.nodes { + go func(con *mq.Consumer) { + con.Consume(ctx) + }(con) + } + log.Printf("HTTP server started on %s", addr) + config := d.server.TLSConfig() + if config.UseTLS { + return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil) + } + return http.ListenAndServe(addr, nil) } -func (tm *DAG) AddEdge(from, to string, edgeType EdgeType) { - tm.mu.Lock() - defer tm.mu.Unlock() - fromNode, ok := tm.Nodes[from] +func (d *DAG) PublishTask(ctx context.Context, payload json.RawMessage, taskID ...string) mq.Result { + queue, ok := mq.GetQueue(ctx) if !ok { - return + queue = d.FirstNode } - toNode, ok := tm.Nodes[to] - if !ok { - return + var id string + if len(taskID) > 0 { + id = taskID[0] + } else { + id = mq.NewID() } - fromNode.Edges = append(fromNode.Edges, Edge{ - From: fromNode, - To: toNode, - Type: edgeType, + task := mq.Task{ + ID: id, + Payload: payload, + CreatedAt: time.Now(), + } + err := d.server.Publish(ctx, task, queue) + if err != nil { + return mq.Result{Error: err} + } + return mq.Result{ + Payload: payload, + Queue: queue, + MessageID: id, + } +} + +func (d *DAG) FindFirstNode() (string, bool) { + inDegree := make(map[string]int) + for n, _ := range d.nodes { + inDegree[n] = 0 + } + for _, outNode := range d.edges { + inDegree[outNode]++ + } + for _, targets := range d.loopEdges { + for _, outNode := range targets { + inDegree[outNode]++ + } + } + for n, count := range inDegree { + if count == 0 { + return n, true + } + } + return "", false +} + +func (d *DAG) Request(ctx context.Context, payload []byte) mq.Result { + return d.sendSync(ctx, mq.Result{Payload: payload}) +} + +func (d *DAG) Send(ctx context.Context, payload []byte) mq.Result { + if d.FirstNode == "" { + return mq.Result{Error: fmt.Errorf("initial node not defined")} + } + if d.server.SyncMode() { + return d.sendSync(ctx, mq.Result{Payload: payload}) + } + resultCh := make(chan mq.Result) + result := d.PublishTask(ctx, payload) + if result.Error != nil { + return result + } + d.mu.Lock() + d.taskChMap[result.MessageID] = resultCh + d.mu.Unlock() + finalResult := <-resultCh + return finalResult +} + +func (d *DAG) processNode(ctx context.Context, task mq.Result) mq.Result { + if con, ok := d.nodes[task.Queue]; ok { + return con.ProcessTask(ctx, mq.Task{ + ID: task.MessageID, + Payload: task.Payload, + }) + } + return mq.Result{Error: fmt.Errorf("no consumer to process %s", task.Queue)} +} + +func (d *DAG) sendSync(ctx context.Context, task mq.Result) mq.Result { + if task.MessageID == "" { + task.MessageID = mq.NewID() + } + if task.Queue == "" { + task.Queue = d.FirstNode + } + ctx = mq.SetHeaders(ctx, map[string]string{ + consts.QueueKey: task.Queue, }) + result := d.processNode(ctx, task) + if result.Error != nil { + return result + } + for _, target := range d.loopEdges[task.Queue] { + var items, results []json.RawMessage + if err := json.Unmarshal(result.Payload, &items); err != nil { + return mq.Result{Error: err} + } + for _, item := range items { + ctx = mq.SetHeaders(ctx, map[string]string{ + consts.QueueKey: target, + }) + result = d.sendSync(ctx, mq.Result{ + Payload: item, + Queue: target, + MessageID: result.MessageID, + }) + if result.Error != nil { + return result + } + results = append(results, result.Payload) + } + bt, err := json.Marshal(results) + if err != nil { + return mq.Result{Error: err} + } + result.Payload = bt + } + if conditions, ok := d.conditions[task.Queue]; ok { + if target, exists := conditions[result.Status]; exists { + ctx = mq.SetHeaders(ctx, map[string]string{ + consts.QueueKey: target, + }) + result = d.sendSync(ctx, mq.Result{ + Payload: result.Payload, + Queue: target, + MessageID: result.MessageID, + }) + if result.Error != nil { + return result + } + } + } + if target, ok := d.edges[task.Queue]; ok { + ctx = mq.SetHeaders(ctx, map[string]string{ + consts.QueueKey: target, + }) + result = d.sendSync(ctx, mq.Result{ + Payload: result.Payload, + Queue: target, + MessageID: result.MessageID, + }) + if result.Error != nil { + return result + } + } + return result } -func (tm *DAG) ProcessTask(ctx context.Context, node string, payload []byte) Result { - tm.mu.Lock() - defer tm.mu.Unlock() - taskID := xid.New().String() - task := &Task{ - ID: taskID, - NodeKey: node, - Payload: payload, - Results: make(map[string]Result), +func (d *DAG) getCompletedResults(task mq.Result, ok bool, triggeredNode string) ([]byte, bool, bool) { + var result any + var payload []byte + completed := false + multipleResults := false + if ok && triggeredNode != "" { + taskResults, ok := d.taskResults[task.MessageID] + if ok { + nodeResult, exists := taskResults[triggeredNode] + if exists { + multipleResults = nodeResult.multipleResults + nodeResult.completed++ + if nodeResult.completed == nodeResult.totalItems { + completed = true + } + if multipleResults { + nodeResult.results = append(nodeResult.results, task.Payload) + if completed { + result = nodeResult.results + } + } else { + nodeResult.result = task.Payload + if completed { + result = nodeResult.result + } + } + } + if completed { + delete(taskResults, triggeredNode) + } + } } - manager := NewTaskManager(tm) - tm.taskContext[taskID] = manager - return manager.processTask(ctx, node, task) + if completed { + payload, _ = json.Marshal(result) + } else { + payload = task.Payload + } + return payload, completed, multipleResults +} + +func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result { + if task.Error != nil { + return mq.Result{Error: task.Error} + } + triggeredNode, ok := mq.GetTriggerNode(ctx) + payload, completed, multipleResults := d.getCompletedResults(task, ok, triggeredNode) + if loopNodes, exists := d.loopEdges[task.Queue]; exists { + var items []json.RawMessage + if err := json.Unmarshal(payload, &items); err != nil { + return mq.Result{Error: task.Error} + } + d.taskResults[task.MessageID] = map[string]*taskContext{ + task.Queue: { + totalItems: len(items), + multipleResults: true, + }, + } + + ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue}) + for _, loopNode := range loopNodes { + for _, item := range items { + ctx = mq.SetHeaders(ctx, map[string]string{ + consts.QueueKey: loopNode, + }) + result := d.PublishTask(ctx, item, task.MessageID) + if result.Error != nil { + return result + } + } + } + + return task + } + if multipleResults && completed { + task.Queue = triggeredNode + } + if conditions, ok := d.conditions[task.Queue]; ok { + if target, exists := conditions[task.Status]; exists { + d.taskResults[task.MessageID] = map[string]*taskContext{ + task.Queue: { + totalItems: len(conditions), + }, + } + ctx = mq.SetHeaders(ctx, map[string]string{ + consts.QueueKey: target, + consts.TriggerNode: task.Queue, + }) + result := d.PublishTask(ctx, payload, task.MessageID) + if result.Error != nil { + return result + } + } + } else { + ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue}) + edge, exists := d.edges[task.Queue] + if exists { + d.taskResults[task.MessageID] = map[string]*taskContext{ + task.Queue: { + totalItems: 1, + }, + } + ctx = mq.SetHeaders(ctx, map[string]string{ + consts.QueueKey: edge, + }) + result := d.PublishTask(ctx, payload, task.MessageID) + if result.Error != nil { + return result + } + } else if completed { + d.mu.Lock() + if resultCh, ok := d.taskChMap[task.MessageID]; ok { + resultCh <- mq.Result{ + Payload: payload, + Queue: task.Queue, + MessageID: task.MessageID, + Status: "done", + } + delete(d.taskChMap, task.MessageID) + delete(d.taskResults, task.MessageID) + } + d.mu.Unlock() + } + } + + return task } diff --git a/examples/dag_v2.go b/examples/dag_v2.go index 8757147..9ca810e 100644 --- a/examples/dag_v2.go +++ b/examples/dag_v2.go @@ -4,272 +4,58 @@ import ( "context" "encoding/json" "fmt" - "sync" - "time" + v2 "github.com/oarkflow/mq/v2" ) -type Task struct { - ID string `json:"id"` - Payload json.RawMessage `json:"payload"` - CreatedAt time.Time `json:"created_at"` - ProcessedAt time.Time `json:"processed_at"` - Status string `json:"status"` - Error error `json:"error"` -} - -type Result struct { - Payload json.RawMessage `json:"payload"` - Queue string `json:"queue"` - MessageID string `json:"message_id"` - Error error `json:"error,omitempty"` - Status string `json:"status"` -} - -const ( - SimpleEdge = iota - LoopEdge - ConditionEdge -) - -type Edge struct { - edgeType int - to string - conditions map[string]string -} - -type Node struct { - key string - handler func(context.Context, Task) Result - edges []Edge -} - -type RadixTrie struct { - children map[rune]*RadixTrie - node *Node - mu sync.RWMutex -} - -func NewRadixTrie() *RadixTrie { - return &RadixTrie{ - children: make(map[rune]*RadixTrie), +func handler1(ctx context.Context, task *v2.Task) v2.Result { + return v2.Result{ + TaskID: task.ID, + NodeKey: "A", + Payload: task.Payload, + Status: "success", } } -func (trie *RadixTrie) Insert(key string, node *Node) { - trie.mu.Lock() - defer trie.mu.Unlock() - - current := trie - for _, char := range key { - if _, exists := current.children[char]; !exists { - current.children[char] = NewRadixTrie() - } - current = current.children[char] - } - current.node = node -} - -func (trie *RadixTrie) Search(key string) (*Node, bool) { - trie.mu.RLock() - defer trie.mu.RUnlock() - current := trie - for _, char := range key { - if _, exists := current.children[char]; !exists { - return nil, false - } - current = current.children[char] - } - if current.node != nil { - return current.node, true - } - return nil, false -} - -type DAG struct { - trie *RadixTrie - mu sync.RWMutex -} - -func NewDAG() *DAG { - return &DAG{ - trie: NewRadixTrie(), +func handler2(ctx context.Context, task *v2.Task) v2.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + return v2.Result{ + TaskID: task.ID, + NodeKey: "B", + Payload: task.Payload, + Status: "success", } } -func (d *DAG) AddNode(key string, handler func(context.Context, Task) Result, isRoot ...bool) { - node := &Node{key: key, handler: handler} - d.trie.Insert(key, node) -} - -func (d *DAG) AddEdge(fromKey string, toKey string) { - d.mu.Lock() - defer d.mu.Unlock() - node, exists := d.trie.Search(fromKey) - if !exists { - fmt.Printf("Node %s not found to add edge.\n", fromKey) - return +func handler3(ctx context.Context, task *v2.Task) v2.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + age := int(user["age"].(float64)) + status := "FAIL" + if age > 20 { + status = "PASS" } - edge := Edge{edgeType: SimpleEdge, to: toKey} - node.edges = append(node.edges, edge) -} - -func (d *DAG) AddLoop(fromKey string, toKey string) { - d.mu.Lock() - defer d.mu.Unlock() - node, exists := d.trie.Search(fromKey) - if !exists { - fmt.Printf("Node %s not found to add loop edge.\n", fromKey) - return + user["status"] = status + resultPayload, _ := json.Marshal(user) + return v2.Result{ + TaskID: task.ID, + NodeKey: "C", + Payload: resultPayload, + Status: status, } - edge := Edge{edgeType: LoopEdge, to: toKey} - node.edges = append(node.edges, edge) -} - -func (d *DAG) AddCondition(fromKey string, conditions map[string]string) { - d.mu.Lock() - defer d.mu.Unlock() - node, exists := d.trie.Search(fromKey) - if !exists { - fmt.Printf("Node %s not found to add condition edge.\n", fromKey) - return - } - edge := Edge{edgeType: ConditionEdge, conditions: conditions} - node.edges = append(node.edges, edge) -} - -type ProcessCallback func(ctx context.Context, key string, result Result) string - -func (d *DAG) ProcessTask(ctx context.Context, key string, task Task) { - node, exists := d.trie.Search(key) - if !exists { - fmt.Printf("Node %s not found.\n", key) - return - } - result := node.handler(ctx, task) - nextKey := d.callback(ctx, key, result) - if nextKey != "" { - d.ProcessTask(ctx, nextKey, task) - } -} - -func (d *DAG) ProcessLoop(ctx context.Context, key string, task Task) { - _, exists := d.trie.Search(key) - if !exists { - fmt.Printf("Node %s not found.\n", key) - return - } - var items []json.RawMessage - err := json.Unmarshal(task.Payload, &items) - if err != nil { - fmt.Printf("Error unmarshaling payload as slice: %v\n", err) - return - } - for _, item := range items { - newTask := Task{ - ID: task.ID, - Payload: item, - } - - d.ProcessTask(ctx, key, newTask) - } -} - -func (d *DAG) callback(ctx context.Context, currentKey string, result Result) string { - fmt.Printf("Callback received result from %s: %s\n", currentKey, string(result.Payload)) - node, exists := d.trie.Search(currentKey) - if !exists { - return "" - } - for _, edge := range node.edges { - switch edge.edgeType { - case SimpleEdge: - return edge.to - case LoopEdge: - - d.ProcessLoop(ctx, edge.to, Task{Payload: result.Payload}) - return "" - case ConditionEdge: - if nextKey, conditionMet := edge.conditions[result.Status]; conditionMet { - return nextKey - } - } - } - return "" -} - -func Node1(ctx context.Context, task Task) Result { - return Result{Payload: task.Payload, MessageID: task.ID} -} - -func Node2(ctx context.Context, task Task) Result { - return Result{Payload: task.Payload, MessageID: task.ID} -} - -func Node3(ctx context.Context, task Task) Result { - var data map[string]any - err := json.Unmarshal(task.Payload, &data) - if err != nil { - return Result{Error: err} - } - data["salary"] = fmt.Sprintf("12000%v", data["user_id"]) - bt, _ := json.Marshal(data) - return Result{Payload: bt, MessageID: task.ID} -} - -func Node4(ctx context.Context, task Task) Result { - var data []map[string]any - err := json.Unmarshal(task.Payload, &data) - if err != nil { - return Result{Error: err} - } - payload := map[string]any{"storage": data} - bt, _ := json.Marshal(payload) - return Result{Payload: bt, MessageID: task.ID} -} - -func CheckCondition(ctx context.Context, task Task) Result { - var data map[string]any - err := json.Unmarshal(task.Payload, &data) - if err != nil { - return Result{Error: err} - } - var status string - if data["user_id"].(float64) == 2 { - status = "pass" - } else { - status = "fail" - } - return Result{Status: status, Payload: task.Payload, MessageID: task.ID} -} - -func Pass(ctx context.Context, task Task) Result { - fmt.Println("Pass") - return Result{Payload: task.Payload} -} - -func Fail(ctx context.Context, task Task) Result { - fmt.Println("Fail") - return Result{Payload: []byte(`{"test2": "asdsa"}`)} } func main() { - dag := NewDAG() - dag.AddNode("queue1", Node1, true) - dag.AddNode("queue2", Node2) - dag.AddNode("queue3", Node3) - dag.AddNode("queue4", Node4) - dag.AddNode("queue5", CheckCondition) - dag.AddNode("queue6", Pass) - dag.AddNode("queue7", Fail) - dag.AddEdge("queue1", "queue2") - dag.AddEdge("queue2", "queue4") - dag.AddEdge("queue3", "queue5") - dag.AddLoop("queue2", "queue3") - dag.AddCondition("queue5", map[string]string{"pass": "queue6", "fail": "queue7"}) - ctx := context.Background() - task := Task{ - ID: "task1", - Payload: []byte(`[{"user_id": 1}, {"user_id": 2}]`), - } - dag.ProcessTask(ctx, "queue1", task) + dag := v2.NewDAG() + dag.AddNode("A", handler1) + dag.AddNode("B", handler2) + dag.AddNode("C", handler3) + dag.AddEdge("A", "B", v2.LoopEdge) + dag.AddEdge("B", "C", v2.SimpleEdge) + initialPayload, _ := json.Marshal([]map[string]any{ + {"user_id": 1, "age": 12}, + {"user_id": 2, "age": 34}, + }) + rs := dag.ProcessTask(context.Background(), "A", initialPayload) + fmt.Println(string(rs.Payload)) } diff --git a/v2/dag.go b/v2/dag.go new file mode 100644 index 0000000..64666c3 --- /dev/null +++ b/v2/dag.go @@ -0,0 +1,102 @@ +package v2 + +import ( + "context" + "encoding/json" + "sync" + + "github.com/oarkflow/xid" +) + +type Handler func(ctx context.Context, task *Task) Result + +type Result struct { + TaskID string `json:"task_id"` + NodeKey string `json:"node_key"` + Payload json.RawMessage `json:"payload"` + Status string `json:"status"` + Error error `json:"error"` +} + +type Task struct { + ID string `json:"id"` + NodeKey string `json:"node_key"` + Payload json.RawMessage `json:"payload"` + Results map[string]Result `json:"results"` +} + +type Node struct { + Key string + Edges []Edge + handler Handler +} + +type EdgeType int + +const ( + SimpleEdge EdgeType = iota + LoopEdge + ConditionEdge +) + +type Edge struct { + From *Node + To *Node + Type EdgeType + Condition func(result Result) bool +} + +type DAG struct { + Nodes map[string]*Node + taskContext map[string]*TaskManager + mu sync.RWMutex +} + +func NewDAG() *DAG { + return &DAG{ + Nodes: make(map[string]*Node), + taskContext: make(map[string]*TaskManager), + } +} + +func (tm *DAG) AddNode(key string, handler Handler) { + tm.mu.Lock() + defer tm.mu.Unlock() + tm.Nodes[key] = &Node{ + Key: key, + handler: handler, + } +} + +func (tm *DAG) AddEdge(from, to string, edgeType EdgeType) { + tm.mu.Lock() + defer tm.mu.Unlock() + fromNode, ok := tm.Nodes[from] + if !ok { + return + } + toNode, ok := tm.Nodes[to] + if !ok { + return + } + fromNode.Edges = append(fromNode.Edges, Edge{ + From: fromNode, + To: toNode, + Type: edgeType, + }) +} + +func (tm *DAG) ProcessTask(ctx context.Context, node string, payload []byte) Result { + tm.mu.Lock() + defer tm.mu.Unlock() + taskID := xid.New().String() + task := &Task{ + ID: taskID, + NodeKey: node, + Payload: payload, + Results: make(map[string]Result), + } + manager := NewTaskManager(tm) + tm.taskContext[taskID] = manager + return manager.processTask(ctx, node, task) +} diff --git a/dag/task_manager.go b/v2/task_manager.go similarity index 96% rename from dag/task_manager.go rename to v2/task_manager.go index ae92306..316130c 100644 --- a/dag/task_manager.go +++ b/v2/task_manager.go @@ -18,9 +18,10 @@ type TaskManager struct { func NewTaskManager(d *DAG) *TaskManager { return &TaskManager{ - dag: d, - results: make([]Result, 0), - done: make(chan struct{}), + dag: d, + nodeResults: make(map[string]Result), + results: make([]Result, 0), + done: make(chan struct{}), } } From 18ce5ec13bced54731054d5095020377de720998 Mon Sep 17 00:00:00 2001 From: Oarkflow Date: Tue, 8 Oct 2024 14:49:21 +0545 Subject: [PATCH 03/17] feat: separate broker --- examples/dag_v2.go | 53 ++++++++++++++++++++++++++++++---------------- v2/dag.go | 16 +++++++++----- v2/task_manager.go | 45 ++++++++++++++++++++++++++------------- 3 files changed, 76 insertions(+), 38 deletions(-) diff --git a/examples/dag_v2.go b/examples/dag_v2.go index 9ca810e..52bef21 100644 --- a/examples/dag_v2.go +++ b/examples/dag_v2.go @@ -8,23 +8,13 @@ import ( ) func handler1(ctx context.Context, task *v2.Task) v2.Result { - return v2.Result{ - TaskID: task.ID, - NodeKey: "A", - Payload: task.Payload, - Status: "success", - } + return v2.Result{TaskID: task.ID, NodeKey: "A", Payload: task.Payload} } func handler2(ctx context.Context, task *v2.Task) v2.Result { var user map[string]any json.Unmarshal(task.Payload, &user) - return v2.Result{ - TaskID: task.ID, - NodeKey: "B", - Payload: task.Payload, - Status: "success", - } + return v2.Result{TaskID: task.ID, NodeKey: "B", Payload: task.Payload} } func handler3(ctx context.Context, task *v2.Task) v2.Result { @@ -37,12 +27,30 @@ func handler3(ctx context.Context, task *v2.Task) v2.Result { } user["status"] = status resultPayload, _ := json.Marshal(user) - return v2.Result{ - TaskID: task.ID, - NodeKey: "C", - Payload: resultPayload, - Status: status, - } + return v2.Result{TaskID: task.ID, NodeKey: "C", Payload: resultPayload, Status: status} +} + +func handler4(ctx context.Context, task *v2.Task) v2.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + user["final"] = "D" + resultPayload, _ := json.Marshal(user) + return v2.Result{TaskID: task.ID, NodeKey: "D", Payload: resultPayload} +} + +func handler5(ctx context.Context, task *v2.Task) v2.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + user["salary"] = "E" + resultPayload, _ := json.Marshal(user) + return v2.Result{TaskID: task.ID, NodeKey: "E", Payload: resultPayload} +} + +func handler6(ctx context.Context, task *v2.Task) v2.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + resultPayload, _ := json.Marshal(map[string]any{"storage": user}) + return v2.Result{TaskID: task.ID, NodeKey: "F", Payload: resultPayload} } func main() { @@ -50,12 +58,21 @@ func main() { dag.AddNode("A", handler1) dag.AddNode("B", handler2) dag.AddNode("C", handler3) + dag.AddNode("D", handler4) + dag.AddNode("E", handler5) + dag.AddNode("F", handler6) dag.AddEdge("A", "B", v2.LoopEdge) + dag.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"}) dag.AddEdge("B", "C", v2.SimpleEdge) + dag.AddEdge("D", "F", v2.SimpleEdge) + dag.AddEdge("E", "F", v2.SimpleEdge) + initialPayload, _ := json.Marshal([]map[string]any{ {"user_id": 1, "age": 12}, {"user_id": 2, "age": 34}, }) rs := dag.ProcessTask(context.Background(), "A", initialPayload) fmt.Println(string(rs.Payload)) + rs = dag.ProcessTask(context.Background(), "A", initialPayload) + fmt.Println(string(rs.Payload)) } diff --git a/v2/dag.go b/v2/dag.go index 64666c3..b357300 100644 --- a/v2/dag.go +++ b/v2/dag.go @@ -36,19 +36,18 @@ type EdgeType int const ( SimpleEdge EdgeType = iota LoopEdge - ConditionEdge ) type Edge struct { - From *Node - To *Node - Type EdgeType - Condition func(result Result) bool + From *Node + To *Node + Type EdgeType } type DAG struct { Nodes map[string]*Node taskContext map[string]*TaskManager + conditions map[string]map[string]string mu sync.RWMutex } @@ -56,6 +55,7 @@ func NewDAG() *DAG { return &DAG{ Nodes: make(map[string]*Node), taskContext: make(map[string]*TaskManager), + conditions: make(map[string]map[string]string), } } @@ -68,6 +68,12 @@ func (tm *DAG) AddNode(key string, handler Handler) { } } +func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) { + tm.mu.Lock() + defer tm.mu.Unlock() + tm.conditions[fromNode] = conditions +} + func (tm *DAG) AddEdge(from, to string, edgeType EdgeType) { tm.mu.Lock() defer tm.mu.Unlock() diff --git a/v2/task_manager.go b/v2/task_manager.go index 316130c..5e3991d 100644 --- a/v2/task_manager.go +++ b/v2/task_manager.go @@ -12,7 +12,7 @@ type TaskManager struct { wg sync.WaitGroup mutex sync.Mutex results []Result - nodeResults map[string]Result // Store results per node for future reference + nodeResults map[string]Result done chan struct{} } @@ -93,7 +93,7 @@ func (tm *TaskManager) callback(results any) Result { func (tm *TaskManager) appendFinalResult(result Result) { tm.mutex.Lock() tm.results = append(tm.results, result) - tm.nodeResults[result.NodeKey] = result // Store result by node key + tm.nodeResults[result.NodeKey] = result tm.mutex.Unlock() } @@ -107,25 +107,41 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *Task, return default: result = node.handler(ctx, task) - if result.Error != nil { // Exit the flow on error + if result.Error != nil { tm.appendFinalResult(result) return } } tm.mutex.Lock() - task.Results[node.Key] = result // Store intermediate results + task.Results[node.Key] = result tm.mutex.Unlock() - if len(node.Edges) == 0 { + + edges := make([]Edge, len(node.Edges)) + copy(edges, node.Edges) + if result.Status != "" { + if conditions, ok := tm.dag.conditions[result.NodeKey]; ok { + if targetNodeKey, ok := conditions[result.Status]; ok { + if targetNode, ok := tm.dag.Nodes[targetNodeKey]; ok { + edges = append(edges, Edge{ + From: node, + To: targetNode, + Type: SimpleEdge, + }) + } + } + } + } + if len(edges) == 0 { if parentNode != nil { tm.appendFinalResult(result) } return } - for _, edge := range node.Edges { + for _, edge := range edges { switch edge.Type { case LoopEdge: var items []json.RawMessage - err := json.Unmarshal(task.Payload, &items) + err := json.Unmarshal(result.Payload, &items) if err != nil { tm.appendFinalResult(Result{TaskID: task.ID, NodeKey: node.Key, Error: err}) return @@ -140,17 +156,16 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *Task, tm.wg.Add(1) go tm.processNode(ctx, edge.To, loopTask, node) } - case ConditionEdge: - if edge.Condition(result) && edge.To != nil { - tm.wg.Add(1) - go tm.processNode(ctx, edge.To, task, node) - } else if parentNode != nil { - tm.appendFinalResult(result) - } case SimpleEdge: if edge.To != nil { tm.wg.Add(1) - go tm.processNode(ctx, edge.To, task, node) + t := &Task{ + ID: task.ID, + NodeKey: edge.From.Key, + Payload: result.Payload, + Results: task.Results, + } + go tm.processNode(ctx, edge.To, t, node) } else if parentNode != nil { tm.appendFinalResult(result) } From be6bffab766b14f1dc3457b4794d08a1773a4794 Mon Sep 17 00:00:00 2001 From: Oarkflow Date: Tue, 8 Oct 2024 14:53:31 +0545 Subject: [PATCH 04/17] feat: separate broker --- examples/dag_v2.go | 12 ++++++------ v2/dag.go | 2 +- v2/task_manager.go | 9 +++++---- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/examples/dag_v2.go b/examples/dag_v2.go index 52bef21..733f79a 100644 --- a/examples/dag_v2.go +++ b/examples/dag_v2.go @@ -8,13 +8,13 @@ import ( ) func handler1(ctx context.Context, task *v2.Task) v2.Result { - return v2.Result{TaskID: task.ID, NodeKey: "A", Payload: task.Payload} + return v2.Result{TaskID: task.ID, Payload: task.Payload} } func handler2(ctx context.Context, task *v2.Task) v2.Result { var user map[string]any json.Unmarshal(task.Payload, &user) - return v2.Result{TaskID: task.ID, NodeKey: "B", Payload: task.Payload} + return v2.Result{TaskID: task.ID, Payload: task.Payload} } func handler3(ctx context.Context, task *v2.Task) v2.Result { @@ -27,7 +27,7 @@ func handler3(ctx context.Context, task *v2.Task) v2.Result { } user["status"] = status resultPayload, _ := json.Marshal(user) - return v2.Result{TaskID: task.ID, NodeKey: "C", Payload: resultPayload, Status: status} + return v2.Result{TaskID: task.ID, Payload: resultPayload, Status: status} } func handler4(ctx context.Context, task *v2.Task) v2.Result { @@ -35,7 +35,7 @@ func handler4(ctx context.Context, task *v2.Task) v2.Result { json.Unmarshal(task.Payload, &user) user["final"] = "D" resultPayload, _ := json.Marshal(user) - return v2.Result{TaskID: task.ID, NodeKey: "D", Payload: resultPayload} + return v2.Result{TaskID: task.ID, Payload: resultPayload} } func handler5(ctx context.Context, task *v2.Task) v2.Result { @@ -43,14 +43,14 @@ func handler5(ctx context.Context, task *v2.Task) v2.Result { json.Unmarshal(task.Payload, &user) user["salary"] = "E" resultPayload, _ := json.Marshal(user) - return v2.Result{TaskID: task.ID, NodeKey: "E", Payload: resultPayload} + return v2.Result{TaskID: task.ID, Payload: resultPayload} } func handler6(ctx context.Context, task *v2.Task) v2.Result { var user map[string]any json.Unmarshal(task.Payload, &user) resultPayload, _ := json.Marshal(map[string]any{"storage": user}) - return v2.Result{TaskID: task.ID, NodeKey: "F", Payload: resultPayload} + return v2.Result{TaskID: task.ID, Payload: resultPayload} } func main() { diff --git a/v2/dag.go b/v2/dag.go index b357300..90ac9ca 100644 --- a/v2/dag.go +++ b/v2/dag.go @@ -12,10 +12,10 @@ type Handler func(ctx context.Context, task *Task) Result type Result struct { TaskID string `json:"task_id"` - NodeKey string `json:"node_key"` Payload json.RawMessage `json:"payload"` Status string `json:"status"` Error error `json:"error"` + nodeKey string } type Task struct { diff --git a/v2/task_manager.go b/v2/task_manager.go index 5e3991d..874cb7e 100644 --- a/v2/task_manager.go +++ b/v2/task_manager.go @@ -93,7 +93,7 @@ func (tm *TaskManager) callback(results any) Result { func (tm *TaskManager) appendFinalResult(result Result) { tm.mutex.Lock() tm.results = append(tm.results, result) - tm.nodeResults[result.NodeKey] = result + tm.nodeResults[result.nodeKey] = result tm.mutex.Unlock() } @@ -102,11 +102,12 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *Task, var result Result select { case <-ctx.Done(): - result = Result{TaskID: task.ID, NodeKey: node.Key, Error: ctx.Err()} + result = Result{TaskID: task.ID, nodeKey: node.Key, Error: ctx.Err()} tm.appendFinalResult(result) return default: result = node.handler(ctx, task) + result.nodeKey = node.Key if result.Error != nil { tm.appendFinalResult(result) return @@ -119,7 +120,7 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *Task, edges := make([]Edge, len(node.Edges)) copy(edges, node.Edges) if result.Status != "" { - if conditions, ok := tm.dag.conditions[result.NodeKey]; ok { + if conditions, ok := tm.dag.conditions[result.nodeKey]; ok { if targetNodeKey, ok := conditions[result.Status]; ok { if targetNode, ok := tm.dag.Nodes[targetNodeKey]; ok { edges = append(edges, Edge{ @@ -143,7 +144,7 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *Task, var items []json.RawMessage err := json.Unmarshal(result.Payload, &items) if err != nil { - tm.appendFinalResult(Result{TaskID: task.ID, NodeKey: node.Key, Error: err}) + tm.appendFinalResult(Result{TaskID: task.ID, nodeKey: node.Key, Error: err}) return } for _, item := range items { From 04101856df6dafeb74805a4096656c4916d6dd60 Mon Sep 17 00:00:00 2001 From: Oarkflow Date: Tue, 8 Oct 2024 15:21:33 +0545 Subject: [PATCH 05/17] feat: separate broker --- examples/dag_v2.go | 10 +++---- v2/dag.go | 72 ++++++++++++++++++++++++++++++++++++++-------- v2/task_manager.go | 46 +++++++---------------------- 3 files changed, 75 insertions(+), 53 deletions(-) diff --git a/examples/dag_v2.go b/examples/dag_v2.go index 733f79a..1a9099d 100644 --- a/examples/dag_v2.go +++ b/examples/dag_v2.go @@ -4,7 +4,7 @@ import ( "context" "encoding/json" "fmt" - v2 "github.com/oarkflow/mq/v2" + "github.com/oarkflow/mq/v2" ) func handler1(ctx context.Context, task *v2.Task) v2.Result { @@ -63,9 +63,9 @@ func main() { dag.AddNode("F", handler6) dag.AddEdge("A", "B", v2.LoopEdge) dag.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"}) - dag.AddEdge("B", "C", v2.SimpleEdge) - dag.AddEdge("D", "F", v2.SimpleEdge) - dag.AddEdge("E", "F", v2.SimpleEdge) + dag.AddEdge("B", "C") + dag.AddEdge("D", "F") + dag.AddEdge("E", "F") initialPayload, _ := json.Marshal([]map[string]any{ {"user_id": 1, "age": 12}, @@ -73,6 +73,4 @@ func main() { }) rs := dag.ProcessTask(context.Background(), "A", initialPayload) fmt.Println(string(rs.Payload)) - rs = dag.ProcessTask(context.Background(), "A", initialPayload) - fmt.Println(string(rs.Payload)) } diff --git a/v2/dag.go b/v2/dag.go index 90ac9ca..ac011ad 100644 --- a/v2/dag.go +++ b/v2/dag.go @@ -3,6 +3,7 @@ package v2 import ( "context" "encoding/json" + "fmt" "sync" "github.com/oarkflow/xid" @@ -11,6 +12,7 @@ import ( type Handler func(ctx context.Context, task *Task) Result type Result struct { + Ctx context.Context TaskID string `json:"task_id"` Payload json.RawMessage `json:"payload"` Status string `json:"status"` @@ -18,6 +20,44 @@ type Result struct { nodeKey string } +func (r Result) Unmarshal(data any) error { + if r.Payload == nil { + return fmt.Errorf("payload is nil") + } + return json.Unmarshal(r.Payload, data) +} + +func (r Result) String() string { + return string(r.Payload) +} + +func HandleError(ctx context.Context, err error, status ...string) Result { + st := "Failed" + if len(status) > 0 { + st = status[0] + } + if err == nil { + return Result{} + } + return Result{ + Status: st, + Error: err, + Ctx: ctx, + } +} + +func (r Result) WithData(status string, data []byte) Result { + if r.Error != nil { + return r + } + return Result{ + Status: status, + Payload: data, + Error: nil, + Ctx: r.Ctx, + } +} + type Task struct { ID string `json:"id"` NodeKey string `json:"node_key"` @@ -25,6 +65,17 @@ type Task struct { Results map[string]Result `json:"results"` } +func NewTask(id string, payload json.RawMessage, nodeKey string, results ...map[string]Result) *Task { + if id == "" { + id = xid.New().String() + } + result := make(map[string]Result) + if len(results) > 0 && results[0] != nil { + result = results[0] + } + return &Task{ID: id, Payload: payload, NodeKey: nodeKey, Results: result} +} + type Node struct { Key string Edges []Edge @@ -33,6 +84,8 @@ type Node struct { type EdgeType int +func (c EdgeType) IsValid() bool { return c >= SimpleEdge && c <= LoopEdge } + const ( SimpleEdge EdgeType = iota LoopEdge @@ -74,7 +127,7 @@ func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) { tm.conditions[fromNode] = conditions } -func (tm *DAG) AddEdge(from, to string, edgeType EdgeType) { +func (tm *DAG) AddEdge(from, to string, edgeTypes ...EdgeType) { tm.mu.Lock() defer tm.mu.Unlock() fromNode, ok := tm.Nodes[from] @@ -85,23 +138,18 @@ func (tm *DAG) AddEdge(from, to string, edgeType EdgeType) { if !ok { return } - fromNode.Edges = append(fromNode.Edges, Edge{ - From: fromNode, - To: toNode, - Type: edgeType, - }) + edge := Edge{From: fromNode, To: toNode} + if len(edgeTypes) > 0 && edgeTypes[0].IsValid() { + edge.Type = edgeTypes[0] + } + fromNode.Edges = append(fromNode.Edges, edge) } func (tm *DAG) ProcessTask(ctx context.Context, node string, payload []byte) Result { tm.mu.Lock() defer tm.mu.Unlock() taskID := xid.New().String() - task := &Task{ - ID: taskID, - NodeKey: node, - Payload: payload, - Results: make(map[string]Result), - } + task := NewTask(taskID, payload, node, make(map[string]Result)) manager := NewTaskManager(tm) tm.taskContext[taskID] = manager return manager.processTask(ctx, node, task) diff --git a/v2/task_manager.go b/v2/task_manager.go index 874cb7e..42c0c50 100644 --- a/v2/task_manager.go +++ b/v2/task_manager.go @@ -43,13 +43,13 @@ func (tm *TaskManager) processTask(ctx context.Context, nodeID string, task *Tas tm.mutex.Lock() defer tm.mutex.Unlock() if len(tm.results) == 1 { - return tm.callback(tm.results[0]) + return tm.callback(ctx, tm.results[0]) } - return tm.callback(tm.results) + return tm.callback(ctx, tm.results) } } -func (tm *TaskManager) callback(results any) Result { +func (tm *TaskManager) callback(ctx context.Context, results any) Result { var rs Result switch res := results.(type) { case []Result: @@ -57,35 +57,26 @@ func (tm *TaskManager) callback(results any) Result { for i, result := range res { if i == 0 { rs.TaskID = result.TaskID + rs.Status = result.Status } var item json.RawMessage err := json.Unmarshal(result.Payload, &item) if err != nil { - rs.Error = err - return rs + return HandleError(ctx, err) } aggregatedOutput = append(aggregatedOutput, item) } finalOutput, err := json.Marshal(aggregatedOutput) - if err != nil { - rs.Error = err - return rs - } - rs.Payload = finalOutput + return HandleError(ctx, err).WithData(rs.Status, finalOutput) case Result: rs.TaskID = res.TaskID var item json.RawMessage err := json.Unmarshal(res.Payload, &item) if err != nil { - rs.Error = err - return rs + return HandleError(ctx, err) } finalOutput, err := json.Marshal(item) - if err != nil { - rs.Error = err - return rs - } - rs.Payload = finalOutput + return HandleError(ctx, err).WithData(res.Status, finalOutput) } return rs } @@ -116,18 +107,13 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *Task, tm.mutex.Lock() task.Results[node.Key] = result tm.mutex.Unlock() - edges := make([]Edge, len(node.Edges)) copy(edges, node.Edges) if result.Status != "" { if conditions, ok := tm.dag.conditions[result.nodeKey]; ok { if targetNodeKey, ok := conditions[result.Status]; ok { if targetNode, ok := tm.dag.Nodes[targetNodeKey]; ok { - edges = append(edges, Edge{ - From: node, - To: targetNode, - Type: SimpleEdge, - }) + edges = append(edges, Edge{From: node, To: targetNode}) } } } @@ -148,24 +134,14 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *Task, return } for _, item := range items { - loopTask := &Task{ - ID: task.ID, - NodeKey: edge.From.Key, - Payload: item, - Results: task.Results, - } + loopTask := NewTask(task.ID, item, edge.From.Key, task.Results) tm.wg.Add(1) go tm.processNode(ctx, edge.To, loopTask, node) } case SimpleEdge: if edge.To != nil { tm.wg.Add(1) - t := &Task{ - ID: task.ID, - NodeKey: edge.From.Key, - Payload: result.Payload, - Results: task.Results, - } + t := NewTask(task.ID, result.Payload, edge.From.Key, task.Results) go tm.processNode(ctx, edge.To, t, node) } else if parentNode != nil { tm.appendFinalResult(result) From a226982aa2f7deeb4eb1bd55766ff02c9dd70242 Mon Sep 17 00:00:00 2001 From: sujit Date: Tue, 8 Oct 2024 16:10:38 +0545 Subject: [PATCH 06/17] feat: add example --- examples/dag_v2.go | 13 +++++++------ v2/task_manager.go | 7 +++++-- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/examples/dag_v2.go b/examples/dag_v2.go index 1a9099d..bbf1496 100644 --- a/examples/dag_v2.go +++ b/examples/dag_v2.go @@ -4,17 +4,18 @@ import ( "context" "encoding/json" "fmt" + "github.com/oarkflow/mq/v2" ) func handler1(ctx context.Context, task *v2.Task) v2.Result { - return v2.Result{TaskID: task.ID, Payload: task.Payload} + return v2.Result{Payload: task.Payload, Ctx: ctx} } func handler2(ctx context.Context, task *v2.Task) v2.Result { var user map[string]any json.Unmarshal(task.Payload, &user) - return v2.Result{TaskID: task.ID, Payload: task.Payload} + return v2.Result{Payload: task.Payload, Ctx: ctx} } func handler3(ctx context.Context, task *v2.Task) v2.Result { @@ -27,7 +28,7 @@ func handler3(ctx context.Context, task *v2.Task) v2.Result { } user["status"] = status resultPayload, _ := json.Marshal(user) - return v2.Result{TaskID: task.ID, Payload: resultPayload, Status: status} + return v2.Result{Payload: resultPayload, Status: status, Ctx: ctx} } func handler4(ctx context.Context, task *v2.Task) v2.Result { @@ -35,7 +36,7 @@ func handler4(ctx context.Context, task *v2.Task) v2.Result { json.Unmarshal(task.Payload, &user) user["final"] = "D" resultPayload, _ := json.Marshal(user) - return v2.Result{TaskID: task.ID, Payload: resultPayload} + return v2.Result{Payload: resultPayload, Ctx: ctx} } func handler5(ctx context.Context, task *v2.Task) v2.Result { @@ -43,14 +44,14 @@ func handler5(ctx context.Context, task *v2.Task) v2.Result { json.Unmarshal(task.Payload, &user) user["salary"] = "E" resultPayload, _ := json.Marshal(user) - return v2.Result{TaskID: task.ID, Payload: resultPayload} + return v2.Result{Payload: resultPayload, Ctx: ctx} } func handler6(ctx context.Context, task *v2.Task) v2.Result { var user map[string]any json.Unmarshal(task.Payload, &user) resultPayload, _ := json.Marshal(map[string]any{"storage": user}) - return v2.Result{TaskID: task.ID, Payload: resultPayload} + return v2.Result{Payload: resultPayload, Ctx: ctx} } func main() { diff --git a/v2/task_manager.go b/v2/task_manager.go index 42c0c50..ed6f484 100644 --- a/v2/task_manager.go +++ b/v2/task_manager.go @@ -104,6 +104,9 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *Task, return } } + if result.Ctx == nil { + result.Ctx = ctx + } tm.mutex.Lock() task.Results[node.Key] = result tm.mutex.Unlock() @@ -136,13 +139,13 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *Task, for _, item := range items { loopTask := NewTask(task.ID, item, edge.From.Key, task.Results) tm.wg.Add(1) - go tm.processNode(ctx, edge.To, loopTask, node) + go tm.processNode(result.Ctx, edge.To, loopTask, node) } case SimpleEdge: if edge.To != nil { tm.wg.Add(1) t := NewTask(task.ID, result.Payload, edge.From.Key, task.Results) - go tm.processNode(ctx, edge.To, t, node) + go tm.processNode(result.Ctx, edge.To, t, node) } else if parentNode != nil { tm.appendFinalResult(result) } From a4dff9c77b116a024b95642f69fc5197eb27b795 Mon Sep 17 00:00:00 2001 From: sujit Date: Tue, 8 Oct 2024 16:18:53 +0545 Subject: [PATCH 07/17] feat: add example --- codec/codec.go | 4 +- consts/constants.go | 2 +- consumer.go | 8 ++-- ctx.go | 16 ++++--- dag/dag.go | 96 ++++++++++++++++++++--------------------- examples/dag.go | 2 +- examples/tasks/tasks.go | 27 ++++++------ options.go | 10 ++--- 8 files changed, 84 insertions(+), 81 deletions(-) diff --git a/codec/codec.go b/codec/codec.go index 2d397bf..ac71082 100644 --- a/codec/codec.go +++ b/codec/codec.go @@ -31,7 +31,7 @@ func NewMessage(cmd consts.CMD, payload json.RawMessage, queue string, headers m func (m *Message) Serialize(aesKey, hmacKey []byte, encrypt bool) ([]byte, string, error) { var buf bytes.Buffer - // Serialize Headers, Queue, Command, Payload, and Metadata + // Serialize Headers, Topic, Command, Payload, and Metadata if err := writeLengthPrefixedJSON(&buf, m.Headers); err != nil { return nil, "", fmt.Errorf("error serializing headers: %v", err) } @@ -62,7 +62,7 @@ func Deserialize(data, aesKey, hmacKey []byte, receivedHMAC string, decrypt bool buf := bytes.NewReader(data) - // Deserialize Headers, Queue, Command, Payload, and Metadata + // Deserialize Headers, Topic, Command, Payload, and Metadata headers := make(map[string]string) if err := readLengthPrefixedJSON(buf, &headers); err != nil { return nil, fmt.Errorf("error deserializing headers: %v", err) diff --git a/consts/constants.go b/consts/constants.go index 2776c94..9abbc6b 100644 --- a/consts/constants.go +++ b/consts/constants.go @@ -54,7 +54,7 @@ var ( PublisherKey = "Publisher-Key" ContentType = "Content-Type" AwaitResponseKey = "Await-Response" - QueueKey = "Queue" + QueueKey = "Topic" TypeJson = "application/json" HeaderKey = "headers" TriggerNode = "triggerNode" diff --git a/consumer.go b/consumer.go index 2aa2926..609ead1 100644 --- a/consumer.go +++ b/consumer.go @@ -89,9 +89,9 @@ func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.C return } ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue}) - result := c.ProcessTask(ctx, task) - result.MessageID = task.ID - result.Queue = msg.Queue + result := c.ProcessTask(ctx, &task) + result.TaskID = task.ID + result.Topic = msg.Queue if result.Status == "" { if result.Error != nil { result.Status = "FAILED" @@ -107,7 +107,7 @@ func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.C } // ProcessTask handles a received task message and invokes the appropriate handler. -func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result { +func (c *Consumer) ProcessTask(ctx context.Context, msg *Task) Result { queue, _ := GetQueue(ctx) handler, exists := c.handlers[queue] if !exists { diff --git a/ctx.go b/ctx.go index 2817b3e..b1e66d1 100644 --- a/ctx.go +++ b/ctx.go @@ -16,15 +16,17 @@ import ( ) type Task struct { - ID string `json:"id"` - Payload json.RawMessage `json:"payload"` - CreatedAt time.Time `json:"created_at"` - ProcessedAt time.Time `json:"processed_at"` - Status string `json:"status"` - Error error `json:"error"` + ID string `json:"id"` + Results map[string]Result `json:"results"` + Topic string `json:"topic"` + Payload json.RawMessage `json:"payload"` + CreatedAt time.Time `json:"created_at"` + ProcessedAt time.Time `json:"processed_at"` + Status string `json:"status"` + Error error `json:"error"` } -type Handler func(context.Context, Task) Result +type Handler func(context.Context, *Task) Result func IsClosed(conn net.Conn) bool { _, err := conn.Read(make([]byte, 1)) diff --git a/dag/dag.go b/dag/dag.go index c0745ef..0f2b51f 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -124,9 +124,9 @@ func (d *DAG) PublishTask(ctx context.Context, payload json.RawMessage, taskID . return mq.Result{Error: err} } return mq.Result{ - Payload: payload, - Queue: queue, - MessageID: id, + Payload: payload, + Topic: queue, + TaskID: id, } } @@ -168,37 +168,37 @@ func (d *DAG) Send(ctx context.Context, payload []byte) mq.Result { return result } d.mu.Lock() - d.taskChMap[result.MessageID] = resultCh + d.taskChMap[result.TaskID] = resultCh d.mu.Unlock() finalResult := <-resultCh return finalResult } func (d *DAG) processNode(ctx context.Context, task mq.Result) mq.Result { - if con, ok := d.nodes[task.Queue]; ok { + if con, ok := d.nodes[task.Topic]; ok { return con.ProcessTask(ctx, mq.Task{ - ID: task.MessageID, + ID: task.TaskID, Payload: task.Payload, }) } - return mq.Result{Error: fmt.Errorf("no consumer to process %s", task.Queue)} + return mq.Result{Error: fmt.Errorf("no consumer to process %s", task.Topic)} } func (d *DAG) sendSync(ctx context.Context, task mq.Result) mq.Result { - if task.MessageID == "" { - task.MessageID = mq.NewID() + if task.TaskID == "" { + task.TaskID = mq.NewID() } - if task.Queue == "" { - task.Queue = d.FirstNode + if task.Topic == "" { + task.Topic = d.FirstNode } ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: task.Queue, + consts.QueueKey: task.Topic, }) result := d.processNode(ctx, task) if result.Error != nil { return result } - for _, target := range d.loopEdges[task.Queue] { + for _, target := range d.loopEdges[task.Topic] { var items, results []json.RawMessage if err := json.Unmarshal(result.Payload, &items); err != nil { return mq.Result{Error: err} @@ -208,9 +208,9 @@ func (d *DAG) sendSync(ctx context.Context, task mq.Result) mq.Result { consts.QueueKey: target, }) result = d.sendSync(ctx, mq.Result{ - Payload: item, - Queue: target, - MessageID: result.MessageID, + Payload: item, + Topic: target, + TaskID: result.TaskID, }) if result.Error != nil { return result @@ -223,29 +223,29 @@ func (d *DAG) sendSync(ctx context.Context, task mq.Result) mq.Result { } result.Payload = bt } - if conditions, ok := d.conditions[task.Queue]; ok { + if conditions, ok := d.conditions[task.Topic]; ok { if target, exists := conditions[result.Status]; exists { ctx = mq.SetHeaders(ctx, map[string]string{ consts.QueueKey: target, }) result = d.sendSync(ctx, mq.Result{ - Payload: result.Payload, - Queue: target, - MessageID: result.MessageID, + Payload: result.Payload, + Topic: target, + TaskID: result.TaskID, }) if result.Error != nil { return result } } } - if target, ok := d.edges[task.Queue]; ok { + if target, ok := d.edges[task.Topic]; ok { ctx = mq.SetHeaders(ctx, map[string]string{ consts.QueueKey: target, }) result = d.sendSync(ctx, mq.Result{ - Payload: result.Payload, - Queue: target, - MessageID: result.MessageID, + Payload: result.Payload, + Topic: target, + TaskID: result.TaskID, }) if result.Error != nil { return result @@ -260,7 +260,7 @@ func (d *DAG) getCompletedResults(task mq.Result, ok bool, triggeredNode string) completed := false multipleResults := false if ok && triggeredNode != "" { - taskResults, ok := d.taskResults[task.MessageID] + taskResults, ok := d.taskResults[task.TaskID] if ok { nodeResult, exists := taskResults[triggeredNode] if exists { @@ -300,25 +300,25 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result { } triggeredNode, ok := mq.GetTriggerNode(ctx) payload, completed, multipleResults := d.getCompletedResults(task, ok, triggeredNode) - if loopNodes, exists := d.loopEdges[task.Queue]; exists { + if loopNodes, exists := d.loopEdges[task.Topic]; exists { var items []json.RawMessage if err := json.Unmarshal(payload, &items); err != nil { return mq.Result{Error: task.Error} } - d.taskResults[task.MessageID] = map[string]*taskContext{ - task.Queue: { + d.taskResults[task.TaskID] = map[string]*taskContext{ + task.Topic: { totalItems: len(items), multipleResults: true, }, } - ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue}) + ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Topic}) for _, loopNode := range loopNodes { for _, item := range items { ctx = mq.SetHeaders(ctx, map[string]string{ consts.QueueKey: loopNode, }) - result := d.PublishTask(ctx, item, task.MessageID) + result := d.PublishTask(ctx, item, task.TaskID) if result.Error != nil { return result } @@ -328,51 +328,51 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result { return task } if multipleResults && completed { - task.Queue = triggeredNode + task.Topic = triggeredNode } - if conditions, ok := d.conditions[task.Queue]; ok { + if conditions, ok := d.conditions[task.Topic]; ok { if target, exists := conditions[task.Status]; exists { - d.taskResults[task.MessageID] = map[string]*taskContext{ - task.Queue: { + d.taskResults[task.TaskID] = map[string]*taskContext{ + task.Topic: { totalItems: len(conditions), }, } ctx = mq.SetHeaders(ctx, map[string]string{ consts.QueueKey: target, - consts.TriggerNode: task.Queue, + consts.TriggerNode: task.Topic, }) - result := d.PublishTask(ctx, payload, task.MessageID) + result := d.PublishTask(ctx, payload, task.TaskID) if result.Error != nil { return result } } } else { - ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue}) - edge, exists := d.edges[task.Queue] + ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Topic}) + edge, exists := d.edges[task.Topic] if exists { - d.taskResults[task.MessageID] = map[string]*taskContext{ - task.Queue: { + d.taskResults[task.TaskID] = map[string]*taskContext{ + task.Topic: { totalItems: 1, }, } ctx = mq.SetHeaders(ctx, map[string]string{ consts.QueueKey: edge, }) - result := d.PublishTask(ctx, payload, task.MessageID) + result := d.PublishTask(ctx, payload, task.TaskID) if result.Error != nil { return result } } else if completed { d.mu.Lock() - if resultCh, ok := d.taskChMap[task.MessageID]; ok { + if resultCh, ok := d.taskChMap[task.TaskID]; ok { resultCh <- mq.Result{ - Payload: payload, - Queue: task.Queue, - MessageID: task.MessageID, - Status: "done", + Payload: payload, + Topic: task.Topic, + TaskID: task.TaskID, + Status: "done", } - delete(d.taskChMap, task.MessageID) - delete(d.taskResults, task.MessageID) + delete(d.taskChMap, task.TaskID) + delete(d.taskResults, task.TaskID) } d.mu.Unlock() } diff --git a/examples/dag.go b/examples/dag.go index 992b303..f2e66cd 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -82,7 +82,7 @@ func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Requ } w.Header().Set("Content-Type", "application/json") result := map[string]any{ - "message_id": rs.MessageID, + "message_id": rs.TaskID, "payload": string(rs.Payload), "error": rs.Error, } diff --git a/examples/tasks/tasks.go b/examples/tasks/tasks.go index fd7d534..3a6f64c 100644 --- a/examples/tasks/tasks.go +++ b/examples/tasks/tasks.go @@ -4,18 +4,19 @@ import ( "context" "encoding/json" "fmt" + "github.com/oarkflow/mq" ) -func Node1(ctx context.Context, task mq.Task) mq.Result { - return mq.Result{Payload: task.Payload, MessageID: task.ID} +func Node1(ctx context.Context, task *mq.Task) mq.Result { + return mq.Result{Payload: task.Payload, TaskID: task.ID} } -func Node2(ctx context.Context, task mq.Task) mq.Result { - return mq.Result{Payload: task.Payload, MessageID: task.ID} +func Node2(ctx context.Context, task *mq.Task) mq.Result { + return mq.Result{Payload: task.Payload, TaskID: task.ID} } -func Node3(ctx context.Context, task mq.Task) mq.Result { +func Node3(ctx context.Context, task *mq.Task) mq.Result { var data map[string]any err := json.Unmarshal(task.Payload, &data) if err != nil { @@ -23,10 +24,10 @@ func Node3(ctx context.Context, task mq.Task) mq.Result { } data["salary"] = fmt.Sprintf("12000%v", data["user_id"]) bt, _ := json.Marshal(data) - return mq.Result{Payload: bt, MessageID: task.ID} + return mq.Result{Payload: bt, TaskID: task.ID} } -func Node4(ctx context.Context, task mq.Task) mq.Result { +func Node4(ctx context.Context, task *mq.Task) mq.Result { var data []map[string]any err := json.Unmarshal(task.Payload, &data) if err != nil { @@ -34,10 +35,10 @@ func Node4(ctx context.Context, task mq.Task) mq.Result { } payload := map[string]any{"storage": data} bt, _ := json.Marshal(payload) - return mq.Result{Payload: bt, MessageID: task.ID} + return mq.Result{Payload: bt, TaskID: task.ID} } -func CheckCondition(ctx context.Context, task mq.Task) mq.Result { +func CheckCondition(ctx context.Context, task *mq.Task) mq.Result { var data map[string]any err := json.Unmarshal(task.Payload, &data) if err != nil { @@ -49,20 +50,20 @@ func CheckCondition(ctx context.Context, task mq.Task) mq.Result { } else { status = "fail" } - return mq.Result{Status: status, Payload: task.Payload, MessageID: task.ID} + return mq.Result{Status: status, Payload: task.Payload, TaskID: task.ID} } -func Pass(ctx context.Context, task mq.Task) mq.Result { +func Pass(ctx context.Context, task *mq.Task) mq.Result { fmt.Println("Pass") return mq.Result{Payload: task.Payload} } -func Fail(ctx context.Context, task mq.Task) mq.Result { +func Fail(ctx context.Context, task *mq.Task) mq.Result { fmt.Println("Fail") return mq.Result{Payload: []byte(`{"test2": "asdsa"}`)} } func Callback(ctx context.Context, task mq.Result) mq.Result { - fmt.Println("Received task", task.MessageID, "Payload", string(task.Payload), task.Error, task.Queue) + fmt.Println("Received task", task.TaskID, "Payload", string(task.Payload), task.Error, task.Topic) return mq.Result{} } diff --git a/options.go b/options.go index 76ea6ea..4f98d04 100644 --- a/options.go +++ b/options.go @@ -7,11 +7,11 @@ import ( ) type Result struct { - Payload json.RawMessage `json:"payload"` - Queue string `json:"queue"` - MessageID string `json:"message_id"` - Error error `json:"error,omitempty"` - Status string `json:"status"` + Payload json.RawMessage `json:"payload"` + Topic string `json:"topic"` + TaskID string `json:"task_id"` + Error error `json:"error,omitempty"` + Status string `json:"status"` } type TLSConfig struct { From 612eb535ec0fe1458c563a2e15ad180c355b4a79 Mon Sep 17 00:00:00 2001 From: sujit Date: Tue, 8 Oct 2024 16:24:52 +0545 Subject: [PATCH 08/17] feat: add example --- examples/dag_v2.go | 25 ++++++++-------- options.go | 40 +++++++++++++++++++++++++ v2/dag.go | 73 ++++++---------------------------------------- v2/task_manager.go | 48 +++++++++++++++--------------- 4 files changed, 87 insertions(+), 99 deletions(-) diff --git a/examples/dag_v2.go b/examples/dag_v2.go index bbf1496..b3a7b34 100644 --- a/examples/dag_v2.go +++ b/examples/dag_v2.go @@ -5,20 +5,21 @@ import ( "encoding/json" "fmt" + "github.com/oarkflow/mq" "github.com/oarkflow/mq/v2" ) -func handler1(ctx context.Context, task *v2.Task) v2.Result { - return v2.Result{Payload: task.Payload, Ctx: ctx} +func handler1(ctx context.Context, task *mq.Task) mq.Result { + return mq.Result{Payload: task.Payload, Ctx: ctx} } -func handler2(ctx context.Context, task *v2.Task) v2.Result { +func handler2(ctx context.Context, task *mq.Task) mq.Result { var user map[string]any json.Unmarshal(task.Payload, &user) - return v2.Result{Payload: task.Payload, Ctx: ctx} + return mq.Result{Payload: task.Payload, Ctx: ctx} } -func handler3(ctx context.Context, task *v2.Task) v2.Result { +func handler3(ctx context.Context, task *mq.Task) mq.Result { var user map[string]any json.Unmarshal(task.Payload, &user) age := int(user["age"].(float64)) @@ -28,30 +29,30 @@ func handler3(ctx context.Context, task *v2.Task) v2.Result { } user["status"] = status resultPayload, _ := json.Marshal(user) - return v2.Result{Payload: resultPayload, Status: status, Ctx: ctx} + return mq.Result{Payload: resultPayload, Status: status, Ctx: ctx} } -func handler4(ctx context.Context, task *v2.Task) v2.Result { +func handler4(ctx context.Context, task *mq.Task) mq.Result { var user map[string]any json.Unmarshal(task.Payload, &user) user["final"] = "D" resultPayload, _ := json.Marshal(user) - return v2.Result{Payload: resultPayload, Ctx: ctx} + return mq.Result{Payload: resultPayload, Ctx: ctx} } -func handler5(ctx context.Context, task *v2.Task) v2.Result { +func handler5(ctx context.Context, task *mq.Task) mq.Result { var user map[string]any json.Unmarshal(task.Payload, &user) user["salary"] = "E" resultPayload, _ := json.Marshal(user) - return v2.Result{Payload: resultPayload, Ctx: ctx} + return mq.Result{Payload: resultPayload, Ctx: ctx} } -func handler6(ctx context.Context, task *v2.Task) v2.Result { +func handler6(ctx context.Context, task *mq.Task) mq.Result { var user map[string]any json.Unmarshal(task.Payload, &user) resultPayload, _ := json.Marshal(map[string]any{"storage": user}) - return v2.Result{Payload: resultPayload, Ctx: ctx} + return mq.Result{Payload: resultPayload, Ctx: ctx} } func main() { diff --git a/options.go b/options.go index 4f98d04..520c8fe 100644 --- a/options.go +++ b/options.go @@ -3,10 +3,12 @@ package mq import ( "context" "encoding/json" + "fmt" "time" ) type Result struct { + Ctx context.Context Payload json.RawMessage `json:"payload"` Topic string `json:"topic"` TaskID string `json:"task_id"` @@ -14,6 +16,44 @@ type Result struct { Status string `json:"status"` } +func (r Result) Unmarshal(data any) error { + if r.Payload == nil { + return fmt.Errorf("payload is nil") + } + return json.Unmarshal(r.Payload, data) +} + +func (r Result) String() string { + return string(r.Payload) +} + +func HandleError(ctx context.Context, err error, status ...string) Result { + st := "Failed" + if len(status) > 0 { + st = status[0] + } + if err == nil { + return Result{} + } + return Result{ + Status: st, + Error: err, + Ctx: ctx, + } +} + +func (r Result) WithData(status string, data []byte) Result { + if r.Error != nil { + return r + } + return Result{ + Status: status, + Payload: data, + Error: nil, + Ctx: r.Ctx, + } +} + type TLSConfig struct { UseTLS bool CertPath string diff --git a/v2/dag.go b/v2/dag.go index ac011ad..24a7ad5 100644 --- a/v2/dag.go +++ b/v2/dag.go @@ -3,83 +3,28 @@ package v2 import ( "context" "encoding/json" - "fmt" "sync" "github.com/oarkflow/xid" + + "github.com/oarkflow/mq" ) -type Handler func(ctx context.Context, task *Task) Result - -type Result struct { - Ctx context.Context - TaskID string `json:"task_id"` - Payload json.RawMessage `json:"payload"` - Status string `json:"status"` - Error error `json:"error"` - nodeKey string -} - -func (r Result) Unmarshal(data any) error { - if r.Payload == nil { - return fmt.Errorf("payload is nil") - } - return json.Unmarshal(r.Payload, data) -} - -func (r Result) String() string { - return string(r.Payload) -} - -func HandleError(ctx context.Context, err error, status ...string) Result { - st := "Failed" - if len(status) > 0 { - st = status[0] - } - if err == nil { - return Result{} - } - return Result{ - Status: st, - Error: err, - Ctx: ctx, - } -} - -func (r Result) WithData(status string, data []byte) Result { - if r.Error != nil { - return r - } - return Result{ - Status: status, - Payload: data, - Error: nil, - Ctx: r.Ctx, - } -} - -type Task struct { - ID string `json:"id"` - NodeKey string `json:"node_key"` - Payload json.RawMessage `json:"payload"` - Results map[string]Result `json:"results"` -} - -func NewTask(id string, payload json.RawMessage, nodeKey string, results ...map[string]Result) *Task { +func NewTask(id string, payload json.RawMessage, nodeKey string, results ...map[string]mq.Result) *mq.Task { if id == "" { id = xid.New().String() } - result := make(map[string]Result) + result := make(map[string]mq.Result) if len(results) > 0 && results[0] != nil { result = results[0] } - return &Task{ID: id, Payload: payload, NodeKey: nodeKey, Results: result} + return &mq.Task{ID: id, Payload: payload, Topic: nodeKey, Results: result} } type Node struct { Key string Edges []Edge - handler Handler + handler mq.Handler } type EdgeType int @@ -112,7 +57,7 @@ func NewDAG() *DAG { } } -func (tm *DAG) AddNode(key string, handler Handler) { +func (tm *DAG) AddNode(key string, handler mq.Handler) { tm.mu.Lock() defer tm.mu.Unlock() tm.Nodes[key] = &Node{ @@ -145,11 +90,11 @@ func (tm *DAG) AddEdge(from, to string, edgeTypes ...EdgeType) { fromNode.Edges = append(fromNode.Edges, edge) } -func (tm *DAG) ProcessTask(ctx context.Context, node string, payload []byte) Result { +func (tm *DAG) ProcessTask(ctx context.Context, node string, payload []byte) mq.Result { tm.mu.Lock() defer tm.mu.Unlock() taskID := xid.New().String() - task := NewTask(taskID, payload, node, make(map[string]Result)) + task := NewTask(taskID, payload, node, make(map[string]mq.Result)) manager := NewTaskManager(tm) tm.taskContext[taskID] = manager return manager.processTask(ctx, node, task) diff --git a/v2/task_manager.go b/v2/task_manager.go index ed6f484..18ca0f9 100644 --- a/v2/task_manager.go +++ b/v2/task_manager.go @@ -5,30 +5,32 @@ import ( "encoding/json" "fmt" "sync" + + "github.com/oarkflow/mq" ) type TaskManager struct { dag *DAG wg sync.WaitGroup mutex sync.Mutex - results []Result - nodeResults map[string]Result + results []mq.Result + nodeResults map[string]mq.Result done chan struct{} } func NewTaskManager(d *DAG) *TaskManager { return &TaskManager{ dag: d, - nodeResults: make(map[string]Result), - results: make([]Result, 0), + nodeResults: make(map[string]mq.Result), + results: make([]mq.Result, 0), done: make(chan struct{}), } } -func (tm *TaskManager) processTask(ctx context.Context, nodeID string, task *Task) Result { +func (tm *TaskManager) processTask(ctx context.Context, nodeID string, task *mq.Task) mq.Result { node, ok := tm.dag.Nodes[nodeID] if !ok { - return Result{Error: fmt.Errorf("nodeID %s not found", nodeID)} + return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)} } tm.wg.Add(1) go tm.processNode(ctx, node, task, nil) @@ -38,7 +40,7 @@ func (tm *TaskManager) processTask(ctx context.Context, nodeID string, task *Tas }() select { case <-ctx.Done(): - return Result{Error: ctx.Err()} + return mq.Result{Error: ctx.Err()} case <-tm.done: tm.mutex.Lock() defer tm.mutex.Unlock() @@ -49,10 +51,10 @@ func (tm *TaskManager) processTask(ctx context.Context, nodeID string, task *Tas } } -func (tm *TaskManager) callback(ctx context.Context, results any) Result { - var rs Result +func (tm *TaskManager) callback(ctx context.Context, results any) mq.Result { + var rs mq.Result switch res := results.(type) { - case []Result: + case []mq.Result: aggregatedOutput := make([]json.RawMessage, 0) for i, result := range res { if i == 0 { @@ -62,43 +64,43 @@ func (tm *TaskManager) callback(ctx context.Context, results any) Result { var item json.RawMessage err := json.Unmarshal(result.Payload, &item) if err != nil { - return HandleError(ctx, err) + return mq.HandleError(ctx, err) } aggregatedOutput = append(aggregatedOutput, item) } finalOutput, err := json.Marshal(aggregatedOutput) - return HandleError(ctx, err).WithData(rs.Status, finalOutput) - case Result: + return mq.HandleError(ctx, err).WithData(rs.Status, finalOutput) + case mq.Result: rs.TaskID = res.TaskID var item json.RawMessage err := json.Unmarshal(res.Payload, &item) if err != nil { - return HandleError(ctx, err) + return mq.HandleError(ctx, err) } finalOutput, err := json.Marshal(item) - return HandleError(ctx, err).WithData(res.Status, finalOutput) + return mq.HandleError(ctx, err).WithData(res.Status, finalOutput) } return rs } -func (tm *TaskManager) appendFinalResult(result Result) { +func (tm *TaskManager) appendFinalResult(result mq.Result) { tm.mutex.Lock() tm.results = append(tm.results, result) - tm.nodeResults[result.nodeKey] = result + tm.nodeResults[result.Topic] = result tm.mutex.Unlock() } -func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *Task, parentNode *Node) { +func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *mq.Task, parentNode *Node) { defer tm.wg.Done() - var result Result + var result mq.Result select { case <-ctx.Done(): - result = Result{TaskID: task.ID, nodeKey: node.Key, Error: ctx.Err()} + result = mq.Result{TaskID: task.ID, Topic: node.Key, Error: ctx.Err()} tm.appendFinalResult(result) return default: result = node.handler(ctx, task) - result.nodeKey = node.Key + result.Topic = node.Key if result.Error != nil { tm.appendFinalResult(result) return @@ -113,7 +115,7 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *Task, edges := make([]Edge, len(node.Edges)) copy(edges, node.Edges) if result.Status != "" { - if conditions, ok := tm.dag.conditions[result.nodeKey]; ok { + if conditions, ok := tm.dag.conditions[result.Topic]; ok { if targetNodeKey, ok := conditions[result.Status]; ok { if targetNode, ok := tm.dag.Nodes[targetNodeKey]; ok { edges = append(edges, Edge{From: node, To: targetNode}) @@ -133,7 +135,7 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *Task, var items []json.RawMessage err := json.Unmarshal(result.Payload, &items) if err != nil { - tm.appendFinalResult(Result{TaskID: task.ID, nodeKey: node.Key, Error: err}) + tm.appendFinalResult(mq.Result{TaskID: task.ID, Topic: node.Key, Error: err}) return } for _, item := range items { From 75419c45bb9b41913bfed519f52d7ef8f75ed31a Mon Sep 17 00:00:00 2001 From: sujit Date: Tue, 8 Oct 2024 17:02:37 +0545 Subject: [PATCH 09/17] feat: add example --- v2/dag.go | 12 +++++++----- v2/task_manager.go | 11 ++++++++--- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/v2/dag.go b/v2/dag.go index 24a7ad5..8a1d5ce 100644 --- a/v2/dag.go +++ b/v2/dag.go @@ -22,9 +22,9 @@ func NewTask(id string, payload json.RawMessage, nodeKey string, results ...map[ } type Node struct { - Key string - Edges []Edge - handler mq.Handler + Key string + Edges []Edge + consumer *mq.Consumer } type EdgeType int @@ -60,9 +60,11 @@ func NewDAG() *DAG { func (tm *DAG) AddNode(key string, handler mq.Handler) { tm.mu.Lock() defer tm.mu.Unlock() + con := mq.NewConsumer(key) + con.RegisterHandler(key, handler) tm.Nodes[key] = &Node{ - Key: key, - handler: handler, + Key: key, + consumer: con, } } diff --git a/v2/task_manager.go b/v2/task_manager.go index 18ca0f9..83c06a5 100644 --- a/v2/task_manager.go +++ b/v2/task_manager.go @@ -7,6 +7,7 @@ import ( "sync" "github.com/oarkflow/mq" + "github.com/oarkflow/mq/consts" ) type TaskManager struct { @@ -99,7 +100,8 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *mq.Tas tm.appendFinalResult(result) return default: - result = node.handler(ctx, task) + ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: node.Key}) + result = node.consumer.ProcessTask(ctx, task) result.Topic = node.Key if result.Error != nil { tm.appendFinalResult(result) @@ -109,6 +111,7 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *mq.Tas if result.Ctx == nil { result.Ctx = ctx } + ctx = result.Ctx tm.mutex.Lock() task.Results[node.Key] = result tm.mutex.Unlock() @@ -141,13 +144,15 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *mq.Tas for _, item := range items { loopTask := NewTask(task.ID, item, edge.From.Key, task.Results) tm.wg.Add(1) - go tm.processNode(result.Ctx, edge.To, loopTask, node) + ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) + go tm.processNode(ctx, edge.To, loopTask, node) } case SimpleEdge: if edge.To != nil { tm.wg.Add(1) t := NewTask(task.ID, result.Payload, edge.From.Key, task.Results) - go tm.processNode(result.Ctx, edge.To, t, node) + ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) + go tm.processNode(ctx, edge.To, t, node) } else if parentNode != nil { tm.appendFinalResult(result) } From ff2922eddffe4d7150574631a1ca08e3cd27613d Mon Sep 17 00:00:00 2001 From: sujit Date: Tue, 8 Oct 2024 18:34:21 +0545 Subject: [PATCH 10/17] feat: add example --- ctx.go | 5 +++ dag/dag.go | 2 +- examples/dag.go | 6 ++-- examples/dag_v2.go | 79 +++++++++++++++++++++++++++++++++++----------- options.go | 3 -- v2/dag.go | 40 +++++++++++++++++++++-- v2/task_manager.go | 33 ++++++++++++------- 7 files changed, 129 insertions(+), 39 deletions(-) diff --git a/ctx.go b/ctx.go index b1e66d1..47bdfe1 100644 --- a/ctx.go +++ b/ctx.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "os" + "sync" "time" "github.com/oarkflow/xid" @@ -38,7 +39,11 @@ func IsClosed(conn net.Conn) bool { return false } +var m = sync.RWMutex{} + func SetHeaders(ctx context.Context, headers map[string]string) context.Context { + m.Lock() + defer m.Unlock() hd, ok := GetHeaders(ctx) if !ok { hd = make(map[string]string) diff --git a/dag/dag.go b/dag/dag.go index 0f2b51f..8139bc5 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -176,7 +176,7 @@ func (d *DAG) Send(ctx context.Context, payload []byte) mq.Result { func (d *DAG) processNode(ctx context.Context, task mq.Result) mq.Result { if con, ok := d.nodes[task.Topic]; ok { - return con.ProcessTask(ctx, mq.Task{ + return con.ProcessTask(ctx, &mq.Task{ ID: task.TaskID, Payload: task.Payload, }) diff --git a/examples/dag.go b/examples/dag.go index f2e66cd..59f8431 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -1,5 +1,6 @@ package main +/* import ( "context" "encoding/json" @@ -46,13 +47,13 @@ func main() { }() time.Sleep(10 * time.Second) - /*d.Prepare() + d.Prepare() http.HandleFunc("POST /publish", requestHandler("publish")) http.HandleFunc("POST /request", requestHandler("request")) err := d.Start(context.TODO(), ":8083") if err != nil { panic(err) - }*/ + } } func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Request) { @@ -89,3 +90,4 @@ func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Requ json.NewEncoder(w).Encode(result) } } +*/ diff --git a/examples/dag_v2.go b/examples/dag_v2.go index b3a7b34..9bcf5ba 100644 --- a/examples/dag_v2.go +++ b/examples/dag_v2.go @@ -4,19 +4,21 @@ import ( "context" "encoding/json" "fmt" + "io" + "net/http" "github.com/oarkflow/mq" "github.com/oarkflow/mq/v2" ) func handler1(ctx context.Context, task *mq.Task) mq.Result { - return mq.Result{Payload: task.Payload, Ctx: ctx} + return mq.Result{Payload: task.Payload} } func handler2(ctx context.Context, task *mq.Task) mq.Result { var user map[string]any json.Unmarshal(task.Payload, &user) - return mq.Result{Payload: task.Payload, Ctx: ctx} + return mq.Result{Payload: task.Payload} } func handler3(ctx context.Context, task *mq.Task) mq.Result { @@ -29,7 +31,7 @@ func handler3(ctx context.Context, task *mq.Task) mq.Result { } user["status"] = status resultPayload, _ := json.Marshal(user) - return mq.Result{Payload: resultPayload, Status: status, Ctx: ctx} + return mq.Result{Payload: resultPayload, Status: status} } func handler4(ctx context.Context, task *mq.Task) mq.Result { @@ -37,7 +39,7 @@ func handler4(ctx context.Context, task *mq.Task) mq.Result { json.Unmarshal(task.Payload, &user) user["final"] = "D" resultPayload, _ := json.Marshal(user) - return mq.Result{Payload: resultPayload, Ctx: ctx} + return mq.Result{Payload: resultPayload} } func handler5(ctx context.Context, task *mq.Task) mq.Result { @@ -45,34 +47,73 @@ func handler5(ctx context.Context, task *mq.Task) mq.Result { json.Unmarshal(task.Payload, &user) user["salary"] = "E" resultPayload, _ := json.Marshal(user) - return mq.Result{Payload: resultPayload, Ctx: ctx} + return mq.Result{Payload: resultPayload} } func handler6(ctx context.Context, task *mq.Task) mq.Result { var user map[string]any json.Unmarshal(task.Payload, &user) resultPayload, _ := json.Marshal(map[string]any{"storage": user}) - return mq.Result{Payload: resultPayload, Ctx: ctx} + return mq.Result{Payload: resultPayload} } +var ( + d = v2.NewDAG(mq.WithSyncMode(true)) +) + func main() { - dag := v2.NewDAG() - dag.AddNode("A", handler1) - dag.AddNode("B", handler2) - dag.AddNode("C", handler3) - dag.AddNode("D", handler4) - dag.AddNode("E", handler5) - dag.AddNode("F", handler6) - dag.AddEdge("A", "B", v2.LoopEdge) - dag.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"}) - dag.AddEdge("B", "C") - dag.AddEdge("D", "F") - dag.AddEdge("E", "F") + d.AddNode("A", handler1) + d.AddNode("B", handler2) + d.AddNode("C", handler3) + d.AddNode("D", handler4) + d.AddNode("E", handler5) + d.AddNode("F", handler6) + d.AddEdge("A", "B", v2.LoopEdge) + d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"}) + d.AddEdge("B", "C") + d.AddEdge("D", "F") + d.AddEdge("E", "F") initialPayload, _ := json.Marshal([]map[string]any{ {"user_id": 1, "age": 12}, {"user_id": 2, "age": 34}, }) - rs := dag.ProcessTask(context.Background(), "A", initialPayload) + rs := d.ProcessTask(context.Background(), "A", initialPayload) fmt.Println(string(rs.Payload)) + http.HandleFunc("POST /publish", requestHandler("publish")) + http.HandleFunc("POST /request", requestHandler("request")) + err := d.Start(context.TODO(), ":8083") + if err != nil { + panic(err) + } +} + +func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Invalid request method", http.StatusMethodNotAllowed) + return + } + var payload []byte + if r.Body != nil { + defer r.Body.Close() + var err error + payload, err = io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + } else { + http.Error(w, "Empty request body", http.StatusBadRequest) + return + } + rs := d.ProcessTask(context.Background(), "A", payload) + w.Header().Set("Content-Type", "application/json") + result := map[string]any{ + "message_id": rs.TaskID, + "payload": string(rs.Payload), + "error": rs.Error, + } + json.NewEncoder(w).Encode(result) + } } diff --git a/options.go b/options.go index 520c8fe..096cc18 100644 --- a/options.go +++ b/options.go @@ -8,7 +8,6 @@ import ( ) type Result struct { - Ctx context.Context Payload json.RawMessage `json:"payload"` Topic string `json:"topic"` TaskID string `json:"task_id"` @@ -38,7 +37,6 @@ func HandleError(ctx context.Context, err error, status ...string) Result { return Result{ Status: st, Error: err, - Ctx: ctx, } } @@ -50,7 +48,6 @@ func (r Result) WithData(status string, data []byte) Result { Status: status, Payload: data, Error: nil, - Ctx: r.Ctx, } } diff --git a/v2/dag.go b/v2/dag.go index 8a1d5ce..9c61695 100644 --- a/v2/dag.go +++ b/v2/dag.go @@ -3,6 +3,8 @@ package v2 import ( "context" "encoding/json" + "log" + "net/http" "sync" "github.com/oarkflow/xid" @@ -44,17 +46,51 @@ type Edge struct { type DAG struct { Nodes map[string]*Node + server *mq.Broker taskContext map[string]*TaskManager conditions map[string]map[string]string mu sync.RWMutex } -func NewDAG() *DAG { - return &DAG{ +func NewDAG(opts ...mq.Option) *DAG { + d := &DAG{ Nodes: make(map[string]*Node), taskContext: make(map[string]*TaskManager), conditions: make(map[string]map[string]string), } + opts = append(opts, mq.WithCallback(d.onTaskCallback)) + d.server = mq.NewBroker(opts...) + return d +} + +func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result { + if taskContext, ok := tm.taskContext[result.TaskID]; ok { + return taskContext.handleCallback(ctx, result) + } + return mq.Result{} +} + +func (tm *DAG) Start(ctx context.Context, addr string) error { + if tm.server.SyncMode() { + return nil + } + go func() { + err := tm.server.Start(ctx) + if err != nil { + panic(err) + } + }() + for _, con := range tm.Nodes { + go func(con *Node) { + con.consumer.Consume(ctx) + }(con) + } + log.Printf("HTTP server started on %s", addr) + config := tm.server.TLSConfig() + if config.UseTLS { + return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil) + } + return http.ListenAndServe(addr, nil) } func (tm *DAG) AddNode(key string, handler mq.Handler) { diff --git a/v2/task_manager.go b/v2/task_manager.go index 83c06a5..435d163 100644 --- a/v2/task_manager.go +++ b/v2/task_manager.go @@ -46,13 +46,18 @@ func (tm *TaskManager) processTask(ctx context.Context, nodeID string, task *mq. tm.mutex.Lock() defer tm.mutex.Unlock() if len(tm.results) == 1 { - return tm.callback(ctx, tm.results[0]) + return tm.handleResult(ctx, tm.results[0]) } - return tm.callback(ctx, tm.results) + return tm.handleResult(ctx, tm.results) } } -func (tm *TaskManager) callback(ctx context.Context, results any) mq.Result { +func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.Result { + fmt.Println(string(result.Payload), result.Topic, result.TaskID) + return mq.Result{} +} + +func (tm *TaskManager) handleResult(ctx context.Context, results any) mq.Result { var rs mq.Result switch res := results.(type) { case []mq.Result: @@ -101,17 +106,21 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *mq.Tas return default: ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: node.Key}) - result = node.consumer.ProcessTask(ctx, task) - result.Topic = node.Key - if result.Error != nil { - tm.appendFinalResult(result) - return + if tm.dag.server.SyncMode() { + result = node.consumer.ProcessTask(ctx, task) + result.Topic = node.Key + if result.Error != nil { + tm.appendFinalResult(result) + return + } + } else { + err := tm.dag.server.Publish(ctx, *task, node.Key) + if err != nil { + tm.appendFinalResult(mq.Result{Error: err}) + return + } } } - if result.Ctx == nil { - result.Ctx = ctx - } - ctx = result.Ctx tm.mutex.Lock() task.Results[node.Key] = result tm.mutex.Unlock() From 64d70ed7f1d6ee6c56520b037e6d4d4c3b2329c5 Mon Sep 17 00:00:00 2001 From: sujit Date: Tue, 8 Oct 2024 20:14:15 +0545 Subject: [PATCH 11/17] feat: add example --- broker.go | 10 ++--- consumer.go | 39 +++++++------------ ctx.go | 93 +++++++++++++++++++++++--------------------- examples/consumer.go | 9 +++-- examples/dag_v2.go | 13 +++++-- v2/dag.go | 5 +-- v2/task_manager.go | 7 +++- 7 files changed, 86 insertions(+), 90 deletions(-) diff --git a/broker.go b/broker.go index 95ef01a..bfca49d 100644 --- a/broker.go +++ b/broker.go @@ -97,16 +97,12 @@ func (b *Broker) MessageAck(ctx context.Context, msg *codec.Message) { func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) { msg.Command = consts.RESPONSE - headers, ok := GetHeaders(ctx) - if !ok { - return - } b.HandleCallback(ctx, msg) - awaitResponse, ok := headers[consts.AwaitResponseKey] + awaitResponse, ok := GetAwaitResponse(ctx) if !(ok && awaitResponse == "true") { return } - publisherID, exists := headers[consts.PublisherKey] + publisherID, exists := GetPublisherID(ctx) if !exists { return } @@ -126,7 +122,7 @@ func (b *Broker) Publish(ctx context.Context, task Task, queue string) error { if err != nil { return err } - msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers) + msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers.headers) b.broadcastToConsumers(ctx, msg) return nil } diff --git a/consumer.go b/consumer.go index 609ead1..6911868 100644 --- a/consumer.go +++ b/consumer.go @@ -3,7 +3,6 @@ package mq import ( "context" "encoding/json" - "errors" "fmt" "log" "net" @@ -19,20 +18,21 @@ import ( // Consumer structure to hold consumer-specific configurations and state. type Consumer struct { - id string - handlers map[string]Handler - conn net.Conn - queues []string - opts Options + id string + handler Handler + conn net.Conn + queue string + opts Options } // NewConsumer initializes a new consumer with the provided options. -func NewConsumer(id string, opts ...Option) *Consumer { +func NewConsumer(id string, queue string, handler Handler, opts ...Option) *Consumer { options := setupOptions(opts...) return &Consumer{ - handlers: make(map[string]Handler), - id: id, - opts: options, + id: id, + opts: options, + queue: queue, + handler: handler, } } @@ -108,12 +108,7 @@ func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.C // ProcessTask handles a received task message and invokes the appropriate handler. func (c *Consumer) ProcessTask(ctx context.Context, msg *Task) Result { - queue, _ := GetQueue(ctx) - handler, exists := c.handlers[queue] - if !exists { - return Result{Error: errors.New("No handler for queue " + queue)} - } - return handler(ctx, msg) + return c.handler(ctx, msg) } // AttemptConnect tries to establish a connection to the server, with TLS or without, based on the configuration. @@ -159,10 +154,8 @@ func (c *Consumer) Consume(ctx context.Context) error { if err != nil { return err } - for _, q := range c.queues { - if err := c.subscribe(ctx, q); err != nil { - return fmt.Errorf("failed to connect to server for queue %s: %v", q, err) - } + if err := c.subscribe(ctx, c.queue); err != nil { + return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err) } var wg sync.WaitGroup wg.Add(1) @@ -191,9 +184,3 @@ func (c *Consumer) waitForAck(conn net.Conn) error { } return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command) } - -// RegisterHandler registers a handler for a queue. -func (c *Consumer) RegisterHandler(queue string, handler Handler) { - c.queues = append(c.queues, queue) - c.handlers[queue] = handler -} diff --git a/ctx.go b/ctx.go index 47bdfe1..a48bb2c 100644 --- a/ctx.go +++ b/ctx.go @@ -37,91 +37,94 @@ func IsClosed(conn net.Conn) bool { } } return false +} // HeaderMap wraps a map and a mutex for thread-safe access +type HeaderMap struct { + mu sync.RWMutex + headers map[string]string } -var m = sync.RWMutex{} +// NewHeaderMap initializes a new HeaderMap +func NewHeaderMap() *HeaderMap { + return &HeaderMap{ + headers: make(map[string]string), + } +} func SetHeaders(ctx context.Context, headers map[string]string) context.Context { - m.Lock() - defer m.Unlock() - hd, ok := GetHeaders(ctx) - if !ok { - hd = make(map[string]string) + hd, _ := GetHeaders(ctx) + if hd == nil { + hd = NewHeaderMap() } + hd.mu.Lock() + defer hd.mu.Unlock() for key, val := range headers { - hd[key] = val + hd.headers[key] = val } return context.WithValue(ctx, consts.HeaderKey, hd) } func WithHeaders(ctx context.Context, headers map[string]string) map[string]string { - hd, ok := GetHeaders(ctx) - if !ok { - hd = make(map[string]string) + hd, _ := GetHeaders(ctx) + if hd == nil { + hd = NewHeaderMap() } + hd.mu.Lock() + defer hd.mu.Unlock() for key, val := range headers { - hd[key] = val + hd.headers[key] = val } - return hd + return getMapAsRegularMap(hd) } -func GetHeaders(ctx context.Context) (map[string]string, bool) { - headers, ok := ctx.Value(consts.HeaderKey).(map[string]string) +func GetHeaders(ctx context.Context) (*HeaderMap, bool) { + headers, ok := ctx.Value(consts.HeaderKey).(*HeaderMap) return headers, ok } func GetHeader(ctx context.Context, key string) (string, bool) { - headers, ok := ctx.Value(consts.HeaderKey).(map[string]string) + headers, ok := GetHeaders(ctx) if !ok { return "", false } - val, ok := headers[key] + headers.mu.RLock() + defer headers.mu.RUnlock() + val, ok := headers.headers[key] return val, ok } func GetContentType(ctx context.Context) (string, bool) { - headers, ok := GetHeaders(ctx) - if !ok { - return "", false - } - contentType, ok := headers[consts.ContentType] - return contentType, ok + return GetHeader(ctx, consts.ContentType) } func GetQueue(ctx context.Context) (string, bool) { - headers, ok := GetHeaders(ctx) - if !ok { - return "", false - } - contentType, ok := headers[consts.QueueKey] - return contentType, ok + return GetHeader(ctx, consts.QueueKey) } func GetConsumerID(ctx context.Context) (string, bool) { - headers, ok := GetHeaders(ctx) - if !ok { - return "", false - } - contentType, ok := headers[consts.ConsumerKey] - return contentType, ok + return GetHeader(ctx, consts.ConsumerKey) } func GetTriggerNode(ctx context.Context) (string, bool) { - headers, ok := GetHeaders(ctx) - if !ok { - return "", false - } - contentType, ok := headers[consts.TriggerNode] - return contentType, ok + return GetHeader(ctx, consts.TriggerNode) +} + +func GetAwaitResponse(ctx context.Context) (string, bool) { + return GetHeader(ctx, consts.AwaitResponseKey) } func GetPublisherID(ctx context.Context) (string, bool) { - headers, ok := GetHeaders(ctx) - if !ok { - return "", false + return GetHeader(ctx, consts.PublisherKey) +} + +// Helper function to convert HeaderMap to a regular map +func getMapAsRegularMap(hd *HeaderMap) map[string]string { + result := make(map[string]string) + hd.mu.RLock() + defer hd.mu.RUnlock() + for key, value := range hd.headers { + result[key] = value } - contentType, ok := headers[consts.PublisherKey] - return contentType, ok + return result } func NewID() string { diff --git a/examples/consumer.go b/examples/consumer.go index a312348..7b9575e 100644 --- a/examples/consumer.go +++ b/examples/consumer.go @@ -2,15 +2,16 @@ package main import ( "context" + "github.com/oarkflow/mq" "github.com/oarkflow/mq/examples/tasks" ) func main() { - consumer := mq.NewConsumer("consumer-1") + consumer1 := mq.NewConsumer("consumer-1", "queue1", tasks.Node1) + consumer2 := mq.NewConsumer("consumer-2", "queue2", tasks.Node2) // consumer := mq.NewConsumer("consumer-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key")) - consumer.RegisterHandler("queue1", tasks.Node1) - consumer.RegisterHandler("queue2", tasks.Node2) - consumer.Consume(context.Background()) + go consumer1.Consume(context.Background()) + consumer2.Consume(context.Background()) } diff --git a/examples/dag_v2.go b/examples/dag_v2.go index 9bcf5ba..9d21f19 100644 --- a/examples/dag_v2.go +++ b/examples/dag_v2.go @@ -78,14 +78,19 @@ func main() { {"user_id": 1, "age": 12}, {"user_id": 2, "age": 34}, }) - rs := d.ProcessTask(context.Background(), "A", initialPayload) - fmt.Println(string(rs.Payload)) - http.HandleFunc("POST /publish", requestHandler("publish")) + for i := 0; i < 100; i++ { + rs := d.ProcessTask(context.Background(), "A", initialPayload) + if rs.Error != nil { + panic(rs.Error) + } + fmt.Println(string(rs.Payload)) + } + /*http.HandleFunc("POST /publish", requestHandler("publish")) http.HandleFunc("POST /request", requestHandler("request")) err := d.Start(context.TODO(), ":8083") if err != nil { panic(err) - } + }*/ } func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Request) { diff --git a/v2/dag.go b/v2/dag.go index 9c61695..c1e26bc 100644 --- a/v2/dag.go +++ b/v2/dag.go @@ -96,8 +96,7 @@ func (tm *DAG) Start(ctx context.Context, addr string) error { func (tm *DAG) AddNode(key string, handler mq.Handler) { tm.mu.Lock() defer tm.mu.Unlock() - con := mq.NewConsumer(key) - con.RegisterHandler(key, handler) + con := mq.NewConsumer(key, key, handler) tm.Nodes[key] = &Node{ Key: key, consumer: con, @@ -133,7 +132,7 @@ func (tm *DAG) ProcessTask(ctx context.Context, node string, payload []byte) mq. defer tm.mu.Unlock() taskID := xid.New().String() task := NewTask(taskID, payload, node, make(map[string]mq.Result)) - manager := NewTaskManager(tm) + manager := NewTaskManager(tm, taskID) tm.taskContext[taskID] = manager return manager.processTask(ctx, node, task) } diff --git a/v2/task_manager.go b/v2/task_manager.go index 435d163..bdec392 100644 --- a/v2/task_manager.go +++ b/v2/task_manager.go @@ -11,6 +11,7 @@ import ( ) type TaskManager struct { + taskID string dag *DAG wg sync.WaitGroup mutex sync.Mutex @@ -19,12 +20,13 @@ type TaskManager struct { done chan struct{} } -func NewTaskManager(d *DAG) *TaskManager { +func NewTaskManager(d *DAG, taskID string) *TaskManager { return &TaskManager{ dag: d, nodeResults: make(map[string]mq.Result), results: make([]mq.Result, 0), done: make(chan struct{}), + taskID: taskID, } } @@ -67,6 +69,9 @@ func (tm *TaskManager) handleResult(ctx context.Context, results any) mq.Result rs.TaskID = result.TaskID rs.Status = result.Status } + if result.Error != nil { + return mq.HandleError(ctx, result.Error) + } var item json.RawMessage err := json.Unmarshal(result.Payload, &item) if err != nil { From ddf37d090a1087d5e0efc2b2f2edeef941cc5ef0 Mon Sep 17 00:00:00 2001 From: sujit Date: Tue, 8 Oct 2024 21:17:06 +0545 Subject: [PATCH 12/17] feat: add example --- v2/task_manager.go | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/v2/task_manager.go b/v2/task_manager.go index bdec392..a28a063 100644 --- a/v2/task_manager.go +++ b/v2/task_manager.go @@ -18,6 +18,7 @@ type TaskManager struct { results []mq.Result nodeResults map[string]mq.Result done chan struct{} + finalResult chan mq.Result // Channel to collect final results } func NewTaskManager(d *DAG, taskID string) *TaskManager { @@ -27,6 +28,7 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager { results: make([]mq.Result, 0), done: make(chan struct{}), taskID: taskID, + finalResult: make(chan mq.Result), // Initialize finalResult channel } } @@ -36,7 +38,7 @@ func (tm *TaskManager) processTask(ctx context.Context, nodeID string, task *mq. return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)} } tm.wg.Add(1) - go tm.processNode(ctx, node, task, nil) + go tm.processNode(ctx, node, task) go func() { tm.wg.Wait() close(tm.done) @@ -55,7 +57,6 @@ func (tm *TaskManager) processTask(ctx context.Context, nodeID string, task *mq. } func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.Result { - fmt.Println(string(result.Payload), result.Topic, result.TaskID) return mq.Result{} } @@ -101,7 +102,7 @@ func (tm *TaskManager) appendFinalResult(result mq.Result) { tm.mutex.Unlock() } -func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *mq.Task, parentNode *Node) { +func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *mq.Task) { defer tm.wg.Done() var result mq.Result select { @@ -141,9 +142,7 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *mq.Tas } } if len(edges) == 0 { - if parentNode != nil { - tm.appendFinalResult(result) - } + tm.appendFinalResult(result) return } for _, edge := range edges { @@ -159,16 +158,14 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *mq.Tas loopTask := NewTask(task.ID, item, edge.From.Key, task.Results) tm.wg.Add(1) ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) - go tm.processNode(ctx, edge.To, loopTask, node) + go tm.processNode(ctx, edge.To, loopTask) } case SimpleEdge: if edge.To != nil { tm.wg.Add(1) t := NewTask(task.ID, result.Payload, edge.From.Key, task.Results) ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) - go tm.processNode(ctx, edge.To, t, node) - } else if parentNode != nil { - tm.appendFinalResult(result) + go tm.processNode(ctx, edge.To, t) } } } From af11b3826dc33e17231f769b3fb0018ebaedc5df Mon Sep 17 00:00:00 2001 From: sujit Date: Tue, 8 Oct 2024 21:38:40 +0545 Subject: [PATCH 13/17] feat: add example --- broker.go | 2 +- ctx.go | 15 +++++++-------- dag/dag.go | 5 ++--- examples/dag_v2.go | 13 +++++++------ v2/dag.go | 11 +++-------- v2/task_manager.go | 40 +++++++++++++++++++--------------------- 6 files changed, 39 insertions(+), 47 deletions(-) diff --git a/broker.go b/broker.go index bfca49d..f35ef57 100644 --- a/broker.go +++ b/broker.go @@ -116,7 +116,7 @@ func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) } } -func (b *Broker) Publish(ctx context.Context, task Task, queue string) error { +func (b *Broker) Publish(ctx context.Context, task *Task, queue string) error { headers, _ := GetHeaders(ctx) payload, err := json.Marshal(task) if err != nil { diff --git a/ctx.go b/ctx.go index a48bb2c..8907cb0 100644 --- a/ctx.go +++ b/ctx.go @@ -17,14 +17,13 @@ import ( ) type Task struct { - ID string `json:"id"` - Results map[string]Result `json:"results"` - Topic string `json:"topic"` - Payload json.RawMessage `json:"payload"` - CreatedAt time.Time `json:"created_at"` - ProcessedAt time.Time `json:"processed_at"` - Status string `json:"status"` - Error error `json:"error"` + ID string `json:"id"` + Topic string `json:"topic"` + Payload json.RawMessage `json:"payload"` + CreatedAt time.Time `json:"created_at"` + ProcessedAt time.Time `json:"processed_at"` + Status string `json:"status"` + Error error `json:"error"` } type Handler func(context.Context, *Task) Result diff --git a/dag/dag.go b/dag/dag.go index 8139bc5..0915ecf 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -50,11 +50,10 @@ func New(opts ...mq.Option) *DAG { func (d *DAG) AddNode(name string, handler mq.Handler, firstNode ...bool) { tlsConfig := d.server.TLSConfig() - con := mq.NewConsumer(name, mq.WithTLS(tlsConfig.UseTLS, tlsConfig.CertPath, tlsConfig.KeyPath), mq.WithCAPath(tlsConfig.CAPath)) + con := mq.NewConsumer(name, name, handler, mq.WithTLS(tlsConfig.UseTLS, tlsConfig.CertPath, tlsConfig.KeyPath), mq.WithCAPath(tlsConfig.CAPath)) if len(firstNode) > 0 { d.FirstNode = name } - con.RegisterHandler(name, handler) d.nodes[name] = con } @@ -114,7 +113,7 @@ func (d *DAG) PublishTask(ctx context.Context, payload json.RawMessage, taskID . } else { id = mq.NewID() } - task := mq.Task{ + task := &mq.Task{ ID: id, Payload: payload, CreatedAt: time.Now(), diff --git a/examples/dag_v2.go b/examples/dag_v2.go index 9d21f19..82e12eb 100644 --- a/examples/dag_v2.go +++ b/examples/dag_v2.go @@ -78,13 +78,14 @@ func main() { {"user_id": 1, "age": 12}, {"user_id": 2, "age": 34}, }) - for i := 0; i < 100; i++ { - rs := d.ProcessTask(context.Background(), "A", initialPayload) - if rs.Error != nil { - panic(rs.Error) - } - fmt.Println(string(rs.Payload)) + /*for i := 0; i < 100; i++ { + + }*/ + rs := d.ProcessTask(context.Background(), "A", initialPayload) + if rs.Error != nil { + panic(rs.Error) } + fmt.Println(rs.TaskID, "Task", string(rs.Payload)) /*http.HandleFunc("POST /publish", requestHandler("publish")) http.HandleFunc("POST /request", requestHandler("request")) err := d.Start(context.TODO(), ":8083") diff --git a/v2/dag.go b/v2/dag.go index c1e26bc..8868591 100644 --- a/v2/dag.go +++ b/v2/dag.go @@ -12,15 +12,11 @@ import ( "github.com/oarkflow/mq" ) -func NewTask(id string, payload json.RawMessage, nodeKey string, results ...map[string]mq.Result) *mq.Task { +func NewTask(id string, payload json.RawMessage, nodeKey string) *mq.Task { if id == "" { id = xid.New().String() } - result := make(map[string]mq.Result) - if len(results) > 0 && results[0] != nil { - result = results[0] - } - return &mq.Task{ID: id, Payload: payload, Topic: nodeKey, Results: result} + return &mq.Task{ID: id, Payload: payload, Topic: nodeKey} } type Node struct { @@ -131,8 +127,7 @@ func (tm *DAG) ProcessTask(ctx context.Context, node string, payload []byte) mq. tm.mu.Lock() defer tm.mu.Unlock() taskID := xid.New().String() - task := NewTask(taskID, payload, node, make(map[string]mq.Result)) manager := NewTaskManager(tm, taskID) tm.taskContext[taskID] = manager - return manager.processTask(ctx, node, task) + return manager.processTask(ctx, node, payload) } diff --git a/v2/task_manager.go b/v2/task_manager.go index a28a063..66daedf 100644 --- a/v2/task_manager.go +++ b/v2/task_manager.go @@ -32,13 +32,13 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager { } } -func (tm *TaskManager) processTask(ctx context.Context, nodeID string, task *mq.Task) mq.Result { +func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result { node, ok := tm.dag.Nodes[nodeID] if !ok { return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)} } tm.wg.Add(1) - go tm.processNode(ctx, node, task) + go tm.processNode(ctx, node, payload) go func() { tm.wg.Wait() close(tm.done) @@ -65,10 +65,10 @@ func (tm *TaskManager) handleResult(ctx context.Context, results any) mq.Result switch res := results.(type) { case []mq.Result: aggregatedOutput := make([]json.RawMessage, 0) + status := "" for i, result := range res { if i == 0 { - rs.TaskID = result.TaskID - rs.Status = result.Status + status = result.Status } if result.Error != nil { return mq.HandleError(ctx, result.Error) @@ -81,16 +81,16 @@ func (tm *TaskManager) handleResult(ctx context.Context, results any) mq.Result aggregatedOutput = append(aggregatedOutput, item) } finalOutput, err := json.Marshal(aggregatedOutput) - return mq.HandleError(ctx, err).WithData(rs.Status, finalOutput) - case mq.Result: - rs.TaskID = res.TaskID - var item json.RawMessage - err := json.Unmarshal(res.Payload, &item) if err != nil { return mq.HandleError(ctx, err) } - finalOutput, err := json.Marshal(item) - return mq.HandleError(ctx, err).WithData(res.Status, finalOutput) + return mq.Result{ + TaskID: tm.taskID, + Payload: finalOutput, + Status: status, + } + case mq.Result: + return res } return rs } @@ -102,25 +102,25 @@ func (tm *TaskManager) appendFinalResult(result mq.Result) { tm.mutex.Unlock() } -func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *mq.Task) { +func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) { defer tm.wg.Done() var result mq.Result select { case <-ctx.Done(): - result = mq.Result{TaskID: task.ID, Topic: node.Key, Error: ctx.Err()} + result = mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: ctx.Err()} tm.appendFinalResult(result) return default: ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: node.Key}) if tm.dag.server.SyncMode() { - result = node.consumer.ProcessTask(ctx, task) + result = node.consumer.ProcessTask(ctx, NewTask(tm.taskID, payload, node.Key)) result.Topic = node.Key if result.Error != nil { tm.appendFinalResult(result) return } } else { - err := tm.dag.server.Publish(ctx, *task, node.Key) + err := tm.dag.server.Publish(ctx, NewTask(tm.taskID, payload, node.Key), node.Key) if err != nil { tm.appendFinalResult(mq.Result{Error: err}) return @@ -128,7 +128,7 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *mq.Tas } } tm.mutex.Lock() - task.Results[node.Key] = result + tm.nodeResults[node.Key] = result tm.mutex.Unlock() edges := make([]Edge, len(node.Edges)) copy(edges, node.Edges) @@ -151,21 +151,19 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, task *mq.Tas var items []json.RawMessage err := json.Unmarshal(result.Payload, &items) if err != nil { - tm.appendFinalResult(mq.Result{TaskID: task.ID, Topic: node.Key, Error: err}) + tm.appendFinalResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err}) return } for _, item := range items { - loopTask := NewTask(task.ID, item, edge.From.Key, task.Results) tm.wg.Add(1) ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) - go tm.processNode(ctx, edge.To, loopTask) + go tm.processNode(ctx, edge.To, item) } case SimpleEdge: if edge.To != nil { tm.wg.Add(1) - t := NewTask(task.ID, result.Payload, edge.From.Key, task.Results) ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) - go tm.processNode(ctx, edge.To, t) + go tm.processNode(ctx, edge.To, result.Payload) } } } From 3e8f47086f50987a353ea5dc7ebafeee3d5f4d1c Mon Sep 17 00:00:00 2001 From: sujit Date: Tue, 8 Oct 2024 23:14:54 +0545 Subject: [PATCH 14/17] feat: add example --- consumer.go | 1 + ctx.go | 2 - dag/dag.go | 381 --------------------------------------------- examples/dag.go | 108 +++++++------ examples/dag_v2.go | 125 --------------- v2/dag.go | 2 + v2/task_manager.go | 161 +++++++++++-------- 7 files changed, 166 insertions(+), 614 deletions(-) delete mode 100644 dag/dag.go delete mode 100644 examples/dag_v2.go diff --git a/consumer.go b/consumer.go index 6911868..26b62ab 100644 --- a/consumer.go +++ b/consumer.go @@ -154,6 +154,7 @@ func (c *Consumer) Consume(ctx context.Context) error { if err != nil { return err } + if err := c.subscribe(ctx, c.queue); err != nil { return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err) } diff --git a/ctx.go b/ctx.go index 8907cb0..eaf139a 100644 --- a/ctx.go +++ b/ctx.go @@ -118,8 +118,6 @@ func GetPublisherID(ctx context.Context) (string, bool) { // Helper function to convert HeaderMap to a regular map func getMapAsRegularMap(hd *HeaderMap) map[string]string { result := make(map[string]string) - hd.mu.RLock() - defer hd.mu.RUnlock() for key, value := range hd.headers { result[key] = value } diff --git a/dag/dag.go b/dag/dag.go deleted file mode 100644 index 0915ecf..0000000 --- a/dag/dag.go +++ /dev/null @@ -1,381 +0,0 @@ -package dag - -import ( - "context" - "encoding/json" - "fmt" - "log" - "net/http" - "sync" - "time" - - "github.com/oarkflow/mq/consts" - - "github.com/oarkflow/mq" -) - -type taskContext struct { - totalItems int - completed int - results []json.RawMessage - result json.RawMessage - multipleResults bool -} - -type DAG struct { - FirstNode string - server *mq.Broker - nodes map[string]*mq.Consumer - edges map[string]string - conditions map[string]map[string]string - loopEdges map[string][]string - taskChMap map[string]chan mq.Result - taskResults map[string]map[string]*taskContext - mu sync.Mutex -} - -func New(opts ...mq.Option) *DAG { - d := &DAG{ - nodes: make(map[string]*mq.Consumer), - edges: make(map[string]string), - conditions: make(map[string]map[string]string), - loopEdges: make(map[string][]string), - taskChMap: make(map[string]chan mq.Result), - taskResults: make(map[string]map[string]*taskContext), - } - opts = append(opts, mq.WithCallback(d.TaskCallback)) - d.server = mq.NewBroker(opts...) - return d -} - -func (d *DAG) AddNode(name string, handler mq.Handler, firstNode ...bool) { - tlsConfig := d.server.TLSConfig() - con := mq.NewConsumer(name, name, handler, mq.WithTLS(tlsConfig.UseTLS, tlsConfig.CertPath, tlsConfig.KeyPath), mq.WithCAPath(tlsConfig.CAPath)) - if len(firstNode) > 0 { - d.FirstNode = name - } - d.nodes[name] = con -} - -func (d *DAG) AddCondition(fromNode string, conditions map[string]string) { - d.conditions[fromNode] = conditions -} - -func (d *DAG) AddEdge(fromNode string, toNodes string) { - d.edges[fromNode] = toNodes -} - -func (d *DAG) AddLoop(fromNode string, toNode ...string) { - d.loopEdges[fromNode] = toNode -} - -func (d *DAG) Prepare() { - if d.FirstNode == "" { - firstNode, ok := d.FindFirstNode() - if ok && firstNode != "" { - d.FirstNode = firstNode - } - } -} - -func (d *DAG) Start(ctx context.Context, addr string) error { - d.Prepare() - if d.server.SyncMode() { - return nil - } - go func() { - err := d.server.Start(ctx) - if err != nil { - panic(err) - } - }() - for _, con := range d.nodes { - go func(con *mq.Consumer) { - con.Consume(ctx) - }(con) - } - log.Printf("HTTP server started on %s", addr) - config := d.server.TLSConfig() - if config.UseTLS { - return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil) - } - return http.ListenAndServe(addr, nil) -} - -func (d *DAG) PublishTask(ctx context.Context, payload json.RawMessage, taskID ...string) mq.Result { - queue, ok := mq.GetQueue(ctx) - if !ok { - queue = d.FirstNode - } - var id string - if len(taskID) > 0 { - id = taskID[0] - } else { - id = mq.NewID() - } - task := &mq.Task{ - ID: id, - Payload: payload, - CreatedAt: time.Now(), - } - err := d.server.Publish(ctx, task, queue) - if err != nil { - return mq.Result{Error: err} - } - return mq.Result{ - Payload: payload, - Topic: queue, - TaskID: id, - } -} - -func (d *DAG) FindFirstNode() (string, bool) { - inDegree := make(map[string]int) - for n, _ := range d.nodes { - inDegree[n] = 0 - } - for _, outNode := range d.edges { - inDegree[outNode]++ - } - for _, targets := range d.loopEdges { - for _, outNode := range targets { - inDegree[outNode]++ - } - } - for n, count := range inDegree { - if count == 0 { - return n, true - } - } - return "", false -} - -func (d *DAG) Request(ctx context.Context, payload []byte) mq.Result { - return d.sendSync(ctx, mq.Result{Payload: payload}) -} - -func (d *DAG) Send(ctx context.Context, payload []byte) mq.Result { - if d.FirstNode == "" { - return mq.Result{Error: fmt.Errorf("initial node not defined")} - } - if d.server.SyncMode() { - return d.sendSync(ctx, mq.Result{Payload: payload}) - } - resultCh := make(chan mq.Result) - result := d.PublishTask(ctx, payload) - if result.Error != nil { - return result - } - d.mu.Lock() - d.taskChMap[result.TaskID] = resultCh - d.mu.Unlock() - finalResult := <-resultCh - return finalResult -} - -func (d *DAG) processNode(ctx context.Context, task mq.Result) mq.Result { - if con, ok := d.nodes[task.Topic]; ok { - return con.ProcessTask(ctx, &mq.Task{ - ID: task.TaskID, - Payload: task.Payload, - }) - } - return mq.Result{Error: fmt.Errorf("no consumer to process %s", task.Topic)} -} - -func (d *DAG) sendSync(ctx context.Context, task mq.Result) mq.Result { - if task.TaskID == "" { - task.TaskID = mq.NewID() - } - if task.Topic == "" { - task.Topic = d.FirstNode - } - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: task.Topic, - }) - result := d.processNode(ctx, task) - if result.Error != nil { - return result - } - for _, target := range d.loopEdges[task.Topic] { - var items, results []json.RawMessage - if err := json.Unmarshal(result.Payload, &items); err != nil { - return mq.Result{Error: err} - } - for _, item := range items { - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: target, - }) - result = d.sendSync(ctx, mq.Result{ - Payload: item, - Topic: target, - TaskID: result.TaskID, - }) - if result.Error != nil { - return result - } - results = append(results, result.Payload) - } - bt, err := json.Marshal(results) - if err != nil { - return mq.Result{Error: err} - } - result.Payload = bt - } - if conditions, ok := d.conditions[task.Topic]; ok { - if target, exists := conditions[result.Status]; exists { - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: target, - }) - result = d.sendSync(ctx, mq.Result{ - Payload: result.Payload, - Topic: target, - TaskID: result.TaskID, - }) - if result.Error != nil { - return result - } - } - } - if target, ok := d.edges[task.Topic]; ok { - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: target, - }) - result = d.sendSync(ctx, mq.Result{ - Payload: result.Payload, - Topic: target, - TaskID: result.TaskID, - }) - if result.Error != nil { - return result - } - } - return result -} - -func (d *DAG) getCompletedResults(task mq.Result, ok bool, triggeredNode string) ([]byte, bool, bool) { - var result any - var payload []byte - completed := false - multipleResults := false - if ok && triggeredNode != "" { - taskResults, ok := d.taskResults[task.TaskID] - if ok { - nodeResult, exists := taskResults[triggeredNode] - if exists { - multipleResults = nodeResult.multipleResults - nodeResult.completed++ - if nodeResult.completed == nodeResult.totalItems { - completed = true - } - if multipleResults { - nodeResult.results = append(nodeResult.results, task.Payload) - if completed { - result = nodeResult.results - } - } else { - nodeResult.result = task.Payload - if completed { - result = nodeResult.result - } - } - } - if completed { - delete(taskResults, triggeredNode) - } - } - } - if completed { - payload, _ = json.Marshal(result) - } else { - payload = task.Payload - } - return payload, completed, multipleResults -} - -func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result { - if task.Error != nil { - return mq.Result{Error: task.Error} - } - triggeredNode, ok := mq.GetTriggerNode(ctx) - payload, completed, multipleResults := d.getCompletedResults(task, ok, triggeredNode) - if loopNodes, exists := d.loopEdges[task.Topic]; exists { - var items []json.RawMessage - if err := json.Unmarshal(payload, &items); err != nil { - return mq.Result{Error: task.Error} - } - d.taskResults[task.TaskID] = map[string]*taskContext{ - task.Topic: { - totalItems: len(items), - multipleResults: true, - }, - } - - ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Topic}) - for _, loopNode := range loopNodes { - for _, item := range items { - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: loopNode, - }) - result := d.PublishTask(ctx, item, task.TaskID) - if result.Error != nil { - return result - } - } - } - - return task - } - if multipleResults && completed { - task.Topic = triggeredNode - } - if conditions, ok := d.conditions[task.Topic]; ok { - if target, exists := conditions[task.Status]; exists { - d.taskResults[task.TaskID] = map[string]*taskContext{ - task.Topic: { - totalItems: len(conditions), - }, - } - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: target, - consts.TriggerNode: task.Topic, - }) - result := d.PublishTask(ctx, payload, task.TaskID) - if result.Error != nil { - return result - } - } - } else { - ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Topic}) - edge, exists := d.edges[task.Topic] - if exists { - d.taskResults[task.TaskID] = map[string]*taskContext{ - task.Topic: { - totalItems: 1, - }, - } - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: edge, - }) - result := d.PublishTask(ctx, payload, task.TaskID) - if result.Error != nil { - return result - } - } else if completed { - d.mu.Lock() - if resultCh, ok := d.taskChMap[task.TaskID]; ok { - resultCh <- mq.Result{ - Payload: payload, - Topic: task.Topic, - TaskID: task.TaskID, - Status: "done", - } - delete(d.taskChMap, task.TaskID) - delete(d.taskResults, task.TaskID) - } - d.mu.Unlock() - } - } - - return task -} diff --git a/examples/dag.go b/examples/dag.go index 59f8431..76bf297 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -1,53 +1,79 @@ package main -/* import ( "context" "encoding/json" - "fmt" "io" "net/http" - "time" "github.com/oarkflow/mq" - "github.com/oarkflow/mq/dag" - "github.com/oarkflow/mq/examples/tasks" + "github.com/oarkflow/mq/v2" ) -var d *dag.DAG +func handler1(ctx context.Context, task *mq.Task) mq.Result { + return mq.Result{Payload: task.Payload} +} + +func handler2(ctx context.Context, task *mq.Task) mq.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + return mq.Result{Payload: task.Payload} +} + +func handler3(ctx context.Context, task *mq.Task) mq.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + age := int(user["age"].(float64)) + status := "FAIL" + if age > 20 { + status = "PASS" + } + user["status"] = status + resultPayload, _ := json.Marshal(user) + return mq.Result{Payload: resultPayload, Status: status} +} + +func handler4(ctx context.Context, task *mq.Task) mq.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + user["final"] = "D" + resultPayload, _ := json.Marshal(user) + return mq.Result{Payload: resultPayload} +} + +func handler5(ctx context.Context, task *mq.Task) mq.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + user["salary"] = "E" + resultPayload, _ := json.Marshal(user) + return mq.Result{Payload: resultPayload} +} + +func handler6(ctx context.Context, task *mq.Task) mq.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + resultPayload, _ := json.Marshal(map[string]any{"storage": user}) + return mq.Result{Payload: resultPayload} +} + +var ( + d = v2.NewDAG(mq.WithSyncMode(false)) +) func main() { - d = dag.New(mq.WithSyncMode(true), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.crt")) - d.AddNode("queue1", tasks.Node1, true) - d.AddNode("queue2", tasks.Node2) - d.AddNode("queue3", tasks.Node3) - d.AddNode("queue4", tasks.Node4) + d.AddNode("A", handler1) + d.AddNode("B", handler2) + d.AddNode("C", handler3) + d.AddNode("D", handler4) + d.AddNode("E", handler5) + d.AddNode("F", handler6) + d.AddEdge("A", "B", v2.LoopEdge) + d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"}) + d.AddEdge("B", "C") + d.AddEdge("D", "F") + d.AddEdge("E", "F") - d.AddNode("queue5", tasks.CheckCondition) - d.AddNode("queue6", tasks.Pass) - d.AddNode("queue7", tasks.Fail) - - d.AddCondition("queue5", map[string]string{"pass": "queue6", "fail": "queue7"}) - d.AddEdge("queue1", "queue2") - d.AddEdge("queue2", "queue4") - d.AddEdge("queue3", "queue5") - - d.AddLoop("queue2", "queue3") - d.Prepare() - go func() { - d.Start(context.Background(), ":8081") - }() - go func() { - time.Sleep(3 * time.Second) - result := d.Send(context.Background(), []byte(`[{"user_id": 1}, {"user_id": 2}]`)) - if result.Error != nil { - panic(result.Error) - } - fmt.Println("Response", string(result.Payload)) - }() - - time.Sleep(10 * time.Second) - d.Prepare() + // fmt.Println(rs.TaskID, "Task", string(rs.Payload)) http.HandleFunc("POST /publish", requestHandler("publish")) http.HandleFunc("POST /request", requestHandler("request")) err := d.Start(context.TODO(), ":8083") @@ -75,19 +101,13 @@ func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Requ http.Error(w, "Empty request body", http.StatusBadRequest) return } - var rs mq.Result - if requestType == "request" { - rs = d.Request(context.Background(), payload) - } else { - rs = d.Send(context.Background(), payload) - } + rs := d.ProcessTask(context.Background(), "A", payload) w.Header().Set("Content-Type", "application/json") result := map[string]any{ "message_id": rs.TaskID, - "payload": string(rs.Payload), + "payload": rs.Payload, "error": rs.Error, } json.NewEncoder(w).Encode(result) } } -*/ diff --git a/examples/dag_v2.go b/examples/dag_v2.go deleted file mode 100644 index 82e12eb..0000000 --- a/examples/dag_v2.go +++ /dev/null @@ -1,125 +0,0 @@ -package main - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - - "github.com/oarkflow/mq" - "github.com/oarkflow/mq/v2" -) - -func handler1(ctx context.Context, task *mq.Task) mq.Result { - return mq.Result{Payload: task.Payload} -} - -func handler2(ctx context.Context, task *mq.Task) mq.Result { - var user map[string]any - json.Unmarshal(task.Payload, &user) - return mq.Result{Payload: task.Payload} -} - -func handler3(ctx context.Context, task *mq.Task) mq.Result { - var user map[string]any - json.Unmarshal(task.Payload, &user) - age := int(user["age"].(float64)) - status := "FAIL" - if age > 20 { - status = "PASS" - } - user["status"] = status - resultPayload, _ := json.Marshal(user) - return mq.Result{Payload: resultPayload, Status: status} -} - -func handler4(ctx context.Context, task *mq.Task) mq.Result { - var user map[string]any - json.Unmarshal(task.Payload, &user) - user["final"] = "D" - resultPayload, _ := json.Marshal(user) - return mq.Result{Payload: resultPayload} -} - -func handler5(ctx context.Context, task *mq.Task) mq.Result { - var user map[string]any - json.Unmarshal(task.Payload, &user) - user["salary"] = "E" - resultPayload, _ := json.Marshal(user) - return mq.Result{Payload: resultPayload} -} - -func handler6(ctx context.Context, task *mq.Task) mq.Result { - var user map[string]any - json.Unmarshal(task.Payload, &user) - resultPayload, _ := json.Marshal(map[string]any{"storage": user}) - return mq.Result{Payload: resultPayload} -} - -var ( - d = v2.NewDAG(mq.WithSyncMode(true)) -) - -func main() { - d.AddNode("A", handler1) - d.AddNode("B", handler2) - d.AddNode("C", handler3) - d.AddNode("D", handler4) - d.AddNode("E", handler5) - d.AddNode("F", handler6) - d.AddEdge("A", "B", v2.LoopEdge) - d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"}) - d.AddEdge("B", "C") - d.AddEdge("D", "F") - d.AddEdge("E", "F") - - initialPayload, _ := json.Marshal([]map[string]any{ - {"user_id": 1, "age": 12}, - {"user_id": 2, "age": 34}, - }) - /*for i := 0; i < 100; i++ { - - }*/ - rs := d.ProcessTask(context.Background(), "A", initialPayload) - if rs.Error != nil { - panic(rs.Error) - } - fmt.Println(rs.TaskID, "Task", string(rs.Payload)) - /*http.HandleFunc("POST /publish", requestHandler("publish")) - http.HandleFunc("POST /request", requestHandler("request")) - err := d.Start(context.TODO(), ":8083") - if err != nil { - panic(err) - }*/ -} - -func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - http.Error(w, "Invalid request method", http.StatusMethodNotAllowed) - return - } - var payload []byte - if r.Body != nil { - defer r.Body.Close() - var err error - payload, err = io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Failed to read request body", http.StatusBadRequest) - return - } - } else { - http.Error(w, "Empty request body", http.StatusBadRequest) - return - } - rs := d.ProcessTask(context.Background(), "A", payload) - w.Header().Set("Content-Type", "application/json") - result := map[string]any{ - "message_id": rs.TaskID, - "payload": string(rs.Payload), - "error": rs.Error, - } - json.NewEncoder(w).Encode(result) - } -} diff --git a/v2/dag.go b/v2/dag.go index 8868591..5163fa2 100644 --- a/v2/dag.go +++ b/v2/dag.go @@ -6,6 +6,7 @@ import ( "log" "net/http" "sync" + "time" "github.com/oarkflow/xid" @@ -78,6 +79,7 @@ func (tm *DAG) Start(ctx context.Context, addr string) error { }() for _, con := range tm.Nodes { go func(con *Node) { + time.Sleep(1 * time.Second) con.consumer.Consume(ctx) }(con) } diff --git a/v2/task_manager.go b/v2/task_manager.go index 66daedf..b9a5f27 100644 --- a/v2/task_manager.go +++ b/v2/task_manager.go @@ -5,20 +5,22 @@ import ( "encoding/json" "fmt" "sync" + "sync/atomic" "github.com/oarkflow/mq" "github.com/oarkflow/mq/consts" ) type TaskManager struct { - taskID string - dag *DAG - wg sync.WaitGroup - mutex sync.Mutex - results []mq.Result - nodeResults map[string]mq.Result - done chan struct{} - finalResult chan mq.Result // Channel to collect final results + taskID string + dag *DAG + wg sync.WaitGroup + mutex sync.Mutex + results []mq.Result + waitingCallback int64 + nodeResults map[string]mq.Result + done chan struct{} + finalResult chan mq.Result // Channel to collect final results } func NewTaskManager(d *DAG, taskID string) *TaskManager { @@ -26,9 +28,7 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager { dag: d, nodeResults: make(map[string]mq.Result), results: make([]mq.Result, 0), - done: make(chan struct{}), taskID: taskID, - finalResult: make(chan mq.Result), // Initialize finalResult channel } } @@ -37,26 +37,97 @@ func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload j if !ok { return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)} } - tm.wg.Add(1) - go tm.processNode(ctx, node, payload) - go func() { - tm.wg.Wait() - close(tm.done) - }() - select { - case <-ctx.Done(): - return mq.Result{Error: ctx.Err()} - case <-tm.done: - tm.mutex.Lock() - defer tm.mutex.Unlock() - if len(tm.results) == 1 { - return tm.handleResult(ctx, tm.results[0]) + if tm.dag.server.SyncMode() { + tm.done = make(chan struct{}) + tm.wg.Add(1) + go tm.processNode(ctx, node, payload) + go func() { + tm.wg.Wait() + close(tm.done) + }() + select { + case <-ctx.Done(): + return mq.Result{Error: ctx.Err()} + case <-tm.done: + tm.mutex.Lock() + defer tm.mutex.Unlock() + if len(tm.results) == 1 { + return tm.handleResult(ctx, tm.results[0]) + } + return tm.handleResult(ctx, tm.results) + } + } else { + tm.finalResult = make(chan mq.Result) + tm.wg.Add(1) + go tm.processNode(ctx, node, payload) + go func() { + tm.wg.Wait() + }() + select { + case result := <-tm.finalResult: // Block until a result is available + return result + case <-ctx.Done(): // Handle context cancellation + return mq.Result{Error: ctx.Err()} } - return tm.handleResult(ctx, tm.results) } } func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.Result { + if result.Topic != "" { + atomic.AddInt64(&tm.waitingCallback, -1) + } + node, ok := tm.dag.Nodes[result.Topic] + if !ok { + return result + } + edges := make([]Edge, len(node.Edges)) + copy(edges, node.Edges) + if result.Status != "" { + if conditions, ok := tm.dag.conditions[result.Topic]; ok { + if targetNodeKey, ok := conditions[result.Status]; ok { + if targetNode, ok := tm.dag.Nodes[targetNodeKey]; ok { + edges = append(edges, Edge{From: node, To: targetNode}) + } + } + } + } + if len(edges) == 0 { + tm.appendFinalResult(result) + if !tm.dag.server.SyncMode() { + var rs mq.Result + if len(tm.results) == 1 { + rs = tm.handleResult(ctx, tm.results[0]) + } else { + rs = tm.handleResult(ctx, tm.results) + } + if tm.waitingCallback == 0 { + tm.finalResult <- rs + } + } + return result + } + for _, edge := range edges { + switch edge.Type { + case LoopEdge: + var items []json.RawMessage + err := json.Unmarshal(result.Payload, &items) + if err != nil { + tm.appendFinalResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err}) + return result + } + for _, item := range items { + tm.wg.Add(1) + ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) + go tm.processNode(ctx, edge.To, item) + } + case SimpleEdge: + if edge.To != nil { + tm.wg.Add(1) + ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) + go tm.processNode(ctx, edge.To, result.Payload) + } + } + } return mq.Result{} } @@ -103,6 +174,7 @@ func (tm *TaskManager) appendFinalResult(result mq.Result) { } func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) { + atomic.AddInt64(&tm.waitingCallback, 1) defer tm.wg.Done() var result mq.Result select { @@ -115,6 +187,7 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json if tm.dag.server.SyncMode() { result = node.consumer.ProcessTask(ctx, NewTask(tm.taskID, payload, node.Key)) result.Topic = node.Key + result.TaskID = tm.taskID if result.Error != nil { tm.appendFinalResult(result) return @@ -130,41 +203,5 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json tm.mutex.Lock() tm.nodeResults[node.Key] = result tm.mutex.Unlock() - edges := make([]Edge, len(node.Edges)) - copy(edges, node.Edges) - if result.Status != "" { - if conditions, ok := tm.dag.conditions[result.Topic]; ok { - if targetNodeKey, ok := conditions[result.Status]; ok { - if targetNode, ok := tm.dag.Nodes[targetNodeKey]; ok { - edges = append(edges, Edge{From: node, To: targetNode}) - } - } - } - } - if len(edges) == 0 { - tm.appendFinalResult(result) - return - } - for _, edge := range edges { - switch edge.Type { - case LoopEdge: - var items []json.RawMessage - err := json.Unmarshal(result.Payload, &items) - if err != nil { - tm.appendFinalResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err}) - return - } - for _, item := range items { - tm.wg.Add(1) - ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) - go tm.processNode(ctx, edge.To, item) - } - case SimpleEdge: - if edge.To != nil { - tm.wg.Add(1) - ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) - go tm.processNode(ctx, edge.To, result.Payload) - } - } - } + tm.handleCallback(ctx, result) } From e477acf91c148e3f7024c20bdfa1e05e27590da6 Mon Sep 17 00:00:00 2001 From: sujit Date: Tue, 8 Oct 2024 23:57:44 +0545 Subject: [PATCH 15/17] feat: add example --- examples/dag.go | 9 +++-- v2/dag.go | 89 ++++++++++++++++++++++++++++++----------- v2/task_manager.go | 98 +++++++++++++++++++++++++++------------------- 3 files changed, 130 insertions(+), 66 deletions(-) diff --git a/examples/dag.go b/examples/dag.go index 76bf297..a228aea 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -57,11 +57,11 @@ func handler6(ctx context.Context, task *mq.Task) mq.Result { } var ( - d = v2.NewDAG(mq.WithSyncMode(false)) + d = v2.NewDAG(mq.WithSyncMode(true)) ) func main() { - d.AddNode("A", handler1) + d.AddNode("A", handler1, true) d.AddNode("B", handler2) d.AddNode("C", handler3) d.AddNode("D", handler4) @@ -72,7 +72,6 @@ func main() { d.AddEdge("B", "C") d.AddEdge("D", "F") d.AddEdge("E", "F") - // fmt.Println(rs.TaskID, "Task", string(rs.Payload)) http.HandleFunc("POST /publish", requestHandler("publish")) http.HandleFunc("POST /request", requestHandler("request")) @@ -101,7 +100,9 @@ func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Requ http.Error(w, "Empty request body", http.StatusBadRequest) return } - rs := d.ProcessTask(context.Background(), "A", payload) + ctx := context.Background() + // ctx = context.WithValue(ctx, "initial_node", "E") + rs := d.ProcessTask(ctx, payload) w.Header().Set("Content-Type", "application/json") result := map[string]any{ "message_id": rs.TaskID, diff --git a/v2/dag.go b/v2/dag.go index 5163fa2..2115ad4 100644 --- a/v2/dag.go +++ b/v2/dag.go @@ -3,6 +3,7 @@ package v2 import ( "context" "encoding/json" + "fmt" "log" "net/http" "sync" @@ -20,12 +21,6 @@ func NewTask(id string, payload json.RawMessage, nodeKey string) *mq.Task { return &mq.Task{ID: id, Payload: payload, Topic: nodeKey} } -type Node struct { - Key string - Edges []Edge - consumer *mq.Consumer -} - type EdgeType int func (c EdgeType) IsValid() bool { return c >= SimpleEdge && c <= LoopEdge } @@ -35,6 +30,12 @@ const ( LoopEdge ) +type Node struct { + Key string + Edges []Edge + consumer *mq.Consumer +} + type Edge struct { From *Node To *Node @@ -42,6 +43,7 @@ type Edge struct { } type DAG struct { + FirstNode string Nodes map[string]*Node server *mq.Broker taskContext map[string]*TaskManager @@ -68,21 +70,21 @@ func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result { } func (tm *DAG) Start(ctx context.Context, addr string) error { - if tm.server.SyncMode() { - return nil - } - go func() { - err := tm.server.Start(ctx) - if err != nil { - panic(err) + if !tm.server.SyncMode() { + go func() { + err := tm.server.Start(ctx) + if err != nil { + panic(err) + } + }() + for _, con := range tm.Nodes { + go func(con *Node) { + time.Sleep(1 * time.Second) + con.consumer.Consume(ctx) + }(con) } - }() - for _, con := range tm.Nodes { - go func(con *Node) { - time.Sleep(1 * time.Second) - con.consumer.Consume(ctx) - }(con) } + log.Printf("HTTP server started on %s", addr) config := tm.server.TLSConfig() if config.UseTLS { @@ -91,7 +93,7 @@ func (tm *DAG) Start(ctx context.Context, addr string) error { return http.ListenAndServe(addr, nil) } -func (tm *DAG) AddNode(key string, handler mq.Handler) { +func (tm *DAG) AddNode(key string, handler mq.Handler, firstNode ...bool) { tm.mu.Lock() defer tm.mu.Unlock() con := mq.NewConsumer(key, key, handler) @@ -99,6 +101,9 @@ func (tm *DAG) AddNode(key string, handler mq.Handler) { Key: key, consumer: con, } + if len(firstNode) > 0 && firstNode[0] { + tm.FirstNode = key + } } func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) { @@ -125,11 +130,51 @@ func (tm *DAG) AddEdge(from, to string, edgeTypes ...EdgeType) { fromNode.Edges = append(fromNode.Edges, edge) } -func (tm *DAG) ProcessTask(ctx context.Context, node string, payload []byte) mq.Result { +func (tm *DAG) ProcessTask(ctx context.Context, payload []byte) mq.Result { + val := ctx.Value("initial_node") + initialNode, ok := val.(string) + if !ok { + if tm.FirstNode == "" { + firstNode := tm.FindInitialNode() + if firstNode != nil { + tm.FirstNode = firstNode.Key + } + } + if tm.FirstNode == "" { + return mq.Result{Error: fmt.Errorf("initial node not found")} + } + initialNode = tm.FirstNode + } tm.mu.Lock() defer tm.mu.Unlock() taskID := xid.New().String() manager := NewTaskManager(tm, taskID) tm.taskContext[taskID] = manager - return manager.processTask(ctx, node, payload) + return manager.processTask(ctx, initialNode, payload) +} + +func (tm *DAG) FindInitialNode() *Node { + incomingEdges := make(map[string]bool) + connectedNodes := make(map[string]bool) + for _, node := range tm.Nodes { + for _, edge := range node.Edges { + if edge.Type.IsValid() { + connectedNodes[node.Key] = true + connectedNodes[edge.To.Key] = true + incomingEdges[edge.To.Key] = true + } + } + if cond, ok := tm.conditions[node.Key]; ok { + for _, target := range cond { + connectedNodes[target] = true + incomingEdges[target] = true + } + } + } + for nodeID, node := range tm.Nodes { + if !incomingEdges[nodeID] && connectedNodes[nodeID] { + return node + } + } + return nil } diff --git a/v2/task_manager.go b/v2/task_manager.go index b9a5f27..92f89c2 100644 --- a/v2/task_manager.go +++ b/v2/task_manager.go @@ -32,42 +32,63 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager { } } +func (tm *TaskManager) handleSyncTask(ctx context.Context, node *Node, payload json.RawMessage) mq.Result { + tm.done = make(chan struct{}) + tm.wg.Add(1) + go tm.processNode(ctx, node, payload) + go func() { + tm.wg.Wait() + close(tm.done) + }() + select { + case <-ctx.Done(): + return mq.Result{Error: ctx.Err()} + case <-tm.done: + tm.mutex.Lock() + defer tm.mutex.Unlock() + if len(tm.results) == 1 { + return tm.handleResult(ctx, tm.results[0]) + } + return tm.handleResult(ctx, tm.results) + } +} + +func (tm *TaskManager) handleAsyncTask(ctx context.Context, node *Node, payload json.RawMessage) mq.Result { + tm.finalResult = make(chan mq.Result) + tm.wg.Add(1) + go tm.processNode(ctx, node, payload) + go func() { + tm.wg.Wait() + }() + select { + case result := <-tm.finalResult: // Block until a result is available + return result + case <-ctx.Done(): // Handle context cancellation + return mq.Result{Error: ctx.Err()} + } +} + func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result { node, ok := tm.dag.Nodes[nodeID] if !ok { return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)} } if tm.dag.server.SyncMode() { - tm.done = make(chan struct{}) - tm.wg.Add(1) - go tm.processNode(ctx, node, payload) - go func() { - tm.wg.Wait() - close(tm.done) - }() - select { - case <-ctx.Done(): - return mq.Result{Error: ctx.Err()} - case <-tm.done: - tm.mutex.Lock() - defer tm.mutex.Unlock() - if len(tm.results) == 1 { - return tm.handleResult(ctx, tm.results[0]) - } - return tm.handleResult(ctx, tm.results) + return tm.handleSyncTask(ctx, node, payload) + } + return tm.handleAsyncTask(ctx, node, payload) +} + +func (tm *TaskManager) dispatchFinalResult(ctx context.Context) { + if !tm.dag.server.SyncMode() { + var rs mq.Result + if len(tm.results) == 1 { + rs = tm.handleResult(ctx, tm.results[0]) + } else { + rs = tm.handleResult(ctx, tm.results) } - } else { - tm.finalResult = make(chan mq.Result) - tm.wg.Add(1) - go tm.processNode(ctx, node, payload) - go func() { - tm.wg.Wait() - }() - select { - case result := <-tm.finalResult: // Block until a result is available - return result - case <-ctx.Done(): // Handle context cancellation - return mq.Result{Error: ctx.Err()} + if tm.waitingCallback == 0 { + tm.finalResult <- rs } } } @@ -93,17 +114,7 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq. } if len(edges) == 0 { tm.appendFinalResult(result) - if !tm.dag.server.SyncMode() { - var rs mq.Result - if len(tm.results) == 1 { - rs = tm.handleResult(ctx, tm.results[0]) - } else { - rs = tm.handleResult(ctx, tm.results) - } - if tm.waitingCallback == 0 { - tm.finalResult <- rs - } - } + tm.dispatchFinalResult(ctx) return result } for _, edge := range edges { @@ -205,3 +216,10 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json tm.mutex.Unlock() tm.handleCallback(ctx, result) } + +func (tm *TaskManager) Clear() error { + tm.waitingCallback = 0 + clear(tm.results) + tm.nodeResults = make(map[string]mq.Result) + return nil +} From 9630a2b277b33b2374e63cd43f12789c88c0f5fb Mon Sep 17 00:00:00 2001 From: sujit Date: Wed, 9 Oct 2024 00:08:21 +0545 Subject: [PATCH 16/17] feat: add example --- {v2 => dag}/dag.go | 2 +- {v2 => dag}/task_manager.go | 2 +- examples/dag.go | 7 ++++--- 3 files changed, 6 insertions(+), 5 deletions(-) rename {v2 => dag}/dag.go (99%) rename {v2 => dag}/task_manager.go (99%) diff --git a/v2/dag.go b/dag/dag.go similarity index 99% rename from v2/dag.go rename to dag/dag.go index 2115ad4..224020d 100644 --- a/v2/dag.go +++ b/dag/dag.go @@ -1,4 +1,4 @@ -package v2 +package dag import ( "context" diff --git a/v2/task_manager.go b/dag/task_manager.go similarity index 99% rename from v2/task_manager.go rename to dag/task_manager.go index 92f89c2..3a2347b 100644 --- a/v2/task_manager.go +++ b/dag/task_manager.go @@ -1,4 +1,4 @@ -package v2 +package dag import ( "context" diff --git a/examples/dag.go b/examples/dag.go index a228aea..c3fa15a 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -7,7 +7,7 @@ import ( "net/http" "github.com/oarkflow/mq" - "github.com/oarkflow/mq/v2" + "github.com/oarkflow/mq/dag" ) func handler1(ctx context.Context, task *mq.Task) mq.Result { @@ -57,7 +57,8 @@ func handler6(ctx context.Context, task *mq.Task) mq.Result { } var ( - d = v2.NewDAG(mq.WithSyncMode(true)) + d = dag.NewDAG(mq.WithSyncMode(true)) + // d = dag.NewDAG(mq.WithSyncMode(true), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert")) ) func main() { @@ -67,7 +68,7 @@ func main() { d.AddNode("D", handler4) d.AddNode("E", handler5) d.AddNode("F", handler6) - d.AddEdge("A", "B", v2.LoopEdge) + d.AddEdge("A", "B", dag.LoopEdge) d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"}) d.AddEdge("B", "C") d.AddEdge("D", "F") From ae60d8cc08af082454a1b52dd651c0d37608f7d2 Mon Sep 17 00:00:00 2001 From: sujit Date: Wed, 9 Oct 2024 00:09:46 +0545 Subject: [PATCH 17/17] feat: add example --- dag/dag.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dag/dag.go b/dag/dag.go index 224020d..8e83e6a 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -63,7 +63,7 @@ func NewDAG(opts ...mq.Option) *DAG { } func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result { - if taskContext, ok := tm.taskContext[result.TaskID]; ok { + if taskContext, ok := tm.taskContext[result.TaskID]; ok && result.Topic != "" { return taskContext.handleCallback(ctx, result) } return mq.Result{}