package dag

import (
	"context"
	"fmt"
	"strings"
	"time"

	"github.com/oarkflow/mq"
	"github.com/oarkflow/mq/logger"
)

// SetStartNode marks node as the DAG's entry point. If a start node was
// already configured and is still registered, its IsFirst flag is cleared;
// the new node's IsFirst flag is set when the node is already registered.
func (tm *DAG) SetStartNode(node string) {
	if tm.startNode != "" {
		if oldNode, ok := tm.nodes.Get(tm.startNode); ok {
			oldNode.IsFirst = false
		}
	}
	tm.startNode = node
	if newNode, ok := tm.nodes.Get(node); ok {
		newNode.IsFirst = true
	}
}

// GetStartNode returns the ID of the configured start node, or "" if none
// has been set.
func (tm *DAG) GetStartNode() string {
	return tm.startNode
}

// AddCondition registers conditional routing for fromNode. conditions maps a
// condition key to a target node ID. The map is stored as-is (not copied), so
// later RemoveNode calls may delete entries from it. IsLast flags are
// recomputed afterwards.
func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) *DAG {
	tm.conditions[fromNode] = conditions
	tm.updateNodeIdentifiers()
	return tm
}

// AddNode registers a node backed by an mq consumer that runs
// handler.ProcessTask. Page nodes get a consumer timeout of 0 (no timeout) so
// they can wait indefinitely for user input. Passing true as the optional
// startNode argument makes this node the DAG's start node.
//
// Failures are recorded in tm.Error; AddNode returns early without changes
// once tm.Error is set.
func (tm *DAG) AddNode(nodeType NodeType, name, nodeID string, handler mq.Processor, startNode ...bool) *DAG {
	if tm.Error != nil {
		return tm
	}
	if handler == nil {
		tm.Error = fmt.Errorf("handler for node %s is nil", nodeID)
		return tm
	}

	// tm.server may be nil (it is nil-checked before SyncMode below), so only
	// derive the broker URL when a server is present.
	var consumerOpts []mq.Option
	if tm.server != nil {
		consumerOpts = append(consumerOpts, mq.WithBrokerURL(tm.server.Options().BrokerAddr()))
	}
	if nodeType == Page {
		consumerOpts = append(consumerOpts, mq.WithConsumerTimeout(0)) // 0 = no timeout
	}
	con := mq.NewConsumer(nodeID, nodeID, handler.ProcessTask, consumerOpts...)

	n := &Node{
		Label:     name,
		ID:        nodeID,
		NodeType:  nodeType,
		processor: con,
		IsLast:    true, // assume last until edges are added
	}
	if tm.server != nil && tm.server.SyncMode() {
		n.isReady = true
	}
	tm.nodes.Set(nodeID, n)

	if len(startNode) > 0 && startNode[0] {
		tm.SetStartNode(nodeID)
	}
	if nodeType == Page {
		tm.hasPageNode = true
	}
	return tm
}

// AddNodeWithDebug is AddNode followed by SetNodeDebug(nodeID, debug) when
// the node was added without error.
func (tm *DAG) AddNodeWithDebug(nodeType NodeType, name, nodeID string, handler mq.Processor, debug bool, startNode ...bool) *DAG {
	dag := tm.AddNode(nodeType, name, nodeID, handler, startNode...)
	if dag.Error == nil {
		dag.SetNodeDebug(nodeID, debug)
	}
	return dag
}

// AddDeferredNode registers a node under key without a processor
// (presumably attached later — verify against callers). It is rejected when
// the server runs in sync mode. Passing true as the optional firstNode
// argument makes this node the DAG's start node.
func (tm *DAG) AddDeferredNode(nodeType NodeType, name, key string, firstNode ...bool) error {
	if tm.server != nil && tm.server.SyncMode() {
		return fmt.Errorf("DAG cannot have deferred node in Sync Mode")
	}
	tm.nodes.Set(key, &Node{
		Label:    name,
		ID:       key,
		NodeType: nodeType,
		IsLast:   true, // assume last until edges are added
	})
	if len(firstNode) > 0 && firstNode[0] {
		tm.SetStartNode(key)
	}
	return nil
}

// IsReady reports whether the DAG has at least one node and every registered
// node is ready.
func (tm *DAG) IsReady() bool {
	if tm.nodes.Size() == 0 {
		return false
	}
	ready := true
	tm.nodes.ForEach(func(_ string, n *Node) bool {
		if !n.isReady {
			ready = false
			return false // one unready node is enough; stop iterating
		}
		return true
	})
	return ready
}

// resolveNode looks up a node by ID, ignoring anything after the first "."
// (e.g. "node.port" resolves to "node").
func (tm *DAG) resolveNode(nodeID string) (*Node, bool) {
	base, _, _ := strings.Cut(nodeID, ".")
	return tm.nodes.Get(base)
}

// AddEdge connects from to each of targets with an edge of edgeType. from may
// carry a "."-suffix (see resolveNode); the raw value is kept in
// Edge.FromSource. Targets that are not registered are skipped.
//
// For Iterator edges, an iterator entry keyed by from is (re)initialised.
// Non-iterator edges are also appended to the source node's iterator entry
// when one exists.
//
// NOTE(review): nextNodesCache/prevNodesCache are not invalidated here, so
// results cached before this call may be stale — confirm caches are only
// populated after the graph is fully built.
func (tm *DAG) AddEdge(edgeType EdgeType, label, from string, targets ...string) *DAG {
	if tm.Error != nil {
		return tm
	}
	node, ok := tm.resolveNode(from)
	if !ok {
		tm.Error = fmt.Errorf("node not found %s", from)
		return tm
	}
	// Register the iterator only after the source node is known to exist.
	if edgeType == Iterator {
		tm.iteratorNodes.Set(from, []Edge{})
	}
	for _, target := range targets {
		targetNode, ok := tm.nodes.Get(target)
		if !ok {
			continue
		}
		edge := Edge{From: node, To: targetNode, Type: edgeType, Label: label, FromSource: from}
		node.Edges = append(node.Edges, edge)
		if edgeType != Iterator {
			if edges, ok := tm.iteratorNodes.Get(node.ID); ok {
				tm.iteratorNodes.Set(node.ID, append(edges, edge))
			}
		}
	}
	node.IsLast = false
	tm.updateNodeIdentifiers()
	return tm
}

// getCurrentNode returns the first key in manager.currentNodePayload, or ""
// when it is empty.
func (tm *DAG) getCurrentNode(manager *TaskManager) string {
	if manager.currentNodePayload.Size() == 0 {
		return ""
	}
	return manager.currentNodePayload.Keys()[0]
}

// AddDAGNode embeds dag as a sub-DAG node under key. The sub-DAG is used
// directly as the node's processor and is marked ready immediately. It is
// assigned the topic key and linked back to this DAG. Passing true as the
// optional firstNode argument makes it the start node.
//
// Failures are recorded in tm.Error; AddDAGNode returns early without changes
// once tm.Error is set.
func (tm *DAG) AddDAGNode(nodeType NodeType, name string, key string, dag *DAG, firstNode ...bool) *DAG {
	if tm.Error != nil {
		return tm
	}
	if dag == nil {
		tm.Error = fmt.Errorf("sub-DAG for node %s is nil", key)
		return tm
	}
	dag.AssignTopic(key)
	dag.name += fmt.Sprintf("(%s)", name)
	tm.nodes.Set(key, &Node{
		Label:     name,
		ID:        key,
		NodeType:  nodeType,
		processor: dag,
		isReady:   true,
		IsLast:    true, // assume last until edges are added
	})
	dag.parentDAG = tm
	dag.nodeIDInParentDAG = key
	if len(firstNode) > 0 && firstNode[0] {
		tm.SetStartNode(key)
	}
	return tm
}

// RemoveNode deletes nodeID and reconnects the graph around it. For every
// predecessor A and successor C of the removed node B, a Simple edge A -> C
// labelled "<A->B label>_<B->C label>" is added unless A already has an edge
// to C, or A == C.
//
// Conditions owned by the node, and condition entries that target it, are
// deleted (not rerouted). If the node was the start node, the start node is
// cleared. Its circuit breaker is discarded, and the next/previous-node caches
// are disabled.
func (tm *DAG) RemoveNode(nodeID string) error {
	node, exists := tm.nodes.Get(nodeID)
	if !exists {
		return fmt.Errorf("node %s does not exist", nodeID)
	}

	// Collect edges that point at the node being removed.
	var incomingEdges []Edge
	tm.nodes.ForEach(func(_ string, n *Node) bool {
		for _, edge := range n.Edges {
			if edge.To.ID == nodeID {
				incomingEdges = append(incomingEdges, Edge{
					From:  n,
					To:    node,
					Label: edge.Label,
					Type:  edge.Type,
				})
			}
		}
		return true
	})

	hasEdgeTo := func(n *Node, targetID string) bool {
		for _, e := range n.Edges {
			if e.To.ID == targetID {
				return true
			}
		}
		return false
	}

	// Bridge every predecessor to every successor (A -> B -> C becomes A -> C).
	outgoingEdges := node.Edges
	for _, inEdge := range incomingEdges {
		for _, outEdge := range outgoingEdges {
			if inEdge.From.ID == outEdge.To.ID {
				continue // would create a self-loop
			}
			if hasEdgeTo(inEdge.From, outEdge.To.ID) {
				continue // an edge to this target already exists
			}
			inEdge.From.Edges = append(inEdge.From.Edges, Edge{
				From:  inEdge.From,
				To:    outEdge.To,
				Label: inEdge.Label + "_" + outEdge.Label,
				Type:  Simple, // adjusted flows always use Simple edges
			})
		}
	}

	// Drop every edge that still targets the removed node.
	tm.nodes.ForEach(func(_ string, n *Node) bool {
		var kept []Edge
		for _, edge := range n.Edges {
			if edge.To.ID != nodeID {
				kept = append(kept, edge)
			}
		}
		n.Edges = kept
		return true
	})

	// Remove the node's own conditions and any condition routing to it.
	delete(tm.conditions, nodeID)
	for _, cond := range tm.conditions {
		for when, target := range cond {
			if target == nodeID {
				delete(cond, when)
			}
		}
	}

	tm.nodes.Del(nodeID)

	// A removed node must not remain the entry point; clearing it lets
	// parseInitialNode fall back to findStartNode.
	if tm.startNode == nodeID {
		tm.startNode = ""
	}

	// Discard the node's circuit breaker so a node later re-added under the
	// same ID does not inherit its state.
	tm.circuitBreakersMu.Lock()
	delete(tm.circuitBreakers, nodeID)
	tm.circuitBreakersMu.Unlock()

	// GetNextNodes/GetPreviousNodes skip caching while these are nil, so this
	// invalidates (and disables) both caches.
	tm.nextNodesCache = nil
	tm.prevNodesCache = nil

	tm.updateNodeIdentifiers()
	tm.Logger().Info("Node removed and edges adjusted", logger.Field{Key: "removed_node", Value: nodeID})
	return nil
}

// getOrCreateCircuitBreaker returns the circuit breaker for nodeID, creating
// one with the default config (5 failures, 30s reset, 3 half-open calls) on
// first use. The read lock serves the fast path; existence is re-checked
// under the write lock before creating.
func (tm *DAG) getOrCreateCircuitBreaker(nodeID string) *CircuitBreaker {
	tm.circuitBreakersMu.RLock()
	cb, exists := tm.circuitBreakers[nodeID]
	tm.circuitBreakersMu.RUnlock()
	if exists {
		return cb
	}

	tm.circuitBreakersMu.Lock()
	defer tm.circuitBreakersMu.Unlock()
	if cb, exists := tm.circuitBreakers[nodeID]; exists {
		return cb
	}
	config := &CircuitBreakerConfig{
		FailureThreshold: 5,
		ResetTimeout:     30 * time.Second,
		HalfOpenMaxCalls: 3,
	}
	cb = NewCircuitBreaker(config, tm.Logger())
	tm.circuitBreakers[nodeID] = cb
	return cb
}

// GetLastNodes returns every node with no outgoing edges and no conditions.
// The error is always nil.
func (tm *DAG) GetLastNodes() ([]*Node, error) {
	var lastNodes []*Node
	tm.nodes.ForEach(func(_ string, node *Node) bool {
		if len(node.Edges) == 0 && len(tm.conditions[node.ID]) == 0 {
			lastNodes = append(lastNodes, node)
		}
		return true
	})
	return lastNodes, nil
}

// updateNodeIdentifiers recomputes IsLast for every node: a node is last when
// it has neither outgoing edges nor conditions.
func (tm *DAG) updateNodeIdentifiers() {
	tm.nodes.ForEach(func(_ string, node *Node) bool {
		node.IsLast = len(node.Edges) == 0 && len(tm.conditions[node.ID]) == 0
		return true
	})
}

// GetFirstNode returns the configured start node, or nil when no start node
// is set or it is not registered.
func (tm *DAG) GetFirstNode() *Node {
	if tm.startNode == "" {
		return nil
	}
	node, _ := tm.nodes.Get(tm.startNode)
	return node
}

// parseInitialNode picks the node where execution begins. The order of
// preference is:
//  1. a non-empty string stored in ctx under the key "initial_node";
//  2. the configured start node;
//  3. the node found by findStartNode.
//
// It returns an error when none applies.
//
// NOTE(review): a plain string context key can collide across packages
// (staticcheck SA1029); kept as-is since callers presumably set this exact
// key — verify before changing.
func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) {
	if initialNode, ok := ctx.Value("initial_node").(string); ok && initialNode != "" {
		return initialNode, nil
	}
	if tm.startNode != "" {
		return tm.startNode, nil
	}
	if firstNode := tm.findStartNode(); firstNode != nil {
		return firstNode.ID, nil
	}
	return "", fmt.Errorf("no initial node found")
}

