feat: add task completion

This commit is contained in:
sujit
2024-11-18 11:01:21 +05:45
parent f9b09272e7
commit e07b8ed3fd

View File

@@ -12,6 +12,40 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
) )
type DAG struct {
Nodes map[string]*Node
Edges map[string][]string
ParentNodes map[string]string
taskManager map[string]*TaskManager
mu sync.Mutex
finalResult func(taskID string, result Result)
}
func NewDAG(finalResultCallback func(taskID string, result Result)) *DAG {
return &DAG{
Nodes: make(map[string]*Node),
Edges: make(map[string][]string),
ParentNodes: make(map[string]string),
taskManager: make(map[string]*TaskManager),
finalResult: finalResultCallback,
}
}
func (tm *DAG) AddNode(nodeID string, handler func(payload json.RawMessage) Result) {
tm.mu.Lock()
defer tm.mu.Unlock()
tm.Nodes[nodeID] = &Node{ID: nodeID, Handler: handler}
}
func (tm *DAG) AddEdge(from string, to ...string) {
tm.mu.Lock()
defer tm.mu.Unlock()
tm.Edges[from] = append(tm.Edges[from], to...)
for _, targetNode := range to {
tm.ParentNodes[targetNode] = from
}
}
type TaskStatus string type TaskStatus string
const ( const (
@@ -48,14 +82,11 @@ type nodeResult struct {
} }
type TaskManager struct { type TaskManager struct {
Nodes map[string]*Node taskStates map[string]*TaskState
Edges map[string][]string dag *DAG
ParentNodes map[string]string
TaskStates map[string]map[string]*TaskState
mu sync.Mutex mu sync.Mutex
taskQueue chan taskExecution taskQueue chan taskExecution
resultQueue chan nodeResult resultQueue chan nodeResult
finalResult func(taskID string, result Result)
} }
type taskExecution struct { type taskExecution struct {
@@ -64,41 +95,20 @@ type taskExecution struct {
payload json.RawMessage payload json.RawMessage
} }
func NewTaskManager(finalResultCallback func(taskID string, result Result)) *TaskManager { func NewTaskManager(dag *DAG) *TaskManager {
tm := &TaskManager{ tm := &TaskManager{
Nodes: make(map[string]*Node), taskStates: make(map[string]*TaskState),
Edges: make(map[string][]string),
ParentNodes: make(map[string]string),
TaskStates: make(map[string]map[string]*TaskState),
taskQueue: make(chan taskExecution, 100), taskQueue: make(chan taskExecution, 100),
resultQueue: make(chan nodeResult, 100), resultQueue: make(chan nodeResult, 100),
finalResult: finalResultCallback, dag: dag,
} }
go tm.WaitForResult() go tm.WaitForResult()
return tm return tm
} }
func (tm *TaskManager) AddNode(nodeID string, handler func(payload json.RawMessage) Result) {
tm.mu.Lock()
defer tm.mu.Unlock()
tm.Nodes[nodeID] = &Node{ID: nodeID, Handler: handler}
}
func (tm *TaskManager) AddEdge(from string, to ...string) {
tm.mu.Lock()
defer tm.mu.Unlock()
tm.Edges[from] = append(tm.Edges[from], to...)
for _, targetNode := range to {
tm.ParentNodes[targetNode] = from
}
}
func (tm *TaskManager) Trigger(taskID, startNode string, payload json.RawMessage) { func (tm *TaskManager) Trigger(taskID, startNode string, payload json.RawMessage) {
tm.mu.Lock() tm.mu.Lock()
if _, exists := tm.TaskStates[taskID]; !exists { tm.taskStates[startNode] = &TaskState{
tm.TaskStates[taskID] = make(map[string]*TaskState)
}
tm.TaskStates[taskID][startNode] = &TaskState{
NodeID: startNode, NodeID: startNode,
Status: StatusPending, Status: StatusPending,
Timestamp: time.Now(), Timestamp: time.Now(),
@@ -117,16 +127,16 @@ func (tm *TaskManager) Run() {
} }
func (tm *TaskManager) processNode(exec taskExecution) { func (tm *TaskManager) processNode(exec taskExecution) {
node, exists := tm.Nodes[exec.nodeID] node, exists := tm.dag.Nodes[exec.nodeID]
if !exists { if !exists {
fmt.Printf("Node %s does not exist\n", exec.nodeID) fmt.Printf("Node %s does not exist\n", exec.nodeID)
return return
} }
tm.mu.Lock() tm.mu.Lock()
state := tm.TaskStates[exec.taskID][exec.nodeID] state := tm.taskStates[exec.nodeID]
if state == nil { if state == nil {
state = &TaskState{NodeID: exec.nodeID, Status: StatusPending, Timestamp: time.Now(), targetResults: make(map[string]Result)} state = &TaskState{NodeID: exec.nodeID, Status: StatusPending, Timestamp: time.Now(), targetResults: make(map[string]Result)}
tm.TaskStates[exec.taskID][exec.nodeID] = state tm.taskStates[exec.nodeID] = state
} }
state.Status = StatusProcessing state.Status = StatusProcessing
state.Timestamp = time.Now() state.Timestamp = time.Now()
@@ -154,12 +164,12 @@ func (tm *TaskManager) WaitForResult() {
} }
func (tm *TaskManager) onNodeCompleted(nodeResult nodeResult) { func (tm *TaskManager) onNodeCompleted(nodeResult nodeResult) {
nextNodes := tm.Edges[nodeResult.nodeID] nextNodes := tm.dag.Edges[nodeResult.nodeID]
if len(nextNodes) > 0 { if len(nextNodes) > 0 {
for _, nextNodeID := range nextNodes { for _, nextNodeID := range nextNodes {
tm.mu.Lock() tm.mu.Lock()
if _, exists := tm.TaskStates[nodeResult.taskID][nextNodeID]; !exists { if _, exists := tm.taskStates[nextNodeID]; !exists {
tm.TaskStates[nodeResult.taskID][nextNodeID] = &TaskState{ tm.taskStates[nextNodeID] = &TaskState{
NodeID: nextNodeID, NodeID: nextNodeID,
Status: StatusPending, Status: StatusPending,
Timestamp: time.Now(), Timestamp: time.Now(),
@@ -170,30 +180,30 @@ func (tm *TaskManager) onNodeCompleted(nodeResult nodeResult) {
tm.taskQueue <- taskExecution{taskID: nodeResult.taskID, nodeID: nextNodeID, payload: nodeResult.result.Data} tm.taskQueue <- taskExecution{taskID: nodeResult.taskID, nodeID: nextNodeID, payload: nodeResult.result.Data}
} }
} else { } else {
parentNode := tm.ParentNodes[nodeResult.nodeID] parentNode := tm.dag.ParentNodes[nodeResult.nodeID]
if parentNode != "" { if parentNode != "" {
tm.mu.Lock() tm.mu.Lock()
state := tm.TaskStates[nodeResult.taskID][parentNode] state := tm.taskStates[parentNode]
if state == nil { if state == nil {
state = &TaskState{NodeID: parentNode, Status: StatusPending, Timestamp: time.Now(), targetResults: make(map[string]Result)} state = &TaskState{NodeID: parentNode, Status: StatusPending, Timestamp: time.Now(), targetResults: make(map[string]Result)}
tm.TaskStates[nodeResult.taskID][parentNode] = state tm.taskStates[parentNode] = state
} }
state.targetResults[nodeResult.nodeID] = nodeResult.result state.targetResults[nodeResult.nodeID] = nodeResult.result
allTargetNodesdone := len(tm.Edges[parentNode]) == len(state.targetResults) allTargetNodesdone := len(tm.dag.Edges[parentNode]) == len(state.targetResults)
tm.mu.Unlock() tm.mu.Unlock()
if tm.areAllTargetNodesCompleted(parentNode, nodeResult.taskID) && allTargetNodesdone { if tm.areAllTargetNodesCompleted(parentNode) && allTargetNodesdone {
tm.aggregateResults(parentNode, nodeResult.taskID) tm.aggregateResults(parentNode, nodeResult.taskID)
} }
} }
} }
} }
func (tm *TaskManager) areAllTargetNodesCompleted(parentNode string, taskID string) bool { func (tm *TaskManager) areAllTargetNodesCompleted(parentNode string) bool {
tm.mu.Lock() tm.mu.Lock()
defer tm.mu.Unlock() defer tm.mu.Unlock()
for _, targetNode := range tm.Edges[parentNode] { for _, targetNode := range tm.dag.Edges[parentNode] {
state := tm.TaskStates[taskID][targetNode] state := tm.taskStates[targetNode]
if state == nil || state.Status != StatusCompleted { if state == nil || state.Status != StatusCompleted {
return false return false
} }
@@ -204,7 +214,7 @@ func (tm *TaskManager) areAllTargetNodesCompleted(parentNode string, taskID stri
func (tm *TaskManager) aggregateResults(parentNode string, taskID string) { func (tm *TaskManager) aggregateResults(parentNode string, taskID string) {
tm.mu.Lock() tm.mu.Lock()
defer tm.mu.Unlock() defer tm.mu.Unlock()
state := tm.TaskStates[taskID][parentNode] state := tm.taskStates[parentNode]
if len(state.targetResults) > 1 { if len(state.targetResults) > 1 {
aggregatedData := make([]json.RawMessage, len(state.targetResults)) aggregatedData := make([]json.RawMessage, len(state.targetResults))
i := 0 i := 0
@@ -222,7 +232,7 @@ func (tm *TaskManager) aggregateResults(parentNode string, taskID string) {
func (tm *TaskManager) processFinalResult(taskID string, state *TaskState) { func (tm *TaskManager) processFinalResult(taskID string, state *TaskState) {
clear(state.targetResults) clear(state.targetResults)
tm.finalResult(taskID, state.Result) tm.dag.finalResult(taskID, state.Result)
} }
func finalResultCallback(taskID string, result Result) { func finalResultCallback(taskID string, result Result) {
@@ -233,7 +243,7 @@ func generateTaskID() string {
return strconv.Itoa(rand.Intn(100000)) return strconv.Itoa(rand.Intn(100000))
} }
func (tm *TaskManager) formHandler(w http.ResponseWriter, r *http.Request) { func (tm *DAG) formHandler(w http.ResponseWriter, r *http.Request) {
if r.Method == "GET" { if r.Method == "GET" {
http.ServeFile(w, r, "webroot/form.html") http.ServeFile(w, r, "webroot/form.html")
} else if r.Method == "POST" { } else if r.Method == "POST" {
@@ -242,36 +252,41 @@ func (tm *TaskManager) formHandler(w http.ResponseWriter, r *http.Request) {
age := r.FormValue("age") age := r.FormValue("age")
gender := r.FormValue("gender") gender := r.FormValue("gender")
taskID := generateTaskID() taskID := generateTaskID()
manager := NewTaskManager(tm)
tm.mu.Lock()
tm.taskManager[taskID] = manager
tm.mu.Unlock()
go manager.Run()
payload := fmt.Sprintf(`{"email": "%s", "age": "%s", "gender": "%s"}`, email, age, gender) payload := fmt.Sprintf(`{"email": "%s", "age": "%s", "gender": "%s"}`, email, age, gender)
tm.Trigger(taskID, "NodeA", json.RawMessage(payload)) manager.Trigger(taskID, "NodeA", json.RawMessage(payload))
http.Redirect(w, r, "/result?taskID="+taskID, http.StatusFound) http.Redirect(w, r, "/result?taskID="+taskID, http.StatusFound)
} }
} }
func (tm *TaskManager) resultHandler(w http.ResponseWriter, r *http.Request) { func (tm *DAG) resultHandler(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, "webroot/result.html") http.ServeFile(w, r, "webroot/result.html")
} }
func (tm *TaskManager) taskStatusHandler(w http.ResponseWriter, r *http.Request) { func (tm *DAG) taskStatusHandler(w http.ResponseWriter, r *http.Request) {
taskID := r.URL.Query().Get("taskID") taskID := r.URL.Query().Get("taskID")
if taskID == "" { if taskID == "" {
http.Error(w, "taskID is missing", http.StatusBadRequest) http.Error(w, "taskID is missing", http.StatusBadRequest)
return return
} }
tm.mu.Lock() tm.mu.Lock()
state := tm.TaskStates[taskID] manager, ok := tm.taskManager[taskID]
tm.mu.Unlock() tm.mu.Unlock()
if state == nil { if !ok {
http.Error(w, "Invalid taskID", http.StatusNotFound) http.Error(w, "Invalid taskID", http.StatusNotFound)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(state) json.NewEncoder(w).Encode(manager.taskStates)
} }
func main() { func main() {
tm := NewTaskManager(finalResultCallback) dag := NewDAG(finalResultCallback)
tm.AddNode("NodeA", func(payload json.RawMessage) Result { dag.AddNode("NodeA", func(payload json.RawMessage) Result {
var data map[string]any var data map[string]any
if err := json.Unmarshal(payload, &data); err != nil { if err := json.Unmarshal(payload, &data); err != nil {
return Result{Error: err, Status: StatusFailed} return Result{Error: err, Status: StatusFailed}
@@ -280,7 +295,7 @@ func main() {
updatedPayload, _ := json.Marshal(data) updatedPayload, _ := json.Marshal(data)
return Result{Data: updatedPayload, Status: StatusCompleted} return Result{Data: updatedPayload, Status: StatusCompleted}
}) })
tm.AddNode("NodeB", func(payload json.RawMessage) Result { dag.AddNode("NodeB", func(payload json.RawMessage) Result {
var data map[string]any var data map[string]any
if err := json.Unmarshal(payload, &data); err != nil { if err := json.Unmarshal(payload, &data); err != nil {
return Result{Error: err, Status: StatusFailed} return Result{Error: err, Status: StatusFailed}
@@ -289,7 +304,7 @@ func main() {
updatedPayload, _ := json.Marshal(data) updatedPayload, _ := json.Marshal(data)
return Result{Data: updatedPayload, Status: StatusCompleted} return Result{Data: updatedPayload, Status: StatusCompleted}
}) })
tm.AddNode("NodeC", func(payload json.RawMessage) Result { dag.AddNode("NodeC", func(payload json.RawMessage) Result {
var data map[string]any var data map[string]any
if err := json.Unmarshal(payload, &data); err != nil { if err := json.Unmarshal(payload, &data); err != nil {
return Result{Error: err, Status: StatusFailed} return Result{Error: err, Status: StatusFailed}
@@ -298,19 +313,18 @@ func main() {
updatedPayload, _ := json.Marshal(data) updatedPayload, _ := json.Marshal(data)
return Result{Data: updatedPayload, Status: StatusCompleted} return Result{Data: updatedPayload, Status: StatusCompleted}
}) })
tm.AddNode("Result", func(payload json.RawMessage) Result { dag.AddNode("Result", func(payload json.RawMessage) Result {
var data map[string]any var data map[string]any
json.Unmarshal(payload, &data) json.Unmarshal(payload, &data)
return Result{Data: payload, Status: StatusCompleted} return Result{Data: payload, Status: StatusCompleted}
}) })
tm.AddEdge("Form", "NodeA") dag.AddEdge("Form", "NodeA")
tm.AddEdge("NodeA", "NodeB") dag.AddEdge("NodeA", "NodeB")
tm.AddEdge("NodeB", "NodeC") dag.AddEdge("NodeB", "NodeC")
tm.AddEdge("NodeC", "Result") dag.AddEdge("NodeC", "Result")
http.HandleFunc("/form", tm.formHandler) http.HandleFunc("/form", dag.formHandler)
http.HandleFunc("/result", tm.resultHandler) http.HandleFunc("/result", dag.resultHandler)
http.HandleFunc("/task-result", tm.taskStatusHandler) http.HandleFunc("/task-result", dag.taskStatusHandler)
go tm.Run()
http.ListenAndServe(":8080", nil) http.ListenAndServe(":8080", nil)
} }