diff --git a/broker.go b/broker.go index 9b0b497..4a0a80c 100644 --- a/broker.go +++ b/broker.go @@ -284,9 +284,6 @@ func (b *Broker) Start(ctx context.Context) error { c.Close() }() - // Optionally set connection timeouts to prevent idle connections - c.SetReadDeadline(time.Now().Add(5 * time.Minute)) - for { // Attempt to read the message err := b.readMessage(ctx, c) diff --git a/dag/dag.go b/dag/dag.go index 0e7a863..031841f 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -68,6 +68,8 @@ type DAG struct { opts []mq.Option mu sync.RWMutex paused bool + Error error + report string } func (tm *DAG) SetKey(key string) { @@ -283,32 +285,47 @@ func (tm *DAG) AddCondition(fromNode FromNode, conditions map[When]Then) *DAG { } func (tm *DAG) AddIterator(label, from string, targets ...string) *DAG { - tm.addEdge(Iterator, label, from, targets...) + tm.Error = tm.addEdge(Iterator, label, from, targets...) return tm } func (tm *DAG) AddEdge(label, from string, targets ...string) *DAG { - tm.addEdge(Simple, label, from, targets...) + tm.Error = tm.addEdge(Simple, label, from, targets...) return tm } -func (tm *DAG) addEdge(edgeType EdgeType, label, from string, targets ...string) { +func (tm *DAG) addEdge(edgeType EdgeType, label, from string, targets ...string) error { tm.mu.Lock() defer tm.mu.Unlock() fromNode, ok := tm.nodes[from] if !ok { - return + return fmt.Errorf("Error: 'from' node %s does not exist\n", from) } var nodes []*Node for _, target := range targets { toNode, ok := tm.nodes[target] if !ok { - return + return fmt.Errorf("Error: 'from' node %s does not exist\n", target) } nodes = append(nodes, toNode) } edge := Edge{From: fromNode, To: nodes, Type: edgeType, Label: label} fromNode.Edges = append(fromNode.Edges, edge) + return nil +} + +func (tm *DAG) Validate() error { + report, hasCycle, err := tm.ClassifyEdges() + if hasCycle || err != nil { + tm.Error = err + return err + } + tm.report = report + return nil +} + +func (tm *DAG) GetReport() string { + return tm.report } func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { diff --git a/dag/task_manager.go b/dag/task_manager.go index 1af2296..dbdc601 100644 --- a/dag/task_manager.go +++ b/dag/task_manager.go @@ -98,8 +98,10 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq. } edges := tm.getConditionalEdges(node, result) if len(edges) == 0 { - tm.appendFinalResult(result) + tm.appendResult(result, true) return result + } else { + tm.appendResult(result, false) } for _, edge := range edges { switch edge.Type { @@ -107,7 +109,7 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq. var items []json.RawMessage err := json.Unmarshal(result.Payload, &items) if err != nil { - tm.appendFinalResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err}) + tm.appendResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err}, false) return result } for _, target := range edge.To { @@ -170,10 +172,12 @@ func (tm *TaskManager) handleResult(ctx context.Context, results any) mq.Result return rs } -func (tm *TaskManager) appendFinalResult(result mq.Result) { +func (tm *TaskManager) appendResult(result mq.Result, final bool) { tm.mutex.Lock() tm.updateTS(&result) - tm.results = append(tm.results, result) + if final { + tm.results = append(tm.results, result) + } tm.nodeResults[result.Topic] = result tm.mutex.Unlock() } @@ -199,7 +203,7 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json select { case <-ctx.Done(): result = mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: ctx.Err(), Ctx: ctx} - tm.appendFinalResult(result) + tm.appendResult(result, false) return default: ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: node.Key}) @@ -210,14 +214,14 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json result.TaskID = tm.taskID } if result.Error != nil { - tm.appendFinalResult(result) + tm.appendResult(result, false) return } return } err := tm.dag.server.Publish(ctx, mq.NewTask(tm.taskID, payload, node.Key), node.Key) if err != nil { - tm.appendFinalResult(mq.Result{Error: err}) + tm.appendResult(mq.Result{Error: err}, false) return } } diff --git a/dag/ui.go b/dag/ui.go index 118b0f3..0758635 100644 --- a/dag/ui.go +++ b/dag/ui.go @@ -32,7 +32,8 @@ func (tm *DAG) PrintGraph() { } } -func (tm *DAG) ClassifyEdges(startNodes ...string) { +func (tm *DAG) ClassifyEdges(startNodes ...string) (string, bool, error) { + builder := &strings.Builder{} startNode := tm.GetStartNode() tm.mu.RLock() defer tm.mu.RUnlock() @@ -43,57 +44,78 @@ func (tm *DAG) ClassifyEdges(startNodes ...string) { discoveryTime := make(map[string]int) finishedTime := make(map[string]int) timeVal := 0 + inRecursionStack := make(map[string]bool) // track nodes in the recursion stack for cycle detection if startNode == "" { firstNode := tm.findStartNode() if firstNode != nil { startNode = firstNode.Key } } - if startNode != "" { - tm.dfs(startNode, visited, discoveryTime, finishedTime, &timeVal) + if startNode == "" { + return "", false, fmt.Errorf("no start node found") } + hasCycle, cycleErr := tm.dfs(startNode, visited, discoveryTime, finishedTime, &timeVal, inRecursionStack, builder) + if cycleErr != nil { + return builder.String(), hasCycle, cycleErr + } + return builder.String(), hasCycle, nil } -func (tm *DAG) dfs(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, timeVal *int) { +func (tm *DAG) dfs(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, timeVal *int, inRecursionStack map[string]bool, builder *strings.Builder) (bool, error) { visited[v] = true + inRecursionStack[v] = true // mark node as part of recursion stack *timeVal++ discoveryTime[v] = *timeVal node := tm.nodes[v] + hasCycle := false + var err error for _, edge := range node.Edges { for _, adj := range edge.To { - switch edge.Type { - case Simple: - if !visited[adj.Key] { - fmt.Printf("Simple Edge: %s -> %s\n", v, adj.Key) - tm.dfs(adj.Key, visited, discoveryTime, finishedTime, timeVal) + if !visited[adj.Key] { + builder.WriteString(fmt.Sprintf("Traversing Edge: %s -> %s\n", v, adj.Key)) + hasCycle, err := tm.dfs(adj.Key, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder) + if err != nil { + return true, err } - case Iterator: - if !visited[adj.Key] { - fmt.Printf("Iterator Edge: %s -> %s\n", v, adj.Key) - tm.dfs(adj.Key, visited, discoveryTime, finishedTime, timeVal) + if hasCycle { + return true, nil } + } else if inRecursionStack[adj.Key] { + cycleMsg := fmt.Sprintf("Cycle detected: %s -> %s\n", v, adj.Key) + return true, fmt.Errorf(cycleMsg) } } } - tm.handleConditionalEdges(v, visited, discoveryTime, finishedTime, timeVal) + hasCycle, err = tm.handleConditionalEdges(v, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder) + if err != nil { + return true, err + } *timeVal++ finishedTime[v] = *timeVal + inRecursionStack[v] = false // remove from recursion stack after finishing processing + return hasCycle, nil } -func (tm *DAG) handleConditionalEdges(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, time *int) { +func (tm *DAG) handleConditionalEdges(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, time *int, inRecursionStack map[string]bool, builder *strings.Builder) (bool, error) { node := tm.nodes[v] for when, then := range tm.conditions[FromNode(node.Key)] { - if targetNodeKey, ok := tm.nodes[string(then)]; ok { - if !visited[targetNodeKey.Key] { - fmt.Printf("Conditional Edge [%s]: %s -> %s\n", when, v, targetNodeKey.Key) - tm.dfs(targetNodeKey.Key, visited, discoveryTime, finishedTime, time) - } else { - if discoveryTime[v] > discoveryTime[targetNodeKey.Key] { - fmt.Printf("Conditional Loop Edge [%s]: %s -> %s\n", when, v, targetNodeKey.Key) + if targetNode, ok := tm.nodes[string(then)]; ok { + if !visited[targetNode.Key] { + builder.WriteString(fmt.Sprintf("Traversing Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.Key)) + hasCycle, err := tm.dfs(targetNode.Key, visited, discoveryTime, finishedTime, time, inRecursionStack, builder) + if err != nil { + return true, err } + if hasCycle { + return true, nil + } + } else if inRecursionStack[targetNode.Key] { + cycleMsg := fmt.Sprintf("Cycle detected in Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.Key) + return true, fmt.Errorf(cycleMsg) } } } + return false, nil } func (tm *DAG) SaveDOTFile(filename string) error { diff --git a/examples/dag.go b/examples/dag.go index 7583c69..30acbf8 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -52,7 +52,12 @@ func Sync() { func aSync() { f := dag.NewDAG("Sample DAG", "sample-dag", mq.WithCleanTaskOnComplete(), mq.WithNotifyResponse(tasks.NotifyResponse)) setup(f) - err := f.Start(context.TODO(), ":8083") + err := f.Validate() + if err != nil { + panic(err) + } + + err = f.Start(context.TODO(), ":8083") if err != nil { panic(err) } diff --git a/examples/tasks/operations.go b/examples/tasks/operations.go index be5158f..df44834 100644 --- a/examples/tasks/operations.go +++ b/examples/tasks/operations.go @@ -5,8 +5,6 @@ import ( "github.com/oarkflow/json" - "github.com/oarkflow/dipper" - "github.com/oarkflow/mq" "github.com/oarkflow/mq/services" ) @@ -109,38 +107,3 @@ func (e *InAppNotification) ProcessTask(ctx context.Context, task *mq.Task) mq.R } return mq.Result{Payload: task.Payload, Ctx: ctx} } - -type DataBranchHandler struct{ services.Operation } - -func (v *DataBranchHandler) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { - ctx = context.WithValue(ctx, "extra_params", map[string]any{"iphone": true}) - var row map[string]any - var result mq.Result - result.Payload = task.Payload - err := json.Unmarshal(result.Payload, &row) - if err != nil { - result.Error = err - return result - } - b := make(map[string]any) - switch branches := row["data_branch"].(type) { - case map[string]any: - for field, handler := range branches { - data, err := dipper.Get(row, field) - if err != nil { - break - } - b[handler.(string)] = data - } - break - } - br, err := json.Marshal(b) - if err != nil { - result.Error = err - return result - } - result.Status = "branches" - result.Payload = br - result.Ctx = ctx - return result -}