diff --git a/dag/api.go b/dag/api.go index f37fe4d..d7395f0 100644 --- a/dag/api.go +++ b/dag/api.go @@ -4,26 +4,110 @@ import ( "context" "encoding/json" "fmt" - "io" "net/http" - "net/url" "os" - "time" - - "github.com/oarkflow/mq/jsonparser" - "github.com/oarkflow/mq/sio" + "strings" "github.com/oarkflow/mq" + "github.com/oarkflow/mq/sio" + "github.com/oarkflow/mq/consts" - "github.com/oarkflow/mq/metrics" + "github.com/oarkflow/mq/jsonparser" ) -type Request struct { - Payload json.RawMessage `json:"payload"` - Interval time.Duration `json:"interval"` - Schedule bool `json:"schedule"` - Overlap bool `json:"overlap"` - Recurring bool `json:"recurring"` +func renderNotFound(w http.ResponseWriter) { + html := []byte(` +
+

task not found

+

Back to home

+
+`) + w.Header().Set(consts.ContentType, consts.TypeHtml) + w.Write(html) +} + +func (tm *DAG) render(w http.ResponseWriter, r *http.Request) { + ctx, data, err := parse(r) + if err != nil { + http.Error(w, err.Error(), http.StatusNotFound) + return + } + accept := r.Header.Get("Accept") + userCtx := UserContext(ctx) + ctx = context.WithValue(ctx, "method", r.Method) + if r.Method == "GET" && userCtx.Get("task_id") != "" { + manager, ok := tm.taskManager.Get(userCtx.Get("task_id")) + if !ok || manager == nil { + if strings.Contains(accept, "text/html") || accept == "" { + renderNotFound(w) + return + } + http.Error(w, fmt.Sprintf(`{"message": "%s"}`, "task not found"), http.StatusInternalServerError) + return + } + } + result := tm.Process(ctx, data) + if result.Error != nil { + http.Error(w, fmt.Sprintf(`{"message": "%s"}`, result.Error.Error()), http.StatusInternalServerError) + return + } + contentType, ok := result.Ctx.Value(consts.ContentType).(string) + if !ok { + contentType = consts.TypeJson + } + switch contentType { + case consts.TypeHtml: + w.Header().Set(consts.ContentType, consts.TypeHtml) + data, err := jsonparser.GetString(result.Payload, "html_content") + if err != nil { + return + } + w.Write([]byte(data)) + default: + if r.Method != "POST" { + http.Error(w, `{"message": "not allowed"}`, http.StatusMethodNotAllowed) + return + } + w.Header().Set(consts.ContentType, consts.TypeJson) + json.NewEncoder(w).Encode(result.Payload) + } +} + +func (tm *DAG) taskStatusHandler(w http.ResponseWriter, r *http.Request) { + taskID := r.URL.Query().Get("taskID") + if taskID == "" { + http.Error(w, `{"message": "taskID is missing"}`, http.StatusBadRequest) + return + } + manager, ok := tm.taskManager.Get(taskID) + if !ok { + http.Error(w, `{"message": "Invalid TaskID"}`, http.StatusNotFound) + return + } + result := make(map[string]TaskState) + manager.taskStates.ForEach(func(key string, value *TaskState) bool { + key = strings.Split(key, Delimiter)[0] + nodeID := strings.Split(value.NodeID, Delimiter)[0] + rs := jsonparser.Delete(value.Result.Payload, "html_content") + status := value.Status + if status == mq.Processing { + status = mq.Completed + } + state := TaskState{ + NodeID: nodeID, + Status: status, + UpdatedAt: value.UpdatedAt, + Result: mq.Result{ + Payload: rs, + Error: value.Result.Error, + Status: status, + }, + } + result[key] = state + return true + }) + w.Header().Set(consts.ContentType, consts.TypeJson) + json.NewEncoder(w).Encode(result) } func (tm *DAG) SetupWS() *sio.Server { @@ -37,57 +121,11 @@ func (tm *DAG) SetupWS() *sio.Server { } func (tm *DAG) Handlers() { - metrics.HandleHTTP() http.Handle("/", http.FileServer(http.Dir("webroot"))) http.Handle("/notify", tm.SetupWS()) - http.HandleFunc("GET /render", tm.Render) - http.HandleFunc("POST /request", tm.Request) - http.HandleFunc("POST /publish", tm.Publish) - http.HandleFunc("POST /schedule", tm.Schedule) - http.HandleFunc("/pause-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) { - id := request.PathValue("id") - if id != "" { - tm.PauseConsumer(request.Context(), id) - } - }) - http.HandleFunc("/resume-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) { - id := request.PathValue("id") - if id != "" { - tm.ResumeConsumer(request.Context(), id) - } - }) - http.HandleFunc("/pause", func(w http.ResponseWriter, request *http.Request) { - err := tm.Pause(request.Context()) - if err != nil { - http.Error(w, "Failed to pause", http.StatusBadRequest) - return - } - json.NewEncoder(w).Encode(map[string]string{"status": "paused"}) - }) - http.HandleFunc("/resume", func(w http.ResponseWriter, request *http.Request) { - err := tm.Resume(request.Context()) - if err != nil { - http.Error(w, "Failed to resume", http.StatusBadRequest) - return - } - json.NewEncoder(w).Encode(map[string]string{"status": "resumed"}) - }) - http.HandleFunc("/stop", func(w http.ResponseWriter, request *http.Request) { - err := tm.Stop(request.Context()) - if err != nil { - http.Error(w, "Failed to read request body", http.StatusBadRequest) - return - } - json.NewEncoder(w).Encode(map[string]string{"status": "stopped"}) - }) - http.HandleFunc("/close", func(w http.ResponseWriter, request *http.Request) { - err := tm.Close() - if err != nil { - http.Error(w, "Failed to read request body", http.StatusBadRequest) - return - } - json.NewEncoder(w).Encode(map[string]string{"status": "closed"}) - }) + http.HandleFunc("/process", tm.render) + http.HandleFunc("/request", tm.render) + http.HandleFunc("/task/status", tm.taskStatusHandler) http.HandleFunc("/dot", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") fmt.Fprintln(w, tm.ExportDOT()) @@ -112,98 +150,3 @@ func (tm *DAG) Handlers() { } }) } - -func (tm *DAG) request(w http.ResponseWriter, r *http.Request, async bool) { - if r.Method != http.MethodPost { - http.Error(w, "Invalid request method", http.StatusMethodNotAllowed) - return - } - var request Request - if r.Body != nil { - defer r.Body.Close() - var err error - payload, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Failed to read request body", http.StatusBadRequest) - return - } - err = json.Unmarshal(payload, &request) - if err != nil { - http.Error(w, "Failed to unmarshal body", http.StatusBadRequest) - return - } - } else { - http.Error(w, "Empty request body", http.StatusBadRequest) - return - } - ctx := r.Context() - if async { - ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"}) - } - var opts []mq.SchedulerOption - if request.Interval > 0 { - opts = append(opts, mq.WithInterval(request.Interval)) - } - if request.Overlap { - opts = append(opts, mq.WithOverlap()) - } - if request.Recurring { - opts = append(opts, mq.WithRecurring()) - } - ctx = context.WithValue(ctx, "query_params", r.URL.Query()) - var rs mq.Result - if request.Schedule { - rs = tm.ScheduleTask(ctx, request.Payload, opts...) - } else { - rs = tm.Process(ctx, request.Payload) - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(rs) -} - -func (tm *DAG) Render(w http.ResponseWriter, r *http.Request) { - ctx := mq.SetHeaders(r.Context(), map[string]string{consts.AwaitResponseKey: "true", "request_type": "render"}) - ctx = context.WithValue(ctx, "query_params", r.URL.Query()) - rs := tm.Process(ctx, nil) - content, err := jsonparser.GetString(rs.Payload, "html_content") - if err != nil { - http.Error(w, "Failed to read request body", http.StatusBadRequest) - return - } - w.Header().Set("Content-Type", consts.TypeHtml) - w.Write([]byte(content)) -} - -func (tm *DAG) Request(w http.ResponseWriter, r *http.Request) { - tm.request(w, r, true) -} - -func (tm *DAG) Publish(w http.ResponseWriter, r *http.Request) { - tm.request(w, r, false) -} - -func (tm *DAG) Schedule(w http.ResponseWriter, r *http.Request) { - tm.request(w, r, false) -} - -func GetTaskID(ctx context.Context) string { - if queryParams := ctx.Value("query_params"); queryParams != nil { - if params, ok := queryParams.(url.Values); ok { - if id := params.Get("taskID"); id != "" { - return id - } - } - } - return "" -} - -func CanNextNode(ctx context.Context) string { - if queryParams := ctx.Value("query_params"); queryParams != nil { - if params, ok := queryParams.(url.Values); ok { - if id := params.Get("next"); id != "" { - return id - } - } - } - return "" -} diff --git a/dag/consts.go b/dag/consts.go index 372a7ee..7e6d712 100644 --- a/dag/consts.go +++ b/dag/consts.go @@ -1,26 +1,28 @@ package dag -type NodeStatus int - -func (c NodeStatus) IsValid() bool { return c >= Pending && c <= Failed } - -func (c NodeStatus) String() string { - switch c { - case Pending: - return "Pending" - case Processing: - return "Processing" - case Completed: - return "Completed" - case Failed: - return "Failed" - } - return "" -} +import "time" const ( - Pending NodeStatus = iota - Processing - Completed - Failed + Delimiter = "___" + ContextIndex = "index" + DefaultChannelSize = 1000 + RetryInterval = 5 * time.Second +) + +type NodeType int + +func (c NodeType) IsValid() bool { return c >= Function && c <= Page } + +const ( + Function NodeType = iota + Page +) + +type EdgeType int + +func (c EdgeType) IsValid() bool { return c >= Simple && c <= Iterator } + +const ( + Simple EdgeType = iota + Iterator ) diff --git a/dag/v2/context.go b/dag/context.go similarity index 99% rename from dag/v2/context.go rename to dag/context.go index d6fa0d2..c5fc06a 100644 --- a/dag/v2/context.go +++ b/dag/context.go @@ -1,4 +1,4 @@ -package v2 +package dag import ( "context" diff --git a/dag/dag.go b/dag/dag.go index fb1bc64..a010e91 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -2,155 +2,70 @@ package dag import ( "context" + "encoding/json" "fmt" "log" "net/http" + "strings" "time" - "github.com/oarkflow/mq/storage" - "github.com/oarkflow/mq/storage/memory" - - "github.com/oarkflow/mq/sio" - "golang.org/x/time/rate" - "github.com/oarkflow/mq" "github.com/oarkflow/mq/consts" - "github.com/oarkflow/mq/metrics" -) + "github.com/oarkflow/mq/sio" -type EdgeType int - -func (c EdgeType) IsValid() bool { return c >= Simple && c <= Iterator } - -const ( - Simple EdgeType = iota - Iterator -) - -type NodeType int - -func (c NodeType) IsValid() bool { return c >= Process && c <= Page } - -const ( - Process NodeType = iota - Page + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/storage" + "github.com/oarkflow/mq/storage/memory" ) type Node struct { - processor mq.Processor - Name string - Type NodeType - Key string + NodeType NodeType + Label string + ID string Edges []Edge + processor mq.Processor isReady bool } -func (n *Node) ProcessTask(ctx context.Context, msg *mq.Task) mq.Result { - return n.processor.ProcessTask(ctx, msg) -} - -func (n *Node) Close() error { - return n.processor.Close() -} - type Edge struct { - Label string From *Node - To []*Node + To *Node Type EdgeType + Label string } -type ( - FromNode string - When string - Then string -) - type DAG struct { server *mq.Broker consumer *mq.Consumer - taskContext storage.IMap[string, *TaskManager] - nodes map[string]*Node + nodes storage.IMap[string, *Node] + taskManager storage.IMap[string, *TaskManager] iteratorNodes storage.IMap[string, []Edge] - conditions map[FromNode]map[When]Then + finalResult func(taskID string, result mq.Result) pool *mq.Pool - taskCleanupCh chan string name string key string startNode string - consumerTopic string opts []mq.Option + conditions map[string]map[string]string + consumerTopic string reportNodeResultCallback func(mq.Result) + Error error Notifier *sio.Server paused bool - Error error report string - index string } -func (tm *DAG) SetKey(key string) { - tm.key = key -} - -func (tm *DAG) ReportNodeResult(callback func(mq.Result)) { - tm.reportNodeResultCallback = callback -} - -func (tm *DAG) GetType() string { - return tm.key -} - -func (tm *DAG) listenForTaskCleanup() { - for taskID := range tm.taskCleanupCh { - if tm.server.Options().CleanTaskOnComplete() { - tm.taskCleanup(taskID) - } - } -} - -func (tm *DAG) taskCleanup(taskID string) { - tm.taskContext.Del(taskID) - log.Printf("DAG - Task %s cleaned up", taskID) -} - -func (tm *DAG) Consume(ctx context.Context) error { - if tm.consumer != nil { - tm.server.Options().SetSyncMode(true) - return tm.consumer.Consume(ctx) - } - return nil -} - -func (tm *DAG) Stop(ctx context.Context) error { - for _, n := range tm.nodes { - err := n.processor.Stop(ctx) - if err != nil { - return err - } - } - return nil -} - -func (tm *DAG) GetKey() string { - return tm.key -} - -func (tm *DAG) AssignTopic(topic string) { - tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false), mq.WithBrokerURL(tm.server.URL())) - tm.consumerTopic = topic -} - -func NewDAG(name, key string, opts ...mq.Option) *DAG { +func NewDAG(name, key string, finalResultCallback func(taskID string, result mq.Result), opts ...mq.Option) *DAG { callback := func(ctx context.Context, result mq.Result) error { return nil } d := &DAG{ name: name, key: key, - nodes: make(map[string]*Node), + nodes: memory.New[string, *Node](), + taskManager: memory.New[string, *TaskManager](), iteratorNodes: memory.New[string, []Edge](), - taskContext: memory.New[string, *TaskManager](), - conditions: make(map[FromNode]map[When]Then), - taskCleanupCh: make(chan string), + conditions: make(map[string]map[string]string), + finalResult: finalResultCallback, } opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerOnSubscribe(d.onConsumerJoin), mq.WithConsumerOnClose(d.onConsumerClose)) d.server = mq.NewBroker(opts...) @@ -165,10 +80,61 @@ func NewDAG(name, key string, opts ...mq.Option) *DAG { mq.WithTaskStorage(options.Storage()), ) d.pool.Start(d.server.Options().NumOfWorkers()) - go d.listenForTaskCleanup() return d } +func (tm *DAG) SetKey(key string) { + tm.key = key +} + +func (tm *DAG) ReportNodeResult(callback func(mq.Result)) { + tm.reportNodeResultCallback = callback +} + +func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result { + if manager, ok := tm.taskManager.Get(result.TaskID); ok && result.Topic != "" { + manager.onNodeCompleted(nodeResult{ + ctx: ctx, + nodeID: result.Topic, + status: result.Status, + result: result, + }) + } + return mq.Result{} +} + +func (tm *DAG) GetType() string { + return tm.key +} + +func (tm *DAG) Consume(ctx context.Context) error { + if tm.consumer != nil { + tm.server.Options().SetSyncMode(true) + return tm.consumer.Consume(ctx) + } + return nil +} + +func (tm *DAG) Stop(ctx context.Context) error { + tm.nodes.ForEach(func(_ string, n *Node) bool { + err := n.processor.Stop(ctx) + if err != nil { + return false + } + return true + }) + return nil +} + +func (tm *DAG) GetKey() string { + return tm.key +} + +func (tm *DAG) AssignTopic(topic string) { + tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false), mq.WithBrokerURL(tm.server.URL())) + tm.consumerTopic = topic +} + func (tm *DAG) callbackToConsumer(ctx context.Context, result mq.Result) { if tm.consumer != nil { result.Topic = tm.consumerTopic @@ -180,114 +146,89 @@ func (tm *DAG) callbackToConsumer(ctx context.Context, result mq.Result) { } } -func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result { - if taskContext, ok := tm.taskContext.Get(result.TaskID); ok && result.Topic != "" { - return taskContext.handleNextTask(ctx, result) - } - return mq.Result{} -} - func (tm *DAG) onConsumerJoin(_ context.Context, topic, _ string) { - if node, ok := tm.nodes[topic]; ok { + if node, ok := tm.nodes.Get(topic); ok { log.Printf("DAG - CONSUMER ~> ready on %s", topic) node.isReady = true } } func (tm *DAG) onConsumerClose(_ context.Context, topic, _ string) { - if node, ok := tm.nodes[topic]; ok { + if node, ok := tm.nodes.Get(topic); ok { log.Printf("DAG - CONSUMER ~> down on %s", topic) node.isReady = false } } +func (tm *DAG) Pause(_ context.Context) error { + tm.paused = true + return nil +} + +func (tm *DAG) Resume(_ context.Context) error { + tm.paused = false + return nil +} + +func (tm *DAG) Close() error { + var err error + tm.nodes.ForEach(func(_ string, n *Node) bool { + err = n.processor.Close() + if err != nil { + return false + } + return true + }) + return nil +} + func (tm *DAG) SetStartNode(node string) { tm.startNode = node } +func (tm *DAG) SetNotifyResponse(callback mq.Callback) { + tm.server.SetNotifyHandler(callback) +} + func (tm *DAG) GetStartNode() string { return tm.startNode } -func (tm *DAG) Start(ctx context.Context, addr string) error { - // Start the server in a separate goroutine - go func() { - defer mq.RecoverPanic(mq.RecoverTitle) - if err := tm.server.Start(ctx); err != nil { - panic(err) - } - }() - - // Start the node consumers if not in sync mode - if !tm.server.SyncMode() { - for _, con := range tm.nodes { - go func(con *Node) { - defer mq.RecoverPanic(mq.RecoverTitle) - limiter := rate.NewLimiter(rate.Every(1*time.Second), 1) // Retry every second - for { - err := con.processor.Consume(ctx) - if err != nil { - log.Printf("[ERROR] - Consumer %s failed to start: %v", con.Key, err) - } else { - log.Printf("[INFO] - Consumer %s started successfully", con.Key) - break - } - limiter.Wait(ctx) // Wait with rate limiting before retrying - } - }(con) - } - } - log.Printf("DAG - HTTP_SERVER ~> started on http://localhost%s", addr) - tm.Handlers() - config := tm.server.TLSConfig() - if config.UseTLS { - return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil) - } - return http.ListenAndServe(addr, nil) -} - -func (tm *DAG) AddDAGNode(name string, key string, dag *DAG, firstNode ...bool) *DAG { - dag.AssignTopic(key) - tm.nodes[key] = &Node{ - Name: name, - Key: key, - processor: dag, - isReady: true, - } - if len(firstNode) > 0 && firstNode[0] { - tm.startNode = key - } +func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) *DAG { + tm.conditions[fromNode] = conditions return tm } -func (tm *DAG) AddNode(name, key string, handler mq.Processor, firstNode ...bool) *DAG { - con := mq.NewConsumer(key, key, handler.ProcessTask, tm.opts...) +func (tm *DAG) AddNode(nodeType NodeType, name, nodeID string, handler mq.Processor, startNode ...bool) *DAG { + if tm.Error != nil { + return tm + } + con := mq.NewConsumer(nodeID, nodeID, handler.ProcessTask) n := &Node{ - Name: name, - Key: key, + Label: name, + ID: nodeID, + NodeType: nodeType, processor: con, } - if handler.GetType() == "page" { - n.Type = Page - } - if tm.server.SyncMode() { + if tm.server != nil && tm.server.SyncMode() { n.isReady = true } - tm.nodes[key] = n - if len(firstNode) > 0 && firstNode[0] { - tm.startNode = key + tm.nodes.Set(nodeID, n) + if len(startNode) > 0 && startNode[0] { + tm.startNode = nodeID } return tm } -func (tm *DAG) AddDeferredNode(name, key string, firstNode ...bool) error { +func (tm *DAG) AddDeferredNode(nodeType NodeType, name, key string, firstNode ...bool) error { if tm.server.SyncMode() { return fmt.Errorf("DAG cannot have deferred node in Sync Mode") } - tm.nodes[key] = &Node{ - Name: name, - Key: key, - } + tm.nodes.Set(key, &Node{ + Label: name, + ID: key, + NodeType: nodeType, + }) if len(firstNode) > 0 && firstNode[0] { tm.startNode = key } @@ -295,52 +236,124 @@ func (tm *DAG) AddDeferredNode(name, key string, firstNode ...bool) error { } func (tm *DAG) IsReady() bool { - for _, node := range tm.nodes { - if !node.isReady { + var isReady bool + tm.nodes.ForEach(func(_ string, n *Node) bool { + if !n.isReady { return false } + isReady = true + return true + }) + return isReady +} + +func (tm *DAG) AddEdge(edgeType EdgeType, label, from string, targets ...string) *DAG { + if tm.Error != nil { + return tm } - return true -} - -func (tm *DAG) AddCondition(fromNode FromNode, conditions map[When]Then) *DAG { - tm.conditions[fromNode] = conditions - return tm -} - -func (tm *DAG) AddIterator(label, from string, targets ...string) *DAG { - tm.Error = tm.addEdge(Iterator, label, from, targets...) - tm.iteratorNodes.Set(from, []Edge{}) - return tm -} - -func (tm *DAG) AddEdge(label, from string, targets ...string) *DAG { - tm.Error = tm.addEdge(Simple, label, from, targets...) - return tm -} - -func (tm *DAG) addEdge(edgeType EdgeType, label, from string, targets ...string) error { - fromNode, ok := tm.nodes[from] + if edgeType == Iterator { + tm.iteratorNodes.Set(from, []Edge{}) + } + node, ok := tm.nodes.Get(from) if !ok { - return fmt.Errorf("Error: 'from' node %s does not exist\n", from) + tm.Error = fmt.Errorf("node not found %s", from) + return tm } - var nodes []*Node for _, target := range targets { - toNode, ok := tm.nodes[target] - if !ok { - return fmt.Errorf("Error: 'from' node %s does not exist\n", target) - } - nodes = append(nodes, toNode) - } - edge := Edge{From: fromNode, To: nodes, Type: edgeType, Label: label} - fromNode.Edges = append(fromNode.Edges, edge) - if edgeType != Iterator { - if edges, ok := tm.iteratorNodes.Get(fromNode.Key); ok { - edges = append(edges, edge) - tm.iteratorNodes.Set(fromNode.Key, edges) + if targetNode, ok := tm.nodes.Get(target); ok { + edge := Edge{From: node, To: targetNode, Type: edgeType, Label: label} + node.Edges = append(node.Edges, edge) + if edgeType != Iterator { + if edges, ok := tm.iteratorNodes.Get(node.ID); ok { + edges = append(edges, edge) + tm.iteratorNodes.Set(node.ID, edges) + } + } } } - return nil + return tm +} + +func (tm *DAG) getCurrentNode(manager *TaskManager) string { + if manager.currentNodePayload.Size() == 0 { + return "" + } + return manager.currentNodePayload.Keys()[0] +} + +func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { + ctx = context.WithValue(ctx, "task_id", task.ID) + userContext := UserContext(ctx) + next := userContext.Get("next") + manager, ok := tm.taskManager.Get(task.ID) + resultCh := make(chan mq.Result, 1) + if !ok { + manager = NewTaskManager(tm, task.ID, resultCh, tm.iteratorNodes.Clone()) + tm.taskManager.Set(task.ID, manager) + } else { + manager.resultCh = resultCh + } + currentKey := tm.getCurrentNode(manager) + currentNode := strings.Split(currentKey, Delimiter)[0] + node, exists := tm.nodes.Get(currentNode) + method, ok := ctx.Value("method").(string) + if method == "GET" && exists && node.NodeType == Page { + ctx = context.WithValue(ctx, "initial_node", currentNode) + /* + if isLastNode, err := tm.IsLastNode(currentNode); err != nil && isLastNode { + if manager.result != nil { + fmt.Println(string(manager.result.Payload)) + resultCh <- *manager.result + return <-resultCh + } + } + */ + if manager.result != nil { + task.Payload = manager.result.Payload + } + } else if next == "true" { + nodes, err := tm.GetNextNodes(currentNode) + if err != nil { + return mq.Result{Error: err, Ctx: ctx} + } + if len(nodes) > 0 { + ctx = context.WithValue(ctx, "initial_node", nodes[0].ID) + } + } + if currentNodeResult, hasResult := manager.currentNodeResult.Get(currentKey); hasResult { + var taskPayload, resultPayload map[string]any + if err := json.Unmarshal(task.Payload, &taskPayload); err == nil { + if err = json.Unmarshal(currentNodeResult.Payload, &resultPayload); err == nil { + for key, val := range resultPayload { + taskPayload[key] = val + } + task.Payload, _ = json.Marshal(taskPayload) + } + } + } + firstNode, err := tm.parseInitialNode(ctx) + if err != nil { + return mq.Result{Error: err, Ctx: ctx} + } + node, ok = tm.nodes.Get(firstNode) + if ok && node.NodeType != Page && task.Payload == nil { + return mq.Result{Error: fmt.Errorf("payload is required for node %s", firstNode), Ctx: ctx} + } + task.Topic = firstNode + ctx = context.WithValue(ctx, ContextIndex, "0") + manager.ProcessTask(ctx, firstNode, task.Payload) + return <-resultCh +} + +func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result { + var taskID string + userCtx := UserContext(ctx) + if val := userCtx.Get("task_id"); val != "" { + taskID = val + } else { + taskID = mq.NewID() + } + return tm.ProcessTask(ctx, mq.NewTask(taskID, payload, "")) } func (tm *DAG) Validate() error { @@ -357,177 +370,124 @@ func (tm *DAG) GetReport() string { return tm.report } -func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { - if task.ID == "" { - task.ID = mq.NewID() - } - if index, ok := mq.GetHeader(ctx, "index"); ok { - tm.index = index - } - manager, exists := tm.taskContext.Get(task.ID) - if !exists { - manager = NewTaskManager(tm, task.ID, tm.iteratorNodes) - manager.createdAt = task.CreatedAt - tm.taskContext.Set(task.ID, manager) +func (tm *DAG) AddDAGNode(name string, key string, dag *DAG, firstNode ...bool) *DAG { + dag.AssignTopic(key) + tm.nodes.Set(key, &Node{ + Label: name, + ID: key, + processor: dag, + isReady: true, + }) + if len(firstNode) > 0 && firstNode[0] { + tm.startNode = key } + return tm +} - if tm.consumer != nil { - initialNode, err := tm.parseInitialNode(ctx) - if err != nil { - metrics.TasksErrors.WithLabelValues("unknown").Inc() // Increase error count - return mq.Result{Error: err} +func (tm *DAG) Start(ctx context.Context, addr string) error { + // Start the server in a separate goroutine + go func() { + defer mq.RecoverPanic(mq.RecoverTitle) + if err := tm.server.Start(ctx); err != nil { + panic(err) } - task.Topic = initialNode - } - if manager.topic != "" { - task.Topic = manager.topic - canNext := CanNextNode(ctx) - if canNext != "" { - if n, ok := tm.nodes[task.Topic]; ok { - if len(n.Edges) > 0 { - task.Topic = n.Edges[0].To[0].Key + }() + + // Start the node consumers if not in sync mode + if !tm.server.SyncMode() { + tm.nodes.ForEach(func(_ string, con *Node) bool { + go func(con *Node) { + defer mq.RecoverPanic(mq.RecoverTitle) + limiter := rate.NewLimiter(rate.Every(1*time.Second), 1) // Retry every second + for { + err := con.processor.Consume(ctx) + if err != nil { + log.Printf("[ERROR] - Consumer %s failed to start: %v", con.ID, err) + } else { + log.Printf("[INFO] - Consumer %s started successfully", con.ID) + break + } + limiter.Wait(ctx) // Wait with rate limiting before retrying } - } - } else { - } + }(con) + return true + }) } - result := manager.processTask(ctx, task.Topic, task.Payload) - if result.Ctx != nil && tm.index != "" { - result.Ctx = mq.SetHeaders(result.Ctx, map[string]string{"index": tm.index}) + log.Printf("DAG - HTTP_SERVER ~> started on http://%s", addr) + tm.Handlers() + config := tm.server.TLSConfig() + log.Printf("Server listening on http://%s", addr) + if config.UseTLS { + return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil) } - if result.Error != nil { - metrics.TasksErrors.WithLabelValues(task.Topic).Inc() // Increase error count - } else { - metrics.TasksProcessed.WithLabelValues("success").Inc() // Increase processed task count - } - return result -} - -func (tm *DAG) check(ctx context.Context, payload []byte) (context.Context, *mq.Task, error) { - if tm.paused { - return ctx, nil, fmt.Errorf("unable to process task, error: DAG is not accepting any task") - } - if !tm.IsReady() { - return ctx, nil, fmt.Errorf("unable to process task, error: DAG is not ready yet") - } - initialNode, err := tm.parseInitialNode(ctx) - if err != nil { - return ctx, nil, err - } - if tm.server.SyncMode() { - ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"}) - } - taskID := GetTaskID(ctx) - if taskID != "" { - if _, exists := tm.taskContext.Get(taskID); !exists { - return ctx, nil, fmt.Errorf("provided task ID doesn't exist") - } - } - if taskID == "" { - taskID = mq.NewID() - } - return ctx, mq.NewTask(taskID, payload, initialNode), nil -} - -func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result { - ctx, task, err := tm.check(ctx, payload) - if err != nil { - return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")} - } - awaitResponse, _ := mq.GetAwaitResponse(ctx) - if awaitResponse != "true" { - headers, ok := mq.GetHeaders(ctx) - ctxx := context.Background() - if ok { - ctxx = mq.SetHeaders(ctxx, headers.AsMap()) - } - if err := tm.pool.EnqueueTask(ctxx, task, 0); err != nil { - return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: task.Topic, Status: "FAILED", Error: err} - } - return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: task.Topic, Status: "PENDING"} - } - return tm.ProcessTask(ctx, task) + return http.ListenAndServe(addr, nil) } func (tm *DAG) ScheduleTask(ctx context.Context, payload []byte, opts ...mq.SchedulerOption) mq.Result { - ctx, task, err := tm.check(ctx, payload) - if err != nil { - return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")} + var taskID string + userCtx := UserContext(ctx) + if val := userCtx.Get("task_id"); val != "" { + taskID = val + } else { + taskID = mq.NewID() } + t := mq.NewTask(taskID, payload, "") + + ctx = context.WithValue(ctx, "task_id", taskID) + userContext := UserContext(ctx) + next := userContext.Get("next") + manager, ok := tm.taskManager.Get(taskID) + resultCh := make(chan mq.Result, 1) + if !ok { + manager = NewTaskManager(tm, taskID, resultCh, tm.iteratorNodes.Clone()) + tm.taskManager.Set(taskID, manager) + } else { + manager.resultCh = resultCh + } + currentKey := tm.getCurrentNode(manager) + currentNode := strings.Split(currentKey, Delimiter)[0] + node, exists := tm.nodes.Get(currentNode) + method, ok := ctx.Value("method").(string) + if method == "GET" && exists && node.NodeType == Page { + ctx = context.WithValue(ctx, "initial_node", currentNode) + } else if next == "true" { + nodes, err := tm.GetNextNodes(currentNode) + if err != nil { + return mq.Result{Error: err, Ctx: ctx} + } + if len(nodes) > 0 { + ctx = context.WithValue(ctx, "initial_node", nodes[0].ID) + } + } + if currentNodeResult, hasResult := manager.currentNodeResult.Get(currentKey); hasResult { + var taskPayload, resultPayload map[string]any + if err := json.Unmarshal(payload, &taskPayload); err == nil { + if err = json.Unmarshal(currentNodeResult.Payload, &resultPayload); err == nil { + for key, val := range resultPayload { + taskPayload[key] = val + } + payload, _ = json.Marshal(taskPayload) + } + } + } + firstNode, err := tm.parseInitialNode(ctx) + if err != nil { + return mq.Result{Error: err, Ctx: ctx} + } + node, ok = tm.nodes.Get(firstNode) + if ok && node.NodeType != Page && t.Payload == nil { + return mq.Result{Error: fmt.Errorf("payload is required for node %s", firstNode), Ctx: ctx} + } + t.Topic = firstNode + ctx = context.WithValue(ctx, ContextIndex, "0") + headers, ok := mq.GetHeaders(ctx) ctxx := context.Background() if ok { ctxx = mq.SetHeaders(ctxx, headers.AsMap()) } - tm.pool.Scheduler().AddTask(ctxx, task, opts...) - return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: task.Topic, Status: "PENDING"} -} - -func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) { - val := ctx.Value("initial_node") - initialNode, ok := val.(string) - if ok { - return initialNode, nil - } - if tm.startNode == "" { - firstNode := tm.findStartNode() - if firstNode != nil { - tm.startNode = firstNode.Key - } - } - - if tm.startNode == "" { - return "", fmt.Errorf("initial node not found") - } - return tm.startNode, nil -} - -func (tm *DAG) findStartNode() *Node { - incomingEdges := make(map[string]bool) - connectedNodes := make(map[string]bool) - for _, node := range tm.nodes { - for _, edge := range node.Edges { - if edge.Type.IsValid() { - connectedNodes[node.Key] = true - for _, to := range edge.To { - connectedNodes[to.Key] = true - incomingEdges[to.Key] = true - } - } - } - if cond, ok := tm.conditions[FromNode(node.Key)]; ok { - for _, target := range cond { - connectedNodes[string(target)] = true - incomingEdges[string(target)] = true - } - } - } - for nodeID, node := range tm.nodes { - if !incomingEdges[nodeID] && connectedNodes[nodeID] { - return node - } - } - return nil -} - -func (tm *DAG) Pause(_ context.Context) error { - tm.paused = true - return nil -} - -func (tm *DAG) Resume(_ context.Context) error { - tm.paused = false - return nil -} - -func (tm *DAG) Close() error { - for _, n := range tm.nodes { - err := n.Close() - if err != nil { - return err - } - } - return nil + tm.pool.Scheduler().AddTask(ctxx, t, opts...) + return mq.Result{CreatedAt: t.CreatedAt, TaskID: t.ID, Topic: t.Topic, Status: "PENDING"} } func (tm *DAG) PauseConsumer(ctx context.Context, id string) { @@ -539,74 +499,26 @@ func (tm *DAG) ResumeConsumer(ctx context.Context, id string) { } func (tm *DAG) doConsumer(ctx context.Context, id string, action consts.CMD) { - if node, ok := tm.nodes[id]; ok { + if node, ok := tm.nodes.Get(id); ok { switch action { case consts.CONSUMER_PAUSE: err := node.processor.Pause(ctx) if err == nil { node.isReady = false - log.Printf("[INFO] - Consumer %s paused successfully", node.Key) + log.Printf("[INFO] - Consumer %s paused successfully", node.ID) } else { - log.Printf("[ERROR] - Failed to pause consumer %s: %v", node.Key, err) + log.Printf("[ERROR] - Failed to pause consumer %s: %v", node.ID, err) } case consts.CONSUMER_RESUME: err := node.processor.Resume(ctx) if err == nil { node.isReady = true - log.Printf("[INFO] - Consumer %s resumed successfully", node.Key) + log.Printf("[INFO] - Consumer %s resumed successfully", node.ID) } else { - log.Printf("[ERROR] - Failed to resume consumer %s: %v", node.Key, err) + log.Printf("[ERROR] - Failed to resume consumer %s: %v", node.ID, err) } } } else { log.Printf("[WARNING] - Consumer %s not found", id) } } - -func (tm *DAG) SetNotifyResponse(callback mq.Callback) { - tm.server.SetNotifyHandler(callback) -} - -func (tm *DAG) GetNextNodes(key string) ([]*Node, error) { - node, exists := tm.nodes[key] - if !exists { - return nil, fmt.Errorf("Node with key %s does not exist", key) - } - var successors []*Node - for _, edge := range node.Edges { - successors = append(successors, edge.To...) - } - if conds, exists := tm.conditions[FromNode(key)]; exists { - for _, targetKey := range conds { - if targetNode, exists := tm.nodes[string(targetKey)]; exists { - successors = append(successors, targetNode) - } - } - } - return successors, nil -} - -func (tm *DAG) GetPreviousNodes(key string) ([]*Node, error) { - var predecessors []*Node - for _, node := range tm.nodes { - for _, edge := range node.Edges { - for _, target := range edge.To { - if target.Key == key { - predecessors = append(predecessors, node) - } - } - } - } - for fromNode, conds := range tm.conditions { - for _, targetKey := range conds { - if string(targetKey) == key { - node, exists := tm.nodes[string(fromNode)] - if !exists { - return nil, fmt.Errorf("Node with key %s does not exist", fromNode) - } - predecessors = append(predecessors, node) - } - } - } - return predecessors, nil -} diff --git a/dag/v2/node.go b/dag/node.go similarity index 99% rename from dag/v2/node.go rename to dag/node.go index bba93ac..66161e8 100644 --- a/dag/v2/node.go +++ b/dag/node.go @@ -1,4 +1,4 @@ -package v2 +package dag import ( "context" diff --git a/dag/task_manager.go b/dag/task_manager.go index 0ea8067..3b4e68a 100644 --- a/dag/task_manager.go +++ b/dag/task_manager.go @@ -4,247 +4,307 @@ import ( "context" "encoding/json" "fmt" + "log" "strings" "time" - "github.com/oarkflow/mq/consts" + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/storage" "github.com/oarkflow/mq/storage/memory" - - "github.com/oarkflow/mq" ) +type TaskState struct { + NodeID string + Status mq.Status + UpdatedAt time.Time + Result mq.Result + targetResults storage.IMap[string, mq.Result] +} + +func newTaskState(nodeID string) *TaskState { + return &TaskState{ + NodeID: nodeID, + Status: mq.Pending, + UpdatedAt: time.Now(), + targetResults: memory.New[string, mq.Result](), + } +} + +type nodeResult struct { + ctx context.Context + nodeID string + status mq.Status + result mq.Result +} + type TaskManager struct { - createdAt time.Time - processedAt time.Time - status string - dag *DAG - taskID string - wg *WaitGroup - topic string - result mq.Result - - iteratorNodes storage.IMap[string, []Edge] - taskNodeStatus storage.IMap[string, *taskNodeStatus] + taskStates storage.IMap[string, *TaskState] + parentNodes storage.IMap[string, string] + childNodes storage.IMap[string, int] + deferredTasks storage.IMap[string, *task] + iteratorNodes storage.IMap[string, []Edge] + currentNodePayload storage.IMap[string, json.RawMessage] + currentNodeResult storage.IMap[string, mq.Result] + result *mq.Result + dag *DAG + taskID string + taskQueue chan *task + resultQueue chan nodeResult + resultCh chan mq.Result + stopCh chan struct{} } -func NewTaskManager(d *DAG, taskID string, iteratorNodes storage.IMap[string, []Edge]) *TaskManager { - return &TaskManager{ - dag: d, - taskNodeStatus: memory.New[string, *taskNodeStatus](), - taskID: taskID, - iteratorNodes: iteratorNodes, - wg: NewWaitGroup(), +type task struct { + ctx context.Context + taskID string + nodeID string + payload json.RawMessage +} + +func newTask(ctx context.Context, taskID, nodeID string, payload json.RawMessage) *task { + return &task{ + ctx: ctx, + taskID: taskID, + nodeID: nodeID, + payload: payload, } } -func (tm *TaskManager) dispatchFinalResult(ctx context.Context) mq.Result { - tm.updateTS(&tm.result) - tm.dag.callbackToConsumer(ctx, tm.result) - if tm.dag.server.NotifyHandler() != nil { - _ = tm.dag.server.NotifyHandler()(ctx, tm.result) +func NewTaskManager(dag *DAG, taskID string, resultCh chan mq.Result, iteratorNodes storage.IMap[string, []Edge]) *TaskManager { + tm := &TaskManager{ + taskStates: memory.New[string, *TaskState](), + parentNodes: memory.New[string, string](), + childNodes: memory.New[string, int](), + deferredTasks: memory.New[string, *task](), + currentNodePayload: memory.New[string, json.RawMessage](), + currentNodeResult: memory.New[string, mq.Result](), + taskQueue: make(chan *task, DefaultChannelSize), + resultQueue: make(chan nodeResult, DefaultChannelSize), + iteratorNodes: iteratorNodes, + stopCh: make(chan struct{}), + resultCh: resultCh, + taskID: taskID, + dag: dag, } - tm.dag.taskCleanupCh <- tm.taskID - tm.topic = tm.result.Topic - return tm.result + go tm.run() + go tm.waitForResult() + return tm } -func (tm *TaskManager) reportNodeResult(result mq.Result, final bool) { - if tm.dag.reportNodeResultCallback != nil { - tm.dag.reportNodeResultCallback(result) - } +func (tm *TaskManager) ProcessTask(ctx context.Context, startNode string, payload json.RawMessage) { + tm.send(ctx, startNode, tm.taskID, payload) } -func (tm *TaskManager) SetTotalItems(topic string, i int) { - if nodeStatus, ok := tm.taskNodeStatus.Get(topic); ok { - nodeStatus.totalItems = i +func (tm *TaskManager) send(ctx context.Context, startNode, taskID string, payload json.RawMessage) { + if index, ok := ctx.Value(ContextIndex).(string); ok { + startNode = strings.Split(startNode, Delimiter)[0] + startNode = fmt.Sprintf("%s%s%s", startNode, Delimiter, index) } -} - -func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) { - topic := getTopic(ctx, node.Key) - tm.taskNodeStatus.Set(topic, newNodeStatus(topic)) - defer mq.RecoverPanic(mq.RecoverTitle) - dag, isDAG := isDAGNode(node) - if isDAG { - if tm.dag.server.SyncMode() && !dag.server.SyncMode() { - dag.server.Options().SetSyncMode(true) - } - } - tm.ChangeNodeStatus(ctx, node.Key, Processing, mq.Result{Payload: payload, Topic: node.Key}) - var result mq.Result - if tm.dag.server.SyncMode() { - defer func() { - if isDAG { - result.Topic = dag.consumerTopic - result.TaskID = tm.taskID - tm.reportNodeResult(result, false) - tm.handleNextTask(result.Ctx, result) - } else { - result.Topic = node.Key - tm.reportNodeResult(result, false) - tm.handleNextTask(ctx, result) - } - }() + if _, exists := tm.taskStates.Get(startNode); !exists { + tm.taskStates.Set(startNode, newTaskState(startNode)) } + t := newTask(ctx, taskID, startNode, payload) select { - case <-ctx.Done(): - result = mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: ctx.Err(), Ctx: ctx} - tm.reportNodeResult(result, true) - tm.ChangeNodeStatus(ctx, node.Key, Failed, result) - return + case tm.taskQueue <- t: default: - ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: node.Key}) - if tm.dag.server.SyncMode() { - result = node.ProcessTask(ctx, mq.NewTask(tm.taskID, payload, node.Key)) - if result.Error != nil { - tm.reportNodeResult(result, true) - tm.ChangeNodeStatus(ctx, node.Key, Failed, result) - return + log.Println("task queue is full, dropping task.") + tm.deferredTasks.Set(taskID, t) + } +} + +func (tm *TaskManager) run() { + for { + select { + case <-tm.stopCh: + log.Println("Stopping TaskManager") + return + case task := <-tm.taskQueue: + tm.processNode(task) + } + } +} + +func (tm *TaskManager) waitForResult() { + for { + select { + case <-tm.stopCh: + log.Println("Stopping Result Listener") + return + case nr := <-tm.resultQueue: + tm.onNodeCompleted(nr) + } + } +} + +func (tm *TaskManager) processNode(exec *task) { + pureNodeID := strings.Split(exec.nodeID, Delimiter)[0] + node, exists := tm.dag.nodes.Get(pureNodeID) + if !exists { + log.Printf("Node %s does not exist while processing node\n", pureNodeID) + return + } + state, _ := tm.taskStates.Get(exec.nodeID) + if state == nil { + log.Printf("State for node %s not found; creating new state.\n", exec.nodeID) + state = newTaskState(exec.nodeID) + tm.taskStates.Set(exec.nodeID, state) + } + state.Status = mq.Processing + state.UpdatedAt = time.Now() + tm.currentNodePayload.Clear() + tm.currentNodeResult.Clear() + tm.currentNodePayload.Set(exec.nodeID, exec.payload) + result := node.processor.ProcessTask(exec.ctx, mq.NewTask(exec.taskID, exec.payload, exec.nodeID)) + tm.currentNodeResult.Set(exec.nodeID, result) + state.Result = result + result.Topic = node.ID + if result.Error != nil { + tm.result = &result + tm.resultCh <- result + tm.processFinalResult(state) + return + } + if node.NodeType == Page { + tm.result = &result + tm.resultCh <- result + return + } + tm.handleNext(exec.ctx, node, state, result) +} + +func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, result mq.Result, childNode string, dispatchFinal bool) { + state.targetResults.Set(childNode, result) + state.targetResults.Del(state.NodeID) + targetsCount, _ := tm.childNodes.Get(state.NodeID) + size := state.targetResults.Size() + nodeID := strings.Split(state.NodeID, Delimiter) + if size == targetsCount { + if size > 1 { + aggregatedData := make([]json.RawMessage, size) + i := 0 + state.targetResults.ForEach(func(_ string, rs mq.Result) bool { + aggregatedData[i] = rs.Payload + i++ + return true + }) + aggregatedPayload, err := json.Marshal(aggregatedData) + if err != nil { + panic(err) } - return - } - err := tm.dag.server.Publish(ctx, mq.NewTask(tm.taskID, payload, node.Key), node.Key) - if err != nil { - tm.reportNodeResult(mq.Result{Error: err}, true) - tm.ChangeNodeStatus(ctx, node.Key, Failed, result) - return + state.Result = mq.Result{Payload: aggregatedPayload, Status: mq.Completed, Ctx: ctx, Topic: state.NodeID} + } else if size == 1 { + state.Result = state.targetResults.Values()[0] } + state.Status = result.Status + state.Result.Status = result.Status } -} - -func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result { - defer mq.RecoverPanic(mq.RecoverTitle) - node, ok := tm.dag.nodes[nodeID] - if !ok { - return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)} + if state.Result.Payload == nil { + state.Result.Payload = result.Payload } - if tm.createdAt.IsZero() { - tm.createdAt = time.Now() - } - tm.wg.Add(1) - go func() { - ctxx := context.Background() - if headers, ok := mq.GetHeaders(ctx); ok { - headers.Set(consts.QueueKey, node.Key) - headers.Set("index", fmt.Sprintf("%s__%d", node.Key, 0)) - ctxx = mq.SetHeaders(ctxx, headers.AsMap()) - } - go tm.processNode(ctx, node, payload) - }() - tm.wg.Wait() - requestType, ok := mq.GetHeader(ctx, "request_type") - if ok && requestType == "render" { - return tm.renderResult(ctx) - } - return tm.dispatchFinalResult(ctx) -} - -func (tm *TaskManager) handleNextTask(ctx context.Context, result mq.Result) mq.Result { - tm.topic = result.Topic - defer func() { - tm.wg.Done() - mq.RecoverPanic(mq.RecoverTitle) - }() - if result.Ctx != nil { - if headers, ok := mq.GetHeaders(ctx); ok { - ctx = mq.SetHeaders(result.Ctx, headers.AsMap()) - } - } - node, ok := tm.dag.nodes[result.Topic] - if !ok { - return result + state.UpdatedAt = time.Now() + if result.Ctx == nil { + result.Ctx = ctx } if result.Error != nil { - tm.reportNodeResult(result, true) - tm.ChangeNodeStatus(ctx, node.Key, Failed, result) - return result + state.Status = mq.Failed } - edges := tm.getConditionalEdges(node, result) - if len(edges) == 0 { - tm.reportNodeResult(result, true) - tm.ChangeNodeStatus(ctx, node.Key, Completed, result) - return result - } else { - tm.reportNodeResult(result, false) - } - if node.Type == Page { - return result - } - for _, edge := range edges { - switch edge.Type { - case Iterator: - var items []json.RawMessage - err := json.Unmarshal(result.Payload, &items) - if err != nil { - tm.reportNodeResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err}, false) - result.Error = err - tm.ChangeNodeStatus(ctx, node.Key, Failed, result) - return result + pn, ok := tm.parentNodes.Get(state.NodeID) + if edges, exists := tm.iteratorNodes.Get(nodeID[0]); exists && state.Status == mq.Completed { + state.Status = mq.Processing + tm.iteratorNodes.Del(nodeID[0]) + state.targetResults.Clear() + if len(nodeID) == 2 { + ctx = context.WithValue(ctx, ContextIndex, nodeID[1]) + } + toProcess := nodeResult{ + ctx: ctx, + nodeID: state.NodeID, + status: state.Status, + result: state.Result, + } + tm.handleEdges(toProcess, edges) + } else if ok { + if targetsCount == size { + parentState, _ := tm.taskStates.Get(pn) + if parentState != nil { + state.Result.Topic = state.NodeID + tm.handlePrevious(ctx, parentState, state.Result, state.NodeID, dispatchFinal) } - tm.SetTotalItems(getTopic(ctx, edge.From.Key), len(items)*len(edge.To)) - for _, target := range edge.To { - for i, item := range items { - tm.wg.Add(1) - go func(ctx context.Context, target *Node, item json.RawMessage, i int) { - ctxx := context.Background() - if headers, ok := mq.GetHeaders(ctx); ok { - headers.Set(consts.QueueKey, target.Key) - headers.Set("index", fmt.Sprintf("%s__%d", target.Key, i)) - ctxx = mq.SetHeaders(ctxx, headers.AsMap()) - } - tm.processNode(ctxx, target, item) - }(ctx, target, item, i) + } + } else { + tm.result = &state.Result + state.Result.Topic = strings.Split(state.NodeID, Delimiter)[0] + tm.resultCh <- state.Result + tm.processFinalResult(state) + } +} + +func (tm *TaskManager) handleNext(ctx context.Context, node *Node, state *TaskState, result mq.Result) { + state.UpdatedAt = time.Now() + if result.Ctx == nil { + result.Ctx = ctx + } + if result.Error != nil { + state.Status = mq.Failed + } else { + edges := tm.getConditionalEdges(node, result) + if len(edges) == 0 { + state.Status = mq.Completed + } + } + if result.Status == "" { + result.Status = state.Status + } + select { + case tm.resultQueue <- nodeResult{ + ctx: ctx, + nodeID: state.NodeID, + result: result, + status: state.Status, + }: + default: + log.Println("Result queue is full, dropping result.") + } +} + +func (tm *TaskManager) onNodeCompleted(rs nodeResult) { + nodeID := strings.Split(rs.nodeID, Delimiter)[0] + node, ok := tm.dag.nodes.Get(nodeID) + if !ok { + return + } + edges := tm.getConditionalEdges(node, rs.result) + hasErrorOrCompleted := rs.result.Error != nil || len(edges) == 0 + if hasErrorOrCompleted { + if index, ok := rs.ctx.Value(ContextIndex).(string); ok { + childNode := fmt.Sprintf("%s%s%s", node.ID, Delimiter, index) + pn, ok := tm.parentNodes.Get(childNode) + if ok { + parentState, _ := tm.taskStates.Get(pn) + if parentState != nil { + pn = strings.Split(pn, Delimiter)[0] + tm.handlePrevious(rs.ctx, parentState, rs.result, rs.nodeID, true) } } } + return } - for _, edge := range edges { - switch edge.Type { - case Simple: - if _, ok := tm.iteratorNodes.Get(edge.From.Key); ok { - continue - } - tm.processEdge(ctx, edge, result) - } - } - return result -} - -func (tm *TaskManager) processEdge(ctx context.Context, edge Edge, result mq.Result) { - tm.SetTotalItems(getTopic(ctx, edge.From.Key), len(edge.To)) - index, _ := mq.GetHeader(ctx, "index") - if index != "" && strings.Contains(index, "__") { - index = strings.Split(index, "__")[1] - } else { - index = "0" - } - for _, target := range edge.To { - tm.wg.Add(1) - go func(ctx context.Context, target *Node, result mq.Result) { - ctxx := context.Background() - if headers, ok := mq.GetHeaders(ctx); ok { - headers.Set(consts.QueueKey, target.Key) - headers.Set("index", fmt.Sprintf("%s__%s", target.Key, index)) - ctxx = mq.SetHeaders(ctxx, headers.AsMap()) - } - tm.processNode(ctxx, target, result.Payload) - }(ctx, target, result) - } + tm.handleEdges(rs, edges) } func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge { edges := make([]Edge, len(node.Edges)) copy(edges, node.Edges) if result.ConditionStatus != "" { - if conditions, ok := tm.dag.conditions[FromNode(result.Topic)]; ok { - if targetNodeKey, ok := conditions[When(result.ConditionStatus)]; ok { - if targetNode, ok := tm.dag.nodes[string(targetNodeKey)]; ok { - edges = append(edges, Edge{From: node, To: []*Node{targetNode}}) + if conditions, ok := tm.dag.conditions[result.Topic]; ok { + if targetNodeKey, ok := conditions[result.ConditionStatus]; ok { + if targetNode, ok := tm.dag.nodes.Get(targetNodeKey); ok { + edges = append(edges, Edge{From: node, To: targetNode}) } } else if targetNodeKey, ok = conditions["default"]; ok { - if targetNode, ok := tm.dag.nodes[string(targetNodeKey)]; ok { - edges = append(edges, Edge{From: node, To: []*Node{targetNode}}) + if targetNode, ok := tm.dag.nodes.Get(targetNodeKey); ok { + edges = append(edges, Edge{From: node, To: targetNode}) } } } @@ -252,123 +312,77 @@ func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge return edges } -func (tm *TaskManager) renderResult(ctx context.Context) mq.Result { - var rs mq.Result - tm.updateTS(&rs) - tm.dag.callbackToConsumer(ctx, rs) - tm.topic = rs.Topic - return rs -} - -func (tm *TaskManager) ChangeNodeStatus(ctx context.Context, nodeID string, status NodeStatus, rs mq.Result) { - topic := nodeID - if !strings.Contains(nodeID, "__") { - nodeID = getTopic(ctx, nodeID) - } else { - topic = strings.Split(nodeID, "__")[0] - } - nodeStatus, ok := tm.taskNodeStatus.Get(nodeID) - if !ok || nodeStatus == nil { - return - } - - nodeStatus.markAs(rs, status) - switch status { - case Completed: - canProceed := false - edges, ok := tm.iteratorNodes.Get(topic) - if ok { - if len(edges) == 0 { - canProceed = true - } else { - nodeStatus.status = Processing - nodeStatus.totalItems = 1 - nodeStatus.itemResults.Clear() - for _, edge := range edges { - tm.processEdge(ctx, edge, rs) +func (tm *TaskManager) handleEdges(currentResult nodeResult, edges []Edge) { + for _, edge := range edges { + index, ok := currentResult.ctx.Value(ContextIndex).(string) + if !ok { + index = "0" + } + parentNode := fmt.Sprintf("%s%s%s", edge.From.ID, Delimiter, index) + if edge.Type == Simple { + if _, ok := tm.iteratorNodes.Get(edge.From.ID); ok { + continue + } + } + if edge.Type == Iterator { + var items []json.RawMessage + err := json.Unmarshal(currentResult.result.Payload, &items) + if err != nil { + log.Printf("Error unmarshalling data for node %s: %v\n", edge.To.ID, err) + tm.resultQueue <- nodeResult{ + ctx: currentResult.ctx, + nodeID: edge.To.ID, + status: mq.Failed, + result: mq.Result{Error: err}, } - tm.iteratorNodes.Del(topic) + return } - } - if canProceed || !ok { - if topic == tm.dag.startNode { - tm.result = rs - } else { - tm.markParentTask(ctx, topic, nodeID, status, rs) + tm.childNodes.Set(parentNode, len(items)) + for i, item := range items { + childNode := fmt.Sprintf("%s%s%d", edge.To.ID, Delimiter, i) + ctx := context.WithValue(currentResult.ctx, ContextIndex, fmt.Sprintf("%d", i)) + tm.parentNodes.Set(childNode, parentNode) + tm.send(ctx, edge.To.ID, tm.taskID, item) } - } - case Failed: - if topic == tm.dag.startNode { - tm.result = rs } else { - tm.markParentTask(ctx, topic, nodeID, status, rs) - } - } -} - -func (tm *TaskManager) markParentTask(ctx context.Context, topic, nodeID string, status NodeStatus, rs mq.Result) { - parentNodes, err := tm.dag.GetPreviousNodes(topic) - if err != nil { - return - } - var index string - nodeParts := strings.Split(nodeID, "__") - if len(nodeParts) == 2 { - index = nodeParts[1] - } - for _, parentNode := range parentNodes { - parentKey := fmt.Sprintf("%s__%s", parentNode.Key, index) - parentNodeStatus, exists := tm.taskNodeStatus.Get(parentKey) - if !exists { - parentKey = fmt.Sprintf("%s__%s", parentNode.Key, "0") - parentNodeStatus, exists = tm.taskNodeStatus.Get(parentKey) - } - if exists { - parentNodeStatus.itemResults.Set(nodeID, rs) - if parentNodeStatus.IsDone() { - rt := tm.prepareResult(ctx, parentNodeStatus) - tm.ChangeNodeStatus(ctx, parentKey, status, rt) + tm.childNodes.Set(parentNode, 1) + idx, ok := currentResult.ctx.Value(ContextIndex).(string) + if !ok { + idx = "0" } + childNode := fmt.Sprintf("%s%s%s", edge.To.ID, Delimiter, idx) + ctx := context.WithValue(currentResult.ctx, ContextIndex, idx) + tm.parentNodes.Set(childNode, parentNode) + tm.send(ctx, edge.To.ID, tm.taskID, currentResult.result.Payload) } } } -func (tm *TaskManager) prepareResult(ctx context.Context, nodeStatus *taskNodeStatus) mq.Result { - aggregatedOutput := make([]json.RawMessage, 0) - var status mq.Status - var topic string - var err1 error - if nodeStatus.totalItems == 1 { - rs := nodeStatus.itemResults.Values()[0] - if rs.Ctx == nil { - rs.Ctx = ctx +func (tm *TaskManager) retryDeferredTasks() { + const maxRetries = 5 + retries := 0 + for retries < maxRetries { + select { + case <-tm.stopCh: + log.Println("Stopping Deferred task Retrier") + return + case <-time.After(RetryInterval): + tm.deferredTasks.ForEach(func(taskID string, task *task) bool { + tm.send(task.ctx, task.nodeID, taskID, task.payload) + retries++ + return true + }) } - return rs } - nodeStatus.itemResults.ForEach(func(key string, result mq.Result) bool { - if topic == "" { - topic = result.Topic - status = result.Status - } - if result.Error != nil { - err1 = result.Error - return false - } - var item json.RawMessage - err := json.Unmarshal(result.Payload, &item) - if err != nil { - err1 = err - return false - } - aggregatedOutput = append(aggregatedOutput, item) - return true - }) - if err1 != nil { - return mq.HandleError(ctx, err1) - } - finalOutput, err := json.Marshal(aggregatedOutput) - if err != nil { - return mq.HandleError(ctx, err) - } - return mq.Result{TaskID: tm.taskID, Payload: finalOutput, Status: status, Topic: topic, Ctx: ctx} +} + +func (tm *TaskManager) processFinalResult(state *TaskState) { + state.targetResults.Clear() + if tm.dag.finalResult != nil { + tm.dag.finalResult(tm.taskID, state.Result) + } +} + +func (tm *TaskManager) Stop() { + close(tm.stopCh) } diff --git a/dag/ui.go b/dag/ui.go index d2074af..388bbe2 100644 --- a/dag/ui.go +++ b/dag/ui.go @@ -9,25 +9,24 @@ import ( func (tm *DAG) PrintGraph() { fmt.Println("DAG Graph structure:") - for _, node := range tm.nodes { - fmt.Printf("Node: %s (%s) -> ", node.Name, node.Key) - if conditions, ok := tm.conditions[FromNode(node.Key)]; ok { + tm.nodes.ForEach(func(_ string, node *Node) bool { + fmt.Printf("Node: %s (%s) -> ", node.Label, node.ID) + if conditions, ok := tm.conditions[node.ID]; ok { var c []string for when, then := range conditions { - if target, ok := tm.nodes[string(then)]; ok { - c = append(c, fmt.Sprintf("If [%s] Then %s (%s)", when, target.Name, target.Key)) + if target, ok := tm.nodes.Get(then); ok { + c = append(c, fmt.Sprintf("If [%s] Then %s (%s)", when, target.Label, target.ID)) } } fmt.Println(strings.Join(c, ", ")) } var edges []string - for _, edge := range node.Edges { - for _, target := range edge.To { - edges = append(edges, fmt.Sprintf("%s (%s)", target.Name, target.Key)) - } + for _, target := range node.Edges { + edges = append(edges, fmt.Sprintf("%s (%s)", target.To.Label, target.To.ID)) } fmt.Println(strings.Join(edges, ", ")) - } + return true + }) } func (tm *DAG) ClassifyEdges(startNodes ...string) (string, bool, error) { @@ -44,7 +43,7 @@ func (tm *DAG) ClassifyEdges(startNodes ...string) (string, bool, error) { if startNode == "" { firstNode := tm.findStartNode() if firstNode != nil { - startNode = firstNode.Key + startNode = firstNode.ID } } if startNode == "" { @@ -62,24 +61,22 @@ func (tm *DAG) dfs(v string, visited map[string]bool, discoveryTime, finishedTim inRecursionStack[v] = true // mark node as part of recursion stack *timeVal++ discoveryTime[v] = *timeVal - node := tm.nodes[v] + node, _ := tm.nodes.Get(v) hasCycle := false var err error for _, edge := range node.Edges { - for _, adj := range edge.To { - if !visited[adj.Key] { - builder.WriteString(fmt.Sprintf("Traversing Edge: %s -> %s\n", v, adj.Key)) - hasCycle, err := tm.dfs(adj.Key, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder) - if err != nil { - return true, err - } - if hasCycle { - return true, nil - } - } else if inRecursionStack[adj.Key] { - cycleMsg := fmt.Sprintf("Cycle detected: %s -> %s\n", v, adj.Key) - return true, fmt.Errorf(cycleMsg) + if !visited[edge.To.ID] { + builder.WriteString(fmt.Sprintf("Traversing Edge: %s -> %s\n", v, edge.To.ID)) + hasCycle, err := tm.dfs(edge.To.ID, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder) + if err != nil { + return true, err } + if hasCycle { + return true, nil + } + } else if inRecursionStack[edge.To.ID] { + cycleMsg := fmt.Sprintf("Cycle detected: %s -> %s\n", v, edge.To.ID) + return true, fmt.Errorf(cycleMsg) } } hasCycle, err = tm.handleConditionalEdges(v, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder) @@ -93,20 +90,20 @@ func (tm *DAG) dfs(v string, visited map[string]bool, discoveryTime, finishedTim } func (tm *DAG) handleConditionalEdges(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, time *int, inRecursionStack map[string]bool, builder *strings.Builder) (bool, error) { - node := tm.nodes[v] - for when, then := range tm.conditions[FromNode(node.Key)] { - if targetNode, ok := tm.nodes[string(then)]; ok { - if !visited[targetNode.Key] { - builder.WriteString(fmt.Sprintf("Traversing Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.Key)) - hasCycle, err := tm.dfs(targetNode.Key, visited, discoveryTime, finishedTime, time, inRecursionStack, builder) + node, _ := tm.nodes.Get(v) + for when, then := range tm.conditions[node.ID] { + if targetNode, ok := tm.nodes.Get(then); ok { + if !visited[targetNode.ID] { + builder.WriteString(fmt.Sprintf("Traversing Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.ID)) + hasCycle, err := tm.dfs(targetNode.ID, visited, discoveryTime, finishedTime, time, inRecursionStack, builder) if err != nil { return true, err } if hasCycle { return true, nil } - } else if inRecursionStack[targetNode.Key] { - cycleMsg := fmt.Sprintf("Cycle detected in Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.Key) + } else if inRecursionStack[targetNode.ID] { + cycleMsg := fmt.Sprintf("Cycle detected in Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.ID) return true, fmt.Errorf(cycleMsg) } } @@ -146,98 +143,113 @@ func (tm *DAG) ExportDOT() string { var sb strings.Builder sb.WriteString(fmt.Sprintf(`digraph "%s" {`, tm.name)) sb.WriteString("\n") - sb.WriteString(fmt.Sprintf(` label="%s";`, tm.name)) + sb.WriteString(` label="Enhanced DAG Representation";`) sb.WriteString("\n") - sb.WriteString(` labelloc="t";`) + sb.WriteString(` labelloc="t"; fontsize=22; fontname="Helvetica";`) sb.WriteString("\n") - sb.WriteString(` fontsize=20;`) + sb.WriteString(` node [shape=box, fontname="Helvetica", fillcolor="#B3CDE0", fontcolor="#2C3E50", fontsize=10, margin="0.25,0.15", style="rounded,filled"];`) sb.WriteString("\n") - sb.WriteString(` node [shape=box, style="rounded,filled", fillcolor="lightgray", fontname="Arial", margin="0.2,0.1"];`) + sb.WriteString(` edge [fontname="Helvetica", fontsize=12, arrowsize=0.8];`) sb.WriteString("\n") - sb.WriteString(` edge [fontname="Arial", fontsize=12, arrowsize=0.8];`) - sb.WriteString("\n") - sb.WriteString(` size="10,10";`) - sb.WriteString("\n") - sb.WriteString(` ratio="fill";`) + sb.WriteString(` rankdir=TB;`) sb.WriteString("\n") sortedNodes := tm.TopologicalSort() for _, nodeKey := range sortedNodes { - node := tm.nodes[nodeKey] - nodeColor := "lightblue" - sb.WriteString(fmt.Sprintf(` "%s" [label=" %s", fillcolor="%s", id="node_%s"];`, node.Key, node.Name, nodeColor, node.Key)) + node, _ := tm.nodes.Get(nodeKey) + nodeColor := "lightgray" + nodeShape := "box" + labelSuffix := "" + + // Apply styles based on NodeType + switch node.NodeType { + case Function: + nodeColor = "#D4EDDA" + labelSuffix = " [Function]" + case Page: + nodeColor = "#f0d2d1" + labelSuffix = " [Page]" + } + sb.WriteString(fmt.Sprintf( + ` "%s" [label="%s%s", fontcolor="#2C3E50", fillcolor="%s", shape="%s", style="rounded,filled", id="node_%s"];`, + node.ID, node.Label, labelSuffix, nodeColor, nodeShape, node.ID)) sb.WriteString("\n") } + + // Define edges with unique styling by EdgeType for _, nodeKey := range sortedNodes { - node := tm.nodes[nodeKey] + node, _ := tm.nodes.Get(nodeKey) for _, edge := range node.Edges { - var edgeStyle string + edgeStyle := "solid" + edgeColor := "black" + labelSuffix := "" + + // Apply styles based on EdgeType switch edge.Type { case Iterator: edgeStyle = "dashed" - default: + edgeColor = "blue" + labelSuffix = " [Iter]" + case Simple: edgeStyle = "solid" + edgeColor = "black" + labelSuffix = "" } - edgeColor := "black" - for _, to := range edge.To { - sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="%s", style=%s, fontsize=10, arrowsize=0.6];`, node.Key, to.Key, edge.Label, edgeColor, edgeStyle)) - sb.WriteString("\n") - } + sb.WriteString(fmt.Sprintf( + ` "%s" -> "%s" [label="%s%s", color="%s", style="%s"];`, + node.ID, edge.To.ID, edge.Label, labelSuffix, edgeColor, edgeStyle)) + sb.WriteString("\n") } } for fromNodeKey, conditions := range tm.conditions { for when, then := range conditions { - if toNode, ok := tm.nodes[string(then)]; ok { - sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="purple", style=dotted, fontsize=10, arrowsize=0.6];`, fromNodeKey, toNode.Key, when)) + if toNode, ok := tm.nodes.Get(then); ok { + sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="purple", style=dotted, fontsize=10, arrowsize=0.6];`, fromNodeKey, toNode.ID, when)) sb.WriteString("\n") } } } + + // Optional: Group related nodes into subgraphs (e.g., loops) for _, nodeKey := range sortedNodes { - node := tm.nodes[nodeKey] + node, _ := tm.nodes.Get(nodeKey) if node.processor != nil { subDAG, _ := isDAGNode(node) if subDAG != nil { sb.WriteString(fmt.Sprintf(` subgraph "cluster_%s" {`, subDAG.name)) sb.WriteString("\n") - sb.WriteString(fmt.Sprintf(` label=" %s";`, subDAG.name)) + sb.WriteString(fmt.Sprintf(` label="Subgraph: %s";`, subDAG.name)) sb.WriteString("\n") - sb.WriteString(` style=dashed;`) + sb.WriteString(` style=filled; color=gray90;`) sb.WriteString("\n") - sb.WriteString(` bgcolor="lightgray";`) - sb.WriteString("\n") - sb.WriteString(` node [shape=rectangle, style="filled", fillcolor="lightblue", fontname="Arial", margin="0.2,0.1"];`) - sb.WriteString("\n") - for subNodeKey, subNode := range subDAG.nodes { - sb.WriteString(fmt.Sprintf(` "%s" [label=" %s"];`, subNodeKey, subNode.Name)) + subDAG.nodes.ForEach(func(subNodeKey string, subNode *Node) bool { + sb.WriteString(fmt.Sprintf(` "%s" [label="%s"];`, subNode.ID, subNode.Label)) sb.WriteString("\n") - } - for subNodeKey, subNode := range subDAG.nodes { + return true + }) + subDAG.nodes.ForEach(func(subNodeKey string, subNode *Node) bool { for _, edge := range subNode.Edges { - for _, to := range edge.To { - sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="black", style=solid, arrowsize=0.6];`, subNodeKey, to.Key, edge.Label)) - sb.WriteString("\n") - } + sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label="%s"];`, subNodeKey, edge.To.ID, edge.Label)) + sb.WriteString("\n") } - } - sb.WriteString(` }`) - sb.WriteString("\n") - sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="black", style=solid, arrowsize=0.6];`, node.Key, subDAG.startNode, subDAG.name)) - sb.WriteString("\n") + return true + }) + sb.WriteString(" }\n") } } } - sb.WriteString(`}`) - sb.WriteString("\n") + + sb.WriteString("}\n") return sb.String() } func (tm *DAG) TopologicalSort() (stack []string) { visited := make(map[string]bool) - for _, node := range tm.nodes { - if !visited[node.Key] { - tm.topologicalSortUtil(node.Key, visited, &stack) + tm.nodes.ForEach(func(_ string, node *Node) bool { + if !visited[node.ID] { + tm.topologicalSortUtil(node.ID, visited, &stack) } - } + return true + }) for i, j := 0, len(stack)-1; i < j; i, j = i+1, j-1 { stack[i], stack[j] = stack[j], stack[i] } @@ -246,13 +258,23 @@ func (tm *DAG) TopologicalSort() (stack []string) { func (tm *DAG) topologicalSortUtil(v string, visited map[string]bool, stack *[]string) { visited[v] = true - node := tm.nodes[v] + node, ok := tm.nodes.Get(v) + if !ok { + fmt.Println("Not found", v) + } for _, edge := range node.Edges { - for _, to := range edge.To { - if !visited[to.Key] { - tm.topologicalSortUtil(to.Key, visited, stack) - } + if !visited[edge.To.ID] { + tm.topologicalSortUtil(edge.To.ID, visited, stack) } } *stack = append(*stack, v) } + +func isDAGNode(node *Node) (*DAG, bool) { + switch node := node.processor.(type) { + case *DAG: + return node, true + default: + return nil, false + } +} diff --git a/dag/v1/api.go b/dag/v1/api.go new file mode 100644 index 0000000..bf6be8f --- /dev/null +++ b/dag/v1/api.go @@ -0,0 +1,209 @@ +package v1 + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "time" + + "github.com/oarkflow/mq/jsonparser" + "github.com/oarkflow/mq/sio" + + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/consts" + "github.com/oarkflow/mq/metrics" +) + +type Request struct { + Payload json.RawMessage `json:"payload"` + Interval time.Duration `json:"interval"` + Schedule bool `json:"schedule"` + Overlap bool `json:"overlap"` + Recurring bool `json:"recurring"` +} + +func (tm *DAG) SetupWS() *sio.Server { + ws := sio.New(sio.Config{ + CheckOrigin: func(r *http.Request) bool { return true }, + EnableCompression: true, + }) + WsEvents(ws) + tm.Notifier = ws + return ws +} + +func (tm *DAG) Handlers() { + metrics.HandleHTTP() + http.Handle("/", http.FileServer(http.Dir("webroot"))) + http.Handle("/notify", tm.SetupWS()) + http.HandleFunc("GET /render", tm.Render) + http.HandleFunc("POST /request", tm.Request) + http.HandleFunc("POST /publish", tm.Publish) + http.HandleFunc("POST /schedule", tm.Schedule) + http.HandleFunc("/pause-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) { + id := request.PathValue("id") + if id != "" { + tm.PauseConsumer(request.Context(), id) + } + }) + http.HandleFunc("/resume-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) { + id := request.PathValue("id") + if id != "" { + tm.ResumeConsumer(request.Context(), id) + } + }) + http.HandleFunc("/pause", func(w http.ResponseWriter, request *http.Request) { + err := tm.Pause(request.Context()) + if err != nil { + http.Error(w, "Failed to pause", http.StatusBadRequest) + return + } + json.NewEncoder(w).Encode(map[string]string{"status": "paused"}) + }) + http.HandleFunc("/resume", func(w http.ResponseWriter, request *http.Request) { + err := tm.Resume(request.Context()) + if err != nil { + http.Error(w, "Failed to resume", http.StatusBadRequest) + return + } + json.NewEncoder(w).Encode(map[string]string{"status": "resumed"}) + }) + http.HandleFunc("/stop", func(w http.ResponseWriter, request *http.Request) { + err := tm.Stop(request.Context()) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + json.NewEncoder(w).Encode(map[string]string{"status": "stopped"}) + }) + http.HandleFunc("/close", func(w http.ResponseWriter, request *http.Request) { + err := tm.Close() + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + json.NewEncoder(w).Encode(map[string]string{"status": "closed"}) + }) + http.HandleFunc("/dot", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + fmt.Fprintln(w, tm.ExportDOT()) + }) + http.HandleFunc("/ui", func(w http.ResponseWriter, r *http.Request) { + image := fmt.Sprintf("%s.svg", mq.NewID()) + err := tm.SaveSVG(image) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + defer os.Remove(image) + svgBytes, err := os.ReadFile(image) + if err != nil { + http.Error(w, "Could not read SVG file", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "image/svg+xml") + if _, err := w.Write(svgBytes); err != nil { + http.Error(w, "Could not write SVG response", http.StatusInternalServerError) + return + } + }) +} + +func (tm *DAG) request(w http.ResponseWriter, r *http.Request, async bool) { + if r.Method != http.MethodPost { + http.Error(w, "Invalid request method", http.StatusMethodNotAllowed) + return + } + var request Request + if r.Body != nil { + defer r.Body.Close() + var err error + payload, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + err = json.Unmarshal(payload, &request) + if err != nil { + http.Error(w, "Failed to unmarshal body", http.StatusBadRequest) + return + } + } else { + http.Error(w, "Empty request body", http.StatusBadRequest) + return + } + ctx := r.Context() + if async { + ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"}) + } + var opts []mq.SchedulerOption + if request.Interval > 0 { + opts = append(opts, mq.WithInterval(request.Interval)) + } + if request.Overlap { + opts = append(opts, mq.WithOverlap()) + } + if request.Recurring { + opts = append(opts, mq.WithRecurring()) + } + ctx = context.WithValue(ctx, "query_params", r.URL.Query()) + var rs mq.Result + if request.Schedule { + rs = tm.ScheduleTask(ctx, request.Payload, opts...) + } else { + rs = tm.Process(ctx, request.Payload) + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(rs) +} + +func (tm *DAG) Render(w http.ResponseWriter, r *http.Request) { + ctx := mq.SetHeaders(r.Context(), map[string]string{consts.AwaitResponseKey: "true", "request_type": "render"}) + ctx = context.WithValue(ctx, "query_params", r.URL.Query()) + rs := tm.Process(ctx, nil) + content, err := jsonparser.GetString(rs.Payload, "html_content") + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", consts.TypeHtml) + w.Write([]byte(content)) +} + +func (tm *DAG) Request(w http.ResponseWriter, r *http.Request) { + tm.request(w, r, true) +} + +func (tm *DAG) Publish(w http.ResponseWriter, r *http.Request) { + tm.request(w, r, false) +} + +func (tm *DAG) Schedule(w http.ResponseWriter, r *http.Request) { + tm.request(w, r, false) +} + +func GetTaskID(ctx context.Context) string { + if queryParams := ctx.Value("query_params"); queryParams != nil { + if params, ok := queryParams.(url.Values); ok { + if id := params.Get("taskID"); id != "" { + return id + } + } + } + return "" +} + +func CanNextNode(ctx context.Context) string { + if queryParams := ctx.Value("query_params"); queryParams != nil { + if params, ok := queryParams.(url.Values); ok { + if id := params.Get("next"); id != "" { + return id + } + } + } + return "" +} diff --git a/dag/v1/consts.go b/dag/v1/consts.go new file mode 100644 index 0000000..ef48130 --- /dev/null +++ b/dag/v1/consts.go @@ -0,0 +1,26 @@ +package v1 + +type NodeStatus int + +func (c NodeStatus) IsValid() bool { return c >= Pending && c <= Failed } + +func (c NodeStatus) String() string { + switch c { + case Pending: + return "Pending" + case Processing: + return "Processing" + case Completed: + return "Completed" + case Failed: + return "Failed" + } + return "" +} + +const ( + Pending NodeStatus = iota + Processing + Completed + Failed +) diff --git a/dag/v1/dag.go b/dag/v1/dag.go new file mode 100644 index 0000000..98b38bd --- /dev/null +++ b/dag/v1/dag.go @@ -0,0 +1,612 @@ +package v1 + +import ( + "context" + "fmt" + "log" + "net/http" + "time" + + "github.com/oarkflow/mq/storage" + "github.com/oarkflow/mq/storage/memory" + + "github.com/oarkflow/mq/sio" + + "golang.org/x/time/rate" + + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/consts" + "github.com/oarkflow/mq/metrics" +) + +type EdgeType int + +func (c EdgeType) IsValid() bool { return c >= Simple && c <= Iterator } + +const ( + Simple EdgeType = iota + Iterator +) + +type NodeType int + +func (c NodeType) IsValid() bool { return c >= Process && c <= Page } + +const ( + Process NodeType = iota + Page +) + +type Node struct { + processor mq.Processor + Name string + Type NodeType + Key string + Edges []Edge + isReady bool +} + +func (n *Node) ProcessTask(ctx context.Context, msg *mq.Task) mq.Result { + return n.processor.ProcessTask(ctx, msg) +} + +func (n *Node) Close() error { + return n.processor.Close() +} + +type Edge struct { + Label string + From *Node + To []*Node + Type EdgeType +} + +type ( + FromNode string + When string + Then string +) + +type DAG struct { + server *mq.Broker + consumer *mq.Consumer + taskContext storage.IMap[string, *TaskManager] + nodes map[string]*Node + iteratorNodes storage.IMap[string, []Edge] + conditions map[FromNode]map[When]Then + pool *mq.Pool + taskCleanupCh chan string + name string + key string + startNode string + consumerTopic string + opts []mq.Option + reportNodeResultCallback func(mq.Result) + Notifier *sio.Server + paused bool + Error error + report string + index string +} + +func (tm *DAG) SetKey(key string) { + tm.key = key +} + +func (tm *DAG) ReportNodeResult(callback func(mq.Result)) { + tm.reportNodeResultCallback = callback +} + +func (tm *DAG) GetType() string { + return tm.key +} + +func (tm *DAG) listenForTaskCleanup() { + for taskID := range tm.taskCleanupCh { + if tm.server.Options().CleanTaskOnComplete() { + tm.taskCleanup(taskID) + } + } +} + +func (tm *DAG) taskCleanup(taskID string) { + tm.taskContext.Del(taskID) + log.Printf("DAG - Task %s cleaned up", taskID) +} + +func (tm *DAG) Consume(ctx context.Context) error { + if tm.consumer != nil { + tm.server.Options().SetSyncMode(true) + return tm.consumer.Consume(ctx) + } + return nil +} + +func (tm *DAG) Stop(ctx context.Context) error { + for _, n := range tm.nodes { + err := n.processor.Stop(ctx) + if err != nil { + return err + } + } + return nil +} + +func (tm *DAG) GetKey() string { + return tm.key +} + +func (tm *DAG) AssignTopic(topic string) { + tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false), mq.WithBrokerURL(tm.server.URL())) + tm.consumerTopic = topic +} + +func NewDAG(name, key string, opts ...mq.Option) *DAG { + callback := func(ctx context.Context, result mq.Result) error { return nil } + d := &DAG{ + name: name, + key: key, + nodes: make(map[string]*Node), + iteratorNodes: memory.New[string, []Edge](), + taskContext: memory.New[string, *TaskManager](), + conditions: make(map[FromNode]map[When]Then), + taskCleanupCh: make(chan string), + } + opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerOnSubscribe(d.onConsumerJoin), mq.WithConsumerOnClose(d.onConsumerClose)) + d.server = mq.NewBroker(opts...) + d.opts = opts + options := d.server.Options() + d.pool = mq.NewPool( + options.NumOfWorkers(), + mq.WithTaskQueueSize(options.QueueSize()), + mq.WithMaxMemoryLoad(options.MaxMemoryLoad()), + mq.WithHandler(d.ProcessTask), + mq.WithPoolCallback(callback), + mq.WithTaskStorage(options.Storage()), + ) + d.pool.Start(d.server.Options().NumOfWorkers()) + go d.listenForTaskCleanup() + return d +} + +func (tm *DAG) callbackToConsumer(ctx context.Context, result mq.Result) { + if tm.consumer != nil { + result.Topic = tm.consumerTopic + if tm.consumer.Conn() == nil { + tm.onTaskCallback(ctx, result) + } else { + tm.consumer.OnResponse(ctx, result) + } + } +} + +func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result { + if taskContext, ok := tm.taskContext.Get(result.TaskID); ok && result.Topic != "" { + return taskContext.handleNextTask(ctx, result) + } + return mq.Result{} +} + +func (tm *DAG) onConsumerJoin(_ context.Context, topic, _ string) { + if node, ok := tm.nodes[topic]; ok { + log.Printf("DAG - CONSUMER ~> ready on %s", topic) + node.isReady = true + } +} + +func (tm *DAG) onConsumerClose(_ context.Context, topic, _ string) { + if node, ok := tm.nodes[topic]; ok { + log.Printf("DAG - CONSUMER ~> down on %s", topic) + node.isReady = false + } +} + +func (tm *DAG) SetStartNode(node string) { + tm.startNode = node +} + +func (tm *DAG) GetStartNode() string { + return tm.startNode +} + +func (tm *DAG) Start(ctx context.Context, addr string) error { + // Start the server in a separate goroutine + go func() { + defer mq.RecoverPanic(mq.RecoverTitle) + if err := tm.server.Start(ctx); err != nil { + panic(err) + } + }() + + // Start the node consumers if not in sync mode + if !tm.server.SyncMode() { + for _, con := range tm.nodes { + go func(con *Node) { + defer mq.RecoverPanic(mq.RecoverTitle) + limiter := rate.NewLimiter(rate.Every(1*time.Second), 1) // Retry every second + for { + err := con.processor.Consume(ctx) + if err != nil { + log.Printf("[ERROR] - Consumer %s failed to start: %v", con.Key, err) + } else { + log.Printf("[INFO] - Consumer %s started successfully", con.Key) + break + } + limiter.Wait(ctx) // Wait with rate limiting before retrying + } + }(con) + } + } + log.Printf("DAG - HTTP_SERVER ~> started on http://localhost%s", addr) + tm.Handlers() + config := tm.server.TLSConfig() + if config.UseTLS { + return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil) + } + return http.ListenAndServe(addr, nil) +} + +func (tm *DAG) AddDAGNode(name string, key string, dag *DAG, firstNode ...bool) *DAG { + dag.AssignTopic(key) + tm.nodes[key] = &Node{ + Name: name, + Key: key, + processor: dag, + isReady: true, + } + if len(firstNode) > 0 && firstNode[0] { + tm.startNode = key + } + return tm +} + +func (tm *DAG) AddNode(name, key string, handler mq.Processor, firstNode ...bool) *DAG { + con := mq.NewConsumer(key, key, handler.ProcessTask, tm.opts...) + n := &Node{ + Name: name, + Key: key, + processor: con, + } + if handler.GetType() == "page" { + n.Type = Page + } + if tm.server.SyncMode() { + n.isReady = true + } + tm.nodes[key] = n + if len(firstNode) > 0 && firstNode[0] { + tm.startNode = key + } + return tm +} + +func (tm *DAG) AddDeferredNode(name, key string, firstNode ...bool) error { + if tm.server.SyncMode() { + return fmt.Errorf("DAG cannot have deferred node in Sync Mode") + } + tm.nodes[key] = &Node{ + Name: name, + Key: key, + } + if len(firstNode) > 0 && firstNode[0] { + tm.startNode = key + } + return nil +} + +func (tm *DAG) IsReady() bool { + for _, node := range tm.nodes { + if !node.isReady { + return false + } + } + return true +} + +func (tm *DAG) AddCondition(fromNode FromNode, conditions map[When]Then) *DAG { + tm.conditions[fromNode] = conditions + return tm +} + +func (tm *DAG) AddIterator(label, from string, targets ...string) *DAG { + tm.Error = tm.addEdge(Iterator, label, from, targets...) + tm.iteratorNodes.Set(from, []Edge{}) + return tm +} + +func (tm *DAG) AddEdge(label, from string, targets ...string) *DAG { + tm.Error = tm.addEdge(Simple, label, from, targets...) + return tm +} + +func (tm *DAG) addEdge(edgeType EdgeType, label, from string, targets ...string) error { + fromNode, ok := tm.nodes[from] + if !ok { + return fmt.Errorf("Error: 'from' node %s does not exist\n", from) + } + var nodes []*Node + for _, target := range targets { + toNode, ok := tm.nodes[target] + if !ok { + return fmt.Errorf("Error: 'from' node %s does not exist\n", target) + } + nodes = append(nodes, toNode) + } + edge := Edge{From: fromNode, To: nodes, Type: edgeType, Label: label} + fromNode.Edges = append(fromNode.Edges, edge) + if edgeType != Iterator { + if edges, ok := tm.iteratorNodes.Get(fromNode.Key); ok { + edges = append(edges, edge) + tm.iteratorNodes.Set(fromNode.Key, edges) + } + } + return nil +} + +func (tm *DAG) Validate() error { + report, hasCycle, err := tm.ClassifyEdges() + if hasCycle || err != nil { + tm.Error = err + return err + } + tm.report = report + return nil +} + +func (tm *DAG) GetReport() string { + return tm.report +} + +func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { + if task.ID == "" { + task.ID = mq.NewID() + } + if index, ok := mq.GetHeader(ctx, "index"); ok { + tm.index = index + } + manager, exists := tm.taskContext.Get(task.ID) + if !exists { + manager = NewTaskManager(tm, task.ID, tm.iteratorNodes) + manager.createdAt = task.CreatedAt + tm.taskContext.Set(task.ID, manager) + } + + if tm.consumer != nil { + initialNode, err := tm.parseInitialNode(ctx) + if err != nil { + metrics.TasksErrors.WithLabelValues("unknown").Inc() // Increase error count + return mq.Result{Error: err} + } + task.Topic = initialNode + } + if manager.topic != "" { + task.Topic = manager.topic + canNext := CanNextNode(ctx) + if canNext != "" { + if n, ok := tm.nodes[task.Topic]; ok { + if len(n.Edges) > 0 { + task.Topic = n.Edges[0].To[0].Key + } + } + } else { + } + } + result := manager.processTask(ctx, task.Topic, task.Payload) + if result.Ctx != nil && tm.index != "" { + result.Ctx = mq.SetHeaders(result.Ctx, map[string]string{"index": tm.index}) + } + if result.Error != nil { + metrics.TasksErrors.WithLabelValues(task.Topic).Inc() // Increase error count + } else { + metrics.TasksProcessed.WithLabelValues("success").Inc() // Increase processed task count + } + return result +} + +func (tm *DAG) check(ctx context.Context, payload []byte) (context.Context, *mq.Task, error) { + if tm.paused { + return ctx, nil, fmt.Errorf("unable to process task, error: DAG is not accepting any task") + } + if !tm.IsReady() { + return ctx, nil, fmt.Errorf("unable to process task, error: DAG is not ready yet") + } + initialNode, err := tm.parseInitialNode(ctx) + if err != nil { + return ctx, nil, err + } + if tm.server.SyncMode() { + ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"}) + } + taskID := GetTaskID(ctx) + if taskID != "" { + if _, exists := tm.taskContext.Get(taskID); !exists { + return ctx, nil, fmt.Errorf("provided task ID doesn't exist") + } + } + if taskID == "" { + taskID = mq.NewID() + } + return ctx, mq.NewTask(taskID, payload, initialNode), nil +} + +func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result { + ctx, task, err := tm.check(ctx, payload) + if err != nil { + return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")} + } + awaitResponse, _ := mq.GetAwaitResponse(ctx) + if awaitResponse != "true" { + headers, ok := mq.GetHeaders(ctx) + ctxx := context.Background() + if ok { + ctxx = mq.SetHeaders(ctxx, headers.AsMap()) + } + if err := tm.pool.EnqueueTask(ctxx, task, 0); err != nil { + return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: task.Topic, Status: "FAILED", Error: err} + } + return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: task.Topic, Status: "PENDING"} + } + return tm.ProcessTask(ctx, task) +} + +func (tm *DAG) ScheduleTask(ctx context.Context, payload []byte, opts ...mq.SchedulerOption) mq.Result { + ctx, task, err := tm.check(ctx, payload) + if err != nil { + return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")} + } + headers, ok := mq.GetHeaders(ctx) + ctxx := context.Background() + if ok { + ctxx = mq.SetHeaders(ctxx, headers.AsMap()) + } + tm.pool.Scheduler().AddTask(ctxx, task, opts...) + return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: task.Topic, Status: "PENDING"} +} + +func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) { + val := ctx.Value("initial_node") + initialNode, ok := val.(string) + if ok { + return initialNode, nil + } + if tm.startNode == "" { + firstNode := tm.findStartNode() + if firstNode != nil { + tm.startNode = firstNode.Key + } + } + + if tm.startNode == "" { + return "", fmt.Errorf("initial node not found") + } + return tm.startNode, nil +} + +func (tm *DAG) findStartNode() *Node { + incomingEdges := make(map[string]bool) + connectedNodes := make(map[string]bool) + for _, node := range tm.nodes { + for _, edge := range node.Edges { + if edge.Type.IsValid() { + connectedNodes[node.Key] = true + for _, to := range edge.To { + connectedNodes[to.Key] = true + incomingEdges[to.Key] = true + } + } + } + if cond, ok := tm.conditions[FromNode(node.Key)]; ok { + for _, target := range cond { + connectedNodes[string(target)] = true + incomingEdges[string(target)] = true + } + } + } + for nodeID, node := range tm.nodes { + if !incomingEdges[nodeID] && connectedNodes[nodeID] { + return node + } + } + return nil +} + +func (tm *DAG) Pause(_ context.Context) error { + tm.paused = true + return nil +} + +func (tm *DAG) Resume(_ context.Context) error { + tm.paused = false + return nil +} + +func (tm *DAG) Close() error { + for _, n := range tm.nodes { + err := n.Close() + if err != nil { + return err + } + } + return nil +} + +func (tm *DAG) PauseConsumer(ctx context.Context, id string) { + tm.doConsumer(ctx, id, consts.CONSUMER_PAUSE) +} + +func (tm *DAG) ResumeConsumer(ctx context.Context, id string) { + tm.doConsumer(ctx, id, consts.CONSUMER_RESUME) +} + +func (tm *DAG) doConsumer(ctx context.Context, id string, action consts.CMD) { + if node, ok := tm.nodes[id]; ok { + switch action { + case consts.CONSUMER_PAUSE: + err := node.processor.Pause(ctx) + if err == nil { + node.isReady = false + log.Printf("[INFO] - Consumer %s paused successfully", node.Key) + } else { + log.Printf("[ERROR] - Failed to pause consumer %s: %v", node.Key, err) + } + case consts.CONSUMER_RESUME: + err := node.processor.Resume(ctx) + if err == nil { + node.isReady = true + log.Printf("[INFO] - Consumer %s resumed successfully", node.Key) + } else { + log.Printf("[ERROR] - Failed to resume consumer %s: %v", node.Key, err) + } + } + } else { + log.Printf("[WARNING] - Consumer %s not found", id) + } +} + +func (tm *DAG) SetNotifyResponse(callback mq.Callback) { + tm.server.SetNotifyHandler(callback) +} + +func (tm *DAG) GetNextNodes(key string) ([]*Node, error) { + node, exists := tm.nodes[key] + if !exists { + return nil, fmt.Errorf("Node with key %s does not exist", key) + } + var successors []*Node + for _, edge := range node.Edges { + successors = append(successors, edge.To...) + } + if conds, exists := tm.conditions[FromNode(key)]; exists { + for _, targetKey := range conds { + if targetNode, exists := tm.nodes[string(targetKey)]; exists { + successors = append(successors, targetNode) + } + } + } + return successors, nil +} + +func (tm *DAG) GetPreviousNodes(key string) ([]*Node, error) { + var predecessors []*Node + for _, node := range tm.nodes { + for _, edge := range node.Edges { + for _, target := range edge.To { + if target.Key == key { + predecessors = append(predecessors, node) + } + } + } + } + for fromNode, conds := range tm.conditions { + for _, targetKey := range conds { + if string(targetKey) == key { + node, exists := tm.nodes[string(fromNode)] + if !exists { + return nil, fmt.Errorf("Node with key %s does not exist", fromNode) + } + predecessors = append(predecessors, node) + } + } + } + return predecessors, nil +} diff --git a/dag/v2/operation.go b/dag/v1/operation.go similarity index 99% rename from dag/v2/operation.go rename to dag/v1/operation.go index f926bad..9500bb6 100644 --- a/dag/v2/operation.go +++ b/dag/v1/operation.go @@ -1,4 +1,4 @@ -package v2 +package v1 import ( "context" diff --git a/dag/v2/operations.go b/dag/v1/operations.go similarity index 99% rename from dag/v2/operations.go rename to dag/v1/operations.go index 7678fdf..a7fc07f 100644 --- a/dag/v2/operations.go +++ b/dag/v1/operations.go @@ -1,4 +1,4 @@ -package v2 +package v1 import ( "sync" diff --git a/dag/v1/task_manager.go b/dag/v1/task_manager.go new file mode 100644 index 0000000..27c8848 --- /dev/null +++ b/dag/v1/task_manager.go @@ -0,0 +1,374 @@ +package v1 + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/oarkflow/mq/consts" + "github.com/oarkflow/mq/storage" + "github.com/oarkflow/mq/storage/memory" + + "github.com/oarkflow/mq" +) + +type TaskManager struct { + createdAt time.Time + processedAt time.Time + status string + dag *DAG + taskID string + wg *WaitGroup + topic string + result mq.Result + + iteratorNodes storage.IMap[string, []Edge] + taskNodeStatus storage.IMap[string, *taskNodeStatus] +} + +func NewTaskManager(d *DAG, taskID string, iteratorNodes storage.IMap[string, []Edge]) *TaskManager { + return &TaskManager{ + dag: d, + taskNodeStatus: memory.New[string, *taskNodeStatus](), + taskID: taskID, + iteratorNodes: iteratorNodes, + wg: NewWaitGroup(), + } +} + +func (tm *TaskManager) dispatchFinalResult(ctx context.Context) mq.Result { + tm.updateTS(&tm.result) + tm.dag.callbackToConsumer(ctx, tm.result) + if tm.dag.server.NotifyHandler() != nil { + _ = tm.dag.server.NotifyHandler()(ctx, tm.result) + } + tm.dag.taskCleanupCh <- tm.taskID + tm.topic = tm.result.Topic + return tm.result +} + +func (tm *TaskManager) reportNodeResult(result mq.Result, final bool) { + if tm.dag.reportNodeResultCallback != nil { + tm.dag.reportNodeResultCallback(result) + } +} + +func (tm *TaskManager) SetTotalItems(topic string, i int) { + if nodeStatus, ok := tm.taskNodeStatus.Get(topic); ok { + nodeStatus.totalItems = i + } +} + +func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) { + topic := getTopic(ctx, node.Key) + tm.taskNodeStatus.Set(topic, newNodeStatus(topic)) + defer mq.RecoverPanic(mq.RecoverTitle) + dag, isDAG := isDAGNode(node) + if isDAG { + if tm.dag.server.SyncMode() && !dag.server.SyncMode() { + dag.server.Options().SetSyncMode(true) + } + } + tm.ChangeNodeStatus(ctx, node.Key, Processing, mq.Result{Payload: payload, Topic: node.Key}) + var result mq.Result + if tm.dag.server.SyncMode() { + defer func() { + if isDAG { + result.Topic = dag.consumerTopic + result.TaskID = tm.taskID + tm.reportNodeResult(result, false) + tm.handleNextTask(result.Ctx, result) + } else { + result.Topic = node.Key + tm.reportNodeResult(result, false) + tm.handleNextTask(ctx, result) + } + }() + } + select { + case <-ctx.Done(): + result = mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: ctx.Err(), Ctx: ctx} + tm.reportNodeResult(result, true) + tm.ChangeNodeStatus(ctx, node.Key, Failed, result) + return + default: + ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: node.Key}) + if tm.dag.server.SyncMode() { + result = node.ProcessTask(ctx, mq.NewTask(tm.taskID, payload, node.Key)) + if result.Error != nil { + tm.reportNodeResult(result, true) + tm.ChangeNodeStatus(ctx, node.Key, Failed, result) + return + } + return + } + err := tm.dag.server.Publish(ctx, mq.NewTask(tm.taskID, payload, node.Key), node.Key) + if err != nil { + tm.reportNodeResult(mq.Result{Error: err}, true) + tm.ChangeNodeStatus(ctx, node.Key, Failed, result) + return + } + } +} + +func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result { + defer mq.RecoverPanic(mq.RecoverTitle) + node, ok := tm.dag.nodes[nodeID] + if !ok { + return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)} + } + if tm.createdAt.IsZero() { + tm.createdAt = time.Now() + } + tm.wg.Add(1) + go func() { + ctxx := context.Background() + if headers, ok := mq.GetHeaders(ctx); ok { + headers.Set(consts.QueueKey, node.Key) + headers.Set("index", fmt.Sprintf("%s__%d", node.Key, 0)) + ctxx = mq.SetHeaders(ctxx, headers.AsMap()) + } + go tm.processNode(ctx, node, payload) + }() + tm.wg.Wait() + requestType, ok := mq.GetHeader(ctx, "request_type") + if ok && requestType == "render" { + return tm.renderResult(ctx) + } + return tm.dispatchFinalResult(ctx) +} + +func (tm *TaskManager) handleNextTask(ctx context.Context, result mq.Result) mq.Result { + tm.topic = result.Topic + defer func() { + tm.wg.Done() + mq.RecoverPanic(mq.RecoverTitle) + }() + if result.Ctx != nil { + if headers, ok := mq.GetHeaders(ctx); ok { + ctx = mq.SetHeaders(result.Ctx, headers.AsMap()) + } + } + node, ok := tm.dag.nodes[result.Topic] + if !ok { + return result + } + if result.Error != nil { + tm.reportNodeResult(result, true) + tm.ChangeNodeStatus(ctx, node.Key, Failed, result) + return result + } + edges := tm.getConditionalEdges(node, result) + if len(edges) == 0 { + tm.reportNodeResult(result, true) + tm.ChangeNodeStatus(ctx, node.Key, Completed, result) + return result + } else { + tm.reportNodeResult(result, false) + } + if node.Type == Page { + return result + } + for _, edge := range edges { + switch edge.Type { + case Iterator: + var items []json.RawMessage + err := json.Unmarshal(result.Payload, &items) + if err != nil { + tm.reportNodeResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err}, false) + result.Error = err + tm.ChangeNodeStatus(ctx, node.Key, Failed, result) + return result + } + tm.SetTotalItems(getTopic(ctx, edge.From.Key), len(items)*len(edge.To)) + for _, target := range edge.To { + for i, item := range items { + tm.wg.Add(1) + go func(ctx context.Context, target *Node, item json.RawMessage, i int) { + ctxx := context.Background() + if headers, ok := mq.GetHeaders(ctx); ok { + headers.Set(consts.QueueKey, target.Key) + headers.Set("index", fmt.Sprintf("%s__%d", target.Key, i)) + ctxx = mq.SetHeaders(ctxx, headers.AsMap()) + } + tm.processNode(ctxx, target, item) + }(ctx, target, item, i) + } + } + } + } + for _, edge := range edges { + switch edge.Type { + case Simple: + if _, ok := tm.iteratorNodes.Get(edge.From.Key); ok { + continue + } + tm.processEdge(ctx, edge, result) + } + } + return result +} + +func (tm *TaskManager) processEdge(ctx context.Context, edge Edge, result mq.Result) { + tm.SetTotalItems(getTopic(ctx, edge.From.Key), len(edge.To)) + index, _ := mq.GetHeader(ctx, "index") + if index != "" && strings.Contains(index, "__") { + index = strings.Split(index, "__")[1] + } else { + index = "0" + } + for _, target := range edge.To { + tm.wg.Add(1) + go func(ctx context.Context, target *Node, result mq.Result) { + ctxx := context.Background() + if headers, ok := mq.GetHeaders(ctx); ok { + headers.Set(consts.QueueKey, target.Key) + headers.Set("index", fmt.Sprintf("%s__%s", target.Key, index)) + ctxx = mq.SetHeaders(ctxx, headers.AsMap()) + } + tm.processNode(ctxx, target, result.Payload) + }(ctx, target, result) + } +} + +func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge { + edges := make([]Edge, len(node.Edges)) + copy(edges, node.Edges) + if result.ConditionStatus != "" { + if conditions, ok := tm.dag.conditions[FromNode(result.Topic)]; ok { + if targetNodeKey, ok := conditions[When(result.ConditionStatus)]; ok { + if targetNode, ok := tm.dag.nodes[string(targetNodeKey)]; ok { + edges = append(edges, Edge{From: node, To: []*Node{targetNode}}) + } + } else if targetNodeKey, ok = conditions["default"]; ok { + if targetNode, ok := tm.dag.nodes[string(targetNodeKey)]; ok { + edges = append(edges, Edge{From: node, To: []*Node{targetNode}}) + } + } + } + } + return edges +} + +func (tm *TaskManager) renderResult(ctx context.Context) mq.Result { + var rs mq.Result + tm.updateTS(&rs) + tm.dag.callbackToConsumer(ctx, rs) + tm.topic = rs.Topic + return rs +} + +func (tm *TaskManager) ChangeNodeStatus(ctx context.Context, nodeID string, status NodeStatus, rs mq.Result) { + topic := nodeID + if !strings.Contains(nodeID, "__") { + nodeID = getTopic(ctx, nodeID) + } else { + topic = strings.Split(nodeID, "__")[0] + } + nodeStatus, ok := tm.taskNodeStatus.Get(nodeID) + if !ok || nodeStatus == nil { + return + } + + nodeStatus.markAs(rs, status) + switch status { + case Completed: + canProceed := false + edges, ok := tm.iteratorNodes.Get(topic) + if ok { + if len(edges) == 0 { + canProceed = true + } else { + nodeStatus.status = Processing + nodeStatus.totalItems = 1 + nodeStatus.itemResults.Clear() + for _, edge := range edges { + tm.processEdge(ctx, edge, rs) + } + tm.iteratorNodes.Del(topic) + } + } + if canProceed || !ok { + if topic == tm.dag.startNode { + tm.result = rs + } else { + tm.markParentTask(ctx, topic, nodeID, status, rs) + } + } + case Failed: + if topic == tm.dag.startNode { + tm.result = rs + } else { + tm.markParentTask(ctx, topic, nodeID, status, rs) + } + } +} + +func (tm *TaskManager) markParentTask(ctx context.Context, topic, nodeID string, status NodeStatus, rs mq.Result) { + parentNodes, err := tm.dag.GetPreviousNodes(topic) + if err != nil { + return + } + var index string + nodeParts := strings.Split(nodeID, "__") + if len(nodeParts) == 2 { + index = nodeParts[1] + } + for _, parentNode := range parentNodes { + parentKey := fmt.Sprintf("%s__%s", parentNode.Key, index) + parentNodeStatus, exists := tm.taskNodeStatus.Get(parentKey) + if !exists { + parentKey = fmt.Sprintf("%s__%s", parentNode.Key, "0") + parentNodeStatus, exists = tm.taskNodeStatus.Get(parentKey) + } + if exists { + parentNodeStatus.itemResults.Set(nodeID, rs) + if parentNodeStatus.IsDone() { + rt := tm.prepareResult(ctx, parentNodeStatus) + tm.ChangeNodeStatus(ctx, parentKey, status, rt) + } + } + } +} + +func (tm *TaskManager) prepareResult(ctx context.Context, nodeStatus *taskNodeStatus) mq.Result { + aggregatedOutput := make([]json.RawMessage, 0) + var status mq.Status + var topic string + var err1 error + if nodeStatus.totalItems == 1 { + rs := nodeStatus.itemResults.Values()[0] + if rs.Ctx == nil { + rs.Ctx = ctx + } + return rs + } + nodeStatus.itemResults.ForEach(func(key string, result mq.Result) bool { + if topic == "" { + topic = result.Topic + status = result.Status + } + if result.Error != nil { + err1 = result.Error + return false + } + var item json.RawMessage + err := json.Unmarshal(result.Payload, &item) + if err != nil { + err1 = err + return false + } + aggregatedOutput = append(aggregatedOutput, item) + return true + }) + if err1 != nil { + return mq.HandleError(ctx, err1) + } + finalOutput, err := json.Marshal(aggregatedOutput) + if err != nil { + return mq.HandleError(ctx, err) + } + return mq.Result{TaskID: tm.taskID, Payload: finalOutput, Status: status, Topic: topic, Ctx: ctx} +} diff --git a/dag/v2/ui.go b/dag/v1/ui.go similarity index 54% rename from dag/v2/ui.go rename to dag/v1/ui.go index bd32cc5..78fb02b 100644 --- a/dag/v2/ui.go +++ b/dag/v1/ui.go @@ -1,4 +1,4 @@ -package v2 +package v1 import ( "fmt" @@ -9,24 +9,25 @@ import ( func (tm *DAG) PrintGraph() { fmt.Println("DAG Graph structure:") - tm.nodes.ForEach(func(_ string, node *Node) bool { - fmt.Printf("Node: %s (%s) -> ", node.Label, node.ID) - if conditions, ok := tm.conditions[node.ID]; ok { + for _, node := range tm.nodes { + fmt.Printf("Node: %s (%s) -> ", node.Name, node.Key) + if conditions, ok := tm.conditions[FromNode(node.Key)]; ok { var c []string for when, then := range conditions { - if target, ok := tm.nodes.Get(then); ok { - c = append(c, fmt.Sprintf("If [%s] Then %s (%s)", when, target.Label, target.ID)) + if target, ok := tm.nodes[string(then)]; ok { + c = append(c, fmt.Sprintf("If [%s] Then %s (%s)", when, target.Name, target.Key)) } } fmt.Println(strings.Join(c, ", ")) } var edges []string - for _, target := range node.Edges { - edges = append(edges, fmt.Sprintf("%s (%s)", target.To.Label, target.To.ID)) + for _, edge := range node.Edges { + for _, target := range edge.To { + edges = append(edges, fmt.Sprintf("%s (%s)", target.Name, target.Key)) + } } fmt.Println(strings.Join(edges, ", ")) - return true - }) + } } func (tm *DAG) ClassifyEdges(startNodes ...string) (string, bool, error) { @@ -43,7 +44,7 @@ func (tm *DAG) ClassifyEdges(startNodes ...string) (string, bool, error) { if startNode == "" { firstNode := tm.findStartNode() if firstNode != nil { - startNode = firstNode.ID + startNode = firstNode.Key } } if startNode == "" { @@ -61,22 +62,24 @@ func (tm *DAG) dfs(v string, visited map[string]bool, discoveryTime, finishedTim inRecursionStack[v] = true // mark node as part of recursion stack *timeVal++ discoveryTime[v] = *timeVal - node, _ := tm.nodes.Get(v) + node := tm.nodes[v] hasCycle := false var err error for _, edge := range node.Edges { - if !visited[edge.To.ID] { - builder.WriteString(fmt.Sprintf("Traversing Edge: %s -> %s\n", v, edge.To.ID)) - hasCycle, err := tm.dfs(edge.To.ID, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder) - if err != nil { - return true, err + for _, adj := range edge.To { + if !visited[adj.Key] { + builder.WriteString(fmt.Sprintf("Traversing Edge: %s -> %s\n", v, adj.Key)) + hasCycle, err := tm.dfs(adj.Key, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder) + if err != nil { + return true, err + } + if hasCycle { + return true, nil + } + } else if inRecursionStack[adj.Key] { + cycleMsg := fmt.Sprintf("Cycle detected: %s -> %s\n", v, adj.Key) + return true, fmt.Errorf(cycleMsg) } - if hasCycle { - return true, nil - } - } else if inRecursionStack[edge.To.ID] { - cycleMsg := fmt.Sprintf("Cycle detected: %s -> %s\n", v, edge.To.ID) - return true, fmt.Errorf(cycleMsg) } } hasCycle, err = tm.handleConditionalEdges(v, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder) @@ -90,20 +93,20 @@ func (tm *DAG) dfs(v string, visited map[string]bool, discoveryTime, finishedTim } func (tm *DAG) handleConditionalEdges(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, time *int, inRecursionStack map[string]bool, builder *strings.Builder) (bool, error) { - node, _ := tm.nodes.Get(v) - for when, then := range tm.conditions[node.ID] { - if targetNode, ok := tm.nodes.Get(then); ok { - if !visited[targetNode.ID] { - builder.WriteString(fmt.Sprintf("Traversing Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.ID)) - hasCycle, err := tm.dfs(targetNode.ID, visited, discoveryTime, finishedTime, time, inRecursionStack, builder) + node := tm.nodes[v] + for when, then := range tm.conditions[FromNode(node.Key)] { + if targetNode, ok := tm.nodes[string(then)]; ok { + if !visited[targetNode.Key] { + builder.WriteString(fmt.Sprintf("Traversing Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.Key)) + hasCycle, err := tm.dfs(targetNode.Key, visited, discoveryTime, finishedTime, time, inRecursionStack, builder) if err != nil { return true, err } if hasCycle { return true, nil } - } else if inRecursionStack[targetNode.ID] { - cycleMsg := fmt.Sprintf("Cycle detected in Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.ID) + } else if inRecursionStack[targetNode.Key] { + cycleMsg := fmt.Sprintf("Cycle detected in Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.Key) return true, fmt.Errorf(cycleMsg) } } @@ -143,113 +146,98 @@ func (tm *DAG) ExportDOT() string { var sb strings.Builder sb.WriteString(fmt.Sprintf(`digraph "%s" {`, tm.name)) sb.WriteString("\n") - sb.WriteString(` label="Enhanced DAG Representation";`) + sb.WriteString(fmt.Sprintf(` label="%s";`, tm.name)) sb.WriteString("\n") - sb.WriteString(` labelloc="t"; fontsize=22; fontname="Helvetica";`) + sb.WriteString(` labelloc="t";`) sb.WriteString("\n") - sb.WriteString(` node [shape=box, fontname="Helvetica", fillcolor="#B3CDE0", fontcolor="#2C3E50", fontsize=10, margin="0.25,0.15", style="rounded,filled"];`) + sb.WriteString(` fontsize=20;`) sb.WriteString("\n") - sb.WriteString(` edge [fontname="Helvetica", fontsize=12, arrowsize=0.8];`) + sb.WriteString(` node [shape=box, style="rounded,filled", fillcolor="lightgray", fontname="Arial", margin="0.2,0.1"];`) sb.WriteString("\n") - sb.WriteString(` rankdir=TB;`) + sb.WriteString(` edge [fontname="Arial", fontsize=12, arrowsize=0.8];`) + sb.WriteString("\n") + sb.WriteString(` size="10,10";`) + sb.WriteString("\n") + sb.WriteString(` ratio="fill";`) sb.WriteString("\n") sortedNodes := tm.TopologicalSort() for _, nodeKey := range sortedNodes { - node, _ := tm.nodes.Get(nodeKey) - nodeColor := "lightgray" - nodeShape := "box" - labelSuffix := "" - - // Apply styles based on NodeType - switch node.NodeType { - case Function: - nodeColor = "#D4EDDA" - labelSuffix = " [Function]" - case Page: - nodeColor = "#F08080" - labelSuffix = " [Page]" - } - sb.WriteString(fmt.Sprintf( - ` "%s" [label="%s%s", fontcolor="#2C3E50", fillcolor="%s", shape="%s", style="rounded,filled", id="node_%s"];`, - node.ID, node.Label, labelSuffix, nodeColor, nodeShape, node.ID)) + node := tm.nodes[nodeKey] + nodeColor := "lightblue" + sb.WriteString(fmt.Sprintf(` "%s" [label=" %s", fillcolor="%s", id="node_%s"];`, node.Key, node.Name, nodeColor, node.Key)) sb.WriteString("\n") } - - // Define edges with unique styling by EdgeType for _, nodeKey := range sortedNodes { - node, _ := tm.nodes.Get(nodeKey) + node := tm.nodes[nodeKey] for _, edge := range node.Edges { - edgeStyle := "solid" - edgeColor := "black" - labelSuffix := "" - - // Apply styles based on EdgeType + var edgeStyle string switch edge.Type { case Iterator: edgeStyle = "dashed" - edgeColor = "blue" - labelSuffix = " [Iter]" - case Simple: + default: edgeStyle = "solid" - edgeColor = "black" - labelSuffix = "" } - sb.WriteString(fmt.Sprintf( - ` "%s" -> "%s" [label="%s%s", color="%s", style="%s"];`, - node.ID, edge.To.ID, edge.Label, labelSuffix, edgeColor, edgeStyle)) - sb.WriteString("\n") - } - } - for fromNodeKey, conditions := range tm.conditions { - for when, then := range conditions { - if toNode, ok := tm.nodes.Get(then); ok { - sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="purple", style=dotted, fontsize=10, arrowsize=0.6];`, fromNodeKey, toNode.ID, when)) + edgeColor := "black" + for _, to := range edge.To { + sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="%s", style=%s, fontsize=10, arrowsize=0.6];`, node.Key, to.Key, edge.Label, edgeColor, edgeStyle)) + sb.WriteString("\n") + } + } + } + for fromNodeKey, conditions := range tm.conditions { + for when, then := range conditions { + if toNode, ok := tm.nodes[string(then)]; ok { + sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="purple", style=dotted, fontsize=10, arrowsize=0.6];`, fromNodeKey, toNode.Key, when)) sb.WriteString("\n") } } } - - // Optional: Group related nodes into subgraphs (e.g., loops) for _, nodeKey := range sortedNodes { - node, _ := tm.nodes.Get(nodeKey) + node := tm.nodes[nodeKey] if node.processor != nil { subDAG, _ := isDAGNode(node) if subDAG != nil { sb.WriteString(fmt.Sprintf(` subgraph "cluster_%s" {`, subDAG.name)) sb.WriteString("\n") - sb.WriteString(fmt.Sprintf(` label="Subgraph: %s";`, subDAG.name)) + sb.WriteString(fmt.Sprintf(` label=" %s";`, subDAG.name)) sb.WriteString("\n") - sb.WriteString(` style=filled; color=gray90;`) + sb.WriteString(` style=dashed;`) sb.WriteString("\n") - subDAG.nodes.ForEach(func(subNodeKey string, subNode *Node) bool { - sb.WriteString(fmt.Sprintf(` "%s" [label="%s"];`, subNode.ID, subNode.Label)) + sb.WriteString(` bgcolor="lightgray";`) + sb.WriteString("\n") + sb.WriteString(` node [shape=rectangle, style="filled", fillcolor="lightblue", fontname="Arial", margin="0.2,0.1"];`) + sb.WriteString("\n") + for subNodeKey, subNode := range subDAG.nodes { + sb.WriteString(fmt.Sprintf(` "%s" [label=" %s"];`, subNodeKey, subNode.Name)) sb.WriteString("\n") - return true - }) - subDAG.nodes.ForEach(func(subNodeKey string, subNode *Node) bool { + } + for subNodeKey, subNode := range subDAG.nodes { for _, edge := range subNode.Edges { - sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label="%s"];`, subNodeKey, edge.To.ID, edge.Label)) - sb.WriteString("\n") + for _, to := range edge.To { + sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="black", style=solid, arrowsize=0.6];`, subNodeKey, to.Key, edge.Label)) + sb.WriteString("\n") + } } - return true - }) - sb.WriteString(" }\n") + } + sb.WriteString(` }`) + sb.WriteString("\n") + sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="black", style=solid, arrowsize=0.6];`, node.Key, subDAG.startNode, subDAG.name)) + sb.WriteString("\n") } } } - - sb.WriteString("}\n") + sb.WriteString(`}`) + sb.WriteString("\n") return sb.String() } func (tm *DAG) TopologicalSort() (stack []string) { visited := make(map[string]bool) - tm.nodes.ForEach(func(_ string, node *Node) bool { - if !visited[node.ID] { - tm.topologicalSortUtil(node.ID, visited, &stack) + for _, node := range tm.nodes { + if !visited[node.Key] { + tm.topologicalSortUtil(node.Key, visited, &stack) } - return true - }) + } for i, j := 0, len(stack)-1; i < j; i, j = i+1, j-1 { stack[i], stack[j] = stack[j], stack[i] } @@ -258,23 +246,13 @@ func (tm *DAG) TopologicalSort() (stack []string) { func (tm *DAG) topologicalSortUtil(v string, visited map[string]bool, stack *[]string) { visited[v] = true - node, ok := tm.nodes.Get(v) - if !ok { - fmt.Println("Not found", v) - } + node := tm.nodes[v] for _, edge := range node.Edges { - if !visited[edge.To.ID] { - tm.topologicalSortUtil(edge.To.ID, visited, stack) + for _, to := range edge.To { + if !visited[to.Key] { + tm.topologicalSortUtil(to.Key, visited, stack) + } } } *stack = append(*stack, v) } - -func isDAGNode(node *Node) (*DAG, bool) { - switch node := node.processor.(type) { - case *DAG: - return node, true - default: - return nil, false - } -} diff --git a/dag/utils.go b/dag/v1/utils.go similarity index 98% rename from dag/utils.go rename to dag/v1/utils.go index efd8642..52c085f 100644 --- a/dag/utils.go +++ b/dag/v1/utils.go @@ -1,4 +1,4 @@ -package dag +package v1 import ( "context" diff --git a/dag/waitgroup.go b/dag/v1/waitgroup.go similarity index 98% rename from dag/waitgroup.go rename to dag/v1/waitgroup.go index e63f9bf..dfb84cc 100644 --- a/dag/waitgroup.go +++ b/dag/v1/waitgroup.go @@ -1,4 +1,4 @@ -package dag +package v1 import ( "sync" diff --git a/dag/v2/websocket.go b/dag/v1/websocket.go similarity index 98% rename from dag/v2/websocket.go rename to dag/v1/websocket.go index 0f0534a..4d4b544 100644 --- a/dag/v2/websocket.go +++ b/dag/v1/websocket.go @@ -1,4 +1,4 @@ -package v2 +package v1 import ( "encoding/json" diff --git a/dag/v2/api.go b/dag/v2/api.go deleted file mode 100644 index f2b6f4a..0000000 --- a/dag/v2/api.go +++ /dev/null @@ -1,152 +0,0 @@ -package v2 - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "os" - "strings" - - "github.com/oarkflow/mq" - "github.com/oarkflow/mq/sio" - - "github.com/oarkflow/mq/consts" - "github.com/oarkflow/mq/jsonparser" -) - -func renderNotFound(w http.ResponseWriter) { - html := []byte(` -
-

task not found

-

Back to home

-
-`) - w.Header().Set(consts.ContentType, consts.TypeHtml) - w.Write(html) -} - -func (tm *DAG) render(w http.ResponseWriter, r *http.Request) { - ctx, data, err := parse(r) - if err != nil { - http.Error(w, err.Error(), http.StatusNotFound) - return - } - accept := r.Header.Get("Accept") - userCtx := UserContext(ctx) - ctx = context.WithValue(ctx, "method", r.Method) - if r.Method == "GET" && userCtx.Get("task_id") != "" { - manager, ok := tm.taskManager.Get(userCtx.Get("task_id")) - if !ok || manager == nil { - if strings.Contains(accept, "text/html") || accept == "" { - renderNotFound(w) - return - } - http.Error(w, fmt.Sprintf(`{"message": "%s"}`, "task not found"), http.StatusInternalServerError) - return - } - } - result := tm.Process(ctx, data) - if result.Error != nil { - http.Error(w, fmt.Sprintf(`{"message": "%s"}`, result.Error.Error()), http.StatusInternalServerError) - return - } - contentType, ok := result.Ctx.Value(consts.ContentType).(string) - if !ok { - contentType = consts.TypeJson - } - switch contentType { - case consts.TypeHtml: - w.Header().Set(consts.ContentType, consts.TypeHtml) - data, err := jsonparser.GetString(result.Payload, "html_content") - if err != nil { - return - } - w.Write([]byte(data)) - default: - if r.Method != "POST" { - http.Error(w, `{"message": "not allowed"}`, http.StatusMethodNotAllowed) - return - } - w.Header().Set(consts.ContentType, consts.TypeJson) - json.NewEncoder(w).Encode(result.Payload) - } -} - -func (tm *DAG) taskStatusHandler(w http.ResponseWriter, r *http.Request) { - taskID := r.URL.Query().Get("taskID") - if taskID == "" { - http.Error(w, `{"message": "taskID is missing"}`, http.StatusBadRequest) - return - } - manager, ok := tm.taskManager.Get(taskID) - if !ok { - http.Error(w, `{"message": "Invalid TaskID"}`, http.StatusNotFound) - return - } - result := make(map[string]TaskState) - manager.taskStates.ForEach(func(key string, value *TaskState) bool { - key = strings.Split(key, Delimiter)[0] - nodeID := strings.Split(value.NodeID, Delimiter)[0] - rs := jsonparser.Delete(value.Result.Payload, "html_content") - status := value.Status - if status == mq.Processing { - status = mq.Completed - } - state := TaskState{ - NodeID: nodeID, - Status: status, - UpdatedAt: value.UpdatedAt, - Result: mq.Result{ - Payload: rs, - Error: value.Result.Error, - Status: status, - }, - } - result[key] = state - return true - }) - w.Header().Set(consts.ContentType, consts.TypeJson) - json.NewEncoder(w).Encode(result) -} - -func (tm *DAG) SetupWS() *sio.Server { - ws := sio.New(sio.Config{ - CheckOrigin: func(r *http.Request) bool { return true }, - EnableCompression: true, - }) - WsEvents(ws) - tm.Notifier = ws - return ws -} - -func (tm *DAG) Handlers() { - http.Handle("/", http.FileServer(http.Dir("webroot"))) - http.Handle("/notify", tm.SetupWS()) - http.HandleFunc("/process", tm.render) - http.HandleFunc("/request", tm.render) - http.HandleFunc("/task/status", tm.taskStatusHandler) - http.HandleFunc("/dot", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/plain") - fmt.Fprintln(w, tm.ExportDOT()) - }) - http.HandleFunc("/ui", func(w http.ResponseWriter, r *http.Request) { - image := fmt.Sprintf("%s.svg", mq.NewID()) - err := tm.SaveSVG(image) - if err != nil { - http.Error(w, "Failed to read request body", http.StatusBadRequest) - return - } - defer os.Remove(image) - svgBytes, err := os.ReadFile(image) - if err != nil { - http.Error(w, "Could not read SVG file", http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "image/svg+xml") - if _, err := w.Write(svgBytes); err != nil { - http.Error(w, "Could not write SVG response", http.StatusInternalServerError) - return - } - }) -} diff --git a/dag/v2/consts.go b/dag/v2/consts.go deleted file mode 100644 index 3010831..0000000 --- a/dag/v2/consts.go +++ /dev/null @@ -1,28 +0,0 @@ -package v2 - -import "time" - -const ( - Delimiter = "___" - ContextIndex = "index" - DefaultChannelSize = 1000 - RetryInterval = 5 * time.Second -) - -type NodeType int - -func (c NodeType) IsValid() bool { return c >= Function && c <= Page } - -const ( - Function NodeType = iota - Page -) - -type EdgeType int - -func (c EdgeType) IsValid() bool { return c >= Simple && c <= Iterator } - -const ( - Simple EdgeType = iota - Iterator -) diff --git a/dag/v2/dag.go b/dag/v2/dag.go deleted file mode 100644 index 0e0cf28..0000000 --- a/dag/v2/dag.go +++ /dev/null @@ -1,524 +0,0 @@ -package v2 - -import ( - "context" - "encoding/json" - "fmt" - "log" - "net/http" - "strings" - "time" - - "golang.org/x/time/rate" - - "github.com/oarkflow/mq/consts" - "github.com/oarkflow/mq/sio" - - "github.com/oarkflow/mq" - "github.com/oarkflow/mq/storage" - "github.com/oarkflow/mq/storage/memory" -) - -type Node struct { - NodeType NodeType - Label string - ID string - Edges []Edge - processor mq.Processor - isReady bool -} - -type Edge struct { - From *Node - To *Node - Type EdgeType - Label string -} - -type DAG struct { - server *mq.Broker - consumer *mq.Consumer - nodes storage.IMap[string, *Node] - taskManager storage.IMap[string, *TaskManager] - iteratorNodes storage.IMap[string, []Edge] - finalResult func(taskID string, result mq.Result) - pool *mq.Pool - name string - key string - startNode string - opts []mq.Option - conditions map[string]map[string]string - consumerTopic string - reportNodeResultCallback func(mq.Result) - Error error - Notifier *sio.Server - paused bool - report string -} - -func NewDAG(name, key string, finalResultCallback func(taskID string, result mq.Result), opts ...mq.Option) *DAG { - callback := func(ctx context.Context, result mq.Result) error { return nil } - d := &DAG{ - name: name, - key: key, - nodes: memory.New[string, *Node](), - taskManager: memory.New[string, *TaskManager](), - iteratorNodes: memory.New[string, []Edge](), - conditions: make(map[string]map[string]string), - finalResult: finalResultCallback, - } - opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerOnSubscribe(d.onConsumerJoin), mq.WithConsumerOnClose(d.onConsumerClose)) - d.server = mq.NewBroker(opts...) - d.opts = opts - options := d.server.Options() - d.pool = mq.NewPool( - options.NumOfWorkers(), - mq.WithTaskQueueSize(options.QueueSize()), - mq.WithMaxMemoryLoad(options.MaxMemoryLoad()), - mq.WithHandler(d.ProcessTask), - mq.WithPoolCallback(callback), - mq.WithTaskStorage(options.Storage()), - ) - d.pool.Start(d.server.Options().NumOfWorkers()) - return d -} - -func (tm *DAG) SetKey(key string) { - tm.key = key -} - -func (tm *DAG) ReportNodeResult(callback func(mq.Result)) { - tm.reportNodeResultCallback = callback -} - -func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result { - if manager, ok := tm.taskManager.Get(result.TaskID); ok && result.Topic != "" { - manager.onNodeCompleted(nodeResult{ - ctx: ctx, - nodeID: result.Topic, - status: result.Status, - result: result, - }) - } - return mq.Result{} -} - -func (tm *DAG) GetType() string { - return tm.key -} - -func (tm *DAG) Consume(ctx context.Context) error { - if tm.consumer != nil { - tm.server.Options().SetSyncMode(true) - return tm.consumer.Consume(ctx) - } - return nil -} - -func (tm *DAG) Stop(ctx context.Context) error { - tm.nodes.ForEach(func(_ string, n *Node) bool { - err := n.processor.Stop(ctx) - if err != nil { - return false - } - return true - }) - return nil -} - -func (tm *DAG) GetKey() string { - return tm.key -} - -func (tm *DAG) AssignTopic(topic string) { - tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false), mq.WithBrokerURL(tm.server.URL())) - tm.consumerTopic = topic -} - -func (tm *DAG) callbackToConsumer(ctx context.Context, result mq.Result) { - if tm.consumer != nil { - result.Topic = tm.consumerTopic - if tm.consumer.Conn() == nil { - tm.onTaskCallback(ctx, result) - } else { - tm.consumer.OnResponse(ctx, result) - } - } -} - -func (tm *DAG) onConsumerJoin(_ context.Context, topic, _ string) { - if node, ok := tm.nodes.Get(topic); ok { - log.Printf("DAG - CONSUMER ~> ready on %s", topic) - node.isReady = true - } -} - -func (tm *DAG) onConsumerClose(_ context.Context, topic, _ string) { - if node, ok := tm.nodes.Get(topic); ok { - log.Printf("DAG - CONSUMER ~> down on %s", topic) - node.isReady = false - } -} - -func (tm *DAG) Pause(_ context.Context) error { - tm.paused = true - return nil -} - -func (tm *DAG) Resume(_ context.Context) error { - tm.paused = false - return nil -} - -func (tm *DAG) Close() error { - var err error - tm.nodes.ForEach(func(_ string, n *Node) bool { - err = n.processor.Close() - if err != nil { - return false - } - return true - }) - return nil -} - -func (tm *DAG) SetStartNode(node string) { - tm.startNode = node -} - -func (tm *DAG) SetNotifyResponse(callback mq.Callback) { - tm.server.SetNotifyHandler(callback) -} - -func (tm *DAG) GetStartNode() string { - return tm.startNode -} - -func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) *DAG { - tm.conditions[fromNode] = conditions - return tm -} - -func (tm *DAG) AddNode(nodeType NodeType, name, nodeID string, handler mq.Processor, startNode ...bool) *DAG { - if tm.Error != nil { - return tm - } - con := mq.NewConsumer(nodeID, nodeID, handler.ProcessTask) - n := &Node{ - Label: name, - ID: nodeID, - NodeType: nodeType, - processor: con, - } - if tm.server != nil && tm.server.SyncMode() { - n.isReady = true - } - tm.nodes.Set(nodeID, n) - if len(startNode) > 0 && startNode[0] { - tm.startNode = nodeID - } - return tm -} - -func (tm *DAG) AddDeferredNode(nodeType NodeType, name, key string, firstNode ...bool) error { - if tm.server.SyncMode() { - return fmt.Errorf("DAG cannot have deferred node in Sync Mode") - } - tm.nodes.Set(key, &Node{ - Label: name, - ID: key, - NodeType: nodeType, - }) - if len(firstNode) > 0 && firstNode[0] { - tm.startNode = key - } - return nil -} - -func (tm *DAG) IsReady() bool { - var isReady bool - tm.nodes.ForEach(func(_ string, n *Node) bool { - if !n.isReady { - return false - } - isReady = true - return true - }) - return isReady -} - -func (tm *DAG) AddEdge(edgeType EdgeType, label, from string, targets ...string) *DAG { - if tm.Error != nil { - return tm - } - if edgeType == Iterator { - tm.iteratorNodes.Set(from, []Edge{}) - } - node, ok := tm.nodes.Get(from) - if !ok { - tm.Error = fmt.Errorf("node not found %s", from) - return tm - } - for _, target := range targets { - if targetNode, ok := tm.nodes.Get(target); ok { - edge := Edge{From: node, To: targetNode, Type: edgeType, Label: label} - node.Edges = append(node.Edges, edge) - if edgeType != Iterator { - if edges, ok := tm.iteratorNodes.Get(node.ID); ok { - edges = append(edges, edge) - tm.iteratorNodes.Set(node.ID, edges) - } - } - } - } - return tm -} - -func (tm *DAG) getCurrentNode(manager *TaskManager) string { - if manager.currentNodePayload.Size() == 0 { - return "" - } - return manager.currentNodePayload.Keys()[0] -} - -func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { - ctx = context.WithValue(ctx, "task_id", task.ID) - userContext := UserContext(ctx) - next := userContext.Get("next") - manager, ok := tm.taskManager.Get(task.ID) - resultCh := make(chan mq.Result, 1) - if !ok { - manager = NewTaskManager(tm, task.ID, resultCh, tm.iteratorNodes.Clone()) - tm.taskManager.Set(task.ID, manager) - } else { - manager.resultCh = resultCh - } - currentKey := tm.getCurrentNode(manager) - currentNode := strings.Split(currentKey, Delimiter)[0] - node, exists := tm.nodes.Get(currentNode) - method, ok := ctx.Value("method").(string) - if method == "GET" && exists && node.NodeType == Page { - ctx = context.WithValue(ctx, "initial_node", currentNode) - /* - if isLastNode, err := tm.IsLastNode(currentNode); err != nil && isLastNode { - if manager.result != nil { - fmt.Println(string(manager.result.Payload)) - resultCh <- *manager.result - return <-resultCh - } - } - */ - if manager.result != nil { - task.Payload = manager.result.Payload - } - } else if next == "true" { - nodes, err := tm.GetNextNodes(currentNode) - if err != nil { - return mq.Result{Error: err, Ctx: ctx} - } - if len(nodes) > 0 { - ctx = context.WithValue(ctx, "initial_node", nodes[0].ID) - } - } - if currentNodeResult, hasResult := manager.currentNodeResult.Get(currentKey); hasResult { - var taskPayload, resultPayload map[string]any - if err := json.Unmarshal(task.Payload, &taskPayload); err == nil { - if err = json.Unmarshal(currentNodeResult.Payload, &resultPayload); err == nil { - for key, val := range resultPayload { - taskPayload[key] = val - } - task.Payload, _ = json.Marshal(taskPayload) - } - } - } - firstNode, err := tm.parseInitialNode(ctx) - if err != nil { - return mq.Result{Error: err, Ctx: ctx} - } - node, ok = tm.nodes.Get(firstNode) - if ok && node.NodeType != Page && task.Payload == nil { - return mq.Result{Error: fmt.Errorf("payload is required for node %s", firstNode), Ctx: ctx} - } - task.Topic = firstNode - ctx = context.WithValue(ctx, ContextIndex, "0") - manager.ProcessTask(ctx, firstNode, task.Payload) - return <-resultCh -} - -func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result { - var taskID string - userCtx := UserContext(ctx) - if val := userCtx.Get("task_id"); val != "" { - taskID = val - } else { - taskID = mq.NewID() - } - return tm.ProcessTask(ctx, mq.NewTask(taskID, payload, "")) -} - -func (tm *DAG) Validate() error { - report, hasCycle, err := tm.ClassifyEdges() - if hasCycle || err != nil { - tm.Error = err - return err - } - tm.report = report - return nil -} - -func (tm *DAG) GetReport() string { - return tm.report -} - -func (tm *DAG) AddDAGNode(name string, key string, dag *DAG, firstNode ...bool) *DAG { - dag.AssignTopic(key) - tm.nodes.Set(key, &Node{ - Label: name, - ID: key, - processor: dag, - isReady: true, - }) - if len(firstNode) > 0 && firstNode[0] { - tm.startNode = key - } - return tm -} - -func (tm *DAG) Start(ctx context.Context, addr string) error { - // Start the server in a separate goroutine - go func() { - defer mq.RecoverPanic(mq.RecoverTitle) - if err := tm.server.Start(ctx); err != nil { - panic(err) - } - }() - - // Start the node consumers if not in sync mode - if !tm.server.SyncMode() { - tm.nodes.ForEach(func(_ string, con *Node) bool { - go func(con *Node) { - defer mq.RecoverPanic(mq.RecoverTitle) - limiter := rate.NewLimiter(rate.Every(1*time.Second), 1) // Retry every second - for { - err := con.processor.Consume(ctx) - if err != nil { - log.Printf("[ERROR] - Consumer %s failed to start: %v", con.ID, err) - } else { - log.Printf("[INFO] - Consumer %s started successfully", con.ID) - break - } - limiter.Wait(ctx) // Wait with rate limiting before retrying - } - }(con) - return true - }) - } - log.Printf("DAG - HTTP_SERVER ~> started on http://%s", addr) - tm.Handlers() - config := tm.server.TLSConfig() - log.Printf("Server listening on http://%s", addr) - if config.UseTLS { - return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil) - } - return http.ListenAndServe(addr, nil) -} - -func (tm *DAG) ScheduleTask(ctx context.Context, payload []byte, opts ...mq.SchedulerOption) mq.Result { - var taskID string - userCtx := UserContext(ctx) - if val := userCtx.Get("task_id"); val != "" { - taskID = val - } else { - taskID = mq.NewID() - } - t := mq.NewTask(taskID, payload, "") - - ctx = context.WithValue(ctx, "task_id", taskID) - userContext := UserContext(ctx) - next := userContext.Get("next") - manager, ok := tm.taskManager.Get(taskID) - resultCh := make(chan mq.Result, 1) - if !ok { - manager = NewTaskManager(tm, taskID, resultCh, tm.iteratorNodes.Clone()) - tm.taskManager.Set(taskID, manager) - } else { - manager.resultCh = resultCh - } - currentKey := tm.getCurrentNode(manager) - currentNode := strings.Split(currentKey, Delimiter)[0] - node, exists := tm.nodes.Get(currentNode) - method, ok := ctx.Value("method").(string) - if method == "GET" && exists && node.NodeType == Page { - ctx = context.WithValue(ctx, "initial_node", currentNode) - } else if next == "true" { - nodes, err := tm.GetNextNodes(currentNode) - if err != nil { - return mq.Result{Error: err, Ctx: ctx} - } - if len(nodes) > 0 { - ctx = context.WithValue(ctx, "initial_node", nodes[0].ID) - } - } - if currentNodeResult, hasResult := manager.currentNodeResult.Get(currentKey); hasResult { - var taskPayload, resultPayload map[string]any - if err := json.Unmarshal(payload, &taskPayload); err == nil { - if err = json.Unmarshal(currentNodeResult.Payload, &resultPayload); err == nil { - for key, val := range resultPayload { - taskPayload[key] = val - } - payload, _ = json.Marshal(taskPayload) - } - } - } - firstNode, err := tm.parseInitialNode(ctx) - if err != nil { - return mq.Result{Error: err, Ctx: ctx} - } - node, ok = tm.nodes.Get(firstNode) - if ok && node.NodeType != Page && t.Payload == nil { - return mq.Result{Error: fmt.Errorf("payload is required for node %s", firstNode), Ctx: ctx} - } - t.Topic = firstNode - ctx = context.WithValue(ctx, ContextIndex, "0") - - headers, ok := mq.GetHeaders(ctx) - ctxx := context.Background() - if ok { - ctxx = mq.SetHeaders(ctxx, headers.AsMap()) - } - tm.pool.Scheduler().AddTask(ctxx, t, opts...) - return mq.Result{CreatedAt: t.CreatedAt, TaskID: t.ID, Topic: t.Topic, Status: "PENDING"} -} - -func (tm *DAG) PauseConsumer(ctx context.Context, id string) { - tm.doConsumer(ctx, id, consts.CONSUMER_PAUSE) -} - -func (tm *DAG) ResumeConsumer(ctx context.Context, id string) { - tm.doConsumer(ctx, id, consts.CONSUMER_RESUME) -} - -func (tm *DAG) doConsumer(ctx context.Context, id string, action consts.CMD) { - if node, ok := tm.nodes.Get(id); ok { - switch action { - case consts.CONSUMER_PAUSE: - err := node.processor.Pause(ctx) - if err == nil { - node.isReady = false - log.Printf("[INFO] - Consumer %s paused successfully", node.ID) - } else { - log.Printf("[ERROR] - Failed to pause consumer %s: %v", node.ID, err) - } - case consts.CONSUMER_RESUME: - err := node.processor.Resume(ctx) - if err == nil { - node.isReady = true - log.Printf("[INFO] - Consumer %s resumed successfully", node.ID) - } else { - log.Printf("[ERROR] - Failed to resume consumer %s: %v", node.ID, err) - } - } - } else { - log.Printf("[WARNING] - Consumer %s not found", id) - } -} diff --git a/dag/v2/task_manager.go b/dag/v2/task_manager.go deleted file mode 100644 index caf8dbf..0000000 --- a/dag/v2/task_manager.go +++ /dev/null @@ -1,388 +0,0 @@ -package v2 - -import ( - "context" - "encoding/json" - "fmt" - "log" - "strings" - "time" - - "github.com/oarkflow/mq" - - "github.com/oarkflow/mq/storage" - "github.com/oarkflow/mq/storage/memory" -) - -type TaskState struct { - NodeID string - Status mq.Status - UpdatedAt time.Time - Result mq.Result - targetResults storage.IMap[string, mq.Result] -} - -func newTaskState(nodeID string) *TaskState { - return &TaskState{ - NodeID: nodeID, - Status: mq.Pending, - UpdatedAt: time.Now(), - targetResults: memory.New[string, mq.Result](), - } -} - -type nodeResult struct { - ctx context.Context - nodeID string - status mq.Status - result mq.Result -} - -type TaskManager struct { - taskStates storage.IMap[string, *TaskState] - parentNodes storage.IMap[string, string] - childNodes storage.IMap[string, int] - deferredTasks storage.IMap[string, *task] - iteratorNodes storage.IMap[string, []Edge] - currentNodePayload storage.IMap[string, json.RawMessage] - currentNodeResult storage.IMap[string, mq.Result] - result *mq.Result - dag *DAG - taskID string - taskQueue chan *task - resultQueue chan nodeResult - resultCh chan mq.Result - stopCh chan struct{} -} - -type task struct { - ctx context.Context - taskID string - nodeID string - payload json.RawMessage -} - -func newTask(ctx context.Context, taskID, nodeID string, payload json.RawMessage) *task { - return &task{ - ctx: ctx, - taskID: taskID, - nodeID: nodeID, - payload: payload, - } -} - -func NewTaskManager(dag *DAG, taskID string, resultCh chan mq.Result, iteratorNodes storage.IMap[string, []Edge]) *TaskManager { - tm := &TaskManager{ - taskStates: memory.New[string, *TaskState](), - parentNodes: memory.New[string, string](), - childNodes: memory.New[string, int](), - deferredTasks: memory.New[string, *task](), - currentNodePayload: memory.New[string, json.RawMessage](), - currentNodeResult: memory.New[string, mq.Result](), - taskQueue: make(chan *task, DefaultChannelSize), - resultQueue: make(chan nodeResult, DefaultChannelSize), - iteratorNodes: iteratorNodes, - stopCh: make(chan struct{}), - resultCh: resultCh, - taskID: taskID, - dag: dag, - } - go tm.run() - go tm.waitForResult() - return tm -} - -func (tm *TaskManager) ProcessTask(ctx context.Context, startNode string, payload json.RawMessage) { - tm.send(ctx, startNode, tm.taskID, payload) -} - -func (tm *TaskManager) send(ctx context.Context, startNode, taskID string, payload json.RawMessage) { - if index, ok := ctx.Value(ContextIndex).(string); ok { - startNode = strings.Split(startNode, Delimiter)[0] - startNode = fmt.Sprintf("%s%s%s", startNode, Delimiter, index) - } - if _, exists := tm.taskStates.Get(startNode); !exists { - tm.taskStates.Set(startNode, newTaskState(startNode)) - } - t := newTask(ctx, taskID, startNode, payload) - select { - case tm.taskQueue <- t: - default: - log.Println("task queue is full, dropping task.") - tm.deferredTasks.Set(taskID, t) - } -} - -func (tm *TaskManager) run() { - for { - select { - case <-tm.stopCh: - log.Println("Stopping TaskManager") - return - case task := <-tm.taskQueue: - tm.processNode(task) - } - } -} - -func (tm *TaskManager) waitForResult() { - for { - select { - case <-tm.stopCh: - log.Println("Stopping Result Listener") - return - case nr := <-tm.resultQueue: - tm.onNodeCompleted(nr) - } - } -} - -func (tm *TaskManager) processNode(exec *task) { - pureNodeID := strings.Split(exec.nodeID, Delimiter)[0] - node, exists := tm.dag.nodes.Get(pureNodeID) - if !exists { - log.Printf("Node %s does not exist while processing node\n", pureNodeID) - return - } - state, _ := tm.taskStates.Get(exec.nodeID) - if state == nil { - log.Printf("State for node %s not found; creating new state.\n", exec.nodeID) - state = newTaskState(exec.nodeID) - tm.taskStates.Set(exec.nodeID, state) - } - state.Status = mq.Processing - state.UpdatedAt = time.Now() - tm.currentNodePayload.Clear() - tm.currentNodeResult.Clear() - tm.currentNodePayload.Set(exec.nodeID, exec.payload) - result := node.processor.ProcessTask(exec.ctx, mq.NewTask(exec.taskID, exec.payload, exec.nodeID)) - tm.currentNodeResult.Set(exec.nodeID, result) - state.Result = result - result.Topic = node.ID - if result.Error != nil { - tm.result = &result - tm.resultCh <- result - tm.processFinalResult(state) - return - } - if node.NodeType == Page { - tm.result = &result - tm.resultCh <- result - return - } - tm.handleNext(exec.ctx, node, state, result) -} - -func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, result mq.Result, childNode string, dispatchFinal bool) { - state.targetResults.Set(childNode, result) - state.targetResults.Del(state.NodeID) - targetsCount, _ := tm.childNodes.Get(state.NodeID) - size := state.targetResults.Size() - nodeID := strings.Split(state.NodeID, Delimiter) - if size == targetsCount { - if size > 1 { - aggregatedData := make([]json.RawMessage, size) - i := 0 - state.targetResults.ForEach(func(_ string, rs mq.Result) bool { - aggregatedData[i] = rs.Payload - i++ - return true - }) - aggregatedPayload, err := json.Marshal(aggregatedData) - if err != nil { - panic(err) - } - state.Result = mq.Result{Payload: aggregatedPayload, Status: mq.Completed, Ctx: ctx, Topic: state.NodeID} - } else if size == 1 { - state.Result = state.targetResults.Values()[0] - } - state.Status = result.Status - state.Result.Status = result.Status - } - if state.Result.Payload == nil { - state.Result.Payload = result.Payload - } - state.UpdatedAt = time.Now() - if result.Ctx == nil { - result.Ctx = ctx - } - if result.Error != nil { - state.Status = mq.Failed - } - pn, ok := tm.parentNodes.Get(state.NodeID) - if edges, exists := tm.iteratorNodes.Get(nodeID[0]); exists && state.Status == mq.Completed { - state.Status = mq.Processing - tm.iteratorNodes.Del(nodeID[0]) - state.targetResults.Clear() - if len(nodeID) == 2 { - ctx = context.WithValue(ctx, ContextIndex, nodeID[1]) - } - toProcess := nodeResult{ - ctx: ctx, - nodeID: state.NodeID, - status: state.Status, - result: state.Result, - } - tm.handleEdges(toProcess, edges) - } else if ok { - if targetsCount == size { - parentState, _ := tm.taskStates.Get(pn) - if parentState != nil { - state.Result.Topic = state.NodeID - tm.handlePrevious(ctx, parentState, state.Result, state.NodeID, dispatchFinal) - } - } - } else { - tm.result = &state.Result - state.Result.Topic = strings.Split(state.NodeID, Delimiter)[0] - tm.resultCh <- state.Result - tm.processFinalResult(state) - } -} - -func (tm *TaskManager) handleNext(ctx context.Context, node *Node, state *TaskState, result mq.Result) { - state.UpdatedAt = time.Now() - if result.Ctx == nil { - result.Ctx = ctx - } - if result.Error != nil { - state.Status = mq.Failed - } else { - edges := tm.getConditionalEdges(node, result) - if len(edges) == 0 { - state.Status = mq.Completed - } - } - if result.Status == "" { - result.Status = state.Status - } - select { - case tm.resultQueue <- nodeResult{ - ctx: ctx, - nodeID: state.NodeID, - result: result, - status: state.Status, - }: - default: - log.Println("Result queue is full, dropping result.") - } -} - -func (tm *TaskManager) onNodeCompleted(rs nodeResult) { - nodeID := strings.Split(rs.nodeID, Delimiter)[0] - node, ok := tm.dag.nodes.Get(nodeID) - if !ok { - return - } - edges := tm.getConditionalEdges(node, rs.result) - hasErrorOrCompleted := rs.result.Error != nil || len(edges) == 0 - if hasErrorOrCompleted { - if index, ok := rs.ctx.Value(ContextIndex).(string); ok { - childNode := fmt.Sprintf("%s%s%s", node.ID, Delimiter, index) - pn, ok := tm.parentNodes.Get(childNode) - if ok { - parentState, _ := tm.taskStates.Get(pn) - if parentState != nil { - pn = strings.Split(pn, Delimiter)[0] - tm.handlePrevious(rs.ctx, parentState, rs.result, rs.nodeID, true) - } - } - } - return - } - tm.handleEdges(rs, edges) -} - -func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge { - edges := make([]Edge, len(node.Edges)) - copy(edges, node.Edges) - if result.ConditionStatus != "" { - if conditions, ok := tm.dag.conditions[result.Topic]; ok { - if targetNodeKey, ok := conditions[result.ConditionStatus]; ok { - if targetNode, ok := tm.dag.nodes.Get(targetNodeKey); ok { - edges = append(edges, Edge{From: node, To: targetNode}) - } - } else if targetNodeKey, ok = conditions["default"]; ok { - if targetNode, ok := tm.dag.nodes.Get(targetNodeKey); ok { - edges = append(edges, Edge{From: node, To: targetNode}) - } - } - } - } - return edges -} - -func (tm *TaskManager) handleEdges(currentResult nodeResult, edges []Edge) { - for _, edge := range edges { - index, ok := currentResult.ctx.Value(ContextIndex).(string) - if !ok { - index = "0" - } - parentNode := fmt.Sprintf("%s%s%s", edge.From.ID, Delimiter, index) - if edge.Type == Simple { - if _, ok := tm.iteratorNodes.Get(edge.From.ID); ok { - continue - } - } - if edge.Type == Iterator { - var items []json.RawMessage - err := json.Unmarshal(currentResult.result.Payload, &items) - if err != nil { - log.Printf("Error unmarshalling data for node %s: %v\n", edge.To.ID, err) - tm.resultQueue <- nodeResult{ - ctx: currentResult.ctx, - nodeID: edge.To.ID, - status: mq.Failed, - result: mq.Result{Error: err}, - } - return - } - tm.childNodes.Set(parentNode, len(items)) - for i, item := range items { - childNode := fmt.Sprintf("%s%s%d", edge.To.ID, Delimiter, i) - ctx := context.WithValue(currentResult.ctx, ContextIndex, fmt.Sprintf("%d", i)) - tm.parentNodes.Set(childNode, parentNode) - tm.send(ctx, edge.To.ID, tm.taskID, item) - } - } else { - tm.childNodes.Set(parentNode, 1) - idx, ok := currentResult.ctx.Value(ContextIndex).(string) - if !ok { - idx = "0" - } - childNode := fmt.Sprintf("%s%s%s", edge.To.ID, Delimiter, idx) - ctx := context.WithValue(currentResult.ctx, ContextIndex, idx) - tm.parentNodes.Set(childNode, parentNode) - tm.send(ctx, edge.To.ID, tm.taskID, currentResult.result.Payload) - } - } -} - -func (tm *TaskManager) retryDeferredTasks() { - const maxRetries = 5 - retries := 0 - for retries < maxRetries { - select { - case <-tm.stopCh: - log.Println("Stopping Deferred task Retrier") - return - case <-time.After(RetryInterval): - tm.deferredTasks.ForEach(func(taskID string, task *task) bool { - tm.send(task.ctx, task.nodeID, taskID, task.payload) - retries++ - return true - }) - } - } -} - -func (tm *TaskManager) processFinalResult(state *TaskState) { - state.targetResults.Clear() - if tm.dag.finalResult != nil { - tm.dag.finalResult(tm.taskID, state.Result) - } -} - -func (tm *TaskManager) Stop() { - close(tm.stopCh) -} diff --git a/dag/websocket.go b/dag/websocket.go index 70f2218..4f2c8cf 100644 --- a/dag/websocket.go +++ b/dag/websocket.go @@ -11,7 +11,6 @@ func WsEvents(s *sio.Server) { } func join(s *sio.Socket, data []byte) { - //just one room at a time for the simple example currentRooms := s.GetRooms() for _, room := range currentRooms { s.Leave(room) diff --git a/examples/dag.go b/examples/dag.go index 9a30d8e..6edc258 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -4,14 +4,13 @@ import ( "context" "encoding/json" "fmt" - v2 "github.com/oarkflow/mq/dag/v2" - "github.com/oarkflow/mq" + "github.com/oarkflow/mq/dag" "github.com/oarkflow/mq/examples/tasks" ) func main() { - f := v2.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) { + f := dag.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) { fmt.Printf("Final result for task %s: %s\n", taskID, string(result.Payload)) }, mq.WithSyncMode(true)) f.SetNotifyResponse(func(ctx context.Context, result mq.Result) error { @@ -25,39 +24,38 @@ func main() { if err != nil { panic(err) } - f.Start(context.Background(), ":8083") + f.Start(context.Background(), ":8082") sendData(f) } -func subDAG() *v2.DAG { - f := v2.NewDAG("Sub DAG", "sub-dag", func(taskID string, result mq.Result) { +func subDAG() *dag.DAG { + f := dag.NewDAG("Sub DAG", "sub-dag", func(taskID string, result mq.Result) { fmt.Printf("Final result for task %s: %s\n", taskID, string(result.Payload)) }, mq.WithSyncMode(true)) f. - AddNode(v2.Function, "Store data", "store:data", &tasks.StoreData{Operation: v2.Operation{Type: "process"}}, true). - AddNode(v2.Function, "Send SMS", "send:sms", &tasks.SendSms{Operation: v2.Operation{Type: "process"}}). - AddNode(v2.Function, "Notification", "notification", &tasks.InAppNotification{Operation: v2.Operation{Type: "process"}}). - AddEdge(v2.Simple, "Store Payload to send sms", "store:data", "send:sms"). - AddEdge(v2.Simple, "Store Payload to notification", "send:sms", "notification") + AddNode(dag.Function, "Store data", "store:data", &tasks.StoreData{Operation: dag.Operation{Type: "process"}}, true). + AddNode(dag.Function, "Send SMS", "send:sms", &tasks.SendSms{Operation: dag.Operation{Type: "process"}}). + AddNode(dag.Function, "Notification", "notification", &tasks.InAppNotification{Operation: dag.Operation{Type: "process"}}). + AddEdge(dag.Simple, "Store Payload to send sms", "store:data", "send:sms"). + AddEdge(dag.Simple, "Store Payload to notification", "send:sms", "notification") return f } -func setup(f *v2.DAG) { +func setup(f *dag.DAG) { f. - AddNode(v2.Function, "Email Delivery", "email:deliver", &tasks.EmailDelivery{Operation: v2.Operation{Type: "process"}}). - AddNode(v2.Function, "Prepare Email", "prepare:email", &tasks.PrepareEmail{Operation: v2.Operation{Type: "process"}}). - AddNode(v2.Function, "Get Input", "get:input", &tasks.GetData{Operation: v2.Operation{Type: "input"}}, true). - AddNode(v2.Function, "Final Payload", "final", &tasks.Final{Operation: v2.Operation{Type: "page"}}). - AddNode(v2.Function, "Iterator Processor", "loop", &tasks.Loop{Operation: v2.Operation{Type: "loop"}}). - AddNode(v2.Function, "Condition", "condition", &tasks.Condition{Operation: v2.Operation{Type: "condition"}}). + AddNode(dag.Function, "Email Delivery", "email:deliver", &tasks.EmailDelivery{Operation: dag.Operation{Type: "process"}}). + AddNode(dag.Function, "Prepare Email", "prepare:email", &tasks.PrepareEmail{Operation: dag.Operation{Type: "process"}}). + AddNode(dag.Function, "Get Input", "get:input", &tasks.GetData{Operation: dag.Operation{Type: "input"}}, true). + AddNode(dag.Function, "Iterator Processor", "loop", &tasks.Loop{Operation: dag.Operation{Type: "loop"}}). + AddNode(dag.Function, "Condition", "condition", &tasks.Condition{Operation: dag.Operation{Type: "condition"}}). AddDAGNode("Persistent", "persistent", subDAG()). - AddEdge(v2.Simple, "Get input to loop", "get:input", "loop"). - AddEdge(v2.Iterator, "Loop to prepare email", "loop", "prepare:email"). - AddEdge(v2.Simple, "Prepare Email to condition", "prepare:email", "condition"). + AddEdge(dag.Simple, "Get input to loop", "get:input", "loop"). + AddEdge(dag.Iterator, "Loop to prepare email", "loop", "prepare:email"). + AddEdge(dag.Simple, "Prepare Email to condition", "prepare:email", "condition"). AddCondition("condition", map[string]string{"pass": "email:deliver", "fail": "persistent"}) } -func sendData(f *v2.DAG) { +func sendData(f *dag.DAG) { data := []map[string]any{ {"phone": "+123456789", "email": "abc.xyz@gmail.com"}, {"phone": "+98765412", "email": "xyz.abc@gmail.com"}, } diff --git a/examples/dag_consumer.go b/examples/dag_consumer.go index b4b9ad9..921e5e2 100644 --- a/examples/dag_consumer.go +++ b/examples/dag_consumer.go @@ -3,31 +3,30 @@ package main import ( "context" "fmt" - v2 "github.com/oarkflow/mq/dag/v2" - "github.com/oarkflow/mq" + "github.com/oarkflow/mq/dag" "github.com/oarkflow/mq/examples/tasks" ) func main() { - d := v2.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) { + d := dag.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) { fmt.Println("Final", string(result.Payload)) }, mq.WithSyncMode(true), mq.WithNotifyResponse(tasks.NotifyResponse), ) - d.AddNode(v2.Function, "C", "C", &tasks.Node3{}, true) - d.AddNode(v2.Function, "D", "D", &tasks.Node4{}) - d.AddNode(v2.Function, "E", "E", &tasks.Node5{}) - d.AddNode(v2.Function, "F", "F", &tasks.Node6{}) - d.AddNode(v2.Function, "G", "G", &tasks.Node7{}) - d.AddNode(v2.Function, "H", "H", &tasks.Node8{}) + d.AddNode(dag.Function, "C", "C", &tasks.Node3{}, true) + d.AddNode(dag.Function, "D", "D", &tasks.Node4{}) + d.AddNode(dag.Function, "E", "E", &tasks.Node5{}) + d.AddNode(dag.Function, "F", "F", &tasks.Node6{}) + d.AddNode(dag.Function, "G", "G", &tasks.Node7{}) + d.AddNode(dag.Function, "H", "H", &tasks.Node8{}) d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"}) - d.AddEdge(v2.Simple, "Label 1", "B", "C") - d.AddEdge(v2.Simple, "Label 2", "D", "F") - d.AddEdge(v2.Simple, "Label 3", "E", "F") - d.AddEdge(v2.Simple, "Label 4", "F", "G", "H") + d.AddEdge(dag.Simple, "Label 1", "B", "C") + d.AddEdge(dag.Simple, "Label 2", "D", "F") + d.AddEdge(dag.Simple, "Label 3", "E", "F") + d.AddEdge(dag.Simple, "Label 4", "F", "G", "H") d.AssignTopic("queue") err := d.Consume(context.Background()) if err != nil { diff --git a/examples/form.go b/examples/form.go index d20ea26..9c37aaa 100644 --- a/examples/form.go +++ b/examples/form.go @@ -4,25 +4,25 @@ import ( "context" "encoding/json" "fmt" + "github.com/oarkflow/mq/dag" "github.com/oarkflow/jet" "github.com/oarkflow/mq" "github.com/oarkflow/mq/consts" - v2 "github.com/oarkflow/mq/dag/v2" ) func main() { - flow := v2.NewDAG("Multi-Step Form", "multi-step-form", func(taskID string, result mq.Result) { + flow := dag.NewDAG("Multi-Step Form", "multi-step-form", func(taskID string, result mq.Result) { fmt.Printf("Final result for task %s: %s\n", taskID, string(result.Payload)) }) - flow.AddNode(v2.Page, "FormStep1", "FormStep1", &FormStep1{}) - flow.AddNode(v2.Page, "FormStep2", "FormStep2", &FormStep2{}) - flow.AddNode(v2.Page, "FormResult", "FormResult", &FormResult{}) + flow.AddNode(dag.Page, "FormStep1", "FormStep1", &FormStep1{}) + flow.AddNode(dag.Page, "FormStep2", "FormStep2", &FormStep2{}) + flow.AddNode(dag.Page, "FormResult", "FormResult", &FormResult{}) // Define edges - flow.AddEdge(v2.Simple, "FormStep1", "FormStep1", "FormStep2") - flow.AddEdge(v2.Simple, "FormStep2", "FormStep2", "FormResult") + flow.AddEdge(dag.Simple, "FormStep1", "FormStep1", "FormStep2") + flow.AddEdge(dag.Simple, "FormStep2", "FormStep2", "FormResult") // Start the flow if flow.Error != nil { @@ -32,7 +32,7 @@ func main() { } type FormStep1 struct { - v2.Operation + dag.Operation } func (p *FormStep1) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { @@ -68,7 +68,7 @@ func (p *FormStep1) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { } type FormStep2 struct { - v2.Operation + dag.Operation } func (p *FormStep2) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { @@ -111,7 +111,7 @@ func (p *FormStep2) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { } type FormResult struct { - v2.Operation + dag.Operation } func (p *FormResult) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { diff --git a/examples/subdag.go b/examples/subdag.go index 55d9d4d..4c197c9 100644 --- a/examples/subdag.go +++ b/examples/subdag.go @@ -3,21 +3,21 @@ package main import ( "context" "fmt" + "github.com/oarkflow/mq/dag/v1" "github.com/oarkflow/mq/examples/tasks" "github.com/oarkflow/mq" - "github.com/oarkflow/mq/dag" ) func main() { - d := dag.NewDAG( + d := v1.NewDAG( "Sample DAG", "sample-dag", mq.WithSyncMode(true), mq.WithNotifyResponse(tasks.NotifyResponse), ) - subDag := dag.NewDAG( + subDag := v1.NewDAG( "Sub DAG", "D", mq.WithNotifyResponse(tasks.NotifySubDAGResponse), @@ -35,7 +35,7 @@ func main() { d.AddDAGNode("D", "D", subDag) d.AddNode("E", "E", &tasks.Node5{}) d.AddIterator("Send each item", "A", "B") - d.AddCondition("C", map[dag.When]dag.Then{"PASS": "D", "FAIL": "E"}) + d.AddCondition("C", map[v1.When]v1.Then{"PASS": "D", "FAIL": "E"}) d.AddEdge("Label 1", "B", "C") fmt.Println(d.ExportDOT()) diff --git a/examples/tasks/operations.go b/examples/tasks/operations.go index 62e0d6b..dea31d0 100644 --- a/examples/tasks/operations.go +++ b/examples/tasks/operations.go @@ -2,9 +2,8 @@ package tasks import ( "context" - v2 "github.com/oarkflow/mq/dag/v2" - "github.com/oarkflow/json" + v2 "github.com/oarkflow/mq/dag" "github.com/oarkflow/mq" ) diff --git a/examples/tasks/tasks.go b/examples/tasks/tasks.go index a72d7db..dcc2e47 100644 --- a/examples/tasks/tasks.go +++ b/examples/tasks/tasks.go @@ -4,7 +4,7 @@ import ( "context" "encoding/json" "fmt" - v2 "github.com/oarkflow/mq/dag/v2" + v2 "github.com/oarkflow/mq/dag" "log" "github.com/oarkflow/mq" diff --git a/examples/v2.go b/examples/v2.go index 1d86632..a259cca 100644 --- a/examples/v2.go +++ b/examples/v2.go @@ -5,16 +5,16 @@ import ( "encoding/json" "fmt" "github.com/oarkflow/mq" + "github.com/oarkflow/mq/dag" "os" "github.com/oarkflow/jet" "github.com/oarkflow/mq/consts" - v2 "github.com/oarkflow/mq/dag/v2" ) type Form struct { - v2.Operation + dag.Operation } func (p *Form) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { @@ -38,7 +38,7 @@ func (p *Form) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { } type NodeA struct { - v2.Operation + dag.Operation } func (p *NodeA) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { @@ -52,7 +52,7 @@ func (p *NodeA) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { } type NodeB struct { - v2.Operation + dag.Operation } func (p *NodeB) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { @@ -66,7 +66,7 @@ func (p *NodeB) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { } type NodeC struct { - v2.Operation + dag.Operation } func (p *NodeC) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { @@ -80,7 +80,7 @@ func (p *NodeC) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { } type Result struct { - v2.Operation + dag.Operation } func (p *Result) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { @@ -115,16 +115,16 @@ func notify(taskID string, result mq.Result) { } func main() { - flow := v2.NewDAG("Sample DAG", "sample-dag", notify) - flow.AddNode(v2.Page, "Form", "Form", &Form{}) - flow.AddNode(v2.Function, "NodeA", "NodeA", &NodeA{}) - flow.AddNode(v2.Function, "NodeB", "NodeB", &NodeB{}) - flow.AddNode(v2.Function, "NodeC", "NodeC", &NodeC{}) - flow.AddNode(v2.Page, "Result", "Result", &Result{}) - flow.AddEdge(v2.Simple, "Form", "Form", "NodeA") - flow.AddEdge(v2.Simple, "NodeA", "NodeA", "NodeB") - flow.AddEdge(v2.Simple, "NodeB", "NodeB", "NodeC") - flow.AddEdge(v2.Simple, "NodeC", "NodeC", "Result") + flow := dag.NewDAG("Sample DAG", "sample-dag", notify) + flow.AddNode(dag.Page, "Form", "Form", &Form{}) + flow.AddNode(dag.Function, "NodeA", "NodeA", &NodeA{}) + flow.AddNode(dag.Function, "NodeB", "NodeB", &NodeB{}) + flow.AddNode(dag.Function, "NodeC", "NodeC", &NodeC{}) + flow.AddNode(dag.Page, "Result", "Result", &Result{}) + flow.AddEdge(dag.Simple, "Form", "Form", "NodeA") + flow.AddEdge(dag.Simple, "NodeA", "NodeA", "NodeB") + flow.AddEdge(dag.Simple, "NodeB", "NodeB", "NodeC") + flow.AddEdge(dag.Simple, "NodeC", "NodeC", "Result") if flow.Error != nil { panic(flow.Error) } diff --git a/examples/v3.go b/examples/v3.go index 9197d33..bb589fb 100644 --- a/examples/v3.go +++ b/examples/v3.go @@ -5,23 +5,23 @@ import ( "encoding/json" "fmt" "github.com/oarkflow/mq" - v2 "github.com/oarkflow/mq/dag/v2" + "github.com/oarkflow/mq/dag" ) func main() { - flow := v2.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) { + flow := dag.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) { // fmt.Printf("Final result for task %s: %s\n", taskID, string(result.Payload)) }) - flow.AddNode(v2.Function, "GetData", "GetData", &GetData{}, true) - flow.AddNode(v2.Function, "Loop", "Loop", &Loop{}) - flow.AddNode(v2.Function, "ValidateAge", "ValidateAge", &ValidateAge{}) - flow.AddNode(v2.Function, "ValidateGender", "ValidateGender", &ValidateGender{}) - flow.AddNode(v2.Function, "Final", "Final", &Final{}) + flow.AddNode(dag.Function, "GetData", "GetData", &GetData{}, true) + flow.AddNode(dag.Function, "Loop", "Loop", &Loop{}) + flow.AddNode(dag.Function, "ValidateAge", "ValidateAge", &ValidateAge{}) + flow.AddNode(dag.Function, "ValidateGender", "ValidateGender", &ValidateGender{}) + flow.AddNode(dag.Function, "Final", "Final", &Final{}) - flow.AddEdge(v2.Simple, "GetData", "GetData", "Loop") - flow.AddEdge(v2.Iterator, "Validate age for each item", "Loop", "ValidateAge") + flow.AddEdge(dag.Simple, "GetData", "GetData", "Loop") + flow.AddEdge(dag.Iterator, "Validate age for each item", "Loop", "ValidateAge") flow.AddCondition("ValidateAge", map[string]string{"pass": "ValidateGender"}) - flow.AddEdge(v2.Simple, "Mark as Done", "Loop", "Final") + flow.AddEdge(dag.Simple, "Mark as Done", "Loop", "Final") // flow.Start(":8080") data := []byte(`[{"age": "15", "gender": "female"}, {"age": "18", "gender": "male"}]`) @@ -38,7 +38,7 @@ func main() { } type GetData struct { - v2.Operation + dag.Operation } func (p *GetData) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { @@ -46,7 +46,7 @@ func (p *GetData) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { } type Loop struct { - v2.Operation + dag.Operation } func (p *Loop) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { @@ -54,7 +54,7 @@ func (p *Loop) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { } type ValidateAge struct { - v2.Operation + dag.Operation } func (p *ValidateAge) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { @@ -73,7 +73,7 @@ func (p *ValidateAge) ProcessTask(ctx context.Context, task *mq.Task) mq.Result } type ValidateGender struct { - v2.Operation + dag.Operation } func (p *ValidateGender) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { @@ -87,7 +87,7 @@ func (p *ValidateGender) ProcessTask(ctx context.Context, task *mq.Task) mq.Resu } type Final struct { - v2.Operation + dag.Operation } func (p *Final) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {