package dag import ( "context" "fmt" "log" "math/rand" // ...new import for jitter... "strings" "sync" "time" "github.com/oarkflow/json" "github.com/oarkflow/mq" dagstorage "github.com/oarkflow/mq/dag/storage" // Import dag storage package with alias "github.com/oarkflow/mq/logger" mqstorage "github.com/oarkflow/mq/storage" "github.com/oarkflow/mq/storage/memory" ) // TaskError is used by node processors to indicate whether an error is recoverable. type TaskError struct { Err error Recoverable bool } func (te TaskError) Error() string { return te.Err.Error() } // TaskState holds state and intermediate results for a given task (identified by a node ID). type TaskState struct { UpdatedAt time.Time targetResults mqstorage.IMap[string, mq.Result] NodeID string Status mq.Status Result mq.Result } func newTaskState(nodeID string) *TaskState { return &TaskState{ NodeID: nodeID, Status: mq.Pending, UpdatedAt: time.Now(), targetResults: memory.New[string, mq.Result](), } } type nodeResult struct { ctx context.Context nodeID string status mq.Status result mq.Result } type task struct { ctx context.Context taskID string nodeID string payload json.RawMessage } func newTask(ctx context.Context, taskID, nodeID string, payload json.RawMessage) *task { return &task{ ctx: ctx, taskID: taskID, nodeID: nodeID, payload: payload, } } type TaskManagerConfig struct { MaxRetries int BaseBackoff time.Duration RecoveryHandler func(ctx context.Context, result mq.Result) error } type TaskManager struct { createdAt time.Time taskStates mqstorage.IMap[string, *TaskState] parentNodes mqstorage.IMap[string, string] childNodes mqstorage.IMap[string, int] deferredTasks mqstorage.IMap[string, *task] iteratorNodes mqstorage.IMap[string, []Edge] currentNodePayload mqstorage.IMap[string, json.RawMessage] currentNodeResult mqstorage.IMap[string, mq.Result] taskQueue chan *task result *mq.Result resultQueue chan nodeResult resultCh chan mq.Result stopCh chan struct{} taskID string dag *DAG maxRetries int baseBackoff time.Duration recoveryHandler func(ctx context.Context, result mq.Result) error pauseMu sync.Mutex pauseCh chan struct{} wg sync.WaitGroup storage dagstorage.TaskStorage // Added TaskStorage for persistence } func NewTaskManager(dag *DAG, taskID string, resultCh chan mq.Result, iteratorNodes mqstorage.IMap[string, []Edge], taskStorage dagstorage.TaskStorage) *TaskManager { config := TaskManagerConfig{ MaxRetries: 3, BaseBackoff: time.Second, } tm := &TaskManager{ createdAt: time.Now(), taskStates: memory.New[string, *TaskState](), parentNodes: memory.New[string, string](), childNodes: memory.New[string, int](), deferredTasks: memory.New[string, *task](), currentNodePayload: memory.New[string, json.RawMessage](), currentNodeResult: memory.New[string, mq.Result](), taskQueue: make(chan *task, DefaultChannelSize), resultQueue: make(chan nodeResult, DefaultChannelSize), resultCh: resultCh, stopCh: make(chan struct{}), taskID: taskID, dag: dag, maxRetries: config.MaxRetries, baseBackoff: config.BaseBackoff, recoveryHandler: config.RecoveryHandler, iteratorNodes: iteratorNodes, storage: taskStorage, } tm.wg.Add(3) go tm.run() go tm.waitForResult() go tm.retryDeferredTasks() return tm } func (tm *TaskManager) ProcessTask(ctx context.Context, startNode string, payload json.RawMessage) { tm.enqueueTask(ctx, startNode, tm.taskID, payload) } func (tm *TaskManager) enqueueTask(ctx context.Context, startNode, taskID string, payload json.RawMessage) { if tm.dag.debug { tm.dag.Logger().Info("enqueueTask called", logger.Field{Key: "startNode", Value: startNode}, logger.Field{Key: "taskID", Value: taskID}, logger.Field{Key: "payloadSize", Value: len(payload)}) } if index, ok := ctx.Value(ContextIndex).(string); ok { base := strings.Split(startNode, Delimiter)[0] startNode = fmt.Sprintf("%s%s%s", base, Delimiter, index) } if _, exists := tm.taskStates.Get(startNode); !exists { tm.taskStates.Set(startNode, newTaskState(startNode)) } t := newTask(ctx, taskID, startNode, payload) // Persist task to storage if tm.storage != nil { persistentTask := &dagstorage.PersistentTask{ ID: taskID, DAGID: tm.dag.key, NodeID: startNode, CurrentNodeID: startNode, ProcessingState: "enqueued", Status: dagstorage.TaskStatusPending, Payload: payload, CreatedAt: time.Now(), UpdatedAt: time.Now(), MaxRetries: tm.maxRetries, } if err := tm.storage.SaveTask(ctx, persistentTask); err != nil { tm.dag.Logger().Error("Failed to persist task", logger.Field{Key: "taskID", Value: taskID}, logger.Field{Key: "error", Value: err.Error()}) } else { // Log task creation activity tm.logActivity(ctx, taskID, startNode, "task_created", "Task enqueued for processing", nil) } } select { case tm.taskQueue <- t: // Successfully enqueued default: // Queue is full, add to deferred tasks with limit if tm.deferredTasks.Size() < 1000 { // Limit deferred tasks to prevent memory issues tm.deferredTasks.Set(taskID, t) tm.dag.Logger().Warn("Task queue full, deferring task", logger.Field{Key: "taskID", Value: taskID}, logger.Field{Key: "nodeID", Value: startNode}) } else { tm.dag.Logger().Error("Deferred tasks queue also full, dropping task", logger.Field{Key: "taskID", Value: taskID}, logger.Field{Key: "nodeID", Value: startNode}) } } } func (tm *TaskManager) run() { defer tm.wg.Done() for { select { case <-tm.stopCh: log.Println("Stopping TaskManager run loop") return default: tm.pauseMu.Lock() pch := tm.pauseCh tm.pauseMu.Unlock() if pch != nil { select { case <-tm.stopCh: log.Println("Stopping TaskManager run loop during pause") return case <-pch: // Resume from pause } } select { case <-tm.stopCh: log.Println("Stopping TaskManager run loop") return case tsk := <-tm.taskQueue: tm.processNode(tsk) } } } } // waitForResult listens for node results on resultQueue and processes them. func (tm *TaskManager) waitForResult() { defer tm.wg.Done() for { select { case <-tm.stopCh: log.Println("Stopping TaskManager result listener") return case nr := <-tm.resultQueue: select { case <-tm.stopCh: log.Println("Stopping TaskManager result listener during processing") return default: tm.onNodeCompleted(nr) } } } } // areDependenciesMet checks if all previous nodes have completed successfully func (tm *TaskManager) areDependenciesMet(nodeID string) bool { pureNodeID := strings.Split(nodeID, Delimiter)[0] // Get previous nodes prevNodes, err := tm.dag.GetPreviousNodes(pureNodeID) if err != nil { tm.dag.Logger().Error("Error getting previous nodes", logger.Field{Key: "nodeID", Value: nodeID}, logger.Field{Key: "error", Value: err.Error()}) return false } // For iterator nodes, we need to be more selective about dependencies // Iterator nodes should only depend on nodes that provide data to them, // not on nodes that they create (which would be circular dependencies) node, exists := tm.dag.nodes.Get(pureNodeID) if exists { // Check if this node has any iterator edges (meaning it's an iterator node) hasIteratorEdges := false for _, edge := range node.Edges { if edge.Type == Iterator { hasIteratorEdges = true break } } if hasIteratorEdges { // For iterator nodes, only check dependencies from Simple edges // Iterator edges represent outputs, not inputs filteredPrevNodes := make([]*Node, 0) for _, prevNode := range prevNodes { // Check if there's a Simple edge from prevNode to this node hasSimpleEdge := false for _, edge := range prevNode.Edges { if edge.To.ID == pureNodeID && edge.Type == Simple { hasSimpleEdge = true break } } if hasSimpleEdge { filteredPrevNodes = append(filteredPrevNodes, prevNode) } } prevNodes = filteredPrevNodes } } // Check if all relevant previous nodes have completed successfully for _, prevNode := range prevNodes { // Check both the pure node ID and the indexed node ID for state state, exists := tm.taskStates.Get(prevNode.ID) if !exists { // Also check if there's a state with an index suffix tm.taskStates.ForEach(func(key string, s *TaskState) bool { if strings.Split(key, Delimiter)[0] == prevNode.ID { state = s exists = true return false // Stop iteration } return true }) } if !exists || state.Status != mq.Completed { tm.dag.Logger().Debug("Dependency not met", logger.Field{Key: "nodeID", Value: nodeID}, logger.Field{Key: "dependency", Value: prevNode.ID}, logger.Field{Key: "stateExists", Value: exists}, logger.Field{Key: "stateStatus", Value: string(state.Status)}, logger.Field{Key: "taskID", Value: tm.taskID}) return false } } return true } func (tm *TaskManager) processNode(exec *task) { startTime := time.Now() pureNodeID := strings.Split(exec.nodeID, Delimiter)[0] node, exists := tm.dag.nodes.Get(pureNodeID) if !exists { tm.dag.Logger().Error("Node not found", logger.Field{Key: "nodeID", Value: pureNodeID}) return } // Check if all dependencies are met before processing if !tm.areDependenciesMet(pureNodeID) { tm.dag.Logger().Warn("Dependencies not met for node, deferring", logger.Field{Key: "nodeID", Value: pureNodeID}) // Defer the task tm.deferredTasks.Set(exec.taskID, exec) return } // Wrap context with timeout if node.Timeout is configured. if node.Timeout > 0 { var cancel context.CancelFunc exec.ctx, cancel = context.WithTimeout(exec.ctx, node.Timeout) defer cancel() } // Check for context cancellation before processing select { case <-exec.ctx.Done(): tm.dag.Logger().Warn("Context cancelled before node processing", logger.Field{Key: "nodeID", Value: exec.nodeID}, logger.Field{Key: "taskID", Value: exec.taskID}) return default: } // Invoke PreProcessHook if available. if tm.dag.PreProcessHook != nil { exec.ctx = tm.dag.PreProcessHook(exec.ctx, node, exec.taskID, exec.payload) } // Debug logging before processing if node.IsDebugEnabled(tm.dag.IsDebugEnabled()) { tm.debugNodeStart(exec, node) } state, _ := tm.taskStates.Get(exec.nodeID) if state == nil { tm.dag.Logger().Warn("State not found; creating new state", logger.Field{Key: "nodeID", Value: exec.nodeID}) state = newTaskState(exec.nodeID) tm.taskStates.Set(exec.nodeID, state) } state.Status = mq.Processing state.UpdatedAt = time.Now() tm.currentNodePayload.Clear() // Update task status in storage if tm.storage != nil { // Update task position and status if err := tm.updateTaskPosition(exec.ctx, exec.taskID, pureNodeID, "processing"); err != nil { tm.dag.Logger().Error("Failed to update task position", logger.Field{Key: "taskID", Value: exec.taskID}, logger.Field{Key: "error", Value: err.Error()}) } if err := tm.storage.UpdateTaskStatus(exec.ctx, exec.taskID, dagstorage.TaskStatusRunning, ""); err != nil { tm.dag.Logger().Error("Failed to update task status", logger.Field{Key: "taskID", Value: exec.taskID}, logger.Field{Key: "error", Value: err.Error()}) } // Log node processing start tm.logActivity(exec.ctx, exec.taskID, pureNodeID, "node_processing_started", "Node processing started", nil) } tm.currentNodeResult.Clear() tm.currentNodePayload.Set(exec.nodeID, exec.payload) var result mq.Result attempts := 0 for { // log.Printf("Tracing: Start processing node %s (attempt %d) on flow %s", exec.nodeID, attempts+1, tm.dag.key) // Get middlewares for this node middlewares := tm.dag.getNodeMiddlewares(pureNodeID) // Execute middlewares and processor result = tm.dag.executeMiddlewares(exec.ctx, mq.NewTask(exec.taskID, exec.payload, exec.nodeID, mq.WithDAG(tm.dag)), middlewares, func(ctx context.Context, task *mq.Task) mq.Result { return node.processor.ProcessTask(ctx, task) }) if result.Error != nil { if te, ok := result.Error.(TaskError); ok && te.Recoverable { if attempts < tm.maxRetries { attempts++ backoff := tm.baseBackoff * time.Duration(1< 1 { aggregated := make([]json.RawMessage, size) i := 0 state.targetResults.ForEach(func(_ string, res mq.Result) bool { aggregated[i] = res.Payload i++ return true }) aggregatedPayload, err := json.Marshal(aggregated) if err != nil { panic(err) } state.Result = mq.Result{Payload: aggregatedPayload, Status: mq.Completed, Ctx: ctx, Topic: state.NodeID} } else if size == 1 { state.Result = state.targetResults.Values()[0] } state.Status = result.Status state.Result.Status = result.Status } if state.Result.Payload == nil { state.Result.Payload = result.Payload } state.UpdatedAt = time.Now() if result.Ctx == nil { result.Ctx = ctx } if result.Error != nil { state.Status = mq.Failed } if parentKey, ok := tm.parentNodes.Get(state.NodeID); ok { parts := strings.Split(state.NodeID, Delimiter) // For iterator nodes, only continue to next edge after ALL children have completed and been aggregated if edges, exists := tm.iteratorNodes.Get(parts[0]); exists && state.Status == mq.Completed && size == targetsCount { state.Status = mq.Processing tm.iteratorNodes.Del(parts[0]) state.targetResults.Clear() if len(parts) == 2 { ctx = context.WithValue(ctx, ContextIndex, parts[1]) } toProcess := nodeResult{ ctx: ctx, nodeID: state.NodeID, status: state.Status, result: state.Result, } tm.handleEdges(toProcess, edges) state.Status = mq.Completed } else if size == targetsCount { if parentState, _ := tm.taskStates.Get(parentKey); parentState != nil { state.Result.Topic = state.NodeID tm.handlePrevious(ctx, parentState, state.Result, state.NodeID, dispatchFinal) } } } else { tm.updateTimestamps(&state.Result) tm.result = &state.Result state.Result.Topic = strings.Split(state.NodeID, Delimiter)[0] tm.resultCh <- state.Result tm.processFinalResult(state) } } func (tm *TaskManager) handleNext(ctx context.Context, node *Node, state *TaskState, result mq.Result) { state.UpdatedAt = time.Now() if result.Ctx == nil { result.Ctx = ctx } if result.Error != nil { state.Status = mq.Failed } else { edges := tm.getConditionalEdges(node, result) if len(edges) == 0 { state.Status = mq.Completed } } if result.Status == "" { result.Status = state.Status } // Update task status in storage based on final result if tm.storage != nil { var status dagstorage.TaskStatus var errorMsg string var action string var message string if result.Error != nil { status = dagstorage.TaskStatusFailed errorMsg = result.Error.Error() action = "node_failed" message = fmt.Sprintf("Node %s failed: %s", state.NodeID, errorMsg) } else if state.Status == mq.Completed { status = dagstorage.TaskStatusCompleted action = "node_completed" message = fmt.Sprintf("Node %s completed successfully", state.NodeID) } else { status = dagstorage.TaskStatusRunning action = "node_processing" message = fmt.Sprintf("Node %s processing", state.NodeID) } if err := tm.storage.UpdateTaskStatus(ctx, tm.taskID, status, errorMsg); err != nil { tm.dag.Logger().Error("Failed to update task status", logger.Field{Key: "taskID", Value: tm.taskID}, logger.Field{Key: "error", Value: err.Error()}) } // Log node completion/failure tm.logActivity(ctx, tm.taskID, state.NodeID, action, message, result.Payload) } tm.enqueueResult(nodeResult{ ctx: ctx, nodeID: state.NodeID, status: state.Status, result: result, }) } func (tm *TaskManager) enqueueResult(nr nodeResult) { select { case tm.resultQueue <- nr: // Successfully enqueued default: tm.dag.Logger().Error("Result queue is full, dropping result", logger.Field{Key: "nodeID", Value: nr.nodeID}, logger.Field{Key: "taskID", Value: nr.result.TaskID}) } } func (tm *TaskManager) onNodeCompleted(nr nodeResult) { if tm.dag.debug { tm.dag.Logger().Info("onNodeCompleted called", logger.Field{Key: "nodeID", Value: nr.nodeID}, logger.Field{Key: "status", Value: string(nr.status)}, logger.Field{Key: "hasError", Value: nr.result.Error != nil}) } nodeID := strings.Split(nr.nodeID, Delimiter)[0] node, ok := tm.dag.nodes.Get(nodeID) if !ok { return } // Handle ResetTo functionality if nr.result.ResetTo != "" { tm.handleResetTo(nr) return } if nr.result.Error != nil || nr.status == mq.Failed { if state, exists := tm.taskStates.Get(nr.nodeID); exists { tm.processFinalResult(state) return } } edges := tm.getConditionalEdges(node, nr.result) if len(edges) > 0 { tm.handleEdges(nr, edges) return } // Check if this is a child node from an iterator (has a parent) if parentKey, exists := tm.parentNodes.Get(nr.nodeID); exists { if tm.dag.debug { tm.dag.Logger().Info("Found parent for node", logger.Field{Key: "nodeID", Value: nr.nodeID}, logger.Field{Key: "parentKey", Value: parentKey}) } if parentState, _ := tm.taskStates.Get(parentKey); parentState != nil { tm.handlePrevious(nr.ctx, parentState, nr.result, nr.nodeID, true) return // Don't send to resultCh if has parent } } if tm.dag.debug { tm.dag.Logger().Info("No parent found for node, sending to resultCh", logger.Field{Key: "nodeID", Value: nr.nodeID}, logger.Field{Key: "result_topic", Value: nr.result.Topic}) } tm.updateTimestamps(&nr.result) tm.resultCh <- nr.result if state, ok := tm.taskStates.Get(nr.nodeID); ok { tm.processFinalResult(state) } } func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge { edges := make([]Edge, len(node.Edges)) copy(edges, node.Edges) // Handle conditional edges based on ConditionStatus if result.ConditionStatus != "" { if conditions, ok := tm.dag.conditions[node.ID]; ok { if targetKey, exists := conditions[result.ConditionStatus]; exists { if targetNode, found := tm.dag.nodes.Get(targetKey); found { conditionalEdge := Edge{ From: node, FromSource: node.ID, To: targetNode, Label: fmt.Sprintf("condition:%s", result.ConditionStatus), Type: Simple, } edges = append(edges, conditionalEdge) } } else if targetKey, exists := conditions["default"]; exists { if targetNode, found := tm.dag.nodes.Get(targetKey); found { conditionalEdge := Edge{ From: node, FromSource: node.ID, To: targetNode, Label: "condition:default", Type: Simple, } edges = append(edges, conditionalEdge) } } } } return edges } func (tm *TaskManager) handleEdges(currentResult nodeResult, edges []Edge) { if len(edges) == 0 { return } if len(edges) == 1 { tm.processSingleEdge(currentResult, edges[0]) return } // For multiple edges, process sequentially to avoid race conditions for _, edge := range edges { tm.processSingleEdge(currentResult, edge) } } func (tm *TaskManager) processSingleEdge(currentResult nodeResult, edge Edge) { index, ok := currentResult.ctx.Value(ContextIndex).(string) if !ok { index = "0" } parentNode := fmt.Sprintf("%s%s%s", edge.From.ID, Delimiter, index) switch edge.Type { case Simple: if _, exists := tm.iteratorNodes.Get(edge.From.ID); exists { return } fallthrough case Iterator: if edge.Type == Iterator { if _, exists := tm.iteratorNodes.Get(edge.From.ID); !exists { return } // Use the actual completing node as parent, not the edge From ID parentNode = currentResult.nodeID var items []json.RawMessage if err := json.Unmarshal(currentResult.result.Payload, &items); err != nil { log.Printf("Error unmarshalling payload for node %s: %v", edge.To.ID, err) tm.enqueueResult(nodeResult{ ctx: currentResult.ctx, nodeID: edge.To.ID, status: mq.Failed, result: mq.Result{Error: err}, }) return } tm.childNodes.Set(parentNode, len(items)) for i, item := range items { childNode := fmt.Sprintf("%s%s%d", edge.To.ID, Delimiter, i) ctx := context.WithValue(currentResult.ctx, ContextIndex, fmt.Sprintf("%d", i)) tm.parentNodes.Set(childNode, parentNode) tm.enqueueTask(ctx, edge.To.ID, tm.taskID, item) } } else { tm.childNodes.Set(parentNode, 1) idx, _ := currentResult.ctx.Value(ContextIndex).(string) childNode := fmt.Sprintf("%s%s%s", edge.To.ID, Delimiter, idx) ctx := context.WithValue(currentResult.ctx, ContextIndex, idx) // If the current result came from an iterator child that has a parent, // we need to preserve that parent relationship for the new target node if originalParent, hasParent := tm.parentNodes.Get(currentResult.nodeID); hasParent { if tm.dag.debug { tm.dag.Logger().Info("Transferring parent relationship for conditional edge", logger.Field{Key: "originalChild", Value: currentResult.nodeID}, logger.Field{Key: "newChild", Value: childNode}, logger.Field{Key: "parent", Value: originalParent}) } // Remove the original child from parent tracking since it's being replaced by conditional target tm.parentNodes.Del(currentResult.nodeID) // This edge target should now report back to the original parent instead tm.parentNodes.Set(childNode, originalParent) } else { if tm.dag.debug { tm.dag.Logger().Info("No parent found for conditional edge source", logger.Field{Key: "nodeID", Value: currentResult.nodeID}) } tm.parentNodes.Set(childNode, parentNode) } tm.enqueueTask(ctx, edge.To.ID, tm.taskID, currentResult.result.Payload) } } } func (tm *TaskManager) retryDeferredTasks() { defer tm.wg.Done() ticker := time.NewTicker(tm.baseBackoff) defer ticker.Stop() for { select { case <-tm.stopCh: log.Println("Stopping deferred task retrier") return case <-ticker.C: // Process deferred tasks with a limit to prevent overwhelming the queue processed := 0 tm.deferredTasks.ForEach(func(taskID string, tsk *task) bool { if processed >= 10 { // Process max 10 deferred tasks per tick return false } select { case tm.taskQueue <- tsk: tm.deferredTasks.Del(taskID) processed++ tm.dag.Logger().Debug("Retried deferred task", logger.Field{Key: "taskID", Value: taskID}) default: // Queue still full, keep the task deferred return false } return true }) } } } func (tm *TaskManager) processFinalResult(state *TaskState) { state.Status = mq.Completed state.targetResults.Clear() // update metrics using the task start time for duration calculation tm.dag.updateTaskMetrics(tm.taskID, state.Result, time.Since(tm.createdAt)) if tm.dag.finalResult != nil { tm.dag.finalResult(tm.taskID, state.Result) } } func (tm *TaskManager) Pause() { tm.pauseMu.Lock() defer tm.pauseMu.Unlock() if tm.pauseCh == nil { tm.pauseCh = make(chan struct{}) log.Println("TaskManager paused") } } func (tm *TaskManager) Resume() { tm.pauseMu.Lock() defer tm.pauseMu.Unlock() if tm.pauseCh != nil { close(tm.pauseCh) tm.pauseCh = nil log.Println("TaskManager resumed") } } // Stop gracefully stops the task manager func (tm *TaskManager) Stop() { close(tm.stopCh) tm.wg.Wait() // Cancel any pending operations tm.pauseMu.Lock() if tm.pauseCh != nil { close(tm.pauseCh) tm.pauseCh = nil } tm.pauseMu.Unlock() // Clean up resources tm.taskStates.Clear() tm.parentNodes.Clear() tm.childNodes.Clear() tm.deferredTasks.Clear() tm.currentNodePayload.Clear() tm.currentNodeResult.Clear() tm.dag.Logger().Info("TaskManager stopped gracefully", logger.Field{Key: "taskID", Value: tm.taskID}) } // debugNodeStart logs debug information when a node starts processing func (tm *TaskManager) debugNodeStart(exec *task, node *Node) { var payload map[string]any if err := json.Unmarshal(exec.payload, &payload); err != nil { payload = map[string]any{"raw_payload": string(exec.payload)} } tm.dag.Logger().Info("🐛 [DEBUG] Node processing started", logger.Field{Key: "dag_name", Value: tm.dag.name}, logger.Field{Key: "task_id", Value: exec.taskID}, logger.Field{Key: "node_id", Value: node.ID}, logger.Field{Key: "node_type", Value: node.NodeType.String()}, logger.Field{Key: "node_label", Value: node.Label}, logger.Field{Key: "timestamp", Value: time.Now().Format(time.RFC3339)}, logger.Field{Key: "has_timeout", Value: node.Timeout > 0}, logger.Field{Key: "timeout_duration", Value: node.Timeout.String()}, logger.Field{Key: "payload_size", Value: len(exec.payload)}, logger.Field{Key: "payload_preview", Value: tm.getPayloadPreview(payload)}, logger.Field{Key: "debug_mode", Value: "individual_node:" + fmt.Sprintf("%t", node.Debug) + ", dag_global:" + fmt.Sprintf("%t", tm.dag.IsDebugEnabled())}, ) // Log processor type if it implements the Debugger interface if debugger, ok := node.processor.(Debugger); ok { debugger.Debug(exec.ctx, mq.NewTask(exec.taskID, exec.payload, exec.nodeID, mq.WithDAG(tm.dag))) } } // debugNodeComplete logs debug information when a node completes processing func (tm *TaskManager) debugNodeComplete(exec *task, node *Node, result mq.Result, latency time.Duration, attempts int) { var resultPayload map[string]any if len(result.Payload) > 0 { if err := json.Unmarshal(result.Payload, &resultPayload); err != nil { resultPayload = map[string]any{"raw_payload": string(result.Payload)} } } tm.dag.Logger().Info("🐛 [DEBUG] Node processing completed", logger.Field{Key: "dag_name", Value: tm.dag.name}, logger.Field{Key: "task_id", Value: exec.taskID}, logger.Field{Key: "node_id", Value: node.ID}, logger.Field{Key: "node_type", Value: node.NodeType.String()}, logger.Field{Key: "node_label", Value: node.Label}, logger.Field{Key: "timestamp", Value: time.Now().Format(time.RFC3339)}, logger.Field{Key: "status", Value: string(result.Status)}, logger.Field{Key: "latency", Value: latency.String()}, logger.Field{Key: "attempts", Value: attempts + 1}, logger.Field{Key: "has_error", Value: result.Error != nil}, logger.Field{Key: "error_message", Value: tm.getErrorMessage(result.Error)}, logger.Field{Key: "result_size", Value: len(result.Payload)}, logger.Field{Key: "result_preview", Value: tm.getPayloadPreview(resultPayload)}, logger.Field{Key: "is_last_node", Value: result.Last}, ) } // getPayloadPreview returns a truncated version of the payload for debug logging func (tm *TaskManager) getPayloadPreview(payload map[string]any) string { if payload == nil { return "null" } preview := make(map[string]any) count := 0 maxFields := 5 // Limit to first 5 fields to avoid log spam for key, value := range payload { if count >= maxFields { preview["..."] = fmt.Sprintf("and %d more fields", len(payload)-maxFields) break } // Truncate string values if they're too long if strVal, ok := value.(string); ok && len(strVal) > 100 { preview[key] = strVal[:97] + "..." } else { preview[key] = value } count++ } previewBytes, _ := json.Marshal(preview) return string(previewBytes) } // getErrorMessage safely extracts error message func (tm *TaskManager) getErrorMessage(err error) string { if err == nil { return "" } return err.Error() } // logActivity logs an activity for a task func (tm *TaskManager) logActivity(ctx context.Context, taskID, nodeID, action, message string, data json.RawMessage) { if tm.storage == nil { return } logEntry := &dagstorage.TaskActivityLog{ TaskID: taskID, DAGID: tm.dag.key, NodeID: nodeID, Action: action, Message: message, Data: data, Level: "info", CreatedAt: time.Now(), } if err := tm.storage.LogActivity(ctx, logEntry); err != nil { tm.dag.Logger().Error("Failed to log activity", logger.Field{Key: "taskID", Value: taskID}, logger.Field{Key: "action", Value: action}, logger.Field{Key: "error", Value: err.Error()}) } } // updateTaskPosition updates the current position of a task in the DAG func (tm *TaskManager) updateTaskPosition(ctx context.Context, taskID, currentNodeID, processingState string) error { if tm.storage == nil { return nil } // Get the current task task, err := tm.storage.GetTask(ctx, taskID) if err != nil { return fmt.Errorf("failed to get task for position update: %w", err) } // Update position fields task.CurrentNodeID = currentNodeID task.ProcessingState = processingState task.UpdatedAt = time.Now() // Save the updated task return tm.storage.SaveTask(ctx, task) } // handleResetTo handles the ResetTo functionality for resetting a task to a specific node func (tm *TaskManager) handleResetTo(nr nodeResult) { resetTo := nr.result.ResetTo nodeID := strings.Split(nr.nodeID, Delimiter)[0] var targetNodeID string var err error if resetTo == "back" { // Use GetPreviousPageNode to find the previous page node var prevNode *Node prevNode, err = tm.dag.GetPreviousPageNode(nodeID) if err != nil { tm.dag.Logger().Error("Failed to get previous page node", logger.Field{Key: "currentNodeID", Value: nodeID}, logger.Field{Key: "error", Value: err.Error()}) // Send error result tm.resultCh <- mq.Result{ Error: fmt.Errorf("failed to reset to previous page node: %w", err), Ctx: nr.ctx, TaskID: nr.result.TaskID, Topic: nr.result.Topic, Status: mq.Failed, Payload: nr.result.Payload, } return } if prevNode == nil { tm.dag.Logger().Error("No previous page node found", logger.Field{Key: "currentNodeID", Value: nodeID}) // Send error result tm.resultCh <- mq.Result{ Error: fmt.Errorf("no previous page node found"), Ctx: nr.ctx, TaskID: nr.result.TaskID, Topic: nr.result.Topic, Status: mq.Failed, Payload: nr.result.Payload, } return } targetNodeID = prevNode.ID } else { // Use the specified node ID targetNodeID = resetTo // Validate that the target node exists if _, exists := tm.dag.nodes.Get(targetNodeID); !exists { tm.dag.Logger().Error("Reset target node does not exist", logger.Field{Key: "targetNodeID", Value: targetNodeID}) // Send error result tm.resultCh <- mq.Result{ Error: fmt.Errorf("reset target node %s does not exist", targetNodeID), Ctx: nr.ctx, TaskID: nr.result.TaskID, Topic: nr.result.Topic, Status: mq.Failed, Payload: nr.result.Payload, } return } } if tm.dag.debug { tm.dag.Logger().Info("Resetting task to node", logger.Field{Key: "taskID", Value: nr.result.TaskID}, logger.Field{Key: "fromNode", Value: nodeID}, logger.Field{Key: "toNode", Value: targetNodeID}, logger.Field{Key: "resetTo", Value: resetTo}) } // Clear task states of all nodes between current node and target node // This ensures that when we reset, the workflow can proceed correctly tm.clearTaskStatesInPath(nodeID, targetNodeID) // Also clear any deferred tasks for the target node itself tm.deferredTasks.ForEach(func(taskID string, tsk *task) bool { if strings.Split(tsk.nodeID, Delimiter)[0] == targetNodeID { tm.deferredTasks.Del(taskID) if tm.dag.debug { tm.dag.Logger().Debug("Cleared deferred task for target node", logger.Field{Key: "nodeID", Value: targetNodeID}, logger.Field{Key: "taskID", Value: taskID}) } } return true }) // Handle dependencies of the target node - if they exist and are not completed, // we need to mark them as completed to allow the workflow to proceed tm.handleTargetNodeDependencies(targetNodeID, nr) // Get previously received data for the target node var previousPayload json.RawMessage if prevResult, hasResult := tm.currentNodeResult.Get(targetNodeID); hasResult { previousPayload = prevResult.Payload if tm.dag.debug { tm.dag.Logger().Info("Using previous payload for reset", logger.Field{Key: "targetNodeID", Value: targetNodeID}, logger.Field{Key: "payloadSize", Value: len(previousPayload)}) } } else { // If no previous data, use the current result's payload previousPayload = nr.result.Payload if tm.dag.debug { tm.dag.Logger().Info("No previous payload found, using current payload", logger.Field{Key: "targetNodeID", Value: targetNodeID}) } } // Reset task state for the target node if state, exists := tm.taskStates.Get(targetNodeID); exists { state.Status = mq.Completed // Mark as completed to satisfy dependencies state.UpdatedAt = time.Now() state.Result = mq.Result{ Status: mq.Completed, Ctx: nr.ctx, } } else { // Create new state if it doesn't exist and mark as completed newState := newTaskState(targetNodeID) newState.Status = mq.Completed newState.Result = mq.Result{ Status: mq.Completed, Ctx: nr.ctx, } tm.taskStates.Set(targetNodeID, newState) } // Update current node result with the reset result (clear ResetTo to avoid loops) resetResult := mq.Result{ TaskID: nr.result.TaskID, Topic: targetNodeID, Status: mq.Completed, // Mark as completed Payload: previousPayload, Ctx: nr.ctx, // ResetTo is intentionally not set to avoid infinite loops } tm.currentNodeResult.Set(targetNodeID, resetResult) // Re-enqueue the task for the target node tm.enqueueTask(nr.ctx, targetNodeID, nr.result.TaskID, previousPayload) // Log the reset activity tm.logActivity(nr.ctx, nr.result.TaskID, targetNodeID, "task_reset", fmt.Sprintf("Task reset from %s to %s", nodeID, targetNodeID), nil) } // clearTaskStatesInPath clears all task states in the path from current node to target node // This is necessary when resetting to ensure the workflow can proceed without dependency issues func (tm *TaskManager) clearTaskStatesInPath(currentNodeID, targetNodeID string) { // Get all nodes in the path from current to target pathNodes := tm.getNodesInPath(currentNodeID, targetNodeID) if tm.dag.debug { tm.dag.Logger().Info("Clearing task states in path", logger.Field{Key: "fromNode", Value: currentNodeID}, logger.Field{Key: "toNode", Value: targetNodeID}, logger.Field{Key: "pathNodeCount", Value: len(pathNodes)}) } // Also clear the current node itself (ValidateInput in the example) if state, exists := tm.taskStates.Get(currentNodeID); exists { state.Status = mq.Pending state.UpdatedAt = time.Now() state.Result = mq.Result{} // Clear previous result if tm.dag.debug { tm.dag.Logger().Debug("Cleared task state for current node", logger.Field{Key: "nodeID", Value: currentNodeID}) } } // Also clear any cached results for the current node tm.currentNodeResult.Del(currentNodeID) // Clear any deferred tasks for the current node tm.deferredTasks.ForEach(func(taskID string, tsk *task) bool { if strings.Split(tsk.nodeID, Delimiter)[0] == currentNodeID { tm.deferredTasks.Del(taskID) if tm.dag.debug { tm.dag.Logger().Debug("Cleared deferred task for current node", logger.Field{Key: "nodeID", Value: currentNodeID}, logger.Field{Key: "taskID", Value: taskID}) } } return true }) // Clear task states for all nodes in the path for _, pathNodeID := range pathNodes { if state, exists := tm.taskStates.Get(pathNodeID); exists { state.Status = mq.Pending state.UpdatedAt = time.Now() state.Result = mq.Result{} // Clear previous result if tm.dag.debug { tm.dag.Logger().Debug("Cleared task state for path node", logger.Field{Key: "nodeID", Value: pathNodeID}) } } // Also clear any cached results for this node tm.currentNodeResult.Del(pathNodeID) // Clear any deferred tasks for this node tm.deferredTasks.ForEach(func(taskID string, tsk *task) bool { if strings.Split(tsk.nodeID, Delimiter)[0] == pathNodeID { tm.deferredTasks.Del(taskID) if tm.dag.debug { tm.dag.Logger().Debug("Cleared deferred task for path node", logger.Field{Key: "nodeID", Value: pathNodeID}, logger.Field{Key: "taskID", Value: taskID}) } } return true }) } } // getNodesInPath returns all nodes in the path from start node to end node func (tm *TaskManager) getNodesInPath(startNodeID, endNodeID string) []string { visited := make(map[string]bool) var result []string // Use BFS to find the path from start to end queue := []string{startNodeID} visited[startNodeID] = true parent := make(map[string]string) found := false for len(queue) > 0 && !found { currentNodeID := queue[0] queue = queue[1:] // Get all nodes that this node points to if node, exists := tm.dag.nodes.Get(currentNodeID); exists { for _, edge := range node.Edges { if edge.Type == Simple || edge.Type == Iterator { targetNodeID := edge.To.ID if !visited[targetNodeID] { visited[targetNodeID] = true parent[targetNodeID] = currentNodeID queue = append(queue, targetNodeID) if targetNodeID == endNodeID { found = true break } } } } } } // If we found the end node, reconstruct the path if found { current := endNodeID for current != startNodeID { result = append([]string{current}, result...) if parentNode, exists := parent[current]; exists { current = parentNode } else { break } } result = append([]string{startNodeID}, result...) } return result } // getAllDownstreamNodes returns all nodes that come after the given node in the workflow func (tm *TaskManager) getAllDownstreamNodes(nodeID string) []string { visited := make(map[string]bool) var result []string // Use BFS to find all downstream nodes queue := []string{nodeID} visited[nodeID] = true for len(queue) > 0 { currentNodeID := queue[0] queue = queue[1:] // Get all nodes that this node points to if node, exists := tm.dag.nodes.Get(currentNodeID); exists { for _, edge := range node.Edges { if edge.Type == Simple || edge.Type == Iterator { targetNodeID := edge.To.ID if !visited[targetNodeID] { visited[targetNodeID] = true result = append(result, targetNodeID) queue = append(queue, targetNodeID) } } } } } return result } // handleTargetNodeDependencies handles the dependencies of the target node during reset // If the target node has unmet dependencies, we mark them as completed to allow the workflow to proceed func (tm *TaskManager) handleTargetNodeDependencies(targetNodeID string, nr nodeResult) { // Get the dependencies of the target node prevNodes, err := tm.dag.GetPreviousNodes(targetNodeID) if err != nil { tm.dag.Logger().Error("Error getting previous nodes for target", logger.Field{Key: "targetNodeID", Value: targetNodeID}, logger.Field{Key: "error", Value: err.Error()}) return } if tm.dag.debug { tm.dag.Logger().Info("Checking dependencies for target node", logger.Field{Key: "targetNodeID", Value: targetNodeID}, logger.Field{Key: "dependencyCount", Value: len(prevNodes)}) } // Check each dependency and ensure it's marked as completed for reset for _, prevNode := range prevNodes { // Check both the pure node ID and the indexed node ID for state state, exists := tm.taskStates.Get(prevNode.ID) if !exists { // Also check if there's a state with an index suffix tm.taskStates.ForEach(func(key string, s *TaskState) bool { if strings.Split(key, Delimiter)[0] == prevNode.ID { state = s exists = true return false // Stop iteration } return true }) } if !exists { // Create new state and mark as completed for reset newState := newTaskState(prevNode.ID) newState.Status = mq.Completed newState.UpdatedAt = time.Now() newState.Result = mq.Result{ Status: mq.Completed, Ctx: nr.ctx, } tm.taskStates.Set(prevNode.ID, newState) if tm.dag.debug { tm.dag.Logger().Debug("Created completed state for dependency node during reset", logger.Field{Key: "dependencyNodeID", Value: prevNode.ID}) } } else if state.Status != mq.Completed { // Mark existing state as completed for reset state.Status = mq.Completed state.UpdatedAt = time.Now() if state.Result.Status == "" { state.Result = mq.Result{ Status: mq.Completed, Ctx: nr.ctx, } } if tm.dag.debug { tm.dag.Logger().Debug("Marked dependency node as completed during reset", logger.Field{Key: "dependencyNodeID", Value: prevNode.ID}, logger.Field{Key: "previousStatus", Value: string(state.Status)}) } } else { if tm.dag.debug { tm.dag.Logger().Debug("Dependency already satisfied", logger.Field{Key: "dependencyNodeID", Value: prevNode.ID}, logger.Field{Key: "status", Value: string(state.Status)}) } } // Ensure cached result exists for this dependency if _, hasResult := tm.currentNodeResult.Get(prevNode.ID); !hasResult { tm.currentNodeResult.Set(prevNode.ID, mq.Result{ Status: mq.Completed, Ctx: nr.ctx, }) } // Clear any deferred tasks for this dependency since it's now satisfied tm.deferredTasks.ForEach(func(taskID string, tsk *task) bool { if strings.Split(tsk.nodeID, Delimiter)[0] == prevNode.ID { tm.deferredTasks.Del(taskID) if tm.dag.debug { tm.dag.Logger().Debug("Cleared deferred task for satisfied dependency", logger.Field{Key: "dependencyNodeID", Value: prevNode.ID}, logger.Field{Key: "taskID", Value: taskID}) } } return true }) } }