feat: add task completion

This commit is contained in:
sujit
2024-11-18 14:52:07 +05:45
parent fa15a45626
commit e52a538d50
3 changed files with 57 additions and 22 deletions

View File

@@ -1,6 +1,7 @@
package v2
import (
"context"
"encoding/json"
"fmt"
"net/http"
@@ -21,6 +22,7 @@ const (
)
type Result struct {
Ctx context.Context
Data json.RawMessage
Error error
Status TaskStatus
@@ -28,7 +30,7 @@ type Result struct {
type Node struct {
ID string
Handler func(payload json.RawMessage) Result
Handler func(ctx context.Context, payload json.RawMessage) Result
Edges []Edge
}
@@ -63,7 +65,7 @@ func NewDAG(finalResultCallback func(taskID string, result Result)) *DAG {
}
}
func (tm *DAG) AddNode(nodeID string, handler func(payload json.RawMessage) Result) *DAG {
func (tm *DAG) AddNode(nodeID string, handler func(ctx context.Context, payload json.RawMessage) Result) *DAG {
if tm.Error != nil {
return tm
}
@@ -126,7 +128,7 @@ func (tm *DAG) formHandler(w http.ResponseWriter, r *http.Request) {
manager := NewTaskManager(tm)
tm.taskManager.Set(taskID, manager)
payload := fmt.Sprintf(`{"email": "%s", "age": "%s", "gender": "%s"}`, email, age, gender)
manager.Trigger(taskID, "NodeA", json.RawMessage(payload))
manager.ProcessTask(r.Context(), taskID, "NodeA", json.RawMessage(payload))
http.Redirect(w, r, "/result?taskID="+taskID, http.StatusFound)
}
}

View File

@@ -1,6 +1,7 @@
package v2
import (
"context"
"encoding/json"
"fmt"
"sync"
@@ -13,12 +14,13 @@ import (
type TaskState struct {
NodeID string
Status TaskStatus
Timestamp time.Time
UpdatedAt time.Time
Result Result
targetResults storage.IMap[string, Result]
}
type nodeResult struct {
ctx context.Context
taskID string
nodeID string
result Result
@@ -27,12 +29,13 @@ type nodeResult struct {
type TaskManager struct {
taskStates map[string]*TaskState
dag *DAG
mu sync.Mutex
mu sync.RWMutex
taskQueue chan taskExecution
resultQueue chan nodeResult
}
type taskExecution struct {
ctx context.Context
taskID string
nodeID string
payload json.RawMessage
@@ -50,18 +53,18 @@ func NewTaskManager(dag *DAG) *TaskManager {
return tm
}
func (tm *TaskManager) Trigger(taskID, startNode string, payload json.RawMessage) {
func (tm *TaskManager) ProcessTask(ctx context.Context, taskID, startNode string, payload json.RawMessage) {
tm.mu.Lock()
tm.taskStates[startNode] = newTaskState(startNode)
tm.mu.Unlock()
tm.taskQueue <- taskExecution{taskID: taskID, nodeID: startNode, payload: payload}
tm.taskQueue <- taskExecution{taskID: taskID, nodeID: startNode, payload: payload, ctx: ctx}
}
func newTaskState(nodeID string) *TaskState {
return &TaskState{
NodeID: nodeID,
Status: StatusPending,
Timestamp: time.Now(),
UpdatedAt: time.Now(),
targetResults: memory.New[string, Result](),
}
}
@@ -81,26 +84,24 @@ func (tm *TaskManager) processNode(exec taskExecution) {
return
}
tm.mu.Lock()
defer tm.mu.Unlock()
state := tm.taskStates[exec.nodeID]
if state == nil {
state = newTaskState(exec.nodeID)
tm.taskStates[exec.nodeID] = state
}
state.Status = StatusProcessing
state.Timestamp = time.Now()
tm.mu.Unlock()
result := node.Handler(exec.payload)
tm.mu.Lock()
state.Timestamp = time.Now()
state.UpdatedAt = time.Now()
result := node.Handler(context.Background(), exec.payload)
state.UpdatedAt = time.Now()
state.Result = result
state.Status = result.Status
tm.mu.Unlock()
if result.Status == StatusFailed {
fmt.Printf("Task %s failed at node %s: %v\n", exec.taskID, exec.nodeID, result.Error)
tm.processFinalResult(exec.taskID, state)
return
}
tm.resultQueue <- nodeResult{taskID: exec.taskID, nodeID: exec.nodeID, result: result}
tm.resultQueue <- nodeResult{taskID: exec.taskID, nodeID: exec.nodeID, result: result, ctx: exec.ctx}
}
func (tm *TaskManager) WaitForResult() {
@@ -123,7 +124,7 @@ func (tm *TaskManager) onNodeCompleted(nodeResult nodeResult) {
tm.taskStates[edge.To.ID] = newTaskState(edge.To.ID)
}
tm.mu.Unlock()
tm.taskQueue <- taskExecution{taskID: nodeResult.taskID, nodeID: edge.To.ID, payload: nodeResult.result.Data}
tm.taskQueue <- taskExecution{taskID: nodeResult.taskID, nodeID: edge.To.ID, payload: nodeResult.result.Data, ctx: nodeResult.ctx}
}
} else {
parentNodes, err := tm.dag.GetPreviousNodes(nodeResult.nodeID)

View File

@@ -1,13 +1,33 @@
package main
import (
"context"
"encoding/json"
"fmt"
"github.com/oarkflow/jet"
v2 "github.com/oarkflow/mq/dag/v2"
)
func NodeA(payload json.RawMessage) v2.Result {
func Form(ctx context.Context, payload json.RawMessage) v2.Result {
var data map[string]any
if err := json.Unmarshal(payload, &data); err != nil {
return v2.Result{Error: err, Status: v2.StatusFailed}
}
if templateFile, ok := data["html_content"].(string); ok {
parser := jet.NewWithMemory(jet.WithDelims("{{", "}}"))
rs, err := parser.ParseTemplate(templateFile, data)
if err != nil {
return v2.Result{Error: err, Status: v2.StatusFailed}
}
ctx = context.WithValue(ctx, "Content-Type", "text/html; charset/utf-8")
return v2.Result{Data: []byte(rs), Status: v2.StatusCompleted, Ctx: ctx}
}
return v2.Result{Data: payload, Status: v2.StatusCompleted}
}
func NodeA(ctx context.Context, payload json.RawMessage) v2.Result {
var data map[string]any
if err := json.Unmarshal(payload, &data); err != nil {
return v2.Result{Error: err, Status: v2.StatusFailed}
@@ -17,7 +37,7 @@ func NodeA(payload json.RawMessage) v2.Result {
return v2.Result{Data: updatedPayload, Status: v2.StatusCompleted}
}
func NodeB(payload json.RawMessage) v2.Result {
func NodeB(ctx context.Context, payload json.RawMessage) v2.Result {
var data map[string]any
if err := json.Unmarshal(payload, &data); err != nil {
return v2.Result{Error: err, Status: v2.StatusFailed}
@@ -27,7 +47,7 @@ func NodeB(payload json.RawMessage) v2.Result {
return v2.Result{Data: updatedPayload, Status: v2.StatusCompleted}
}
func NodeC(payload json.RawMessage) v2.Result {
func NodeC(ctx context.Context, payload json.RawMessage) v2.Result {
var data map[string]any
if err := json.Unmarshal(payload, &data); err != nil {
return v2.Result{Error: err, Status: v2.StatusFailed}
@@ -37,10 +57,20 @@ func NodeC(payload json.RawMessage) v2.Result {
return v2.Result{Data: updatedPayload, Status: v2.StatusCompleted}
}
func Result(payload json.RawMessage) v2.Result {
func Result(ctx context.Context, payload json.RawMessage) v2.Result {
var data map[string]any
json.Unmarshal(payload, &data)
if err := json.Unmarshal(payload, &data); err != nil {
return v2.Result{Error: err, Status: v2.StatusFailed}
}
if templateFile, ok := data["html_content"].(string); ok {
parser := jet.NewWithMemory(jet.WithDelims("{{", "}}"))
rs, err := parser.ParseTemplate(templateFile, data)
if err != nil {
return v2.Result{Error: err, Status: v2.StatusFailed}
}
ctx = context.WithValue(ctx, "Content-Type", "text/html; charset/utf-8")
return v2.Result{Data: []byte(rs), Status: v2.StatusCompleted, Ctx: ctx}
}
return v2.Result{Data: payload, Status: v2.StatusCompleted}
}
@@ -50,10 +80,12 @@ func notify(taskID string, result v2.Result) {
func main() {
dag := v2.NewDAG(notify)
// dag.AddNode("Form", Form)
dag.AddNode("NodeA", NodeA)
dag.AddNode("NodeB", NodeB)
dag.AddNode("NodeC", NodeC)
dag.AddNode("Result", Result)
// dag.AddEdge("Form", "NodeA")
dag.AddEdge("NodeA", "NodeB")
dag.AddEdge("NodeB", "NodeC")
dag.AddEdge("NodeC", "Result")