diff --git a/dag/dag.go b/dag/dag.go index 1f164ba..50b8fde 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -39,7 +39,7 @@ type Node struct { type Edge struct { From *Node - To *Node + To []*Node Type EdgeType } @@ -172,21 +172,30 @@ func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) { tm.conditions[fromNode] = conditions } -func (tm *DAG) AddEdge(from, to string, edgeTypes ...EdgeType) { +func (tm *DAG) AddLoop(from string, targets ...string) { + tm.addEdge(LoopEdge, from, targets...) +} + +func (tm *DAG) AddEdge(from string, targets ...string) { + tm.addEdge(SimpleEdge, from, targets...) +} + +func (tm *DAG) addEdge(edgeType EdgeType, from string, targets ...string) { tm.mu.Lock() defer tm.mu.Unlock() fromNode, ok := tm.nodes[from] if !ok { return } - toNode, ok := tm.nodes[to] - if !ok { - return - } - edge := Edge{From: fromNode, To: toNode} - if len(edgeTypes) > 0 && edgeTypes[0].IsValid() { - edge.Type = edgeTypes[0] + var nodes []*Node + for _, target := range targets { + toNode, ok := tm.nodes[target] + if !ok { + return + } + nodes = append(nodes, toNode) } + edge := Edge{From: fromNode, To: nodes, Type: edgeType} fromNode.Edges = append(fromNode.Edges, edge) } @@ -229,8 +238,11 @@ func (tm *DAG) FindInitialNode() *Node { for _, edge := range node.Edges { if edge.Type.IsValid() { connectedNodes[node.Key] = true - connectedNodes[edge.To.Key] = true - incomingEdges[edge.To.Key] = true + for _, to := range edge.To { + connectedNodes[to.Key] = true + incomingEdges[to.Key] = true + } + } } if cond, ok := tm.conditions[node.Key]; ok { diff --git a/dag/task_manager.go b/dag/task_manager.go index bceafb5..d8789ca 100644 --- a/dag/task_manager.go +++ b/dag/task_manager.go @@ -92,7 +92,7 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq. if conditions, ok := tm.dag.conditions[result.Topic]; ok { if targetNodeKey, ok := conditions[result.Status]; ok { if targetNode, ok := tm.dag.nodes[targetNodeKey]; ok { - edges = append(edges, Edge{From: node, To: targetNode}) + edges = append(edges, Edge{From: node, To: []*Node{targetNode}}) } } } @@ -111,14 +111,16 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq. tm.appendFinalResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err}) return result } - for _, item := range items { - ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) - go tm.processNode(ctx, edge.To, item) + for _, target := range edge.To { + for _, item := range items { + ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: target.Key}) + go tm.processNode(ctx, target, item) + } } case SimpleEdge: - if edge.To != nil { - ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) - go tm.processNode(ctx, edge.To, result.Payload) + for _, target := range edge.To { + ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: target.Key}) + go tm.processNode(ctx, target, result.Payload) } } } diff --git a/examples/dag.go b/examples/dag.go index 4f5d172..2e6f94b 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -29,12 +29,15 @@ func main() { d.AddNode("D", tasks.Node4) d.AddNode("E", tasks.Node5) d.AddNode("F", tasks.Node6) + d.AddNode("G", tasks.Node7) + d.AddNode("H", tasks.Node8) - d.AddEdge("A", "B", dag.LoopEdge) + d.AddLoop("A", "B") d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"}) d.AddEdge("B", "C") d.AddEdge("D", "F") d.AddEdge("E", "F") + d.AddEdge("F", "G", "H") http.HandleFunc("POST /publish", requestHandler("publish")) http.HandleFunc("POST /request", requestHandler("request")) http.HandleFunc("/pause-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) { diff --git a/examples/tasks/tasks.go b/examples/tasks/tasks.go index 38764b3..a8dfbd8 100644 --- a/examples/tasks/tasks.go +++ b/examples/tasks/tasks.go @@ -55,6 +55,22 @@ func Node6(_ context.Context, task *mq.Task) mq.Result { return mq.Result{Payload: resultPayload} } +func Node7(_ context.Context, task *mq.Task) mq.Result { + var user map[string]any + _ = json.Unmarshal(task.Payload, &user) + user["node"] = "7" + resultPayload, _ := json.Marshal(user) + return mq.Result{Payload: resultPayload} +} + +func Node8(_ context.Context, task *mq.Task) mq.Result { + var user map[string]any + _ = json.Unmarshal(task.Payload, &user) + user["node"] = "8" + resultPayload, _ := json.Marshal(user) + return mq.Result{Payload: resultPayload} +} + func Callback(_ context.Context, task mq.Result) mq.Result { fmt.Println("Received task", task.TaskID, "Payload", string(task.Payload), task.Error, task.Topic) return mq.Result{}