diff --git a/dag/dag.go b/dag/dag.go index 6dcbb40..4cbc60f 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -6,6 +6,7 @@ import ( "fmt" "log" "net/http" + "strings" "sync" "time" @@ -24,11 +25,11 @@ func NewTask(id string, payload json.RawMessage, nodeKey string) *mq.Task { type EdgeType int -func (c EdgeType) IsValid() bool { return c >= SimpleEdge && c <= LoopEdge } +func (c EdgeType) IsValid() bool { return c >= Simple && c <= Iterator } const ( - SimpleEdge EdgeType = iota - LoopEdge + Simple EdgeType = iota + Iterator ) type Node struct { @@ -50,8 +51,9 @@ type Edge struct { } type ( - When string - Then string + FromNode string + When string + Then string ) type DAG struct { @@ -59,7 +61,7 @@ type DAG struct { nodes map[string]*Node server *mq.Broker taskContext map[string]*TaskManager - conditions map[string]map[When]Then + conditions map[FromNode]map[When]Then mu sync.RWMutex paused bool opts []mq.Option @@ -69,7 +71,7 @@ func NewDAG(opts ...mq.Option) *DAG { d := &DAG{ nodes: make(map[string]*Node), taskContext: make(map[string]*TaskManager), - conditions: make(map[string]map[When]Then), + conditions: make(map[FromNode]map[When]Then), } opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerOnSubscribe(d.onConsumerJoin), mq.WithConsumerOnClose(d.onConsumerClose)) d.server = mq.NewBroker(opts...) @@ -77,6 +79,99 @@ func NewDAG(opts ...mq.Option) *DAG { return d } +// PrintGraph prints the DAG's adjacency list +func (tm *DAG) PrintGraph() { + tm.mu.RLock() + defer tm.mu.RUnlock() + + fmt.Println("DAG Graph structure:") + for _, node := range tm.nodes { + fmt.Printf("Node: %s (%s) -> ", node.Name, node.Key) + if conditions, ok := tm.conditions[FromNode(node.Key)]; ok { + var c []string + for when, then := range conditions { + if target, ok := tm.nodes[string(then)]; ok { + c = append(c, fmt.Sprintf("If [%s] Then %s (%s)", when, target.Name, target.Key)) + } + } + fmt.Println(strings.Join(c, ", ")) + } + var c []string + for _, edge := range node.Edges { + for _, target := range edge.To { + c = append(c, fmt.Sprintf("%s (%s)", target.Name, target.Key)) + } + } + fmt.Println(strings.Join(c, ", ")) + } +} + +func (tm *DAG) ClassifyEdges(startNodes ...string) { + startNode := tm.GetStartNode() + tm.mu.RLock() + defer tm.mu.RUnlock() + if len(startNodes) > 0 && startNodes[0] != "" { + startNode = startNodes[0] + } + visited := make(map[string]bool) + discoveryTime := make(map[string]int) + finishedTime := make(map[string]int) + timeVal := 0 + if startNode == "" { + firstNode := tm.findStartNode() + if firstNode != nil { + startNode = firstNode.Key + } + } + if startNode != "" { + tm.dfs(startNode, visited, discoveryTime, finishedTime, &timeVal) + } +} + +func (tm *DAG) dfs(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, timeVal *int) { + visited[v] = true + *timeVal++ + discoveryTime[v] = *timeVal + node := tm.nodes[v] + for _, edge := range node.Edges { + for _, adj := range edge.To { + switch edge.Type { + case Simple: + if !visited[adj.Key] { + fmt.Printf("Simple Edge: %s -> %s\n", v, adj.Key) + tm.dfs(adj.Key, visited, discoveryTime, finishedTime, timeVal) + } + case Iterator: + if !visited[adj.Key] { + fmt.Printf("Iterator Edge: %s -> %s\n", v, adj.Key) + tm.dfs(adj.Key, visited, discoveryTime, finishedTime, timeVal) + } + } + + } + } + tm.handleConditionalEdges(v, visited, discoveryTime, finishedTime, timeVal) + *timeVal++ + finishedTime[v] = *timeVal +} + +// handleConditionalEdges processes the conditional edges based on the task result +func (tm *DAG) handleConditionalEdges(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, time *int) { + node := tm.nodes[v] + for when, then := range tm.conditions[FromNode(node.Key)] { + if targetNodeKey, ok := tm.nodes[string(then)]; ok { + if !visited[targetNodeKey.Key] { + fmt.Printf("Conditional Edge [%s]: %s -> %s\n", when, v, targetNodeKey.Key) + tm.dfs(targetNodeKey.Key, visited, discoveryTime, finishedTime, time) + } else { + if discoveryTime[v] > discoveryTime[targetNodeKey.Key] { + fmt.Printf("Conditional Loop Edge [%s]: %s -> %s\n", when, v, targetNodeKey.Key) + } + } + } + } +} + func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result { if taskContext, ok := tm.taskContext[result.TaskID]; ok && result.Topic != "" { return taskContext.handleCallback(ctx, result) @@ -178,18 +273,18 @@ func (tm *DAG) IsReady() bool { return true } -func (tm *DAG) AddCondition(fromNode string, conditions map[When]Then) { +func (tm *DAG) AddCondition(fromNode FromNode, conditions map[When]Then) { tm.mu.Lock() defer tm.mu.Unlock() tm.conditions[fromNode] = conditions } func (tm *DAG) AddLoop(from string, targets ...string) { - tm.addEdge(LoopEdge, from, targets...) + tm.addEdge(Iterator, from, targets...) } func (tm *DAG) AddEdge(from string, targets ...string) { - tm.addEdge(SimpleEdge, from, targets...) + tm.addEdge(Simple, from, targets...) } func (tm *DAG) addEdge(edgeType EdgeType, from string, targets ...string) { @@ -257,7 +352,7 @@ func (tm *DAG) findStartNode() *Node { } } - if cond, ok := tm.conditions[node.Key]; ok { + if cond, ok := tm.conditions[FromNode(node.Key)]; ok { for _, target := range cond { connectedNodes[string(target)] = true incomingEdges[string(target)] = true diff --git a/dag/task_manager.go b/dag/task_manager.go index 88ad693..978e69b 100644 --- a/dag/task_manager.go +++ b/dag/task_manager.go @@ -82,7 +82,7 @@ func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge edges := make([]Edge, len(node.Edges)) copy(edges, node.Edges) if result.Status != "" { - if conditions, ok := tm.dag.conditions[result.Topic]; ok { + if conditions, ok := tm.dag.conditions[FromNode(result.Topic)]; ok { if targetNodeKey, ok := conditions[When(result.Status)]; ok { if targetNode, ok := tm.dag.nodes[string(targetNodeKey)]; ok { edges = append(edges, Edge{From: node, To: []*Node{targetNode}}) @@ -113,7 +113,7 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq. } for _, edge := range edges { switch edge.Type { - case LoopEdge: + case Iterator: var items []json.RawMessage err := json.Unmarshal(result.Payload, &items) if err != nil { @@ -126,7 +126,7 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq. go tm.processNode(ctx, target, item) } } - case SimpleEdge: + case Simple: for _, target := range edge.To { ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: target.Key}) go tm.processNode(ctx, target, result.Payload) diff --git a/examples/dag.go b/examples/dag.go index 9acd5c7..468b369 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -1,126 +1,104 @@ package main import ( + "context" + "encoding/json" "fmt" - "sort" + "io" + "net/http" + + "github.com/oarkflow/mq/consts" + "github.com/oarkflow/mq/examples/tasks" + + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/dag" ) -// DAG represents a Directed Acyclic Graph -type DAG struct { - vertices int - adjList map[int][]int // adjacency list to represent edges -} +var ( + d = dag.NewDAG( + // mq.WithSyncMode(true), + mq.WithNotifyResponse(tasks.NotifyResponse), + mq.WithSecretKey([]byte("wKWa6GKdBd0njDKNQoInBbh6P0KTjmob")), + ) + // d = dag.NewDAG(mq.WithSyncMode(true), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert")) +) -// NewDAG creates a new DAG with a given number of vertices -func NewDAG(vertices int) *DAG { - return &DAG{ - vertices: vertices, - adjList: make(map[int][]int), - } -} - -// AddEdge adds a directed edge from u to v -func (d *DAG) AddEdge(u, v int) { - d.adjList[u] = append(d.adjList[u], v) -} - -// PrintGraph prints the graph's adjacency list -func (d *DAG) PrintGraph() { - for vertex, edges := range d.adjList { - fmt.Printf("Vertex %d -> %v\n", vertex, edges) - } -} - -// DFS traversal function to classify edges as tree, forward, or cross -func (d *DAG) ClassifyEdges() { - visited := make([]bool, d.vertices) - discoveryTime := make([]int, d.vertices) - finishedTime := make([]int, d.vertices) - time := 0 - - for i := 0; i < d.vertices; i++ { - if !visited[i] { - d.dfs(i, visited, discoveryTime, finishedTime, &time) - } - } -} - -// dfs performs a DFS and classifies the edges -func (d *DAG) dfs(v int, visited []bool, discoveryTime []int, finishedTime []int, time *int) { - visited[v] = true - *time++ - discoveryTime[v] = *time - - for _, adj := range d.adjList[v] { - if !visited[adj] { - // Tree Edge: adj not visited, and it's being discovered - fmt.Printf("Tree Edge: %d -> %d\n", v, adj) - d.dfs(adj, visited, discoveryTime, finishedTime, time) - } else { - if discoveryTime[v] < discoveryTime[adj] { - // Forward Edge: adj is a descendant but already discovered - fmt.Printf("Forward Edge: %d -> %d\n", v, adj) - } else if finishedTime[adj] == 0 { - // Cross Edge: adj is in a different branch (adj was visited, but not fully processed) - fmt.Printf("Cross Edge: %d -> %d\n", v, adj) - } - } - } - - *time++ - finishedTime[v] = *time -} - -// TopologicalSort returns a topologically sorted order of the DAG vertices -func (d *DAG) TopologicalSort() []int { - visited := make([]bool, d.vertices) - stack := []int{} - - for i := 0; i < d.vertices; i++ { - if !visited[i] { - d.topologicalSortUtil(i, visited, &stack) - } - } - - // Reverse the stack to get the topological order - sort.Slice(stack, func(i, j int) bool { return stack[i] > stack[j] }) - return stack -} - -// Helper function for topological sorting using DFS -func (d *DAG) topologicalSortUtil(v int, visited []bool, stack *[]int) { - visited[v] = true - - for _, adj := range d.adjList[v] { - if !visited[adj] { - d.topologicalSortUtil(adj, visited, stack) - } - } - - *stack = append(*stack, v) -} - -// Main function to demonstrate DAG edge classification func main() { - // Create a new DAG - dag := NewDAG(6) + d.AddNode("A", "A", tasks.Node1, true) + d.AddNode("B", "B", tasks.Node2) + d.AddNode("C", "C", tasks.Node3) + d.AddNode("D", "D", tasks.Node4) + d.AddNode("E", "E", tasks.Node5) + d.AddNode("F", "F", tasks.Node6) + d.AddNode("G", "G", tasks.Node7) + d.AddNode("H", "H", tasks.Node8) - // Add edges (vertices start from 0) - dag.AddEdge(0, 1) - dag.AddEdge(0, 2) - dag.AddEdge(1, 3) - dag.AddEdge(2, 3) - dag.AddEdge(3, 4) - dag.AddEdge(4, 5) + d.AddLoop("A", "B") + d.AddCondition("C", map[dag.When]dag.Then{"PASS": "D", "FAIL": "E"}) + d.AddEdge("B", "C") + d.AddEdge("D", "F") + d.AddEdge("E", "F") + d.AddEdge("F", "G", "H") - fmt.Println("Graph adjacency list:") - dag.PrintGraph() + // Classify edges + d.ClassifyEdges() - fmt.Println("\nClassifying edges:") - dag.ClassifyEdges() - - // Perform topological sorting - fmt.Println("\nTopologically sorted order:") - order := dag.TopologicalSort() - fmt.Println(order) + http.HandleFunc("POST /publish", requestHandler("publish")) + http.HandleFunc("POST /request", requestHandler("request")) + http.HandleFunc("/pause-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) { + id := request.PathValue("id") + if id != "" { + d.PauseConsumer(request.Context(), id) + } + }) + http.HandleFunc("/resume-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) { + id := request.PathValue("id") + if id != "" { + d.ResumeConsumer(request.Context(), id) + } + }) + http.HandleFunc("/pause", func(writer http.ResponseWriter, request *http.Request) { + d.Pause(true) + }) + http.HandleFunc("/resume", func(writer http.ResponseWriter, request *http.Request) { + d.Pause(false) + }) + err := d.Start(context.TODO(), ":8083") + if err != nil { + panic(err) + } +} + +func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Invalid request method", http.StatusMethodNotAllowed) + return + } + var payload []byte + if r.Body != nil { + defer r.Body.Close() + var err error + payload, err = io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + } else { + http.Error(w, "Empty request body", http.StatusBadRequest) + return + } + ctx := r.Context() + if requestType == "request" { + ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"}) + } + // ctx = context.WithValue(ctx, "initial_node", "E") + rs := d.ProcessTask(ctx, payload) + if rs.Error != nil { + http.Error(w, fmt.Sprintf("[DAG Error] - %v", rs.Error), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(rs) + } }