sujit
2025-09-17 23:09:19 +05:45
parent 73dc3b276f
commit 59fc4f18aa
3 changed files with 300 additions and 12 deletions

View File

@@ -170,11 +170,18 @@ func (tm *DAG) getCurrentNode(manager *TaskManager) string {
func (tm *DAG) AddDAGNode(nodeType NodeType, name string, key string, dag *DAG, firstNode ...bool) *DAG {
dag.AssignTopic(key)
dag.name += fmt.Sprintf("(%s)", name)
// Create a wrapper processor that ensures proper completion reporting for iterator patterns
processor := &DAGNodeProcessor{
subDAG: dag,
nodeID: key,
}
tm.nodes.Set(key, &Node{
Label: name,
ID: key,
NodeType: nodeType,
- processor: dag,
+ processor: processor,
isReady: true,
IsLast: true, // Assume it's last until edges are added
})

View File

@@ -18,6 +18,85 @@ import (
"github.com/oarkflow/mq/storage/memory" "github.com/oarkflow/mq/storage/memory"
) )
// DAGNodeProcessor wraps a sub-DAG to ensure it reports completion properly
// when used as part of an iterator pattern
type DAGNodeProcessor struct {
subDAG *DAG
nodeID string
}
func (p *DAGNodeProcessor) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
// Process the task through the sub-DAG but capture the result
// instead of letting it go to the final callback
// Create a result channel to capture the sub-DAG's result
resultCh := make(chan mq.Result, 1)
// Temporarily replace the sub-DAG's final result callback
originalCallback := p.subDAG.finalResult
p.subDAG.finalResult = func(taskID string, result mq.Result) {
resultCh <- result
}
// Process through the sub-DAG
result := p.subDAG.Process(ctx, task.Payload)
// Restore the original callback
p.subDAG.finalResult = originalCallback
// If the sub-DAG completed immediately, return the result
if result.Status == mq.Completed || result.Error != nil {
return result
}
// Otherwise wait for the final result from the callback
select {
case finalResult := <-resultCh:
return finalResult
case <-ctx.Done():
return mq.Result{Error: ctx.Err(), Status: mq.Failed}
}
}
func (p *DAGNodeProcessor) Consume(ctx context.Context) error {
// No-op for DAG nodes since they're processed directly
return nil
}
func (p *DAGNodeProcessor) Pause(ctx context.Context) error {
// No-op for DAG nodes
return nil
}
func (p *DAGNodeProcessor) Resume(ctx context.Context) error {
// No-op for DAG nodes
return nil
}
func (p *DAGNodeProcessor) Stop(ctx context.Context) error {
return p.subDAG.Stop(ctx)
}
func (p *DAGNodeProcessor) Close() error {
return p.subDAG.Stop(context.Background())
}
func (p *DAGNodeProcessor) GetType() string {
return "DAGNodeProcessor"
}
func (p *DAGNodeProcessor) GetKey() string {
return p.nodeID
}
func (p *DAGNodeProcessor) SetKey(key string) {
p.nodeID = key
}
func (p *DAGNodeProcessor) SetNotifyResponse(callback mq.Callback) {
// Sub-DAG already has its own callback
}
// TaskError is used by node processors to indicate whether an error is recoverable.
type TaskError struct {
Err error
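To make the wrapper's intent easier to follow, here is a minimal, self-contained sketch of the same callback-capture pattern outside the library: the sub-flow normally reports through a final-result callback, and the wrapper temporarily redirects that callback into a buffered channel so the caller can block until the sub-flow finishes or the context is cancelled. SubFlow, Result, and runAndWait below are illustrative stand-ins, not part of the mq API; note that swapping a shared callback like this assumes the sub-DAG is not handling other tasks concurrently.

package main

import (
    "context"
    "fmt"
    "time"
)

// Result is a stand-in for mq.Result.
type Result struct {
    Payload string
    Err     error
}

// SubFlow is a stand-in for a sub-DAG: it finishes asynchronously and
// reports through finalResult, like DAG.finalResult in the diff above.
type SubFlow struct {
    finalResult func(taskID string, r Result)
}

func (s *SubFlow) Process(taskID, payload string) {
    go func() {
        time.Sleep(50 * time.Millisecond) // simulate async work
        if s.finalResult != nil {
            s.finalResult(taskID, Result{Payload: payload + " (processed)"})
        }
    }()
}

// runAndWait mirrors DAGNodeProcessor.ProcessTask: swap the callback,
// start processing, then block until the callback fires or ctx ends.
func runAndWait(ctx context.Context, s *SubFlow, taskID, payload string) Result {
    resultCh := make(chan Result, 1)
    original := s.finalResult
    s.finalResult = func(_ string, r Result) { resultCh <- r }
    defer func() { s.finalResult = original }()

    s.Process(taskID, payload)

    select {
    case r := <-resultCh:
        return r
    case <-ctx.Done():
        return Result{Err: ctx.Err()}
    }
}

func main() {
    ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    defer cancel()
    fmt.Println(runAndWait(ctx, &SubFlow{}, "task-1", "hello"))
}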
@@ -139,6 +218,13 @@ func (tm *TaskManager) ProcessTask(ctx context.Context, startNode string, payloa
}
func (tm *TaskManager) enqueueTask(ctx context.Context, startNode, taskID string, payload json.RawMessage) {
if tm.dag.debug {
tm.dag.Logger().Info("enqueueTask called",
logger.Field{Key: "startNode", Value: startNode},
logger.Field{Key: "taskID", Value: taskID},
logger.Field{Key: "payloadSize", Value: len(payload)})
}
if index, ok := ctx.Value(ContextIndex).(string); ok {
base := strings.Split(startNode, Delimiter)[0]
startNode = fmt.Sprintf("%s%s%s", base, Delimiter, index)
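The context lines above depend on the convention of suffixing a node ID with a per-item index when an iterator fans out: the index travels in the context under ContextIndex and is joined to the base node ID with Delimiter. A small illustration of that naming scheme, using a made-up delimiter value since the actual constant is not shown in this diff:

package main

import (
    "context"
    "fmt"
    "strings"
)

// Hypothetical stand-ins for the library's Delimiter constant and ContextIndex key.
const delimiter = "__"

type ctxKey string

const contextIndex ctxKey = "index"

func main() {
    // A fan-out attaches the item index to the context...
    ctx := context.WithValue(context.Background(), contextIndex, "1")

    // ...and enqueueTask rewrites the target node to an indexed instance.
    startNode := "ValidateAge__0" // may already carry a stale index
    if index, ok := ctx.Value(contextIndex).(string); ok {
        base := strings.Split(startNode, delimiter)[0]
        startNode = fmt.Sprintf("%s%s%s", base, delimiter, index)
    }
    fmt.Println(startNode) // ValidateAge__1
}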
@@ -466,7 +552,11 @@ func (tm *TaskManager) processNode(exec *task) {
if err != nil {
tm.dag.Logger().Error("Error checking if node is last", logger.Field{Key: "nodeID", Value: pureNodeID}, logger.Field{Key: "error", Value: err.Error()})
} else if isLast {
- result.Last = true
+ // Check if this node has a parent (part of iterator pattern)
+ // If it has a parent, it should not be treated as a final node
+ if _, hasParent := tm.parentNodes.Get(exec.nodeID); !hasParent {
+ result.Last = true
+ }
}
tm.currentNodeResult.Set(exec.nodeID, result)
tm.logNodeExecution(exec, pureNodeID, result, nodeLatency)
@@ -535,10 +625,24 @@ func (tm *TaskManager) updateTimestamps(rs *mq.Result) {
}
func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, result mq.Result, childNode string, dispatchFinal bool) {
if tm.dag.debug {
tm.dag.Logger().Info("handlePrevious called",
logger.Field{Key: "parentNodeID", Value: state.NodeID},
logger.Field{Key: "childNode", Value: childNode})
}
state.targetResults.Set(childNode, result)
state.targetResults.Del(state.NodeID)
targetsCount, _ := tm.childNodes.Get(state.NodeID)
size := state.targetResults.Size()
if tm.dag.debug {
tm.dag.Logger().Info("Aggregation check",
logger.Field{Key: "parentNodeID", Value: state.NodeID},
logger.Field{Key: "targetsCount", Value: targetsCount},
logger.Field{Key: "currentSize", Value: size})
}
if size == targetsCount {
if size > 1 {
aggregated := make([]json.RawMessage, size)
@@ -572,7 +676,8 @@ func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, res
}
if parentKey, ok := tm.parentNodes.Get(state.NodeID); ok {
parts := strings.Split(state.NodeID, Delimiter)
- if edges, exists := tm.iteratorNodes.Get(parts[0]); exists && state.Status == mq.Completed {
+ // For iterator nodes, only continue to the next edge after ALL children have completed and been aggregated
+ if edges, exists := tm.iteratorNodes.Get(parts[0]); exists && state.Status == mq.Completed && size == targetsCount {
state.Status = mq.Processing
tm.iteratorNodes.Del(parts[0])
state.targetResults.Clear()
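The surrounding code gathers one result per iterator child in state.targetResults and, once the stored count matches the expected child count, merges the payloads into a single JSON array before the parent node advances along its next edge. A self-contained sketch of that count-and-merge step; parentState and its fields are simplified stand-ins for TaskState, targetResults, and the childNodes lookup:

package main

import (
    "encoding/json"
    "fmt"
)

// parentState is a simplified stand-in for a parent node's task state.
type parentState struct {
    expected      int
    targetResults map[string]json.RawMessage
}

// handleChild records one child's payload; when all children have reported,
// it returns the aggregated array payload and true.
func (s *parentState) handleChild(childID string, payload json.RawMessage) (json.RawMessage, bool) {
    s.targetResults[childID] = payload
    if len(s.targetResults) != s.expected {
        return nil, false // still waiting on siblings
    }
    aggregated := make([]json.RawMessage, 0, s.expected)
    for _, p := range s.targetResults {
        // map iteration order is not deterministic; child ordering is not preserved in this sketch
        aggregated = append(aggregated, p)
    }
    merged, _ := json.Marshal(aggregated)
    return merged, true
}

func main() {
    state := &parentState{expected: 2, targetResults: map[string]json.RawMessage{}}

    if _, done := state.handleChild("ValidateAge__0", json.RawMessage(`{"age":"15"}`)); !done {
        fmt.Println("first child stored, waiting for the rest")
    }
    if merged, done := state.handleChild("ValidateAge__1", json.RawMessage(`{"age":"18"}`)); done {
        fmt.Println("all children done:", string(merged))
    }
}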
@@ -668,6 +773,13 @@ func (tm *TaskManager) enqueueResult(nr nodeResult) {
}
func (tm *TaskManager) onNodeCompleted(nr nodeResult) {
if tm.dag.debug {
tm.dag.Logger().Info("onNodeCompleted called",
logger.Field{Key: "nodeID", Value: nr.nodeID},
logger.Field{Key: "status", Value: string(nr.status)},
logger.Field{Key: "hasError", Value: nr.result.Error != nil})
}
nodeID := strings.Split(nr.nodeID, Delimiter)[0]
node, ok := tm.dag.nodes.Get(nodeID)
if !ok {
@@ -684,14 +796,22 @@ func (tm *TaskManager) onNodeCompleted(nr nodeResult) {
tm.handleEdges(nr, edges)
return
}
- if index, ok := nr.ctx.Value(ContextIndex).(string); ok {
- childNode := fmt.Sprintf("%s%s%s", node.ID, Delimiter, index)
- if parentKey, exists := tm.parentNodes.Get(childNode); exists {
- if parentState, _ := tm.taskStates.Get(parentKey); parentState != nil {
- tm.handlePrevious(nr.ctx, parentState, nr.result, nr.nodeID, true)
- return // Don't send to resultCh if has parent
- }
- }
- }
+ // Check if this is a child node from an iterator (has a parent)
+ if parentKey, exists := tm.parentNodes.Get(nr.nodeID); exists {
+ if tm.dag.debug {
+ tm.dag.Logger().Info("Found parent for node",
+ logger.Field{Key: "nodeID", Value: nr.nodeID},
+ logger.Field{Key: "parentKey", Value: parentKey})
+ }
+ if parentState, _ := tm.taskStates.Get(parentKey); parentState != nil {
+ tm.handlePrevious(nr.ctx, parentState, nr.result, nr.nodeID, true)
+ return // Don't send to resultCh if has parent
+ }
+ }
+ if tm.dag.debug {
+ tm.dag.Logger().Info("No parent found for node, sending to resultCh",
+ logger.Field{Key: "nodeID", Value: nr.nodeID},
+ logger.Field{Key: "result_topic", Value: nr.result.Topic})
+ }
tm.updateTimestamps(&nr.result)
tm.resultCh <- nr.result
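The net effect of this change is that a completed node first checks whether it was registered as an iterator child; only results without a parent reach resultCh and, from there, the DAG's final callback. A reduced model of that routing decision, with the concurrent maps and logger replaced by plain Go equivalents:

package main

import "fmt"

type result struct {
    nodeID  string
    payload string
}

type manager struct {
    parentNodes map[string]string // child node ID -> parent node ID
    resultCh    chan result
}

// onNodeCompleted mirrors the updated flow: results of iterator children are
// handed back to their parent for aggregation, everything else is final.
func (m *manager) onNodeCompleted(r result, aggregate func(parent string, r result)) {
    if parent, ok := m.parentNodes[r.nodeID]; ok {
        aggregate(parent, r) // handlePrevious in the real code
        return               // never reaches the final result channel
    }
    m.resultCh <- r
}

func main() {
    m := &manager{
        parentNodes: map[string]string{"ValidateAge__0": "Loop__0"},
        resultCh:    make(chan result, 1),
    }
    agg := func(parent string, r result) {
        fmt.Printf("aggregate %s under parent %s\n", r.nodeID, parent)
    }
    m.onNodeCompleted(result{nodeID: "ValidateAge__0", payload: "{}"}, agg)
    m.onNodeCompleted(result{nodeID: "Final", payload: "[]"}, agg)
    fmt.Println("final result from:", (<-m.resultCh).nodeID)
}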
@@ -769,7 +889,8 @@ func (tm *TaskManager) processSingleEdge(currentResult nodeResult, edge Edge) {
if _, exists := tm.iteratorNodes.Get(edge.From.ID); !exists {
return
}
- parentNode = edge.From.ID
+ // Use the actual completing node as parent, not the edge From ID
+ parentNode = currentResult.nodeID
var items []json.RawMessage
if err := json.Unmarshal(currentResult.result.Payload, &items); err != nil {
log.Printf("Error unmarshalling payload for node %s: %v", edge.To.ID, err)
@@ -793,7 +914,28 @@ func (tm *TaskManager) processSingleEdge(currentResult nodeResult, edge Edge) {
idx, _ := currentResult.ctx.Value(ContextIndex).(string)
childNode := fmt.Sprintf("%s%s%s", edge.To.ID, Delimiter, idx)
ctx := context.WithValue(currentResult.ctx, ContextIndex, idx)
- tm.parentNodes.Set(childNode, parentNode)
// If the current result came from an iterator child that has a parent,
// we need to preserve that parent relationship for the new target node
if originalParent, hasParent := tm.parentNodes.Get(currentResult.nodeID); hasParent {
if tm.dag.debug {
tm.dag.Logger().Info("Transferring parent relationship for conditional edge",
logger.Field{Key: "originalChild", Value: currentResult.nodeID},
logger.Field{Key: "newChild", Value: childNode},
logger.Field{Key: "parent", Value: originalParent})
}
// Remove the original child from parent tracking since it's being replaced by the conditional target
tm.parentNodes.Del(currentResult.nodeID)
// This edge target should now report back to the original parent instead
tm.parentNodes.Set(childNode, originalParent)
} else {
if tm.dag.debug {
tm.dag.Logger().Info("No parent found for conditional edge source",
logger.Field{Key: "nodeID", Value: currentResult.nodeID})
}
tm.parentNodes.Set(childNode, parentNode)
}
tm.enqueueTask(ctx, edge.To.ID, tm.taskID, currentResult.result.Payload)
}
}
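The new branch covers the case where an iterator child routes onward through a conditional edge (ValidateAge to the persistent sub-DAG in the example file below): the conditional target must inherit the original parent link so the iterator still counts exactly one completion per item. A small stand-alone model of that bookkeeping, with illustrative node IDs:

package main

import "fmt"

// transferParent mirrors the new processSingleEdge branch: if the node that
// just completed was itself an iterator child, its parent link is moved to
// the conditional target so the parent still receives one result per item.
func transferParent(parentNodes map[string]string, completedNode, newChild, fallbackParent string) {
    if originalParent, ok := parentNodes[completedNode]; ok {
        delete(parentNodes, completedNode) // the child is replaced by the conditional target
        parentNodes[newChild] = originalParent
        return
    }
    // no parent recorded: fall back to the completing node, as the diff does
    parentNodes[newChild] = fallbackParent
}

func main() {
    parentNodes := map[string]string{
        "ValidateAge__0": "Loop__0", // set when the iterator fanned out
    }
    // ValidateAge__0 returned ConditionStatus "default", routing item 0 to "persistent".
    transferParent(parentNodes, "ValidateAge__0", "persistent__0", "ValidateAge__0")
    fmt.Println(parentNodes) // map[persistent__0:Loop__0]
}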

examples/debug_dag.go (new file, 139 additions)
View File

@@ -0,0 +1,139 @@
package main
import (
"context"
"fmt"
"log"
"github.com/oarkflow/json"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/dag"
"github.com/oarkflow/mq/examples/tasks"
)
func subDAG() *dag.DAG {
f := dag.NewDAG("Sub DAG", "sub-dag", func(taskID string, result mq.Result) {
fmt.Printf("Sub DAG Final result for task %s: %s\n", taskID, string(result.Payload))
}, mq.WithSyncMode(true))
f.
AddNode(dag.Function, "Store data", "store:data", &tasks.StoreData{Operation: dag.Operation{Type: dag.Function}}, true).
AddNode(dag.Function, "Send SMS", "send:sms", &tasks.SendSms{Operation: dag.Operation{Type: dag.Function}}).
AddNode(dag.Function, "Notification", "notification", &tasks.InAppNotification{Operation: dag.Operation{Type: dag.Function}}).
AddEdge(dag.Simple, "Store Payload to send sms", "store:data", "send:sms").
AddEdge(dag.Simple, "Store Payload to notification", "send:sms", "notification")
return f
}
func main() {
flow := dag.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) {
fmt.Printf("DAG Final result for task %s: %s\n", taskID, string(result.Payload))
})
flow.ConfigureMemoryStorage()
flow.AddNode(dag.Function, "GetData", "GetData", &GetData{}, true)
flow.AddNode(dag.Function, "Loop", "Loop", &Loop{})
flow.AddNode(dag.Function, "ValidateAge", "ValidateAge", &ValidateAge{})
flow.AddNode(dag.Function, "ValidateGender", "ValidateGender", &ValidateGender{})
flow.AddNode(dag.Function, "Final", "Final", &Final{})
flow.AddDAGNode(dag.Function, "Check", "persistent", subDAG())
flow.AddEdge(dag.Simple, "GetData", "GetData", "Loop")
flow.AddEdge(dag.Iterator, "Validate age for each item", "Loop", "ValidateAge")
flow.AddCondition("ValidateAge", map[string]string{"pass": "ValidateGender", "default": "persistent"})
flow.AddEdge(dag.Simple, "Mark as Done", "Loop", "Final")
// Test without the Final node to see if it's causing the issue
// Let's also enable hooks to see the flow
flow.SetPreProcessHook(func(ctx context.Context, node *dag.Node, taskID string, payload json.RawMessage) context.Context {
log.Printf("PRE-HOOK: Processing node %s, taskID %s, payload size: %d", node.ID, taskID, len(payload))
return ctx
})
flow.SetPostProcessHook(func(ctx context.Context, node *dag.Node, taskID string, result mq.Result) {
log.Printf("POST-HOOK: Completed node %s, taskID %s, status: %v, payload size: %d", node.ID, taskID, result.Status, len(result.Payload))
})
data := []byte(`[{"age": "15", "gender": "female"}, {"age": "18", "gender": "male"}]`)
if flow.Error != nil {
panic(flow.Error)
}
rs := flow.Process(context.Background(), data)
if rs.Error != nil {
panic(rs.Error)
}
fmt.Println(rs.Status, rs.Topic, string(rs.Payload))
}
type GetData struct {
dag.Operation
}
func (p *GetData) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
log.Printf("GetData: Processing payload of size %d", len(task.Payload))
return mq.Result{Ctx: ctx, Payload: task.Payload}
}
type Loop struct {
dag.Operation
}
func (p *Loop) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
log.Printf("Loop: Processing payload of size %d", len(task.Payload))
return mq.Result{Ctx: ctx, Payload: task.Payload}
}
type ValidateAge struct {
dag.Operation
}
func (p *ValidateAge) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
var data map[string]any
if err := json.Unmarshal(task.Payload, &data); err != nil {
return mq.Result{Error: fmt.Errorf("ValidateAge Error: %s", err.Error()), Ctx: ctx}
}
var status string
if data["age"] == "18" {
status = "pass"
} else {
status = "default"
}
log.Printf("ValidateAge: Processing age %s, status %s", data["age"], status)
updatedPayload, _ := json.Marshal(data)
return mq.Result{Payload: updatedPayload, Ctx: ctx, ConditionStatus: status}
}
type ValidateGender struct {
dag.Operation
}
func (p *ValidateGender) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
var data map[string]any
if err := json.Unmarshal(task.Payload, &data); err != nil {
return mq.Result{Error: fmt.Errorf("ValidateGender Error: %s", err.Error()), Ctx: ctx}
}
data["female_voter"] = data["gender"] == "female"
log.Printf("ValidateGender: Processing gender %s", data["gender"])
updatedPayload, _ := json.Marshal(data)
return mq.Result{Payload: updatedPayload, Ctx: ctx}
}
type Final struct {
dag.Operation
}
func (p *Final) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
var data []map[string]any
if err := json.Unmarshal(task.Payload, &data); err != nil {
return mq.Result{Error: fmt.Errorf("Final Error: %s", err.Error()), Ctx: ctx}
}
log.Printf("Final: Processing array with %d items", len(data))
for i, row := range data {
row["done"] = true
data[i] = row
}
updatedPayload, err := json.Marshal(data)
if err != nil {
panic(err)
}
return mq.Result{Payload: updatedPayload, Ctx: ctx}
}