From e07b8ed3fd8be7c1ade070829285db901601d851 Mon Sep 17 00:00:00 2001 From: sujit Date: Mon, 18 Nov 2024 11:01:21 +0545 Subject: [PATCH] feat: add task completion --- examples/v2.go | 146 +++++++++++++++++++++++++++---------------------- 1 file changed, 80 insertions(+), 66 deletions(-) diff --git a/examples/v2.go b/examples/v2.go index e71d3ca..590ee0f 100644 --- a/examples/v2.go +++ b/examples/v2.go @@ -12,6 +12,40 @@ import ( "golang.org/x/exp/maps" ) +type DAG struct { + Nodes map[string]*Node + Edges map[string][]string + ParentNodes map[string]string + taskManager map[string]*TaskManager + mu sync.Mutex + finalResult func(taskID string, result Result) +} + +func NewDAG(finalResultCallback func(taskID string, result Result)) *DAG { + return &DAG{ + Nodes: make(map[string]*Node), + Edges: make(map[string][]string), + ParentNodes: make(map[string]string), + taskManager: make(map[string]*TaskManager), + finalResult: finalResultCallback, + } +} + +func (tm *DAG) AddNode(nodeID string, handler func(payload json.RawMessage) Result) { + tm.mu.Lock() + defer tm.mu.Unlock() + tm.Nodes[nodeID] = &Node{ID: nodeID, Handler: handler} +} + +func (tm *DAG) AddEdge(from string, to ...string) { + tm.mu.Lock() + defer tm.mu.Unlock() + tm.Edges[from] = append(tm.Edges[from], to...) + for _, targetNode := range to { + tm.ParentNodes[targetNode] = from + } +} + type TaskStatus string const ( @@ -48,14 +82,11 @@ type nodeResult struct { } type TaskManager struct { - Nodes map[string]*Node - Edges map[string][]string - ParentNodes map[string]string - TaskStates map[string]map[string]*TaskState + taskStates map[string]*TaskState + dag *DAG mu sync.Mutex taskQueue chan taskExecution resultQueue chan nodeResult - finalResult func(taskID string, result Result) } type taskExecution struct { @@ -64,41 +95,20 @@ type taskExecution struct { payload json.RawMessage } -func NewTaskManager(finalResultCallback func(taskID string, result Result)) *TaskManager { +func NewTaskManager(dag *DAG) *TaskManager { tm := &TaskManager{ - Nodes: make(map[string]*Node), - Edges: make(map[string][]string), - ParentNodes: make(map[string]string), - TaskStates: make(map[string]map[string]*TaskState), + taskStates: make(map[string]*TaskState), taskQueue: make(chan taskExecution, 100), resultQueue: make(chan nodeResult, 100), - finalResult: finalResultCallback, + dag: dag, } go tm.WaitForResult() return tm } -func (tm *TaskManager) AddNode(nodeID string, handler func(payload json.RawMessage) Result) { - tm.mu.Lock() - defer tm.mu.Unlock() - tm.Nodes[nodeID] = &Node{ID: nodeID, Handler: handler} -} - -func (tm *TaskManager) AddEdge(from string, to ...string) { - tm.mu.Lock() - defer tm.mu.Unlock() - tm.Edges[from] = append(tm.Edges[from], to...) - for _, targetNode := range to { - tm.ParentNodes[targetNode] = from - } -} - func (tm *TaskManager) Trigger(taskID, startNode string, payload json.RawMessage) { tm.mu.Lock() - if _, exists := tm.TaskStates[taskID]; !exists { - tm.TaskStates[taskID] = make(map[string]*TaskState) - } - tm.TaskStates[taskID][startNode] = &TaskState{ + tm.taskStates[startNode] = &TaskState{ NodeID: startNode, Status: StatusPending, Timestamp: time.Now(), @@ -117,16 +127,16 @@ func (tm *TaskManager) Run() { } func (tm *TaskManager) processNode(exec taskExecution) { - node, exists := tm.Nodes[exec.nodeID] + node, exists := tm.dag.Nodes[exec.nodeID] if !exists { fmt.Printf("Node %s does not exist\n", exec.nodeID) return } tm.mu.Lock() - state := tm.TaskStates[exec.taskID][exec.nodeID] + state := tm.taskStates[exec.nodeID] if state == nil { state = &TaskState{NodeID: exec.nodeID, Status: StatusPending, Timestamp: time.Now(), targetResults: make(map[string]Result)} - tm.TaskStates[exec.taskID][exec.nodeID] = state + tm.taskStates[exec.nodeID] = state } state.Status = StatusProcessing state.Timestamp = time.Now() @@ -154,12 +164,12 @@ func (tm *TaskManager) WaitForResult() { } func (tm *TaskManager) onNodeCompleted(nodeResult nodeResult) { - nextNodes := tm.Edges[nodeResult.nodeID] + nextNodes := tm.dag.Edges[nodeResult.nodeID] if len(nextNodes) > 0 { for _, nextNodeID := range nextNodes { tm.mu.Lock() - if _, exists := tm.TaskStates[nodeResult.taskID][nextNodeID]; !exists { - tm.TaskStates[nodeResult.taskID][nextNodeID] = &TaskState{ + if _, exists := tm.taskStates[nextNodeID]; !exists { + tm.taskStates[nextNodeID] = &TaskState{ NodeID: nextNodeID, Status: StatusPending, Timestamp: time.Now(), @@ -170,30 +180,30 @@ func (tm *TaskManager) onNodeCompleted(nodeResult nodeResult) { tm.taskQueue <- taskExecution{taskID: nodeResult.taskID, nodeID: nextNodeID, payload: nodeResult.result.Data} } } else { - parentNode := tm.ParentNodes[nodeResult.nodeID] + parentNode := tm.dag.ParentNodes[nodeResult.nodeID] if parentNode != "" { tm.mu.Lock() - state := tm.TaskStates[nodeResult.taskID][parentNode] + state := tm.taskStates[parentNode] if state == nil { state = &TaskState{NodeID: parentNode, Status: StatusPending, Timestamp: time.Now(), targetResults: make(map[string]Result)} - tm.TaskStates[nodeResult.taskID][parentNode] = state + tm.taskStates[parentNode] = state } state.targetResults[nodeResult.nodeID] = nodeResult.result - allTargetNodesdone := len(tm.Edges[parentNode]) == len(state.targetResults) + allTargetNodesdone := len(tm.dag.Edges[parentNode]) == len(state.targetResults) tm.mu.Unlock() - if tm.areAllTargetNodesCompleted(parentNode, nodeResult.taskID) && allTargetNodesdone { + if tm.areAllTargetNodesCompleted(parentNode) && allTargetNodesdone { tm.aggregateResults(parentNode, nodeResult.taskID) } } } } -func (tm *TaskManager) areAllTargetNodesCompleted(parentNode string, taskID string) bool { +func (tm *TaskManager) areAllTargetNodesCompleted(parentNode string) bool { tm.mu.Lock() defer tm.mu.Unlock() - for _, targetNode := range tm.Edges[parentNode] { - state := tm.TaskStates[taskID][targetNode] + for _, targetNode := range tm.dag.Edges[parentNode] { + state := tm.taskStates[targetNode] if state == nil || state.Status != StatusCompleted { return false } @@ -204,7 +214,7 @@ func (tm *TaskManager) areAllTargetNodesCompleted(parentNode string, taskID stri func (tm *TaskManager) aggregateResults(parentNode string, taskID string) { tm.mu.Lock() defer tm.mu.Unlock() - state := tm.TaskStates[taskID][parentNode] + state := tm.taskStates[parentNode] if len(state.targetResults) > 1 { aggregatedData := make([]json.RawMessage, len(state.targetResults)) i := 0 @@ -222,7 +232,7 @@ func (tm *TaskManager) aggregateResults(parentNode string, taskID string) { func (tm *TaskManager) processFinalResult(taskID string, state *TaskState) { clear(state.targetResults) - tm.finalResult(taskID, state.Result) + tm.dag.finalResult(taskID, state.Result) } func finalResultCallback(taskID string, result Result) { @@ -233,7 +243,7 @@ func generateTaskID() string { return strconv.Itoa(rand.Intn(100000)) } -func (tm *TaskManager) formHandler(w http.ResponseWriter, r *http.Request) { +func (tm *DAG) formHandler(w http.ResponseWriter, r *http.Request) { if r.Method == "GET" { http.ServeFile(w, r, "webroot/form.html") } else if r.Method == "POST" { @@ -242,36 +252,41 @@ func (tm *TaskManager) formHandler(w http.ResponseWriter, r *http.Request) { age := r.FormValue("age") gender := r.FormValue("gender") taskID := generateTaskID() + manager := NewTaskManager(tm) + tm.mu.Lock() + tm.taskManager[taskID] = manager + tm.mu.Unlock() + go manager.Run() payload := fmt.Sprintf(`{"email": "%s", "age": "%s", "gender": "%s"}`, email, age, gender) - tm.Trigger(taskID, "NodeA", json.RawMessage(payload)) + manager.Trigger(taskID, "NodeA", json.RawMessage(payload)) http.Redirect(w, r, "/result?taskID="+taskID, http.StatusFound) } } -func (tm *TaskManager) resultHandler(w http.ResponseWriter, r *http.Request) { +func (tm *DAG) resultHandler(w http.ResponseWriter, r *http.Request) { http.ServeFile(w, r, "webroot/result.html") } -func (tm *TaskManager) taskStatusHandler(w http.ResponseWriter, r *http.Request) { +func (tm *DAG) taskStatusHandler(w http.ResponseWriter, r *http.Request) { taskID := r.URL.Query().Get("taskID") if taskID == "" { http.Error(w, "taskID is missing", http.StatusBadRequest) return } tm.mu.Lock() - state := tm.TaskStates[taskID] + manager, ok := tm.taskManager[taskID] tm.mu.Unlock() - if state == nil { + if !ok { http.Error(w, "Invalid taskID", http.StatusNotFound) return } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(state) + json.NewEncoder(w).Encode(manager.taskStates) } func main() { - tm := NewTaskManager(finalResultCallback) - tm.AddNode("NodeA", func(payload json.RawMessage) Result { + dag := NewDAG(finalResultCallback) + dag.AddNode("NodeA", func(payload json.RawMessage) Result { var data map[string]any if err := json.Unmarshal(payload, &data); err != nil { return Result{Error: err, Status: StatusFailed} @@ -280,7 +295,7 @@ func main() { updatedPayload, _ := json.Marshal(data) return Result{Data: updatedPayload, Status: StatusCompleted} }) - tm.AddNode("NodeB", func(payload json.RawMessage) Result { + dag.AddNode("NodeB", func(payload json.RawMessage) Result { var data map[string]any if err := json.Unmarshal(payload, &data); err != nil { return Result{Error: err, Status: StatusFailed} @@ -289,7 +304,7 @@ func main() { updatedPayload, _ := json.Marshal(data) return Result{Data: updatedPayload, Status: StatusCompleted} }) - tm.AddNode("NodeC", func(payload json.RawMessage) Result { + dag.AddNode("NodeC", func(payload json.RawMessage) Result { var data map[string]any if err := json.Unmarshal(payload, &data); err != nil { return Result{Error: err, Status: StatusFailed} @@ -298,19 +313,18 @@ func main() { updatedPayload, _ := json.Marshal(data) return Result{Data: updatedPayload, Status: StatusCompleted} }) - tm.AddNode("Result", func(payload json.RawMessage) Result { + dag.AddNode("Result", func(payload json.RawMessage) Result { var data map[string]any json.Unmarshal(payload, &data) return Result{Data: payload, Status: StatusCompleted} }) - tm.AddEdge("Form", "NodeA") - tm.AddEdge("NodeA", "NodeB") - tm.AddEdge("NodeB", "NodeC") - tm.AddEdge("NodeC", "Result") - http.HandleFunc("/form", tm.formHandler) - http.HandleFunc("/result", tm.resultHandler) - http.HandleFunc("/task-result", tm.taskStatusHandler) - go tm.Run() + dag.AddEdge("Form", "NodeA") + dag.AddEdge("NodeA", "NodeB") + dag.AddEdge("NodeB", "NodeC") + dag.AddEdge("NodeC", "Result") + http.HandleFunc("/form", dag.formHandler) + http.HandleFunc("/result", dag.resultHandler) + http.HandleFunc("/task-result", dag.taskStatusHandler) http.ListenAndServe(":8080", nil) }