mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-25 01:20:22 +08:00
update
This commit is contained in:
@@ -170,11 +170,18 @@ func (tm *DAG) getCurrentNode(manager *TaskManager) string {
|
||||
func (tm *DAG) AddDAGNode(nodeType NodeType, name string, key string, dag *DAG, firstNode ...bool) *DAG {
|
||||
dag.AssignTopic(key)
|
||||
dag.name += fmt.Sprintf("(%s)", name)
|
||||
|
||||
// Create a wrapper processor that ensures proper completion reporting for iterator patterns
|
||||
processor := &DAGNodeProcessor{
|
||||
subDAG: dag,
|
||||
nodeID: key,
|
||||
}
|
||||
|
||||
tm.nodes.Set(key, &Node{
|
||||
Label: name,
|
||||
ID: key,
|
||||
NodeType: nodeType,
|
||||
processor: dag,
|
||||
processor: processor,
|
||||
isReady: true,
|
||||
IsLast: true, // Assume it's last until edges are added
|
||||
})
|
||||
|
@@ -18,6 +18,85 @@ import (
|
||||
"github.com/oarkflow/mq/storage/memory"
|
||||
)
|
||||
|
||||
// DAGNodeProcessor wraps a sub-DAG to ensure it reports completion properly
|
||||
// when used as part of an iterator pattern
|
||||
type DAGNodeProcessor struct {
|
||||
subDAG *DAG
|
||||
nodeID string
|
||||
}
|
||||
|
||||
func (p *DAGNodeProcessor) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
|
||||
// Process the task through the sub-DAG but capture the result
|
||||
// instead of letting it go to the final callback
|
||||
|
||||
// Create a result channel to capture the sub-DAG's result
|
||||
resultCh := make(chan mq.Result, 1)
|
||||
|
||||
// Temporarily replace the sub-DAG's final result callback
|
||||
originalCallback := p.subDAG.finalResult
|
||||
p.subDAG.finalResult = func(taskID string, result mq.Result) {
|
||||
resultCh <- result
|
||||
}
|
||||
|
||||
// Process through the sub-DAG
|
||||
result := p.subDAG.Process(ctx, task.Payload)
|
||||
|
||||
// Restore the original callback
|
||||
p.subDAG.finalResult = originalCallback
|
||||
|
||||
// If the sub-DAG completed immediately, return the result
|
||||
if result.Status == mq.Completed || result.Error != nil {
|
||||
return result
|
||||
}
|
||||
|
||||
// Otherwise wait for the final result from the callback
|
||||
select {
|
||||
case finalResult := <-resultCh:
|
||||
return finalResult
|
||||
case <-ctx.Done():
|
||||
return mq.Result{Error: ctx.Err(), Status: mq.Failed}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *DAGNodeProcessor) Consume(ctx context.Context) error {
|
||||
// No-op for DAG nodes since they're processed directly
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *DAGNodeProcessor) Pause(ctx context.Context) error {
|
||||
// No-op for DAG nodes
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *DAGNodeProcessor) Resume(ctx context.Context) error {
|
||||
// No-op for DAG nodes
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *DAGNodeProcessor) Stop(ctx context.Context) error {
|
||||
return p.subDAG.Stop(ctx)
|
||||
}
|
||||
|
||||
func (p *DAGNodeProcessor) Close() error {
|
||||
return p.subDAG.Stop(context.Background())
|
||||
}
|
||||
|
||||
func (p *DAGNodeProcessor) GetType() string {
|
||||
return "DAGNodeProcessor"
|
||||
}
|
||||
|
||||
func (p *DAGNodeProcessor) GetKey() string {
|
||||
return p.nodeID
|
||||
}
|
||||
|
||||
func (p *DAGNodeProcessor) SetKey(key string) {
|
||||
p.nodeID = key
|
||||
}
|
||||
|
||||
func (p *DAGNodeProcessor) SetNotifyResponse(callback mq.Callback) {
|
||||
// Sub-DAG already has its own callback
|
||||
}
|
||||
|
||||
// TaskError is used by node processors to indicate whether an error is recoverable.
|
||||
type TaskError struct {
|
||||
Err error
|
||||
@@ -139,6 +218,13 @@ func (tm *TaskManager) ProcessTask(ctx context.Context, startNode string, payloa
|
||||
}
|
||||
|
||||
func (tm *TaskManager) enqueueTask(ctx context.Context, startNode, taskID string, payload json.RawMessage) {
|
||||
if tm.dag.debug {
|
||||
tm.dag.Logger().Info("enqueueTask called",
|
||||
logger.Field{Key: "startNode", Value: startNode},
|
||||
logger.Field{Key: "taskID", Value: taskID},
|
||||
logger.Field{Key: "payloadSize", Value: len(payload)})
|
||||
}
|
||||
|
||||
if index, ok := ctx.Value(ContextIndex).(string); ok {
|
||||
base := strings.Split(startNode, Delimiter)[0]
|
||||
startNode = fmt.Sprintf("%s%s%s", base, Delimiter, index)
|
||||
@@ -466,8 +552,12 @@ func (tm *TaskManager) processNode(exec *task) {
|
||||
if err != nil {
|
||||
tm.dag.Logger().Error("Error checking if node is last", logger.Field{Key: "nodeID", Value: pureNodeID}, logger.Field{Key: "error", Value: err.Error()})
|
||||
} else if isLast {
|
||||
// Check if this node has a parent (part of iterator pattern)
|
||||
// If it has a parent, it should not be treated as a final node
|
||||
if _, hasParent := tm.parentNodes.Get(exec.nodeID); !hasParent {
|
||||
result.Last = true
|
||||
}
|
||||
}
|
||||
tm.currentNodeResult.Set(exec.nodeID, result)
|
||||
tm.logNodeExecution(exec, pureNodeID, result, nodeLatency)
|
||||
|
||||
@@ -535,10 +625,24 @@ func (tm *TaskManager) updateTimestamps(rs *mq.Result) {
|
||||
}
|
||||
|
||||
func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, result mq.Result, childNode string, dispatchFinal bool) {
|
||||
if tm.dag.debug {
|
||||
tm.dag.Logger().Info("handlePrevious called",
|
||||
logger.Field{Key: "parentNodeID", Value: state.NodeID},
|
||||
logger.Field{Key: "childNode", Value: childNode})
|
||||
}
|
||||
|
||||
state.targetResults.Set(childNode, result)
|
||||
state.targetResults.Del(state.NodeID)
|
||||
targetsCount, _ := tm.childNodes.Get(state.NodeID)
|
||||
size := state.targetResults.Size()
|
||||
|
||||
if tm.dag.debug {
|
||||
tm.dag.Logger().Info("Aggregation check",
|
||||
logger.Field{Key: "parentNodeID", Value: state.NodeID},
|
||||
logger.Field{Key: "targetsCount", Value: targetsCount},
|
||||
logger.Field{Key: "currentSize", Value: size})
|
||||
}
|
||||
|
||||
if size == targetsCount {
|
||||
if size > 1 {
|
||||
aggregated := make([]json.RawMessage, size)
|
||||
@@ -572,7 +676,8 @@ func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, res
|
||||
}
|
||||
if parentKey, ok := tm.parentNodes.Get(state.NodeID); ok {
|
||||
parts := strings.Split(state.NodeID, Delimiter)
|
||||
if edges, exists := tm.iteratorNodes.Get(parts[0]); exists && state.Status == mq.Completed {
|
||||
// For iterator nodes, only continue to next edge after ALL children have completed and been aggregated
|
||||
if edges, exists := tm.iteratorNodes.Get(parts[0]); exists && state.Status == mq.Completed && size == targetsCount {
|
||||
state.Status = mq.Processing
|
||||
tm.iteratorNodes.Del(parts[0])
|
||||
state.targetResults.Clear()
|
||||
@@ -668,6 +773,13 @@ func (tm *TaskManager) enqueueResult(nr nodeResult) {
|
||||
}
|
||||
|
||||
func (tm *TaskManager) onNodeCompleted(nr nodeResult) {
|
||||
if tm.dag.debug {
|
||||
tm.dag.Logger().Info("onNodeCompleted called",
|
||||
logger.Field{Key: "nodeID", Value: nr.nodeID},
|
||||
logger.Field{Key: "status", Value: string(nr.status)},
|
||||
logger.Field{Key: "hasError", Value: nr.result.Error != nil})
|
||||
}
|
||||
|
||||
nodeID := strings.Split(nr.nodeID, Delimiter)[0]
|
||||
node, ok := tm.dag.nodes.Get(nodeID)
|
||||
if !ok {
|
||||
@@ -684,14 +796,22 @@ func (tm *TaskManager) onNodeCompleted(nr nodeResult) {
|
||||
tm.handleEdges(nr, edges)
|
||||
return
|
||||
}
|
||||
if index, ok := nr.ctx.Value(ContextIndex).(string); ok {
|
||||
childNode := fmt.Sprintf("%s%s%s", node.ID, Delimiter, index)
|
||||
if parentKey, exists := tm.parentNodes.Get(childNode); exists {
|
||||
// Check if this is a child node from an iterator (has a parent)
|
||||
if parentKey, exists := tm.parentNodes.Get(nr.nodeID); exists {
|
||||
if tm.dag.debug {
|
||||
tm.dag.Logger().Info("Found parent for node",
|
||||
logger.Field{Key: "nodeID", Value: nr.nodeID},
|
||||
logger.Field{Key: "parentKey", Value: parentKey})
|
||||
}
|
||||
if parentState, _ := tm.taskStates.Get(parentKey); parentState != nil {
|
||||
tm.handlePrevious(nr.ctx, parentState, nr.result, nr.nodeID, true)
|
||||
return // Don't send to resultCh if has parent
|
||||
}
|
||||
}
|
||||
if tm.dag.debug {
|
||||
tm.dag.Logger().Info("No parent found for node, sending to resultCh",
|
||||
logger.Field{Key: "nodeID", Value: nr.nodeID},
|
||||
logger.Field{Key: "result_topic", Value: nr.result.Topic})
|
||||
}
|
||||
tm.updateTimestamps(&nr.result)
|
||||
tm.resultCh <- nr.result
|
||||
@@ -769,7 +889,8 @@ func (tm *TaskManager) processSingleEdge(currentResult nodeResult, edge Edge) {
|
||||
if _, exists := tm.iteratorNodes.Get(edge.From.ID); !exists {
|
||||
return
|
||||
}
|
||||
parentNode = edge.From.ID
|
||||
// Use the actual completing node as parent, not the edge From ID
|
||||
parentNode = currentResult.nodeID
|
||||
var items []json.RawMessage
|
||||
if err := json.Unmarshal(currentResult.result.Payload, &items); err != nil {
|
||||
log.Printf("Error unmarshalling payload for node %s: %v", edge.To.ID, err)
|
||||
@@ -793,7 +914,28 @@ func (tm *TaskManager) processSingleEdge(currentResult nodeResult, edge Edge) {
|
||||
idx, _ := currentResult.ctx.Value(ContextIndex).(string)
|
||||
childNode := fmt.Sprintf("%s%s%s", edge.To.ID, Delimiter, idx)
|
||||
ctx := context.WithValue(currentResult.ctx, ContextIndex, idx)
|
||||
|
||||
// If the current result came from an iterator child that has a parent,
|
||||
// we need to preserve that parent relationship for the new target node
|
||||
if originalParent, hasParent := tm.parentNodes.Get(currentResult.nodeID); hasParent {
|
||||
if tm.dag.debug {
|
||||
tm.dag.Logger().Info("Transferring parent relationship for conditional edge",
|
||||
logger.Field{Key: "originalChild", Value: currentResult.nodeID},
|
||||
logger.Field{Key: "newChild", Value: childNode},
|
||||
logger.Field{Key: "parent", Value: originalParent})
|
||||
}
|
||||
// Remove the original child from parent tracking since it's being replaced by conditional target
|
||||
tm.parentNodes.Del(currentResult.nodeID)
|
||||
// This edge target should now report back to the original parent instead
|
||||
tm.parentNodes.Set(childNode, originalParent)
|
||||
} else {
|
||||
if tm.dag.debug {
|
||||
tm.dag.Logger().Info("No parent found for conditional edge source",
|
||||
logger.Field{Key: "nodeID", Value: currentResult.nodeID})
|
||||
}
|
||||
tm.parentNodes.Set(childNode, parentNode)
|
||||
}
|
||||
|
||||
tm.enqueueTask(ctx, edge.To.ID, tm.taskID, currentResult.result.Payload)
|
||||
}
|
||||
}
|
||||
|
139
examples/debug_dag.go
Normal file
139
examples/debug_dag.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/oarkflow/json"
|
||||
|
||||
"github.com/oarkflow/mq"
|
||||
"github.com/oarkflow/mq/dag"
|
||||
"github.com/oarkflow/mq/examples/tasks"
|
||||
)
|
||||
|
||||
func subDAG() *dag.DAG {
|
||||
f := dag.NewDAG("Sub DAG", "sub-dag", func(taskID string, result mq.Result) {
|
||||
fmt.Printf("Sub DAG Final result for task %s: %s\n", taskID, string(result.Payload))
|
||||
}, mq.WithSyncMode(true))
|
||||
f.
|
||||
AddNode(dag.Function, "Store data", "store:data", &tasks.StoreData{Operation: dag.Operation{Type: dag.Function}}, true).
|
||||
AddNode(dag.Function, "Send SMS", "send:sms", &tasks.SendSms{Operation: dag.Operation{Type: dag.Function}}).
|
||||
AddNode(dag.Function, "Notification", "notification", &tasks.InAppNotification{Operation: dag.Operation{Type: dag.Function}}).
|
||||
AddEdge(dag.Simple, "Store Payload to send sms", "store:data", "send:sms").
|
||||
AddEdge(dag.Simple, "Store Payload to notification", "send:sms", "notification")
|
||||
return f
|
||||
}
|
||||
|
||||
func main() {
|
||||
flow := dag.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) {
|
||||
fmt.Printf("DAG Final result for task %s: %s\n", taskID, string(result.Payload))
|
||||
})
|
||||
flow.ConfigureMemoryStorage()
|
||||
flow.AddNode(dag.Function, "GetData", "GetData", &GetData{}, true)
|
||||
flow.AddNode(dag.Function, "Loop", "Loop", &Loop{})
|
||||
flow.AddNode(dag.Function, "ValidateAge", "ValidateAge", &ValidateAge{})
|
||||
flow.AddNode(dag.Function, "ValidateGender", "ValidateGender", &ValidateGender{})
|
||||
flow.AddNode(dag.Function, "Final", "Final", &Final{})
|
||||
flow.AddDAGNode(dag.Function, "Check", "persistent", subDAG())
|
||||
flow.AddEdge(dag.Simple, "GetData", "GetData", "Loop")
|
||||
flow.AddEdge(dag.Iterator, "Validate age for each item", "Loop", "ValidateAge")
|
||||
flow.AddCondition("ValidateAge", map[string]string{"pass": "ValidateGender", "default": "persistent"})
|
||||
flow.AddEdge(dag.Simple, "Mark as Done", "Loop", "Final")
|
||||
|
||||
// Test without the Final node to see if it's causing the issue
|
||||
// Let's also enable hook to see the flow
|
||||
flow.SetPreProcessHook(func(ctx context.Context, node *dag.Node, taskID string, payload json.RawMessage) context.Context {
|
||||
log.Printf("PRE-HOOK: Processing node %s, taskID %s, payload size: %d", node.ID, taskID, len(payload))
|
||||
return ctx
|
||||
})
|
||||
|
||||
flow.SetPostProcessHook(func(ctx context.Context, node *dag.Node, taskID string, result mq.Result) {
|
||||
log.Printf("POST-HOOK: Completed node %s, taskID %s, status: %v, payload size: %d", node.ID, taskID, result.Status, len(result.Payload))
|
||||
})
|
||||
|
||||
data := []byte(`[{"age": "15", "gender": "female"}, {"age": "18", "gender": "male"}]`)
|
||||
if flow.Error != nil {
|
||||
panic(flow.Error)
|
||||
}
|
||||
|
||||
rs := flow.Process(context.Background(), data)
|
||||
if rs.Error != nil {
|
||||
panic(rs.Error)
|
||||
}
|
||||
fmt.Println(rs.Status, rs.Topic, string(rs.Payload))
|
||||
}
|
||||
|
||||
type GetData struct {
|
||||
dag.Operation
|
||||
}
|
||||
|
||||
func (p *GetData) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
|
||||
log.Printf("GetData: Processing payload of size %d", len(task.Payload))
|
||||
return mq.Result{Ctx: ctx, Payload: task.Payload}
|
||||
}
|
||||
|
||||
type Loop struct {
|
||||
dag.Operation
|
||||
}
|
||||
|
||||
func (p *Loop) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
|
||||
log.Printf("Loop: Processing payload of size %d", len(task.Payload))
|
||||
return mq.Result{Ctx: ctx, Payload: task.Payload}
|
||||
}
|
||||
|
||||
type ValidateAge struct {
|
||||
dag.Operation
|
||||
}
|
||||
|
||||
func (p *ValidateAge) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal(task.Payload, &data); err != nil {
|
||||
return mq.Result{Error: fmt.Errorf("ValidateAge Error: %s", err.Error()), Ctx: ctx}
|
||||
}
|
||||
var status string
|
||||
if data["age"] == "18" {
|
||||
status = "pass"
|
||||
} else {
|
||||
status = "default"
|
||||
}
|
||||
log.Printf("ValidateAge: Processing age %s, status %s", data["age"], status)
|
||||
updatedPayload, _ := json.Marshal(data)
|
||||
return mq.Result{Payload: updatedPayload, Ctx: ctx, ConditionStatus: status}
|
||||
}
|
||||
|
||||
type ValidateGender struct {
|
||||
dag.Operation
|
||||
}
|
||||
|
||||
func (p *ValidateGender) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal(task.Payload, &data); err != nil {
|
||||
return mq.Result{Error: fmt.Errorf("ValidateGender Error: %s", err.Error()), Ctx: ctx}
|
||||
}
|
||||
data["female_voter"] = data["gender"] == "female"
|
||||
log.Printf("ValidateGender: Processing gender %s", data["gender"])
|
||||
updatedPayload, _ := json.Marshal(data)
|
||||
return mq.Result{Payload: updatedPayload, Ctx: ctx}
|
||||
}
|
||||
|
||||
type Final struct {
|
||||
dag.Operation
|
||||
}
|
||||
|
||||
func (p *Final) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var data []map[string]any
|
||||
if err := json.Unmarshal(task.Payload, &data); err != nil {
|
||||
return mq.Result{Error: fmt.Errorf("Final Error: %s", err.Error()), Ctx: ctx}
|
||||
}
|
||||
log.Printf("Final: Processing array with %d items", len(data))
|
||||
for i, row := range data {
|
||||
row["done"] = true
|
||||
data[i] = row
|
||||
}
|
||||
updatedPayload, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return mq.Result{Payload: updatedPayload, Ctx: ctx}
|
||||
}
|
Reference in New Issue
Block a user