// findStartNode returns a node that takes part in at least one valid edge or
// condition and has no incoming ones, or nil if none exists. When several
// nodes qualify, which one is returned follows Go's randomized map iteration
// order.
func (tm *DAG) findStartNode() *Node {
	all := tm.nodes.AsMap()
	incomingEdges := make(map[string]bool)
	connectedNodes := make(map[string]bool)
	for _, node := range all {
		for _, edge := range node.Edges {
			if edge.Type.IsValid() {
				connectedNodes[node.ID] = true
				connectedNodes[edge.To.ID] = true
				incomingEdges[edge.To.ID] = true
			}
		}
		for _, target := range tm.conditions[node.ID] {
			connectedNodes[target] = true
			incomingEdges[target] = true
		}
	}
	for nodeID, node := range all {
		if !incomingEdges[nodeID] && connectedNodes[nodeID] {
			return node
		}
	}
	return nil
}

// IsLastNode reports whether nodeID has neither outgoing edges nor
// conditions. It returns an error if the node is not registered.
func (tm *DAG) IsLastNode(nodeID string) (bool, error) {
	node, exists := tm.nodes.Get(nodeID)
	if !exists {
		return false, fmt.Errorf("node %s not found", nodeID)
	}
	if len(node.Edges) > 0 {
		return false, nil
	}
	if conditions, exists := tm.conditions[nodeID]; exists && len(conditions) > 0 {
		return false, nil
	}
	return true, nil
}

// GetNextNodes returns the direct edge targets of nodeID followed by its
// registered conditional targets. Anything from the first Delimiter onward
// is ignored.
//
// Conditional targets follow map iteration order, and results are cached
// while nextNodesCache is non-nil.
//
// NOTE(review): AddEdge/AddCondition do not invalidate this cache — confirm
// it is only populated after the graph is fully built.
func (tm *DAG) GetNextNodes(nodeID string) ([]*Node, error) {
	nodeID = strings.Split(nodeID, Delimiter)[0]
	if tm.nextNodesCache != nil {
		if cached, exists := tm.nextNodesCache.Load(nodeID); exists {
			return cached.([]*Node), nil
		}
	}
	node, exists := tm.nodes.Get(nodeID)
	if !exists {
		return nil, fmt.Errorf("node %s not found", nodeID)
	}
	var nextNodes []*Node
	for _, edge := range node.Edges {
		nextNodes = append(nextNodes, edge.To)
	}
	for _, targetID := range tm.conditions[nodeID] {
		if targetNode, ok := tm.nodes.Get(targetID); ok {
			nextNodes = append(nextNodes, targetNode)
		}
	}
	if tm.nextNodesCache != nil {
		tm.nextNodesCache.Store(nodeID, nextNodes)
	}
	return nextNodes, nil
}

// GetPreviousNodes returns every node that reaches nodeID through a direct
// edge or a condition. Each predecessor appears at most once. Anything from
// the first Delimiter onward is ignored, and results are cached while
// prevNodesCache is non-nil.
func (tm *DAG) GetPreviousNodes(nodeID string) ([]*Node, error) {
	nodeID = strings.Split(nodeID, Delimiter)[0]
	if tm.prevNodesCache != nil {
		if cached, exists := tm.prevNodesCache.Load(nodeID); exists {
			return cached.([]*Node), nil
		}
	}
	var prevNodes []*Node
	tm.nodes.ForEach(func(id string, node *Node) bool {
		linked := false
		for _, edge := range node.Edges {
			if edge.To.ID == nodeID {
				linked = true
				break
			}
		}
		// Only consult conditions if no direct edge matched, so a node linked
		// both ways is not listed twice.
		if !linked {
			for _, targetID := range tm.conditions[id] {
				if targetID == nodeID {
					linked = true
					break
				}
			}
		}
		if linked {
			prevNodes = append(prevNodes, node)
		}
		return true
	})
	if tm.prevNodesCache != nil {
		tm.prevNodesCache.Store(nodeID, prevNodes)
	}
	return prevNodes, nil
}

// GetNodeByID returns the node registered under nodeID, or an error if there
// is none.
func (tm *DAG) GetNodeByID(nodeID string) (*Node, error) {
	node, exists := tm.nodes.Get(nodeID)
	if !exists {
		return nil, fmt.Errorf("node %s not found", nodeID)
	}
	return node, nil
}

// GetAllNodes returns a fresh map of node ID to node. The map itself may be
// modified by the caller; the *Node values are shared with the DAG.
func (tm *DAG) GetAllNodes() map[string]*Node {
	result := make(map[string]*Node, tm.nodes.Size())
	tm.nodes.ForEach(func(id string, node *Node) bool {
		result[id] = node
		return true
	})
	return result
}

// GetNodeCount returns the number of registered nodes.
func (tm *DAG) GetNodeCount() int {
	return tm.nodes.Size()
}