feat: implement Validate to check for cycles

This commit is contained in:
sujit
2024-10-22 12:40:14 +05:45
parent a84ff6d831
commit ea0a7022f9
6 changed files with 83 additions and 75 deletions

View File

@@ -284,9 +284,6 @@ func (b *Broker) Start(ctx context.Context) error {
c.Close()
}()
// Optionally set connection timeouts to prevent idle connections
c.SetReadDeadline(time.Now().Add(5 * time.Minute))
for {
// Attempt to read the message
err := b.readMessage(ctx, c)

View File

@@ -68,6 +68,8 @@ type DAG struct {
opts []mq.Option
mu sync.RWMutex
paused bool
Error error
report string
}
func (tm *DAG) SetKey(key string) {
@@ -283,32 +285,47 @@ func (tm *DAG) AddCondition(fromNode FromNode, conditions map[When]Then) *DAG {
}
// AddIterator connects the node keyed `from` to each node keyed in `targets`
// with a single edge of type Iterator carrying `label`, and returns the DAG so
// calls can be chained.
//
// The edge is added exactly once. If addEdge reports a failure it is stored in
// tm.Error; a successful call leaves tm.Error untouched, so a failure recorded
// by an earlier call in a builder chain is not erased by a later call that
// succeeds.
func (tm *DAG) AddIterator(label, from string, targets ...string) *DAG {
	if err := tm.addEdge(Iterator, label, from, targets...); err != nil {
		tm.Error = err
	}
	return tm
}
// AddEdge connects the node keyed `from` to each node keyed in `targets` with
// a single edge of type Simple carrying `label`, and returns the DAG so calls
// can be chained.
//
// The edge is added exactly once. If addEdge reports a failure it is stored in
// tm.Error; a successful call leaves tm.Error untouched, so a failure recorded
// by an earlier call in a builder chain is not erased by a later call that
// succeeds.
func (tm *DAG) AddEdge(label, from string, targets ...string) *DAG {
	if err := tm.addEdge(Simple, label, from, targets...); err != nil {
		tm.Error = err
	}
	return tm
}
func (tm *DAG) addEdge(edgeType EdgeType, label, from string, targets ...string) {
func (tm *DAG) addEdge(edgeType EdgeType, label, from string, targets ...string) error {
tm.mu.Lock()
defer tm.mu.Unlock()
fromNode, ok := tm.nodes[from]
if !ok {
return
return fmt.Errorf("Error: 'from' node %s does not exist\n", from)
}
var nodes []*Node
for _, target := range targets {
toNode, ok := tm.nodes[target]
if !ok {
return
return fmt.Errorf("Error: 'from' node %s does not exist\n", target)
}
nodes = append(nodes, toNode)
}
edge := Edge{From: fromNode, To: nodes, Type: edgeType, Label: label}
fromNode.Edges = append(fromNode.Edges, edge)
return nil
}
// Validate runs ClassifyEdges over the graph and reports whether it is safe
// to use.
//
// It returns a non-nil error when ClassifyEdges fails or reports a cycle; the
// same error is also stored in tm.Error. On success the traversal report is
// kept and can be read back with GetReport.
func (tm *DAG) Validate() error {
	report, hasCycle, err := tm.ClassifyEdges()
	if err == nil && hasCycle {
		// A cycle reported without an accompanying error must still fail
		// validation instead of being returned to the caller as nil.
		err = fmt.Errorf("cycle detected in DAG")
	}
	if err != nil {
		tm.Error = err
		return err
	}
	tm.report = report
	return nil
}
// GetReport returns the report string currently stored on the DAG. Validate
// stores it after a successful run; before that it is the empty string.
func (tm *DAG) GetReport() string {
	report := tm.report
	return report
}
func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {

View File

@@ -98,8 +98,10 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.
}
edges := tm.getConditionalEdges(node, result)
if len(edges) == 0 {
tm.appendFinalResult(result)
tm.appendResult(result, true)
return result
} else {
tm.appendResult(result, false)
}
for _, edge := range edges {
switch edge.Type {
@@ -107,7 +109,7 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.
var items []json.RawMessage
err := json.Unmarshal(result.Payload, &items)
if err != nil {
tm.appendFinalResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err})
tm.appendResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err}, false)
return result
}
for _, target := range edge.To {
@@ -170,10 +172,12 @@ func (tm *TaskManager) handleResult(ctx context.Context, results any) mq.Result
return rs
}
func (tm *TaskManager) appendFinalResult(result mq.Result) {
func (tm *TaskManager) appendResult(result mq.Result, final bool) {
tm.mutex.Lock()
tm.updateTS(&result)
tm.results = append(tm.results, result)
if final {
tm.results = append(tm.results, result)
}
tm.nodeResults[result.Topic] = result
tm.mutex.Unlock()
}
@@ -199,7 +203,7 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json
select {
case <-ctx.Done():
result = mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: ctx.Err(), Ctx: ctx}
tm.appendFinalResult(result)
tm.appendResult(result, false)
return
default:
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: node.Key})
@@ -210,14 +214,14 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json
result.TaskID = tm.taskID
}
if result.Error != nil {
tm.appendFinalResult(result)
tm.appendResult(result, false)
return
}
return
}
err := tm.dag.server.Publish(ctx, mq.NewTask(tm.taskID, payload, node.Key), node.Key)
if err != nil {
tm.appendFinalResult(mq.Result{Error: err})
tm.appendResult(mq.Result{Error: err}, false)
return
}
}

