diff --git a/dag/task_manager.go b/dag/task_manager.go index 681d6e7..0ea8067 100644 --- a/dag/task_manager.go +++ b/dag/task_manager.go @@ -335,7 +335,8 @@ func (tm *TaskManager) markParentTask(ctx context.Context, topic, nodeID string, func (tm *TaskManager) prepareResult(ctx context.Context, nodeStatus *taskNodeStatus) mq.Result { aggregatedOutput := make([]json.RawMessage, 0) - var status, topic string + var status mq.Status + var topic string var err1 error if nodeStatus.totalItems == 1 { rs := nodeStatus.itemResults.Values()[0] diff --git a/dag/v2/api.go b/dag/v2/api.go index ee5a5bd..db46702 100644 --- a/dag/v2/api.go +++ b/dag/v2/api.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/oarkflow/mq" "log" "net/http" "strings" @@ -15,7 +16,7 @@ import ( func renderNotFound(w http.ResponseWriter) { html := []byte(`
`) @@ -39,11 +40,11 @@ func (tm *DAG) render(w http.ResponseWriter, r *http.Request) { renderNotFound(w) return } - http.Error(w, fmt.Sprintf(`{"message": "%s"}`, "Task not found"), http.StatusInternalServerError) + http.Error(w, fmt.Sprintf(`{"message": "%s"}`, "task not found"), http.StatusInternalServerError) return } } - result := tm.ProcessTask(ctx, data) + result := tm.Process(ctx, data) if result.Error != nil { http.Error(w, fmt.Sprintf(`{"message": "%s"}`, result.Error.Error()), http.StatusInternalServerError) return @@ -87,14 +88,14 @@ func (tm *DAG) taskStatusHandler(w http.ResponseWriter, r *http.Request) { nodeID := strings.Split(value.NodeID, Delimiter)[0] rs := jsonparser.Delete(value.Result.Payload, "html_content") status := value.Status - if status == Processing { - status = Completed + if status == mq.Processing { + status = mq.Completed } state := TaskState{ NodeID: nodeID, Status: status, UpdatedAt: value.UpdatedAt, - Result: Result{ + Result: mq.Result{ Payload: rs, Error: value.Result.Error, Status: status, diff --git a/dag/v2/consts.go b/dag/v2/consts.go index 13f5a87..3010831 100644 --- a/dag/v2/consts.go +++ b/dag/v2/consts.go @@ -9,15 +9,6 @@ const ( RetryInterval = 5 * time.Second ) -type Status string - -const ( - Pending Status = "Pending" - Processing Status = "Processing" - Completed Status = "Completed" - Failed Status = "Failed" -) - type NodeType int func (c NodeType) IsValid() bool { return c >= Function && c <= Page } diff --git a/dag/v2/dag.go b/dag/v2/dag.go index 219e35d..dd05bb6 100644 --- a/dag/v2/dag.go +++ b/dag/v2/dag.go @@ -2,8 +2,9 @@ package v2 import ( "context" - "encoding/json" "fmt" + "github.com/oarkflow/mq/sio" + "log" "strings" "github.com/oarkflow/mq" @@ -11,20 +12,13 @@ import ( "github.com/oarkflow/mq/storage/memory" ) -type Result struct { - Ctx context.Context `json:"-"` - Payload json.RawMessage - Error error - Status Status - ConditionStatus string - Topic string -} - type Node struct { - NodeType NodeType - ID string - Handler func(ctx context.Context, payload json.RawMessage) Result - Edges []Edge + NodeType NodeType + Label string + ID string + Edges []Edge + processor mq.Processor + isReady bool } type Edge struct { @@ -34,16 +28,26 @@ type Edge struct { } type DAG struct { - nodes storage.IMap[string, *Node] - taskManager storage.IMap[string, *TaskManager] - iteratorNodes storage.IMap[string, []Edge] - finalResult func(taskID string, result Result) - Error error - startNode string - conditions map[string]map[string]string + server *mq.Broker + consumer *mq.Consumer + nodes storage.IMap[string, *Node] + taskManager storage.IMap[string, *TaskManager] + iteratorNodes storage.IMap[string, []Edge] + finalResult func(taskID string, result mq.Result) + pool *mq.Pool + name string + key string + startNode string + conditions map[string]map[string]string + consumerTopic string + reportNodeResultCallback func(mq.Result) + Error error + Notifier *sio.Server + paused bool + report string } -func NewDAG(finalResultCallback func(taskID string, result Result)) *DAG { +func NewDAG(finalResultCallback func(taskID string, result mq.Result)) *DAG { return &DAG{ nodes: memory.New[string, *Node](), taskManager: memory.New[string, *TaskManager](), @@ -53,24 +57,170 @@ func NewDAG(finalResultCallback func(taskID string, result Result)) *DAG { } } +func (tm *DAG) SetKey(key string) { + tm.key = key +} + +func (tm *DAG) ReportNodeResult(callback func(mq.Result)) { + tm.reportNodeResultCallback = callback +} + +func (tm *DAG) GetType() string { + return tm.key +} + +func (tm *DAG) Consume(ctx context.Context) error { + if tm.consumer != nil { + tm.server.Options().SetSyncMode(true) + return tm.consumer.Consume(ctx) + } + return nil +} + +func (tm *DAG) Stop(ctx context.Context) error { + tm.nodes.ForEach(func(_ string, n *Node) bool { + err := n.processor.Stop(ctx) + if err != nil { + return false + } + return true + }) + return nil +} + +func (tm *DAG) GetKey() string { + return tm.key +} + +func (tm *DAG) AssignTopic(topic string) { + tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false), mq.WithBrokerURL(tm.server.URL())) + tm.consumerTopic = topic +} + +func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) { + if manager, ok := tm.taskManager.Get(result.TaskID); ok && result.Topic != "" { + manager.onNodeCompleted(nodeResult{ + ctx: ctx, + nodeID: result.Topic, + status: result.Status, + result: result, + }) + } +} + +func (tm *DAG) callbackToConsumer(ctx context.Context, result mq.Result) { + if tm.consumer != nil { + result.Topic = tm.consumerTopic + if tm.consumer.Conn() == nil { + tm.onTaskCallback(ctx, result) + } else { + tm.consumer.OnResponse(ctx, result) + } + } +} + +func (tm *DAG) onConsumerJoin(_ context.Context, topic, _ string) { + if node, ok := tm.nodes.Get(topic); ok { + log.Printf("DAG - CONSUMER ~> ready on %s", topic) + node.isReady = true + } +} + +func (tm *DAG) onConsumerClose(_ context.Context, topic, _ string) { + if node, ok := tm.nodes.Get(topic); ok { + log.Printf("DAG - CONSUMER ~> down on %s", topic) + node.isReady = false + } +} + +func (tm *DAG) Pause(_ context.Context) error { + tm.paused = true + return nil +} + +func (tm *DAG) Resume(_ context.Context) error { + tm.paused = false + return nil +} + +func (tm *DAG) Close() error { + var err error + tm.nodes.ForEach(func(_ string, n *Node) bool { + err = n.processor.Close() + if err != nil { + return false + } + return true + }) + return nil +} + +func (tm *DAG) SetStartNode(node string) { + tm.startNode = node +} + +func (tm *DAG) SetNotifyResponse(callback mq.Callback) { + tm.server.SetNotifyHandler(callback) +} + +func (tm *DAG) GetStartNode() string { + return tm.startNode +} + func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) *DAG { tm.conditions[fromNode] = conditions return tm } -type Handler func(ctx context.Context, payload json.RawMessage) Result - -func (tm *DAG) AddNode(nodeType NodeType, nodeID string, handler Handler, startNode ...bool) *DAG { +func (tm *DAG) AddNode(nodeType NodeType, name, nodeID string, handler mq.Processor, startNode ...bool) *DAG { if tm.Error != nil { return tm } - tm.nodes.Set(nodeID, &Node{ID: nodeID, Handler: handler, NodeType: nodeType}) + con := mq.NewConsumer(nodeID, nodeID, handler.ProcessTask) + n := &Node{ + Label: name, + ID: nodeID, + NodeType: nodeType, + processor: con, + } + if tm.server != nil && tm.server.SyncMode() { + n.isReady = true + } + tm.nodes.Set(nodeID, n) + tm.nodes.Set(nodeID, &Node{ID: nodeID, processor: handler, NodeType: nodeType}) if len(startNode) > 0 && startNode[0] { tm.startNode = nodeID } return tm } +func (tm *DAG) AddDeferredNode(nodeType NodeType, name, key string, firstNode ...bool) error { + if tm.server.SyncMode() { + return fmt.Errorf("DAG cannot have deferred node in Sync Mode") + } + tm.nodes.Set(key, &Node{ + Label: name, + ID: key, + NodeType: nodeType, + }) + if len(firstNode) > 0 && firstNode[0] { + tm.startNode = key + } + return nil +} + +func (tm *DAG) IsReady() bool { + var isReady bool + tm.nodes.ForEach(func(_ string, n *Node) bool { + if !n.isReady { + return false + } + isReady = true + return true + }) + return isReady +} + func (tm *DAG) AddEdge(edgeType EdgeType, from string, targets ...string) *DAG { if tm.Error != nil { return tm @@ -98,22 +248,15 @@ func (tm *DAG) AddEdge(edgeType EdgeType, from string, targets ...string) *DAG { return tm } -func (tm *DAG) ProcessTask(ctx context.Context, payload []byte) Result { - var taskID string - userCtx := UserContext(ctx) - if val := userCtx.Get("task_id"); val != "" { - taskID = val - } else { - taskID = mq.NewID() - } - ctx = context.WithValue(ctx, "task_id", taskID) +func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { + ctx = context.WithValue(ctx, "task_id", task.ID) userContext := UserContext(ctx) next := userContext.Get("next") - manager, ok := tm.taskManager.Get(taskID) - resultCh := make(chan Result, 1) + manager, ok := tm.taskManager.Get(task.ID) + resultCh := make(chan mq.Result, 1) if !ok { - manager = NewTaskManager(tm, taskID, resultCh, tm.iteratorNodes.Clone()) - tm.taskManager.Set(taskID, manager) + manager = NewTaskManager(tm, task.ID, resultCh, tm.iteratorNodes.Clone()) + tm.taskManager.Set(task.ID, manager) } else { manager.resultCh = resultCh } @@ -125,7 +268,7 @@ func (tm *DAG) ProcessTask(ctx context.Context, payload []byte) Result { } else if next == "true" { nodes, err := tm.GetNextNodes(currentNode) if err != nil { - return Result{Error: err, Ctx: ctx} + return mq.Result{Error: err, Ctx: ctx} } if len(nodes) > 0 { ctx = context.WithValue(ctx, "initial_node", nodes[0].ID) @@ -133,13 +276,25 @@ func (tm *DAG) ProcessTask(ctx context.Context, payload []byte) Result { } firstNode, err := tm.parseInitialNode(ctx) if err != nil { - return Result{Error: err, Ctx: ctx} + return mq.Result{Error: err, Ctx: ctx} } node, ok = tm.nodes.Get(firstNode) - if ok && node.NodeType != Page && payload == nil { - return Result{Error: fmt.Errorf("payload is required for node %s", firstNode), Ctx: ctx} + if ok && node.NodeType != Page && task.Payload == nil { + return mq.Result{Error: fmt.Errorf("payload is required for node %s", firstNode), Ctx: ctx} } + task.Topic = firstNode ctx = context.WithValue(ctx, ContextIndex, "0") - manager.ProcessTask(ctx, firstNode, payload) + manager.ProcessTask(ctx, firstNode, task.Payload) return <-resultCh } + +func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result { + var taskID string + userCtx := UserContext(ctx) + if val := userCtx.Get("task_id"); val != "" { + taskID = val + } else { + taskID = mq.NewID() + } + return tm.ProcessTask(ctx, mq.NewTask(taskID, payload, "")) +} diff --git a/dag/v2/task_manager.go b/dag/v2/task_manager.go index 8155743..4d9fbdd 100644 --- a/dag/v2/task_manager.go +++ b/dag/v2/task_manager.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/oarkflow/mq" "log" "strings" "time" @@ -14,52 +15,52 @@ import ( type TaskState struct { NodeID string - Status Status + Status mq.Status UpdatedAt time.Time - Result Result - targetResults storage.IMap[string, Result] + Result mq.Result + targetResults storage.IMap[string, mq.Result] } func newTaskState(nodeID string) *TaskState { return &TaskState{ NodeID: nodeID, - Status: Pending, + Status: mq.Pending, UpdatedAt: time.Now(), - targetResults: memory.New[string, Result](), + targetResults: memory.New[string, mq.Result](), } } type nodeResult struct { ctx context.Context nodeID string - status Status - result Result + status mq.Status + result mq.Result } type TaskManager struct { taskStates storage.IMap[string, *TaskState] parentNodes storage.IMap[string, string] childNodes storage.IMap[string, int] - deferredTasks storage.IMap[string, *Task] + deferredTasks storage.IMap[string, *task] iteratorNodes storage.IMap[string, []Edge] currentNode string dag *DAG taskID string - taskQueue chan *Task + taskQueue chan *task resultQueue chan nodeResult - resultCh chan Result + resultCh chan mq.Result stopCh chan struct{} } -type Task struct { +type task struct { ctx context.Context taskID string nodeID string payload json.RawMessage } -func NewTask(ctx context.Context, taskID, nodeID string, payload json.RawMessage) *Task { - return &Task{ +func newTask(ctx context.Context, taskID, nodeID string, payload json.RawMessage) *task { + return &task{ ctx: ctx, taskID: taskID, nodeID: nodeID, @@ -67,13 +68,13 @@ func NewTask(ctx context.Context, taskID, nodeID string, payload json.RawMessage } } -func NewTaskManager(dag *DAG, taskID string, resultCh chan Result, iteratorNodes storage.IMap[string, []Edge]) *TaskManager { +func NewTaskManager(dag *DAG, taskID string, resultCh chan mq.Result, iteratorNodes storage.IMap[string, []Edge]) *TaskManager { tm := &TaskManager{ taskStates: memory.New[string, *TaskState](), parentNodes: memory.New[string, string](), childNodes: memory.New[string, int](), - deferredTasks: memory.New[string, *Task](), - taskQueue: make(chan *Task, DefaultChannelSize), + deferredTasks: memory.New[string, *task](), + taskQueue: make(chan *task, DefaultChannelSize), resultQueue: make(chan nodeResult, DefaultChannelSize), iteratorNodes: iteratorNodes, stopCh: make(chan struct{}), @@ -98,12 +99,12 @@ func (tm *TaskManager) send(ctx context.Context, startNode, taskID string, paylo if _, exists := tm.taskStates.Get(startNode); !exists { tm.taskStates.Set(startNode, newTaskState(startNode)) } - task := NewTask(ctx, taskID, startNode, payload) + t := newTask(ctx, taskID, startNode, payload) select { - case tm.taskQueue <- task: + case tm.taskQueue <- t: default: - log.Println("Task queue is full, dropping task.") - tm.deferredTasks.Set(taskID, task) + log.Println("task queue is full, dropping task.") + tm.deferredTasks.Set(taskID, t) } } @@ -131,7 +132,7 @@ func (tm *TaskManager) waitForResult() { } } -func (tm *TaskManager) processNode(exec *Task) { +func (tm *TaskManager) processNode(exec *task) { pureNodeID := strings.Split(exec.nodeID, Delimiter)[0] node, exists := tm.dag.nodes.Get(pureNodeID) if !exists { @@ -144,10 +145,10 @@ func (tm *TaskManager) processNode(exec *Task) { state = newTaskState(exec.nodeID) tm.taskStates.Set(exec.nodeID, state) } - state.Status = Processing + state.Status = mq.Processing state.UpdatedAt = time.Now() tm.currentNode = exec.nodeID - result := node.Handler(exec.ctx, exec.payload) + result := node.processor.ProcessTask(exec.ctx, mq.NewTask(exec.taskID, exec.payload, exec.nodeID)) state.Result = result result.Topic = node.ID if result.Error != nil { @@ -162,7 +163,7 @@ func (tm *TaskManager) processNode(exec *Task) { tm.handleNext(exec.ctx, node, state, result) } -func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, result Result, childNode string, dispatchFinal bool) { +func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, result mq.Result, childNode string, dispatchFinal bool) { state.targetResults.Set(childNode, result) state.targetResults.Del(state.NodeID) targetsCount, _ := tm.childNodes.Get(state.NodeID) @@ -172,7 +173,7 @@ func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, res if size > 1 { aggregatedData := make([]json.RawMessage, size) i := 0 - state.targetResults.ForEach(func(_ string, rs Result) bool { + state.targetResults.ForEach(func(_ string, rs mq.Result) bool { aggregatedData[i] = rs.Payload i++ return true @@ -181,7 +182,7 @@ func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, res if err != nil { panic(err) } - state.Result = Result{Payload: aggregatedPayload, Status: Completed, Ctx: ctx, Topic: state.NodeID} + state.Result = mq.Result{Payload: aggregatedPayload, Status: mq.Completed, Ctx: ctx, Topic: state.NodeID} } else if size == 1 { state.Result = state.targetResults.Values()[0] } @@ -196,11 +197,11 @@ func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, res result.Ctx = ctx } if result.Error != nil { - state.Status = Failed + state.Status = mq.Failed } pn, ok := tm.parentNodes.Get(state.NodeID) - if edges, exists := tm.iteratorNodes.Get(nodeID[0]); exists && state.Status == Completed { - state.Status = Processing + if edges, exists := tm.iteratorNodes.Get(nodeID[0]); exists && state.Status == mq.Completed { + state.Status = mq.Processing tm.iteratorNodes.Del(nodeID[0]) state.targetResults.Clear() if len(nodeID) == 2 { @@ -228,17 +229,17 @@ func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, res } } -func (tm *TaskManager) handleNext(ctx context.Context, node *Node, state *TaskState, result Result) { +func (tm *TaskManager) handleNext(ctx context.Context, node *Node, state *TaskState, result mq.Result) { state.UpdatedAt = time.Now() if result.Ctx == nil { result.Ctx = ctx } if result.Error != nil { - state.Status = Failed + state.Status = mq.Failed } else { edges := tm.getConditionalEdges(node, result) if len(edges) == 0 { - state.Status = Completed + state.Status = mq.Completed } } if result.Status == "" { @@ -281,7 +282,7 @@ func (tm *TaskManager) onNodeCompleted(rs nodeResult) { tm.handleEdges(rs, edges) } -func (tm *TaskManager) getConditionalEdges(node *Node, result Result) []Edge { +func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge { edges := make([]Edge, len(node.Edges)) copy(edges, node.Edges) if result.ConditionStatus != "" { @@ -320,8 +321,8 @@ func (tm *TaskManager) handleEdges(currentResult nodeResult, edges []Edge) { tm.resultQueue <- nodeResult{ ctx: currentResult.ctx, nodeID: edge.To.ID, - status: Failed, - result: Result{Error: err}, + status: mq.Failed, + result: mq.Result{Error: err}, } return } @@ -352,10 +353,10 @@ func (tm *TaskManager) retryDeferredTasks() { for retries < maxRetries { select { case <-tm.stopCh: - log.Println("Stopping Deferred Task Retrier") + log.Println("Stopping Deferred task Retrier") return case <-time.After(RetryInterval): - tm.deferredTasks.ForEach(func(taskID string, task *Task) bool { + tm.deferredTasks.ForEach(func(taskID string, task *task) bool { tm.send(task.ctx, task.nodeID, taskID, task.payload) retries++ return true diff --git a/examples/html.go b/examples/html.go deleted file mode 100644 index 949dd40..0000000 --- a/examples/html.go +++ /dev/null @@ -1,164 +0,0 @@ -package main - -import ( - "context" - "encoding/json" - "fmt" - "github.com/oarkflow/mq/dag/v2" - "log" - "net/http" -) - -func main() { - graph := v2.NewGraph() - customRegistrationNode := &v2.Operation{ - Type: "page", - ID: "customRegistration", - Content: `Click here to verify your email
Verify`, - } - dashboardNode := &v2.Operation{ - Type: "page", - ID: "dashboard", - Content: `Welcome to your dashboard!
`, - } - manualVerificationNode := &v2.Operation{ - Type: "page", - ID: "manualVerificationPage", - Content: `Please verify the user's information manually.
`, - } - verifyApprovedNode := &v2.Operation{ - Type: "process", - ID: "verifyApproved", - Func: func(task *v2.Task) v2.Result { - return v2.Result{} - }, - } - denyVerificationNode := &v2.Operation{ - Type: "process", - ID: "denyVerification", - Func: func(task *v2.Task) v2.Result { - task.FinalResult = "Verification Denied" - return v2.Result{} - }, - } - - graph.AddNode(customRegistrationNode) - graph.AddNode(checkValidityNode) - graph.AddNode(checkManualVerificationNode) - graph.AddNode(approveCustomerNode) - graph.AddNode(sendVerificationEmailNode) - graph.AddNode(verificationLinkPageNode) - graph.AddNode(dashboardNode) - graph.AddNode(manualVerificationNode) - graph.AddNode(verifyApprovedNode) - graph.AddNode(denyVerificationNode) - - graph.AddEdge("customRegistration", "checkValidity") - graph.AddEdge("checkValidity", "checkManualVerification") - graph.AddEdge("checkManualVerification", "approveCustomer") - graph.AddEdge("checkManualVerification", "manualVerificationPage") - graph.AddEdge("approveCustomer", "sendVerificationEmail") - graph.AddEdge("sendVerificationEmail", "verificationLinkPage") - graph.AddEdge("verificationLinkPage", "dashboard") - graph.AddEdge("manualVerificationPage", "verifyApproved") - graph.AddEdge("manualVerificationPage", "denyVerification") - graph.AddEdge("verifyApproved", "approveCustomer") - graph.AddEdge("denyVerification", "verificationLinkPage") - - http.HandleFunc("/verify", func(w http.ResponseWriter, r *http.Request) { - verifyHandler(w, r, graph.Tm) - }) - graph.Start() -} - -func isValidEmail(email string) bool { - return email != "" -} - -func isValidPhone(phone string) bool { - return phone != "" -} - -func verifyHandler(w http.ResponseWriter, r *http.Request, tm *v2.TaskManager) { - taskID := r.URL.Query().Get("taskID") - if taskID == "" { - http.Error(w, "Missing taskID", http.StatusBadRequest) - return - } - task, exists := tm.GetTask(taskID) - if !exists { - http.Error(w, "Task not found", http.StatusNotFound) - return - } - data := map[string]any{ - "email_verified": "true", - } - bt, _ := json.Marshal(data) - task.Payload = bt - log.Printf("Email for taskID %s successfully verified.", task.ID) - nextNode, exists := tm.Graph.Nodes["dashboard"] - if !exists { - http.Error(w, "Dashboard Operation not found", http.StatusInternalServerError) - return - } - result := nextNode.ProcessTask(context.Background(), task) - if result.Error != nil { - http.Error(w, result.Error.Error(), http.StatusInternalServerError) - return - } - fmt.Fprintf(w, string(result.Payload)) -} diff --git a/examples/v2.go b/examples/v2.go index 9a518bc..8bd74e5 100644 --- a/examples/v2.go +++ b/examples/v2.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "fmt" + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/dag" "os" "github.com/oarkflow/jet" @@ -12,94 +14,114 @@ import ( v2 "github.com/oarkflow/mq/dag/v2" ) -func Form(ctx context.Context, payload json.RawMessage) v2.Result { +type Form struct { + dag.Operation +} + +func (p *Form) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { bt, err := os.ReadFile("webroot/form.html") if err != nil { - return v2.Result{Error: err, Ctx: ctx} + return mq.Result{Error: err, Ctx: ctx} } parser := jet.NewWithMemory(jet.WithDelims("{{", "}}")) rs, err := parser.ParseTemplate(string(bt), map[string]any{ "task_id": ctx.Value("task_id"), }) if err != nil { - return v2.Result{Error: err, Ctx: ctx} + return mq.Result{Error: err, Ctx: ctx} } ctx = context.WithValue(ctx, consts.ContentType, consts.TypeHtml) data := map[string]any{ "html_content": rs, } bt, _ = json.Marshal(data) - return v2.Result{Payload: bt, Ctx: ctx} + return mq.Result{Payload: bt, Ctx: ctx} } -func NodeA(ctx context.Context, payload json.RawMessage) v2.Result { +type NodeA struct { + dag.Operation +} + +func (p *NodeA) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { var data map[string]any - if err := json.Unmarshal(payload, &data); err != nil { - return v2.Result{Error: err, Ctx: ctx} + if err := json.Unmarshal(task.Payload, &data); err != nil { + return mq.Result{Error: err, Ctx: ctx} } data["allowed_voting"] = data["age"] == "18" updatedPayload, _ := json.Marshal(data) - return v2.Result{Payload: updatedPayload, Ctx: ctx} + return mq.Result{Payload: updatedPayload, Ctx: ctx} } -func NodeB(ctx context.Context, payload json.RawMessage) v2.Result { +type NodeB struct { + dag.Operation +} + +func (p *NodeB) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { var data map[string]any - if err := json.Unmarshal(payload, &data); err != nil { - return v2.Result{Error: err, Ctx: ctx} + if err := json.Unmarshal(task.Payload, &data); err != nil { + return mq.Result{Error: err, Ctx: ctx} } data["female_voter"] = data["gender"] == "female" updatedPayload, _ := json.Marshal(data) - return v2.Result{Payload: updatedPayload, Ctx: ctx} + return mq.Result{Payload: updatedPayload, Ctx: ctx} } -func NodeC(ctx context.Context, payload json.RawMessage) v2.Result { +type NodeC struct { + dag.Operation +} + +func (p *NodeC) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { var data map[string]any - if err := json.Unmarshal(payload, &data); err != nil { - return v2.Result{Error: err, Ctx: ctx} + if err := json.Unmarshal(task.Payload, &data); err != nil { + return mq.Result{Error: err, Ctx: ctx} } data["voted"] = true updatedPayload, _ := json.Marshal(data) - return v2.Result{Payload: updatedPayload, Ctx: ctx} + return mq.Result{Payload: updatedPayload, Ctx: ctx} } -func Result(ctx context.Context, payload json.RawMessage) v2.Result { +type Result struct { + dag.Operation +} + +func (p *Result) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { bt, err := os.ReadFile("webroot/result.html") if err != nil { - return v2.Result{Error: err, Ctx: ctx} + return mq.Result{Error: err, Ctx: ctx} } var data map[string]any - if payload != nil { - if err := json.Unmarshal(payload, &data); err != nil { - return v2.Result{Error: err, Ctx: ctx} + if task.Payload != nil { + if err := json.Unmarshal(task.Payload, &data); err != nil { + return mq.Result{Error: err, Ctx: ctx} } } if bt != nil { parser := jet.NewWithMemory(jet.WithDelims("{{", "}}")) rs, err := parser.ParseTemplate(string(bt), data) if err != nil { - return v2.Result{Error: err, Ctx: ctx} + return mq.Result{Error: err, Ctx: ctx} } ctx = context.WithValue(ctx, consts.ContentType, consts.TypeHtml) data := map[string]any{ "html_content": rs, } bt, _ := json.Marshal(data) - return v2.Result{Payload: bt, Ctx: ctx} + return mq.Result{Payload: bt, Ctx: ctx} } - return v2.Result{Payload: payload, Ctx: ctx} + return mq.Result{Payload: task.Payload, Ctx: ctx} } -func notify(taskID string, result v2.Result) { - fmt.Printf("Final result for Task %s: %s\n", taskID, string(result.Payload)) +func notify(taskID string, result mq.Result) { + fmt.Printf("Final result for task %s: %s\n", taskID, string(result.Payload)) } func main() { dag := v2.NewDAG(notify) - dag.AddNode(v2.Page, "Form", Form) - dag.AddNode(v2.Function, "NodeA", NodeA) - dag.AddNode(v2.Function, "NodeB", NodeB) - dag.AddNode(v2.Function, "NodeC", NodeC) - dag.AddNode(v2.Page, "Result", Result) + dag.AddNode(v2.Page, "Form", "Form", &Form{}) + dag.AddNode(v2.Function, "NodeA", "NodeA", &NodeA{}) + dag.AddNode(v2.Function, "NodeB", "NodeB", &NodeB{}) + dag.AddNode(v2.Function, "NodeC", "NodeC", &NodeC{}) + dag.AddNode(v2.Page, "Result", "Result", &Result{}) dag.AddEdge(v2.Simple, "Form", "NodeA") dag.AddEdge(v2.Simple, "NodeA", "NodeB") dag.AddEdge(v2.Simple, "NodeB", "NodeC") diff --git a/examples/v3.go b/examples/v3.go index 215f059..d0a3760 100644 --- a/examples/v3.go +++ b/examples/v3.go @@ -4,19 +4,21 @@ import ( "context" "encoding/json" "fmt" + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/dag" v2 "github.com/oarkflow/mq/dag/v2" ) func main() { - dag := v2.NewDAG(func(taskID string, result v2.Result) { - // fmt.Printf("Final result for Task %s: %s\n", taskID, string(result.Payload)) + dag := v2.NewDAG(func(taskID string, result mq.Result) { + // fmt.Printf("Final result for task %s: %s\n", taskID, string(result.Payload)) }) - dag.AddNode(v2.Function, "GetData", GetData, true) - dag.AddNode(v2.Function, "Loop", Loop) - dag.AddNode(v2.Function, "ValidateAge", ValidateAge) - dag.AddNode(v2.Function, "ValidateGender", ValidateGender) - dag.AddNode(v2.Function, "Final", Final) + dag.AddNode(v2.Function, "GetData", "GetData", &GetData{}, true) + dag.AddNode(v2.Function, "Loop", "Loop", &Loop{}) + dag.AddNode(v2.Function, "ValidateAge", "ValidateAge", &ValidateAge{}) + dag.AddNode(v2.Function, "ValidateGender", "ValidateGender", &ValidateGender{}) + dag.AddNode(v2.Function, "Final", "Final", &Final{}) dag.AddEdge(v2.Simple, "GetData", "Loop") dag.AddEdge(v2.Iterator, "Loop", "ValidateAge") @@ -29,25 +31,37 @@ func main() { panic(dag.Error) } - rs := dag.ProcessTask(context.Background(), data) + rs := dag.Process(context.Background(), data) if rs.Error != nil { panic(rs.Error) } fmt.Println(rs.Status, rs.Topic, string(rs.Payload)) } -func GetData(ctx context.Context, payload json.RawMessage) v2.Result { - return v2.Result{Ctx: ctx, Payload: payload} +type GetData struct { + dag.Operation } -func Loop(ctx context.Context, payload json.RawMessage) v2.Result { - return v2.Result{Ctx: ctx, Payload: payload} +func (p *GetData) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { + return mq.Result{Ctx: ctx, Payload: task.Payload} } -func ValidateAge(ctx context.Context, payload json.RawMessage) v2.Result { +type Loop struct { + dag.Operation +} + +func (p *Loop) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { + return mq.Result{Ctx: ctx, Payload: task.Payload} +} + +type ValidateAge struct { + dag.Operation +} + +func (p *ValidateAge) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { var data map[string]any - if err := json.Unmarshal(payload, &data); err != nil { - return v2.Result{Error: fmt.Errorf("ValidateAge Error: %s", err.Error()), Ctx: ctx} + if err := json.Unmarshal(task.Payload, &data); err != nil { + return mq.Result{Error: fmt.Errorf("ValidateAge Error: %s", err.Error()), Ctx: ctx} } var status string if data["age"] == "18" { @@ -56,23 +70,31 @@ func ValidateAge(ctx context.Context, payload json.RawMessage) v2.Result { status = "default" } updatedPayload, _ := json.Marshal(data) - return v2.Result{Payload: updatedPayload, Ctx: ctx, ConditionStatus: status} + return mq.Result{Payload: updatedPayload, Ctx: ctx, ConditionStatus: status} } -func ValidateGender(ctx context.Context, payload json.RawMessage) v2.Result { +type ValidateGender struct { + dag.Operation +} + +func (p *ValidateGender) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { var data map[string]any - if err := json.Unmarshal(payload, &data); err != nil { - return v2.Result{Error: fmt.Errorf("ValidateGender Error: %s", err.Error()), Ctx: ctx} + if err := json.Unmarshal(task.Payload, &data); err != nil { + return mq.Result{Error: fmt.Errorf("ValidateGender Error: %s", err.Error()), Ctx: ctx} } data["female_voter"] = data["gender"] == "female" updatedPayload, _ := json.Marshal(data) - return v2.Result{Payload: updatedPayload, Ctx: ctx} + return mq.Result{Payload: updatedPayload, Ctx: ctx} } -func Final(ctx context.Context, payload json.RawMessage) v2.Result { +type Final struct { + dag.Operation +} + +func (p *Final) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { var data []map[string]any - if err := json.Unmarshal(payload, &data); err != nil { - return v2.Result{Error: fmt.Errorf("Final Error: %s", err.Error()), Ctx: ctx} + if err := json.Unmarshal(task.Payload, &data); err != nil { + return mq.Result{Error: fmt.Errorf("Final Error: %s", err.Error()), Ctx: ctx} } for i, row := range data { row["done"] = true @@ -82,5 +104,5 @@ func Final(ctx context.Context, payload json.RawMessage) v2.Result { if err != nil { panic(err) } - return v2.Result{Payload: updatedPayload, Ctx: ctx} + return mq.Result{Payload: updatedPayload, Ctx: ctx} } diff --git a/options.go b/options.go index 2de23ad..c3a8cb1 100644 --- a/options.go +++ b/options.go @@ -12,6 +12,15 @@ import ( "github.com/oarkflow/mq/consts" ) +type Status string + +const ( + Pending Status = "Pending" + Processing Status = "Processing" + Completed Status = "Completed" + Failed Status = "Failed" +) + type Result struct { CreatedAt time.Time `json:"created_at"` ProcessedAt time.Time `json:"processed_at,omitempty"` @@ -19,7 +28,7 @@ type Result struct { Error error `json:"-"` // Keep error as an error type Topic string `json:"topic"` TaskID string `json:"task_id"` - Status string `json:"status"` + Status Status `json:"status"` ConditionStatus string `json:"condition_status"` Ctx context.Context `json:"-"` Payload json.RawMessage `json:"payload"` @@ -67,8 +76,8 @@ func (r Result) Unmarshal(data any) error { return json.Unmarshal(r.Payload, data) } -func HandleError(ctx context.Context, err error, status ...string) Result { - st := "Failed" +func HandleError(ctx context.Context, err error, status ...Status) Result { + st := Failed if len(status) > 0 { st = status[0] } @@ -82,7 +91,7 @@ func HandleError(ctx context.Context, err error, status ...string) Result { } } -func (r Result) WithData(status string, data []byte) Result { +func (r Result) WithData(status Status, data []byte) Result { if r.Error != nil { return r }