diff --git a/dag/v2/task_manager.go b/dag/v2/task_manager.go index c3c16c1..68e6697 100644 --- a/dag/v2/task_manager.go +++ b/dag/v2/task_manager.go @@ -51,14 +51,18 @@ func NewTaskManager(dag *DAG) *TaskManager { func (tm *TaskManager) Trigger(taskID, startNode string, payload json.RawMessage) { tm.mu.Lock() - tm.taskStates[startNode] = &TaskState{ - NodeID: startNode, + tm.taskStates[startNode] = newTaskState(startNode) + tm.mu.Unlock() + tm.taskQueue <- taskExecution{taskID: taskID, nodeID: startNode, payload: payload} +} + +func newTaskState(nodeID string) *TaskState { + return &TaskState{ + NodeID: nodeID, Status: StatusPending, Timestamp: time.Now(), targetResults: make(map[string]Result), } - tm.mu.Unlock() - tm.taskQueue <- taskExecution{taskID: taskID, nodeID: startNode, payload: payload} } func (tm *TaskManager) Run() { @@ -78,7 +82,7 @@ func (tm *TaskManager) processNode(exec taskExecution) { tm.mu.Lock() state := tm.taskStates[exec.nodeID] if state == nil { - state = &TaskState{NodeID: exec.nodeID, Status: StatusPending, Timestamp: time.Now(), targetResults: make(map[string]Result)} + state = newTaskState(exec.nodeID) tm.taskStates[exec.nodeID] = state } state.Status = StatusProcessing diff --git a/examples/v2.go b/examples/v2.go index c6cf893..fbe02cd 100644 --- a/examples/v2.go +++ b/examples/v2.go @@ -7,42 +7,52 @@ import ( v2 "github.com/oarkflow/mq/dag/v2" ) -func main() { - dag := v2.NewDAG(func(taskID string, result v2.Result) { - fmt.Printf("Final result for Task %s: %s\n", taskID, string(result.Data)) - }) - dag.AddNode("NodeA", func(payload json.RawMessage) v2.Result { - var data map[string]any - if err := json.Unmarshal(payload, &data); err != nil { - return v2.Result{Error: err, Status: v2.StatusFailed} - } - data["allowed_voting"] = data["age"] == "18" - updatedPayload, _ := json.Marshal(data) - return v2.Result{Data: updatedPayload, Status: v2.StatusCompleted} - }) - dag.AddNode("NodeB", func(payload json.RawMessage) v2.Result { - var data map[string]any - if err := json.Unmarshal(payload, &data); err != nil { - return v2.Result{Error: err, Status: v2.StatusFailed} - } - data["female_voter"] = data["gender"] == "female" - updatedPayload, _ := json.Marshal(data) - return v2.Result{Data: updatedPayload, Status: v2.StatusCompleted} - }) - dag.AddNode("NodeC", func(payload json.RawMessage) v2.Result { - var data map[string]any - if err := json.Unmarshal(payload, &data); err != nil { - return v2.Result{Error: err, Status: v2.StatusFailed} - } - data["voted"] = true - updatedPayload, _ := json.Marshal(data) - return v2.Result{Data: updatedPayload, Status: v2.StatusCompleted} - }) - dag.AddNode("Result", func(payload json.RawMessage) v2.Result { - var data map[string]any - json.Unmarshal(payload, &data) +func NodeA(payload json.RawMessage) v2.Result { + var data map[string]any + if err := json.Unmarshal(payload, &data); err != nil { + return v2.Result{Error: err, Status: v2.StatusFailed} + } + data["allowed_voting"] = data["age"] == "18" + updatedPayload, _ := json.Marshal(data) + return v2.Result{Data: updatedPayload, Status: v2.StatusCompleted} +} - return v2.Result{Data: payload, Status: v2.StatusCompleted} - }) +func NodeB(payload json.RawMessage) v2.Result { + var data map[string]any + if err := json.Unmarshal(payload, &data); err != nil { + return v2.Result{Error: err, Status: v2.StatusFailed} + } + data["female_voter"] = data["gender"] == "female" + updatedPayload, _ := json.Marshal(data) + return v2.Result{Data: updatedPayload, Status: v2.StatusCompleted} +} + +func NodeC(payload json.RawMessage) v2.Result { + var data map[string]any + if err := json.Unmarshal(payload, &data); err != nil { + return v2.Result{Error: err, Status: v2.StatusFailed} + } + data["voted"] = true + updatedPayload, _ := json.Marshal(data) + return v2.Result{Data: updatedPayload, Status: v2.StatusCompleted} +} + +func Result(payload json.RawMessage) v2.Result { + var data map[string]any + json.Unmarshal(payload, &data) + + return v2.Result{Data: payload, Status: v2.StatusCompleted} +} + +func notify(taskID string, result v2.Result) { + fmt.Printf("Final result for Task %s: %s\n", taskID, string(result.Data)) +} + +func main() { + dag := v2.NewDAG(notify) + dag.AddNode("NodeA", NodeA) + dag.AddNode("NodeB", NodeB) + dag.AddNode("NodeC", NodeC) + dag.AddNode("Result", Result) dag.Start(":8080") }