View File

@@ -32,7 +32,8 @@ func (tm *DAG) PrintGraph() {
}
}
func (tm *DAG) ClassifyEdges(startNodes ...string) {
func (tm *DAG) ClassifyEdges(startNodes ...string) (string, bool, error) {
builder := &strings.Builder{}
startNode := tm.GetStartNode()
tm.mu.RLock()
defer tm.mu.RUnlock()
@@ -43,57 +44,78 @@ func (tm *DAG) ClassifyEdges(startNodes ...string) {
discoveryTime := make(map[string]int)
finishedTime := make(map[string]int)
timeVal := 0
inRecursionStack := make(map[string]bool) // track nodes in the recursion stack for cycle detection
if startNode == "" {
firstNode := tm.findStartNode()
if firstNode != nil {
startNode = firstNode.Key
}
}
if startNode != "" {
tm.dfs(startNode, visited, discoveryTime, finishedTime, &timeVal)
if startNode == "" {
return "", false, fmt.Errorf("no start node found")
}
hasCycle, cycleErr := tm.dfs(startNode, visited, discoveryTime, finishedTime, &timeVal, inRecursionStack, builder)
if cycleErr != nil {
return builder.String(), hasCycle, cycleErr
}
return builder.String(), hasCycle, nil
}
func (tm *DAG) dfs(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, timeVal *int) {
func (tm *DAG) dfs(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, timeVal *int, inRecursionStack map[string]bool, builder *strings.Builder) (bool, error) {
visited[v] = true
inRecursionStack[v] = true // mark node as part of recursion stack
*timeVal++
discoveryTime[v] = *timeVal
node := tm.nodes[v]
hasCycle := false
var err error
for _, edge := range node.Edges {
for _, adj := range edge.To {
switch edge.Type {
case Simple:
if !visited[adj.Key] {
fmt.Printf("Simple Edge: %s -> %s\n", v, adj.Key)
tm.dfs(adj.Key, visited, discoveryTime, finishedTime, timeVal)
if !visited[adj.Key] {
builder.WriteString(fmt.Sprintf("Traversing Edge: %s -> %s\n", v, adj.Key))
hasCycle, err := tm.dfs(adj.Key, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder)
if err != nil {
return true, err
}
case Iterator:
if !visited[adj.Key] {
fmt.Printf("Iterator Edge: %s -> %s\n", v, adj.Key)
tm.dfs(adj.Key, visited, discoveryTime, finishedTime, timeVal)
if hasCycle {
return true, nil
}
} else if inRecursionStack[adj.Key] {
cycleMsg := fmt.Sprintf("Cycle detected: %s -> %s\n", v, adj.Key)
return true, fmt.Errorf(cycleMsg)
}
}
}
tm.handleConditionalEdges(v, visited, discoveryTime, finishedTime, timeVal)
hasCycle, err = tm.handleConditionalEdges(v, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder)
if err != nil {
return true, err
}
*timeVal++
finishedTime[v] = *timeVal
inRecursionStack[v] = false // remove from recursion stack after finishing processing
return hasCycle, nil
}
func (tm *DAG) handleConditionalEdges(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, time *int) {
func (tm *DAG) handleConditionalEdges(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, time *int, inRecursionStack map[string]bool, builder *strings.Builder) (bool, error) {
node := tm.nodes[v]
for when, then := range tm.conditions[FromNode(node.Key)] {
if targetNodeKey, ok := tm.nodes[string(then)]; ok {
if !visited[targetNodeKey.Key] {
fmt.Printf("Conditional Edge [%s]: %s -> %s\n", when, v, targetNodeKey.Key)
tm.dfs(targetNodeKey.Key, visited, discoveryTime, finishedTime, time)
} else {
if discoveryTime[v] > discoveryTime[targetNodeKey.Key] {
fmt.Printf("Conditional Loop Edge [%s]: %s -> %s\n", when, v, targetNodeKey.Key)
if targetNode, ok := tm.nodes[string(then)]; ok {
if !visited[targetNode.Key] {
builder.WriteString(fmt.Sprintf("Traversing Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.Key))
hasCycle, err := tm.dfs(targetNode.Key, visited, discoveryTime, finishedTime, time, inRecursionStack, builder)
if err != nil {
return true, err
}
if hasCycle {
return true, nil
}
} else if inRecursionStack[targetNode.Key] {
cycleMsg := fmt.Sprintf("Cycle detected in Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.Key)
return true, fmt.Errorf(cycleMsg)
}
}
}
return false, nil
}
func (tm *DAG) SaveDOTFile(filename string) error {

View File

@@ -52,7 +52,12 @@ func Sync() {
func aSync() {
f := dag.NewDAG("Sample DAG", "sample-dag", mq.WithCleanTaskOnComplete(), mq.WithNotifyResponse(tasks.NotifyResponse))
setup(f)
err := f.Start(context.TODO(), ":8083")
err := f.Validate()
if err != nil {
panic(err)
}
err = f.Start(context.TODO(), ":8083")
if err != nil {
panic(err)
}

View File

@@ -5,8 +5,6 @@ import (
"github.com/oarkflow/json"
"github.com/oarkflow/dipper"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/services"
)
@@ -109,38 +107,3 @@ func (e *InAppNotification) ProcessTask(ctx context.Context, task *mq.Task) mq.R
}
return mq.Result{Payload: task.Payload, Ctx: ctx}
}
// DataBranchHandler is a task processor that embeds services.Operation and
// adds its own ProcessTask method.
type DataBranchHandler struct {
	services.Operation
}
// ProcessTask builds the payload for the "branches" status from the task's
// JSON payload.
//
// The payload is decoded into a map. When its "data_branch" entry is itself a
// map, each key is treated as a field path that is resolved against the
// decoded row with dipper.Get, and the value found is stored under the handler
// name given as that entry's value. The collected map is marshalled and
// returned as the result payload with Status "branches".
//
// If the payload cannot be decoded or the output cannot be encoded, the
// returned result carries the error and the original payload. Entries whose
// handler name is not a string are skipped rather than causing a panic.
func (v *DataBranchHandler) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
	ctx = context.WithValue(ctx, "extra_params", map[string]any{"iphone": true})

	result := mq.Result{Payload: task.Payload}
	var row map[string]any
	if err := json.Unmarshal(result.Payload, &row); err != nil {
		result.Error = err
		return result
	}

	b := make(map[string]any)
	if branches, ok := row["data_branch"].(map[string]any); ok {
		for field, handler := range branches {
			data, err := dipper.Get(row, field)
			if err != nil {
				// NOTE(review): this stops at the first field that cannot be
				// resolved; because map iteration order is random, which
				// entries were already collected varies between runs.
				// Confirm whether `continue` was intended.
				break
			}
			name, isString := handler.(string)
			if !isString {
				continue
			}
			b[name] = data
		}
	}

	br, err := json.Marshal(b)
	if err != nil {
		result.Error = err
		return result
	}
	result.Status = "branches"
	result.Payload = br
	result.Ctx = ctx
	return result
}