diff --git a/dag/v2/dag.go b/dag/v2/dag.go index f3e2181..fd5d1d0 100644 --- a/dag/v2/dag.go +++ b/dag/v2/dag.go @@ -1,6 +1,7 @@ package v2 import ( + "context" "encoding/json" "fmt" "net/http" @@ -21,6 +22,7 @@ const ( ) type Result struct { + Ctx context.Context Data json.RawMessage Error error Status TaskStatus @@ -28,7 +30,7 @@ type Result struct { type Node struct { ID string - Handler func(payload json.RawMessage) Result + Handler func(ctx context.Context, payload json.RawMessage) Result Edges []Edge } @@ -63,7 +65,7 @@ func NewDAG(finalResultCallback func(taskID string, result Result)) *DAG { } } -func (tm *DAG) AddNode(nodeID string, handler func(payload json.RawMessage) Result) *DAG { +func (tm *DAG) AddNode(nodeID string, handler func(ctx context.Context, payload json.RawMessage) Result) *DAG { if tm.Error != nil { return tm } @@ -126,7 +128,7 @@ func (tm *DAG) formHandler(w http.ResponseWriter, r *http.Request) { manager := NewTaskManager(tm) tm.taskManager.Set(taskID, manager) payload := fmt.Sprintf(`{"email": "%s", "age": "%s", "gender": "%s"}`, email, age, gender) - manager.Trigger(taskID, "NodeA", json.RawMessage(payload)) + manager.ProcessTask(r.Context(), taskID, "NodeA", json.RawMessage(payload)) http.Redirect(w, r, "/result?taskID="+taskID, http.StatusFound) } } diff --git a/dag/v2/task_manager.go b/dag/v2/task_manager.go index 2cab763..41c72c7 100644 --- a/dag/v2/task_manager.go +++ b/dag/v2/task_manager.go @@ -1,6 +1,7 @@ package v2 import ( + "context" "encoding/json" "fmt" "sync" @@ -13,12 +14,13 @@ import ( type TaskState struct { NodeID string Status TaskStatus - Timestamp time.Time + UpdatedAt time.Time Result Result targetResults storage.IMap[string, Result] } type nodeResult struct { + ctx context.Context taskID string nodeID string result Result @@ -27,12 +29,13 @@ type nodeResult struct { type TaskManager struct { taskStates map[string]*TaskState dag *DAG - mu sync.Mutex + mu sync.RWMutex taskQueue chan taskExecution resultQueue chan nodeResult } type taskExecution struct { + ctx context.Context taskID string nodeID string payload json.RawMessage @@ -50,18 +53,18 @@ func NewTaskManager(dag *DAG) *TaskManager { return tm } -func (tm *TaskManager) Trigger(taskID, startNode string, payload json.RawMessage) { +func (tm *TaskManager) ProcessTask(ctx context.Context, taskID, startNode string, payload json.RawMessage) { tm.mu.Lock() tm.taskStates[startNode] = newTaskState(startNode) tm.mu.Unlock() - tm.taskQueue <- taskExecution{taskID: taskID, nodeID: startNode, payload: payload} + tm.taskQueue <- taskExecution{taskID: taskID, nodeID: startNode, payload: payload, ctx: ctx} } func newTaskState(nodeID string) *TaskState { return &TaskState{ NodeID: nodeID, Status: StatusPending, - Timestamp: time.Now(), + UpdatedAt: time.Now(), targetResults: memory.New[string, Result](), } } @@ -81,26 +84,24 @@ func (tm *TaskManager) processNode(exec taskExecution) { return } tm.mu.Lock() + defer tm.mu.Unlock() state := tm.taskStates[exec.nodeID] if state == nil { state = newTaskState(exec.nodeID) tm.taskStates[exec.nodeID] = state } state.Status = StatusProcessing - state.Timestamp = time.Now() - tm.mu.Unlock() - result := node.Handler(exec.payload) - tm.mu.Lock() - state.Timestamp = time.Now() + state.UpdatedAt = time.Now() + result := node.Handler(context.Background(), exec.payload) + state.UpdatedAt = time.Now() state.Result = result state.Status = result.Status - tm.mu.Unlock() if result.Status == StatusFailed { fmt.Printf("Task %s failed at node %s: %v\n", exec.taskID, exec.nodeID, result.Error) tm.processFinalResult(exec.taskID, state) return } - tm.resultQueue <- nodeResult{taskID: exec.taskID, nodeID: exec.nodeID, result: result} + tm.resultQueue <- nodeResult{taskID: exec.taskID, nodeID: exec.nodeID, result: result, ctx: exec.ctx} } func (tm *TaskManager) WaitForResult() { @@ -123,7 +124,7 @@ func (tm *TaskManager) onNodeCompleted(nodeResult nodeResult) { tm.taskStates[edge.To.ID] = newTaskState(edge.To.ID) } tm.mu.Unlock() - tm.taskQueue <- taskExecution{taskID: nodeResult.taskID, nodeID: edge.To.ID, payload: nodeResult.result.Data} + tm.taskQueue <- taskExecution{taskID: nodeResult.taskID, nodeID: edge.To.ID, payload: nodeResult.result.Data, ctx: nodeResult.ctx} } } else { parentNodes, err := tm.dag.GetPreviousNodes(nodeResult.nodeID) diff --git a/examples/v2.go b/examples/v2.go index dc52e4d..f2ae51c 100644 --- a/examples/v2.go +++ b/examples/v2.go @@ -1,13 +1,33 @@ package main import ( + "context" "encoding/json" "fmt" + "github.com/oarkflow/jet" + v2 "github.com/oarkflow/mq/dag/v2" ) -func NodeA(payload json.RawMessage) v2.Result { +func Form(ctx context.Context, payload json.RawMessage) v2.Result { + var data map[string]any + if err := json.Unmarshal(payload, &data); err != nil { + return v2.Result{Error: err, Status: v2.StatusFailed} + } + if templateFile, ok := data["html_content"].(string); ok { + parser := jet.NewWithMemory(jet.WithDelims("{{", "}}")) + rs, err := parser.ParseTemplate(templateFile, data) + if err != nil { + return v2.Result{Error: err, Status: v2.StatusFailed} + } + ctx = context.WithValue(ctx, "Content-Type", "text/html; charset/utf-8") + return v2.Result{Data: []byte(rs), Status: v2.StatusCompleted, Ctx: ctx} + } + return v2.Result{Data: payload, Status: v2.StatusCompleted} +} + +func NodeA(ctx context.Context, payload json.RawMessage) v2.Result { var data map[string]any if err := json.Unmarshal(payload, &data); err != nil { return v2.Result{Error: err, Status: v2.StatusFailed} @@ -17,7 +37,7 @@ func NodeA(payload json.RawMessage) v2.Result { return v2.Result{Data: updatedPayload, Status: v2.StatusCompleted} } -func NodeB(payload json.RawMessage) v2.Result { +func NodeB(ctx context.Context, payload json.RawMessage) v2.Result { var data map[string]any if err := json.Unmarshal(payload, &data); err != nil { return v2.Result{Error: err, Status: v2.StatusFailed} @@ -27,7 +47,7 @@ func NodeB(payload json.RawMessage) v2.Result { return v2.Result{Data: updatedPayload, Status: v2.StatusCompleted} } -func NodeC(payload json.RawMessage) v2.Result { +func NodeC(ctx context.Context, payload json.RawMessage) v2.Result { var data map[string]any if err := json.Unmarshal(payload, &data); err != nil { return v2.Result{Error: err, Status: v2.StatusFailed} @@ -37,10 +57,20 @@ func NodeC(payload json.RawMessage) v2.Result { return v2.Result{Data: updatedPayload, Status: v2.StatusCompleted} } -func Result(payload json.RawMessage) v2.Result { +func Result(ctx context.Context, payload json.RawMessage) v2.Result { var data map[string]any - json.Unmarshal(payload, &data) - + if err := json.Unmarshal(payload, &data); err != nil { + return v2.Result{Error: err, Status: v2.StatusFailed} + } + if templateFile, ok := data["html_content"].(string); ok { + parser := jet.NewWithMemory(jet.WithDelims("{{", "}}")) + rs, err := parser.ParseTemplate(templateFile, data) + if err != nil { + return v2.Result{Error: err, Status: v2.StatusFailed} + } + ctx = context.WithValue(ctx, "Content-Type", "text/html; charset/utf-8") + return v2.Result{Data: []byte(rs), Status: v2.StatusCompleted, Ctx: ctx} + } return v2.Result{Data: payload, Status: v2.StatusCompleted} } @@ -50,10 +80,12 @@ func notify(taskID string, result v2.Result) { func main() { dag := v2.NewDAG(notify) + // dag.AddNode("Form", Form) dag.AddNode("NodeA", NodeA) dag.AddNode("NodeB", NodeB) dag.AddNode("NodeC", NodeC) dag.AddNode("Result", Result) + // dag.AddEdge("Form", "NodeA") dag.AddEdge("NodeA", "NodeB") dag.AddEdge("NodeB", "NodeC") dag.AddEdge("NodeC", "Result")