diff --git a/dag/dag_node.go b/dag/dag_node.go index 9b47c3a..6fb808c 100644 --- a/dag/dag_node.go +++ b/dag/dag_node.go @@ -170,11 +170,18 @@ func (tm *DAG) getCurrentNode(manager *TaskManager) string { func (tm *DAG) AddDAGNode(nodeType NodeType, name string, key string, dag *DAG, firstNode ...bool) *DAG { dag.AssignTopic(key) dag.name += fmt.Sprintf("(%s)", name) + + // Create a wrapper processor that ensures proper completion reporting for iterator patterns + processor := &DAGNodeProcessor{ + subDAG: dag, + nodeID: key, + } + tm.nodes.Set(key, &Node{ Label: name, ID: key, NodeType: nodeType, - processor: dag, + processor: processor, isReady: true, IsLast: true, // Assume it's last until edges are added }) diff --git a/dag/task_manager.go b/dag/task_manager.go index b651fb4..ed1e601 100644 --- a/dag/task_manager.go +++ b/dag/task_manager.go @@ -18,6 +18,85 @@ import ( "github.com/oarkflow/mq/storage/memory" ) +// DAGNodeProcessor wraps a sub-DAG to ensure it reports completion properly +// when used as part of an iterator pattern +type DAGNodeProcessor struct { + subDAG *DAG + nodeID string +} + +func (p *DAGNodeProcessor) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { + // Process the task through the sub-DAG but capture the result + // instead of letting it go to the final callback + + // Create a result channel to capture the sub-DAG's result + resultCh := make(chan mq.Result, 1) + + // Temporarily replace the sub-DAG's final result callback + originalCallback := p.subDAG.finalResult + p.subDAG.finalResult = func(taskID string, result mq.Result) { + resultCh <- result + } + + // Process through the sub-DAG + result := p.subDAG.Process(ctx, task.Payload) + + // Restore the original callback + p.subDAG.finalResult = originalCallback + + // If the sub-DAG completed immediately, return the result + if result.Status == mq.Completed || result.Error != nil { + return result + } + + // Otherwise wait for the final result from the callback + select { + case finalResult := <-resultCh: + return finalResult + case <-ctx.Done(): + return mq.Result{Error: ctx.Err(), Status: mq.Failed} + } +} + +func (p *DAGNodeProcessor) Consume(ctx context.Context) error { + // No-op for DAG nodes since they're processed directly + return nil +} + +func (p *DAGNodeProcessor) Pause(ctx context.Context) error { + // No-op for DAG nodes + return nil +} + +func (p *DAGNodeProcessor) Resume(ctx context.Context) error { + // No-op for DAG nodes + return nil +} + +func (p *DAGNodeProcessor) Stop(ctx context.Context) error { + return p.subDAG.Stop(ctx) +} + +func (p *DAGNodeProcessor) Close() error { + return p.subDAG.Stop(context.Background()) +} + +func (p *DAGNodeProcessor) GetType() string { + return "DAGNodeProcessor" +} + +func (p *DAGNodeProcessor) GetKey() string { + return p.nodeID +} + +func (p *DAGNodeProcessor) SetKey(key string) { + p.nodeID = key +} + +func (p *DAGNodeProcessor) SetNotifyResponse(callback mq.Callback) { + // Sub-DAG already has its own callback +} + // TaskError is used by node processors to indicate whether an error is recoverable. type TaskError struct { Err error @@ -139,6 +218,13 @@ func (tm *TaskManager) ProcessTask(ctx context.Context, startNode string, payloa } func (tm *TaskManager) enqueueTask(ctx context.Context, startNode, taskID string, payload json.RawMessage) { + if tm.dag.debug { + tm.dag.Logger().Info("enqueueTask called", + logger.Field{Key: "startNode", Value: startNode}, + logger.Field{Key: "taskID", Value: taskID}, + logger.Field{Key: "payloadSize", Value: len(payload)}) + } + if index, ok := ctx.Value(ContextIndex).(string); ok { base := strings.Split(startNode, Delimiter)[0] startNode = fmt.Sprintf("%s%s%s", base, Delimiter, index) @@ -466,7 +552,11 @@ func (tm *TaskManager) processNode(exec *task) { if err != nil { tm.dag.Logger().Error("Error checking if node is last", logger.Field{Key: "nodeID", Value: pureNodeID}, logger.Field{Key: "error", Value: err.Error()}) } else if isLast { - result.Last = true + // Check if this node has a parent (part of iterator pattern) + // If it has a parent, it should not be treated as a final node + if _, hasParent := tm.parentNodes.Get(exec.nodeID); !hasParent { + result.Last = true + } } tm.currentNodeResult.Set(exec.nodeID, result) tm.logNodeExecution(exec, pureNodeID, result, nodeLatency) @@ -535,10 +625,24 @@ func (tm *TaskManager) updateTimestamps(rs *mq.Result) { } func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, result mq.Result, childNode string, dispatchFinal bool) { + if tm.dag.debug { + tm.dag.Logger().Info("handlePrevious called", + logger.Field{Key: "parentNodeID", Value: state.NodeID}, + logger.Field{Key: "childNode", Value: childNode}) + } + state.targetResults.Set(childNode, result) state.targetResults.Del(state.NodeID) targetsCount, _ := tm.childNodes.Get(state.NodeID) size := state.targetResults.Size() + + if tm.dag.debug { + tm.dag.Logger().Info("Aggregation check", + logger.Field{Key: "parentNodeID", Value: state.NodeID}, + logger.Field{Key: "targetsCount", Value: targetsCount}, + logger.Field{Key: "currentSize", Value: size}) + } + if size == targetsCount { if size > 1 { aggregated := make([]json.RawMessage, size) @@ -572,7 +676,8 @@ func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, res } if parentKey, ok := tm.parentNodes.Get(state.NodeID); ok { parts := strings.Split(state.NodeID, Delimiter) - if edges, exists := tm.iteratorNodes.Get(parts[0]); exists && state.Status == mq.Completed { + // For iterator nodes, only continue to next edge after ALL children have completed and been aggregated + if edges, exists := tm.iteratorNodes.Get(parts[0]); exists && state.Status == mq.Completed && size == targetsCount { state.Status = mq.Processing tm.iteratorNodes.Del(parts[0]) state.targetResults.Clear() @@ -668,6 +773,13 @@ func (tm *TaskManager) enqueueResult(nr nodeResult) { } func (tm *TaskManager) onNodeCompleted(nr nodeResult) { + if tm.dag.debug { + tm.dag.Logger().Info("onNodeCompleted called", + logger.Field{Key: "nodeID", Value: nr.nodeID}, + logger.Field{Key: "status", Value: string(nr.status)}, + logger.Field{Key: "hasError", Value: nr.result.Error != nil}) + } + nodeID := strings.Split(nr.nodeID, Delimiter)[0] node, ok := tm.dag.nodes.Get(nodeID) if !ok { @@ -684,14 +796,22 @@ func (tm *TaskManager) onNodeCompleted(nr nodeResult) { tm.handleEdges(nr, edges) return } - if index, ok := nr.ctx.Value(ContextIndex).(string); ok { - childNode := fmt.Sprintf("%s%s%s", node.ID, Delimiter, index) - if parentKey, exists := tm.parentNodes.Get(childNode); exists { - if parentState, _ := tm.taskStates.Get(parentKey); parentState != nil { - tm.handlePrevious(nr.ctx, parentState, nr.result, nr.nodeID, true) - return // Don't send to resultCh if has parent - } + // Check if this is a child node from an iterator (has a parent) + if parentKey, exists := tm.parentNodes.Get(nr.nodeID); exists { + if tm.dag.debug { + tm.dag.Logger().Info("Found parent for node", + logger.Field{Key: "nodeID", Value: nr.nodeID}, + logger.Field{Key: "parentKey", Value: parentKey}) } + if parentState, _ := tm.taskStates.Get(parentKey); parentState != nil { + tm.handlePrevious(nr.ctx, parentState, nr.result, nr.nodeID, true) + return // Don't send to resultCh if has parent + } + } + if tm.dag.debug { + tm.dag.Logger().Info("No parent found for node, sending to resultCh", + logger.Field{Key: "nodeID", Value: nr.nodeID}, + logger.Field{Key: "result_topic", Value: nr.result.Topic}) } tm.updateTimestamps(&nr.result) tm.resultCh <- nr.result @@ -769,7 +889,8 @@ func (tm *TaskManager) processSingleEdge(currentResult nodeResult, edge Edge) { if _, exists := tm.iteratorNodes.Get(edge.From.ID); !exists { return } - parentNode = edge.From.ID + // Use the actual completing node as parent, not the edge From ID + parentNode = currentResult.nodeID var items []json.RawMessage if err := json.Unmarshal(currentResult.result.Payload, &items); err != nil { log.Printf("Error unmarshalling payload for node %s: %v", edge.To.ID, err) @@ -793,7 +914,28 @@ func (tm *TaskManager) processSingleEdge(currentResult nodeResult, edge Edge) { idx, _ := currentResult.ctx.Value(ContextIndex).(string) childNode := fmt.Sprintf("%s%s%s", edge.To.ID, Delimiter, idx) ctx := context.WithValue(currentResult.ctx, ContextIndex, idx) - tm.parentNodes.Set(childNode, parentNode) + + // If the current result came from an iterator child that has a parent, + // we need to preserve that parent relationship for the new target node + if originalParent, hasParent := tm.parentNodes.Get(currentResult.nodeID); hasParent { + if tm.dag.debug { + tm.dag.Logger().Info("Transferring parent relationship for conditional edge", + logger.Field{Key: "originalChild", Value: currentResult.nodeID}, + logger.Field{Key: "newChild", Value: childNode}, + logger.Field{Key: "parent", Value: originalParent}) + } + // Remove the original child from parent tracking since it's being replaced by conditional target + tm.parentNodes.Del(currentResult.nodeID) + // This edge target should now report back to the original parent instead + tm.parentNodes.Set(childNode, originalParent) + } else { + if tm.dag.debug { + tm.dag.Logger().Info("No parent found for conditional edge source", + logger.Field{Key: "nodeID", Value: currentResult.nodeID}) + } + tm.parentNodes.Set(childNode, parentNode) + } + tm.enqueueTask(ctx, edge.To.ID, tm.taskID, currentResult.result.Payload) } } diff --git a/examples/debug_dag.go b/examples/debug_dag.go new file mode 100644 index 0000000..814a505 --- /dev/null +++ b/examples/debug_dag.go @@ -0,0 +1,139 @@ +package main + +import ( + "context" + "fmt" + "log" + + "github.com/oarkflow/json" + + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/dag" + "github.com/oarkflow/mq/examples/tasks" +) + +func subDAG() *dag.DAG { + f := dag.NewDAG("Sub DAG", "sub-dag", func(taskID string, result mq.Result) { + fmt.Printf("Sub DAG Final result for task %s: %s\n", taskID, string(result.Payload)) + }, mq.WithSyncMode(true)) + f. + AddNode(dag.Function, "Store data", "store:data", &tasks.StoreData{Operation: dag.Operation{Type: dag.Function}}, true). + AddNode(dag.Function, "Send SMS", "send:sms", &tasks.SendSms{Operation: dag.Operation{Type: dag.Function}}). + AddNode(dag.Function, "Notification", "notification", &tasks.InAppNotification{Operation: dag.Operation{Type: dag.Function}}). + AddEdge(dag.Simple, "Store Payload to send sms", "store:data", "send:sms"). + AddEdge(dag.Simple, "Store Payload to notification", "send:sms", "notification") + return f +} + +func main() { + flow := dag.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) { + fmt.Printf("DAG Final result for task %s: %s\n", taskID, string(result.Payload)) + }) + flow.ConfigureMemoryStorage() + flow.AddNode(dag.Function, "GetData", "GetData", &GetData{}, true) + flow.AddNode(dag.Function, "Loop", "Loop", &Loop{}) + flow.AddNode(dag.Function, "ValidateAge", "ValidateAge", &ValidateAge{}) + flow.AddNode(dag.Function, "ValidateGender", "ValidateGender", &ValidateGender{}) + flow.AddNode(dag.Function, "Final", "Final", &Final{}) + flow.AddDAGNode(dag.Function, "Check", "persistent", subDAG()) + flow.AddEdge(dag.Simple, "GetData", "GetData", "Loop") + flow.AddEdge(dag.Iterator, "Validate age for each item", "Loop", "ValidateAge") + flow.AddCondition("ValidateAge", map[string]string{"pass": "ValidateGender", "default": "persistent"}) + flow.AddEdge(dag.Simple, "Mark as Done", "Loop", "Final") + + // Test without the Final node to see if it's causing the issue + // Let's also enable hook to see the flow + flow.SetPreProcessHook(func(ctx context.Context, node *dag.Node, taskID string, payload json.RawMessage) context.Context { + log.Printf("PRE-HOOK: Processing node %s, taskID %s, payload size: %d", node.ID, taskID, len(payload)) + return ctx + }) + + flow.SetPostProcessHook(func(ctx context.Context, node *dag.Node, taskID string, result mq.Result) { + log.Printf("POST-HOOK: Completed node %s, taskID %s, status: %v, payload size: %d", node.ID, taskID, result.Status, len(result.Payload)) + }) + + data := []byte(`[{"age": "15", "gender": "female"}, {"age": "18", "gender": "male"}]`) + if flow.Error != nil { + panic(flow.Error) + } + + rs := flow.Process(context.Background(), data) + if rs.Error != nil { + panic(rs.Error) + } + fmt.Println(rs.Status, rs.Topic, string(rs.Payload)) +} + +type GetData struct { + dag.Operation +} + +func (p *GetData) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { + log.Printf("GetData: Processing payload of size %d", len(task.Payload)) + return mq.Result{Ctx: ctx, Payload: task.Payload} +} + +type Loop struct { + dag.Operation +} + +func (p *Loop) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { + log.Printf("Loop: Processing payload of size %d", len(task.Payload)) + return mq.Result{Ctx: ctx, Payload: task.Payload} +} + +type ValidateAge struct { + dag.Operation +} + +func (p *ValidateAge) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { + var data map[string]any + if err := json.Unmarshal(task.Payload, &data); err != nil { + return mq.Result{Error: fmt.Errorf("ValidateAge Error: %s", err.Error()), Ctx: ctx} + } + var status string + if data["age"] == "18" { + status = "pass" + } else { + status = "default" + } + log.Printf("ValidateAge: Processing age %s, status %s", data["age"], status) + updatedPayload, _ := json.Marshal(data) + return mq.Result{Payload: updatedPayload, Ctx: ctx, ConditionStatus: status} +} + +type ValidateGender struct { + dag.Operation +} + +func (p *ValidateGender) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { + var data map[string]any + if err := json.Unmarshal(task.Payload, &data); err != nil { + return mq.Result{Error: fmt.Errorf("ValidateGender Error: %s", err.Error()), Ctx: ctx} + } + data["female_voter"] = data["gender"] == "female" + log.Printf("ValidateGender: Processing gender %s", data["gender"]) + updatedPayload, _ := json.Marshal(data) + return mq.Result{Payload: updatedPayload, Ctx: ctx} +} + +type Final struct { + dag.Operation +} + +func (p *Final) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { + var data []map[string]any + if err := json.Unmarshal(task.Payload, &data); err != nil { + return mq.Result{Error: fmt.Errorf("Final Error: %s", err.Error()), Ctx: ctx} + } + log.Printf("Final: Processing array with %d items", len(data)) + for i, row := range data { + row["done"] = true + data[i] = row + } + updatedPayload, err := json.Marshal(data) + if err != nil { + panic(err) + } + return mq.Result{Payload: updatedPayload, Ctx: ctx} +}