diff --git a/dag/api.go b/dag/api.go
index f37fe4d..d7395f0 100644
--- a/dag/api.go
+++ b/dag/api.go
@@ -4,26 +4,110 @@ import (
"context"
"encoding/json"
"fmt"
- "io"
"net/http"
- "net/url"
"os"
- "time"
-
- "github.com/oarkflow/mq/jsonparser"
- "github.com/oarkflow/mq/sio"
+ "strings"
"github.com/oarkflow/mq"
+ "github.com/oarkflow/mq/sio"
+
"github.com/oarkflow/mq/consts"
- "github.com/oarkflow/mq/metrics"
+ "github.com/oarkflow/mq/jsonparser"
)
-type Request struct {
- Payload json.RawMessage `json:"payload"`
- Interval time.Duration `json:"interval"`
- Schedule bool `json:"schedule"`
- Overlap bool `json:"overlap"`
- Recurring bool `json:"recurring"`
+func renderNotFound(w http.ResponseWriter) {
+ html := []byte(`
+
+`)
+ w.Header().Set(consts.ContentType, consts.TypeHtml)
+ w.Write(html)
+}
+
+func (tm *DAG) render(w http.ResponseWriter, r *http.Request) {
+ ctx, data, err := parse(r)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusNotFound)
+ return
+ }
+ accept := r.Header.Get("Accept")
+ userCtx := UserContext(ctx)
+ ctx = context.WithValue(ctx, "method", r.Method)
+ if r.Method == "GET" && userCtx.Get("task_id") != "" {
+ manager, ok := tm.taskManager.Get(userCtx.Get("task_id"))
+ if !ok || manager == nil {
+ if strings.Contains(accept, "text/html") || accept == "" {
+ renderNotFound(w)
+ return
+ }
+ http.Error(w, fmt.Sprintf(`{"message": "%s"}`, "task not found"), http.StatusInternalServerError)
+ return
+ }
+ }
+ result := tm.Process(ctx, data)
+ if result.Error != nil {
+ http.Error(w, fmt.Sprintf(`{"message": "%s"}`, result.Error.Error()), http.StatusInternalServerError)
+ return
+ }
+ contentType, ok := result.Ctx.Value(consts.ContentType).(string)
+ if !ok {
+ contentType = consts.TypeJson
+ }
+ switch contentType {
+ case consts.TypeHtml:
+ w.Header().Set(consts.ContentType, consts.TypeHtml)
+ data, err := jsonparser.GetString(result.Payload, "html_content")
+ if err != nil {
+ return
+ }
+ w.Write([]byte(data))
+ default:
+ if r.Method != "POST" {
+ http.Error(w, `{"message": "not allowed"}`, http.StatusMethodNotAllowed)
+ return
+ }
+ w.Header().Set(consts.ContentType, consts.TypeJson)
+ json.NewEncoder(w).Encode(result.Payload)
+ }
+}
+
+func (tm *DAG) taskStatusHandler(w http.ResponseWriter, r *http.Request) {
+ taskID := r.URL.Query().Get("taskID")
+ if taskID == "" {
+ http.Error(w, `{"message": "taskID is missing"}`, http.StatusBadRequest)
+ return
+ }
+ manager, ok := tm.taskManager.Get(taskID)
+ if !ok {
+ http.Error(w, `{"message": "Invalid TaskID"}`, http.StatusNotFound)
+ return
+ }
+ result := make(map[string]TaskState)
+ manager.taskStates.ForEach(func(key string, value *TaskState) bool {
+ key = strings.Split(key, Delimiter)[0]
+ nodeID := strings.Split(value.NodeID, Delimiter)[0]
+ rs := jsonparser.Delete(value.Result.Payload, "html_content")
+ status := value.Status
+ if status == mq.Processing {
+ status = mq.Completed
+ }
+ state := TaskState{
+ NodeID: nodeID,
+ Status: status,
+ UpdatedAt: value.UpdatedAt,
+ Result: mq.Result{
+ Payload: rs,
+ Error: value.Result.Error,
+ Status: status,
+ },
+ }
+ result[key] = state
+ return true
+ })
+ w.Header().Set(consts.ContentType, consts.TypeJson)
+ json.NewEncoder(w).Encode(result)
}
func (tm *DAG) SetupWS() *sio.Server {
@@ -37,57 +121,11 @@ func (tm *DAG) SetupWS() *sio.Server {
}
func (tm *DAG) Handlers() {
- metrics.HandleHTTP()
http.Handle("/", http.FileServer(http.Dir("webroot")))
http.Handle("/notify", tm.SetupWS())
- http.HandleFunc("GET /render", tm.Render)
- http.HandleFunc("POST /request", tm.Request)
- http.HandleFunc("POST /publish", tm.Publish)
- http.HandleFunc("POST /schedule", tm.Schedule)
- http.HandleFunc("/pause-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) {
- id := request.PathValue("id")
- if id != "" {
- tm.PauseConsumer(request.Context(), id)
- }
- })
- http.HandleFunc("/resume-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) {
- id := request.PathValue("id")
- if id != "" {
- tm.ResumeConsumer(request.Context(), id)
- }
- })
- http.HandleFunc("/pause", func(w http.ResponseWriter, request *http.Request) {
- err := tm.Pause(request.Context())
- if err != nil {
- http.Error(w, "Failed to pause", http.StatusBadRequest)
- return
- }
- json.NewEncoder(w).Encode(map[string]string{"status": "paused"})
- })
- http.HandleFunc("/resume", func(w http.ResponseWriter, request *http.Request) {
- err := tm.Resume(request.Context())
- if err != nil {
- http.Error(w, "Failed to resume", http.StatusBadRequest)
- return
- }
- json.NewEncoder(w).Encode(map[string]string{"status": "resumed"})
- })
- http.HandleFunc("/stop", func(w http.ResponseWriter, request *http.Request) {
- err := tm.Stop(request.Context())
- if err != nil {
- http.Error(w, "Failed to read request body", http.StatusBadRequest)
- return
- }
- json.NewEncoder(w).Encode(map[string]string{"status": "stopped"})
- })
- http.HandleFunc("/close", func(w http.ResponseWriter, request *http.Request) {
- err := tm.Close()
- if err != nil {
- http.Error(w, "Failed to read request body", http.StatusBadRequest)
- return
- }
- json.NewEncoder(w).Encode(map[string]string{"status": "closed"})
- })
+ http.HandleFunc("/process", tm.render)
+ http.HandleFunc("/request", tm.render)
+ http.HandleFunc("/task/status", tm.taskStatusHandler)
http.HandleFunc("/dot", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
fmt.Fprintln(w, tm.ExportDOT())
@@ -112,98 +150,3 @@ func (tm *DAG) Handlers() {
}
})
}
-
-func (tm *DAG) request(w http.ResponseWriter, r *http.Request, async bool) {
- if r.Method != http.MethodPost {
- http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
- return
- }
- var request Request
- if r.Body != nil {
- defer r.Body.Close()
- var err error
- payload, err := io.ReadAll(r.Body)
- if err != nil {
- http.Error(w, "Failed to read request body", http.StatusBadRequest)
- return
- }
- err = json.Unmarshal(payload, &request)
- if err != nil {
- http.Error(w, "Failed to unmarshal body", http.StatusBadRequest)
- return
- }
- } else {
- http.Error(w, "Empty request body", http.StatusBadRequest)
- return
- }
- ctx := r.Context()
- if async {
- ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
- }
- var opts []mq.SchedulerOption
- if request.Interval > 0 {
- opts = append(opts, mq.WithInterval(request.Interval))
- }
- if request.Overlap {
- opts = append(opts, mq.WithOverlap())
- }
- if request.Recurring {
- opts = append(opts, mq.WithRecurring())
- }
- ctx = context.WithValue(ctx, "query_params", r.URL.Query())
- var rs mq.Result
- if request.Schedule {
- rs = tm.ScheduleTask(ctx, request.Payload, opts...)
- } else {
- rs = tm.Process(ctx, request.Payload)
- }
- w.Header().Set("Content-Type", "application/json")
- json.NewEncoder(w).Encode(rs)
-}
-
-func (tm *DAG) Render(w http.ResponseWriter, r *http.Request) {
- ctx := mq.SetHeaders(r.Context(), map[string]string{consts.AwaitResponseKey: "true", "request_type": "render"})
- ctx = context.WithValue(ctx, "query_params", r.URL.Query())
- rs := tm.Process(ctx, nil)
- content, err := jsonparser.GetString(rs.Payload, "html_content")
- if err != nil {
- http.Error(w, "Failed to read request body", http.StatusBadRequest)
- return
- }
- w.Header().Set("Content-Type", consts.TypeHtml)
- w.Write([]byte(content))
-}
-
-func (tm *DAG) Request(w http.ResponseWriter, r *http.Request) {
- tm.request(w, r, true)
-}
-
-func (tm *DAG) Publish(w http.ResponseWriter, r *http.Request) {
- tm.request(w, r, false)
-}
-
-func (tm *DAG) Schedule(w http.ResponseWriter, r *http.Request) {
- tm.request(w, r, false)
-}
-
-func GetTaskID(ctx context.Context) string {
- if queryParams := ctx.Value("query_params"); queryParams != nil {
- if params, ok := queryParams.(url.Values); ok {
- if id := params.Get("taskID"); id != "" {
- return id
- }
- }
- }
- return ""
-}
-
-func CanNextNode(ctx context.Context) string {
- if queryParams := ctx.Value("query_params"); queryParams != nil {
- if params, ok := queryParams.(url.Values); ok {
- if id := params.Get("next"); id != "" {
- return id
- }
- }
- }
- return ""
-}
diff --git a/dag/consts.go b/dag/consts.go
index 372a7ee..7e6d712 100644
--- a/dag/consts.go
+++ b/dag/consts.go
@@ -1,26 +1,28 @@
package dag
-type NodeStatus int
-
-func (c NodeStatus) IsValid() bool { return c >= Pending && c <= Failed }
-
-func (c NodeStatus) String() string {
- switch c {
- case Pending:
- return "Pending"
- case Processing:
- return "Processing"
- case Completed:
- return "Completed"
- case Failed:
- return "Failed"
- }
- return ""
-}
+import "time"
const (
- Pending NodeStatus = iota
- Processing
- Completed
- Failed
+ Delimiter = "___"
+ ContextIndex = "index"
+ DefaultChannelSize = 1000
+ RetryInterval = 5 * time.Second
+)
+
+type NodeType int
+
+func (c NodeType) IsValid() bool { return c >= Function && c <= Page }
+
+const (
+ Function NodeType = iota
+ Page
+)
+
+type EdgeType int
+
+func (c EdgeType) IsValid() bool { return c >= Simple && c <= Iterator }
+
+const (
+ Simple EdgeType = iota
+ Iterator
)
diff --git a/dag/v2/context.go b/dag/context.go
similarity index 99%
rename from dag/v2/context.go
rename to dag/context.go
index d6fa0d2..c5fc06a 100644
--- a/dag/v2/context.go
+++ b/dag/context.go
@@ -1,4 +1,4 @@
-package v2
+package dag
import (
"context"
diff --git a/dag/dag.go b/dag/dag.go
index fb1bc64..a010e91 100644
--- a/dag/dag.go
+++ b/dag/dag.go
@@ -2,155 +2,70 @@ package dag
import (
"context"
+ "encoding/json"
"fmt"
"log"
"net/http"
+ "strings"
"time"
- "github.com/oarkflow/mq/storage"
- "github.com/oarkflow/mq/storage/memory"
-
- "github.com/oarkflow/mq/sio"
-
"golang.org/x/time/rate"
- "github.com/oarkflow/mq"
"github.com/oarkflow/mq/consts"
- "github.com/oarkflow/mq/metrics"
-)
+ "github.com/oarkflow/mq/sio"
-type EdgeType int
-
-func (c EdgeType) IsValid() bool { return c >= Simple && c <= Iterator }
-
-const (
- Simple EdgeType = iota
- Iterator
-)
-
-type NodeType int
-
-func (c NodeType) IsValid() bool { return c >= Process && c <= Page }
-
-const (
- Process NodeType = iota
- Page
+ "github.com/oarkflow/mq"
+ "github.com/oarkflow/mq/storage"
+ "github.com/oarkflow/mq/storage/memory"
)
type Node struct {
- processor mq.Processor
- Name string
- Type NodeType
- Key string
+ NodeType NodeType
+ Label string
+ ID string
Edges []Edge
+ processor mq.Processor
isReady bool
}
-func (n *Node) ProcessTask(ctx context.Context, msg *mq.Task) mq.Result {
- return n.processor.ProcessTask(ctx, msg)
-}
-
-func (n *Node) Close() error {
- return n.processor.Close()
-}
-
type Edge struct {
- Label string
From *Node
- To []*Node
+ To *Node
Type EdgeType
+ Label string
}
-type (
- FromNode string
- When string
- Then string
-)
-
type DAG struct {
server *mq.Broker
consumer *mq.Consumer
- taskContext storage.IMap[string, *TaskManager]
- nodes map[string]*Node
+ nodes storage.IMap[string, *Node]
+ taskManager storage.IMap[string, *TaskManager]
iteratorNodes storage.IMap[string, []Edge]
- conditions map[FromNode]map[When]Then
+ finalResult func(taskID string, result mq.Result)
pool *mq.Pool
- taskCleanupCh chan string
name string
key string
startNode string
- consumerTopic string
opts []mq.Option
+ conditions map[string]map[string]string
+ consumerTopic string
reportNodeResultCallback func(mq.Result)
+ Error error
Notifier *sio.Server
paused bool
- Error error
report string
- index string
}
-func (tm *DAG) SetKey(key string) {
- tm.key = key
-}
-
-func (tm *DAG) ReportNodeResult(callback func(mq.Result)) {
- tm.reportNodeResultCallback = callback
-}
-
-func (tm *DAG) GetType() string {
- return tm.key
-}
-
-func (tm *DAG) listenForTaskCleanup() {
- for taskID := range tm.taskCleanupCh {
- if tm.server.Options().CleanTaskOnComplete() {
- tm.taskCleanup(taskID)
- }
- }
-}
-
-func (tm *DAG) taskCleanup(taskID string) {
- tm.taskContext.Del(taskID)
- log.Printf("DAG - Task %s cleaned up", taskID)
-}
-
-func (tm *DAG) Consume(ctx context.Context) error {
- if tm.consumer != nil {
- tm.server.Options().SetSyncMode(true)
- return tm.consumer.Consume(ctx)
- }
- return nil
-}
-
-func (tm *DAG) Stop(ctx context.Context) error {
- for _, n := range tm.nodes {
- err := n.processor.Stop(ctx)
- if err != nil {
- return err
- }
- }
- return nil
-}
-
-func (tm *DAG) GetKey() string {
- return tm.key
-}
-
-func (tm *DAG) AssignTopic(topic string) {
- tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false), mq.WithBrokerURL(tm.server.URL()))
- tm.consumerTopic = topic
-}
-
-func NewDAG(name, key string, opts ...mq.Option) *DAG {
+func NewDAG(name, key string, finalResultCallback func(taskID string, result mq.Result), opts ...mq.Option) *DAG {
callback := func(ctx context.Context, result mq.Result) error { return nil }
d := &DAG{
name: name,
key: key,
- nodes: make(map[string]*Node),
+ nodes: memory.New[string, *Node](),
+ taskManager: memory.New[string, *TaskManager](),
iteratorNodes: memory.New[string, []Edge](),
- taskContext: memory.New[string, *TaskManager](),
- conditions: make(map[FromNode]map[When]Then),
- taskCleanupCh: make(chan string),
+ conditions: make(map[string]map[string]string),
+ finalResult: finalResultCallback,
}
opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerOnSubscribe(d.onConsumerJoin), mq.WithConsumerOnClose(d.onConsumerClose))
d.server = mq.NewBroker(opts...)
@@ -165,10 +80,61 @@ func NewDAG(name, key string, opts ...mq.Option) *DAG {
mq.WithTaskStorage(options.Storage()),
)
d.pool.Start(d.server.Options().NumOfWorkers())
- go d.listenForTaskCleanup()
return d
}
+func (tm *DAG) SetKey(key string) {
+ tm.key = key
+}
+
+func (tm *DAG) ReportNodeResult(callback func(mq.Result)) {
+ tm.reportNodeResultCallback = callback
+}
+
+func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result {
+ if manager, ok := tm.taskManager.Get(result.TaskID); ok && result.Topic != "" {
+ manager.onNodeCompleted(nodeResult{
+ ctx: ctx,
+ nodeID: result.Topic,
+ status: result.Status,
+ result: result,
+ })
+ }
+ return mq.Result{}
+}
+
+func (tm *DAG) GetType() string {
+ return tm.key
+}
+
+func (tm *DAG) Consume(ctx context.Context) error {
+ if tm.consumer != nil {
+ tm.server.Options().SetSyncMode(true)
+ return tm.consumer.Consume(ctx)
+ }
+ return nil
+}
+
+func (tm *DAG) Stop(ctx context.Context) error {
+ tm.nodes.ForEach(func(_ string, n *Node) bool {
+ err := n.processor.Stop(ctx)
+ if err != nil {
+ return false
+ }
+ return true
+ })
+ return nil
+}
+
+func (tm *DAG) GetKey() string {
+ return tm.key
+}
+
+func (tm *DAG) AssignTopic(topic string) {
+ tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false), mq.WithBrokerURL(tm.server.URL()))
+ tm.consumerTopic = topic
+}
+
func (tm *DAG) callbackToConsumer(ctx context.Context, result mq.Result) {
if tm.consumer != nil {
result.Topic = tm.consumerTopic
@@ -180,114 +146,89 @@ func (tm *DAG) callbackToConsumer(ctx context.Context, result mq.Result) {
}
}
-func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result {
- if taskContext, ok := tm.taskContext.Get(result.TaskID); ok && result.Topic != "" {
- return taskContext.handleNextTask(ctx, result)
- }
- return mq.Result{}
-}
-
func (tm *DAG) onConsumerJoin(_ context.Context, topic, _ string) {
- if node, ok := tm.nodes[topic]; ok {
+ if node, ok := tm.nodes.Get(topic); ok {
log.Printf("DAG - CONSUMER ~> ready on %s", topic)
node.isReady = true
}
}
func (tm *DAG) onConsumerClose(_ context.Context, topic, _ string) {
- if node, ok := tm.nodes[topic]; ok {
+ if node, ok := tm.nodes.Get(topic); ok {
log.Printf("DAG - CONSUMER ~> down on %s", topic)
node.isReady = false
}
}
+func (tm *DAG) Pause(_ context.Context) error {
+ tm.paused = true
+ return nil
+}
+
+func (tm *DAG) Resume(_ context.Context) error {
+ tm.paused = false
+ return nil
+}
+
+func (tm *DAG) Close() error {
+ var err error
+ tm.nodes.ForEach(func(_ string, n *Node) bool {
+ err = n.processor.Close()
+ if err != nil {
+ return false
+ }
+ return true
+ })
+ return nil
+}
+
func (tm *DAG) SetStartNode(node string) {
tm.startNode = node
}
+func (tm *DAG) SetNotifyResponse(callback mq.Callback) {
+ tm.server.SetNotifyHandler(callback)
+}
+
func (tm *DAG) GetStartNode() string {
return tm.startNode
}
-func (tm *DAG) Start(ctx context.Context, addr string) error {
- // Start the server in a separate goroutine
- go func() {
- defer mq.RecoverPanic(mq.RecoverTitle)
- if err := tm.server.Start(ctx); err != nil {
- panic(err)
- }
- }()
-
- // Start the node consumers if not in sync mode
- if !tm.server.SyncMode() {
- for _, con := range tm.nodes {
- go func(con *Node) {
- defer mq.RecoverPanic(mq.RecoverTitle)
- limiter := rate.NewLimiter(rate.Every(1*time.Second), 1) // Retry every second
- for {
- err := con.processor.Consume(ctx)
- if err != nil {
- log.Printf("[ERROR] - Consumer %s failed to start: %v", con.Key, err)
- } else {
- log.Printf("[INFO] - Consumer %s started successfully", con.Key)
- break
- }
- limiter.Wait(ctx) // Wait with rate limiting before retrying
- }
- }(con)
- }
- }
- log.Printf("DAG - HTTP_SERVER ~> started on http://localhost%s", addr)
- tm.Handlers()
- config := tm.server.TLSConfig()
- if config.UseTLS {
- return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil)
- }
- return http.ListenAndServe(addr, nil)
-}
-
-func (tm *DAG) AddDAGNode(name string, key string, dag *DAG, firstNode ...bool) *DAG {
- dag.AssignTopic(key)
- tm.nodes[key] = &Node{
- Name: name,
- Key: key,
- processor: dag,
- isReady: true,
- }
- if len(firstNode) > 0 && firstNode[0] {
- tm.startNode = key
- }
+func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) *DAG {
+ tm.conditions[fromNode] = conditions
return tm
}
-func (tm *DAG) AddNode(name, key string, handler mq.Processor, firstNode ...bool) *DAG {
- con := mq.NewConsumer(key, key, handler.ProcessTask, tm.opts...)
+func (tm *DAG) AddNode(nodeType NodeType, name, nodeID string, handler mq.Processor, startNode ...bool) *DAG {
+ if tm.Error != nil {
+ return tm
+ }
+ con := mq.NewConsumer(nodeID, nodeID, handler.ProcessTask)
n := &Node{
- Name: name,
- Key: key,
+ Label: name,
+ ID: nodeID,
+ NodeType: nodeType,
processor: con,
}
- if handler.GetType() == "page" {
- n.Type = Page
- }
- if tm.server.SyncMode() {
+ if tm.server != nil && tm.server.SyncMode() {
n.isReady = true
}
- tm.nodes[key] = n
- if len(firstNode) > 0 && firstNode[0] {
- tm.startNode = key
+ tm.nodes.Set(nodeID, n)
+ if len(startNode) > 0 && startNode[0] {
+ tm.startNode = nodeID
}
return tm
}
-func (tm *DAG) AddDeferredNode(name, key string, firstNode ...bool) error {
+func (tm *DAG) AddDeferredNode(nodeType NodeType, name, key string, firstNode ...bool) error {
if tm.server.SyncMode() {
return fmt.Errorf("DAG cannot have deferred node in Sync Mode")
}
- tm.nodes[key] = &Node{
- Name: name,
- Key: key,
- }
+ tm.nodes.Set(key, &Node{
+ Label: name,
+ ID: key,
+ NodeType: nodeType,
+ })
if len(firstNode) > 0 && firstNode[0] {
tm.startNode = key
}
@@ -295,52 +236,124 @@ func (tm *DAG) AddDeferredNode(name, key string, firstNode ...bool) error {
}
func (tm *DAG) IsReady() bool {
- for _, node := range tm.nodes {
- if !node.isReady {
+ var isReady bool
+ tm.nodes.ForEach(func(_ string, n *Node) bool {
+ if !n.isReady {
return false
}
+ isReady = true
+ return true
+ })
+ return isReady
+}
+
+func (tm *DAG) AddEdge(edgeType EdgeType, label, from string, targets ...string) *DAG {
+ if tm.Error != nil {
+ return tm
}
- return true
-}
-
-func (tm *DAG) AddCondition(fromNode FromNode, conditions map[When]Then) *DAG {
- tm.conditions[fromNode] = conditions
- return tm
-}
-
-func (tm *DAG) AddIterator(label, from string, targets ...string) *DAG {
- tm.Error = tm.addEdge(Iterator, label, from, targets...)
- tm.iteratorNodes.Set(from, []Edge{})
- return tm
-}
-
-func (tm *DAG) AddEdge(label, from string, targets ...string) *DAG {
- tm.Error = tm.addEdge(Simple, label, from, targets...)
- return tm
-}
-
-func (tm *DAG) addEdge(edgeType EdgeType, label, from string, targets ...string) error {
- fromNode, ok := tm.nodes[from]
+ if edgeType == Iterator {
+ tm.iteratorNodes.Set(from, []Edge{})
+ }
+ node, ok := tm.nodes.Get(from)
if !ok {
- return fmt.Errorf("Error: 'from' node %s does not exist\n", from)
+ tm.Error = fmt.Errorf("node not found %s", from)
+ return tm
}
- var nodes []*Node
for _, target := range targets {
- toNode, ok := tm.nodes[target]
- if !ok {
- return fmt.Errorf("Error: 'from' node %s does not exist\n", target)
- }
- nodes = append(nodes, toNode)
- }
- edge := Edge{From: fromNode, To: nodes, Type: edgeType, Label: label}
- fromNode.Edges = append(fromNode.Edges, edge)
- if edgeType != Iterator {
- if edges, ok := tm.iteratorNodes.Get(fromNode.Key); ok {
- edges = append(edges, edge)
- tm.iteratorNodes.Set(fromNode.Key, edges)
+ if targetNode, ok := tm.nodes.Get(target); ok {
+ edge := Edge{From: node, To: targetNode, Type: edgeType, Label: label}
+ node.Edges = append(node.Edges, edge)
+ if edgeType != Iterator {
+ if edges, ok := tm.iteratorNodes.Get(node.ID); ok {
+ edges = append(edges, edge)
+ tm.iteratorNodes.Set(node.ID, edges)
+ }
+ }
}
}
- return nil
+ return tm
+}
+
+func (tm *DAG) getCurrentNode(manager *TaskManager) string {
+ if manager.currentNodePayload.Size() == 0 {
+ return ""
+ }
+ return manager.currentNodePayload.Keys()[0]
+}
+
+func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
+ ctx = context.WithValue(ctx, "task_id", task.ID)
+ userContext := UserContext(ctx)
+ next := userContext.Get("next")
+ manager, ok := tm.taskManager.Get(task.ID)
+ resultCh := make(chan mq.Result, 1)
+ if !ok {
+ manager = NewTaskManager(tm, task.ID, resultCh, tm.iteratorNodes.Clone())
+ tm.taskManager.Set(task.ID, manager)
+ } else {
+ manager.resultCh = resultCh
+ }
+ currentKey := tm.getCurrentNode(manager)
+ currentNode := strings.Split(currentKey, Delimiter)[0]
+ node, exists := tm.nodes.Get(currentNode)
+ method, ok := ctx.Value("method").(string)
+ if method == "GET" && exists && node.NodeType == Page {
+ ctx = context.WithValue(ctx, "initial_node", currentNode)
+ /*
+ if isLastNode, err := tm.IsLastNode(currentNode); err != nil && isLastNode {
+ if manager.result != nil {
+ fmt.Println(string(manager.result.Payload))
+ resultCh <- *manager.result
+ return <-resultCh
+ }
+ }
+ */
+ if manager.result != nil {
+ task.Payload = manager.result.Payload
+ }
+ } else if next == "true" {
+ nodes, err := tm.GetNextNodes(currentNode)
+ if err != nil {
+ return mq.Result{Error: err, Ctx: ctx}
+ }
+ if len(nodes) > 0 {
+ ctx = context.WithValue(ctx, "initial_node", nodes[0].ID)
+ }
+ }
+ if currentNodeResult, hasResult := manager.currentNodeResult.Get(currentKey); hasResult {
+ var taskPayload, resultPayload map[string]any
+ if err := json.Unmarshal(task.Payload, &taskPayload); err == nil {
+ if err = json.Unmarshal(currentNodeResult.Payload, &resultPayload); err == nil {
+ for key, val := range resultPayload {
+ taskPayload[key] = val
+ }
+ task.Payload, _ = json.Marshal(taskPayload)
+ }
+ }
+ }
+ firstNode, err := tm.parseInitialNode(ctx)
+ if err != nil {
+ return mq.Result{Error: err, Ctx: ctx}
+ }
+ node, ok = tm.nodes.Get(firstNode)
+ if ok && node.NodeType != Page && task.Payload == nil {
+ return mq.Result{Error: fmt.Errorf("payload is required for node %s", firstNode), Ctx: ctx}
+ }
+ task.Topic = firstNode
+ ctx = context.WithValue(ctx, ContextIndex, "0")
+ manager.ProcessTask(ctx, firstNode, task.Payload)
+ return <-resultCh
+}
+
+func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result {
+ var taskID string
+ userCtx := UserContext(ctx)
+ if val := userCtx.Get("task_id"); val != "" {
+ taskID = val
+ } else {
+ taskID = mq.NewID()
+ }
+ return tm.ProcessTask(ctx, mq.NewTask(taskID, payload, ""))
}
func (tm *DAG) Validate() error {
@@ -357,177 +370,124 @@ func (tm *DAG) GetReport() string {
return tm.report
}
-func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
- if task.ID == "" {
- task.ID = mq.NewID()
- }
- if index, ok := mq.GetHeader(ctx, "index"); ok {
- tm.index = index
- }
- manager, exists := tm.taskContext.Get(task.ID)
- if !exists {
- manager = NewTaskManager(tm, task.ID, tm.iteratorNodes)
- manager.createdAt = task.CreatedAt
- tm.taskContext.Set(task.ID, manager)
+func (tm *DAG) AddDAGNode(name string, key string, dag *DAG, firstNode ...bool) *DAG {
+ dag.AssignTopic(key)
+ tm.nodes.Set(key, &Node{
+ Label: name,
+ ID: key,
+ processor: dag,
+ isReady: true,
+ })
+ if len(firstNode) > 0 && firstNode[0] {
+ tm.startNode = key
}
+ return tm
+}
- if tm.consumer != nil {
- initialNode, err := tm.parseInitialNode(ctx)
- if err != nil {
- metrics.TasksErrors.WithLabelValues("unknown").Inc() // Increase error count
- return mq.Result{Error: err}
+func (tm *DAG) Start(ctx context.Context, addr string) error {
+ // Start the server in a separate goroutine
+ go func() {
+ defer mq.RecoverPanic(mq.RecoverTitle)
+ if err := tm.server.Start(ctx); err != nil {
+ panic(err)
}
- task.Topic = initialNode
- }
- if manager.topic != "" {
- task.Topic = manager.topic
- canNext := CanNextNode(ctx)
- if canNext != "" {
- if n, ok := tm.nodes[task.Topic]; ok {
- if len(n.Edges) > 0 {
- task.Topic = n.Edges[0].To[0].Key
+ }()
+
+ // Start the node consumers if not in sync mode
+ if !tm.server.SyncMode() {
+ tm.nodes.ForEach(func(_ string, con *Node) bool {
+ go func(con *Node) {
+ defer mq.RecoverPanic(mq.RecoverTitle)
+ limiter := rate.NewLimiter(rate.Every(1*time.Second), 1) // Retry every second
+ for {
+ err := con.processor.Consume(ctx)
+ if err != nil {
+ log.Printf("[ERROR] - Consumer %s failed to start: %v", con.ID, err)
+ } else {
+ log.Printf("[INFO] - Consumer %s started successfully", con.ID)
+ break
+ }
+ limiter.Wait(ctx) // Wait with rate limiting before retrying
}
- }
- } else {
- }
+ }(con)
+ return true
+ })
}
- result := manager.processTask(ctx, task.Topic, task.Payload)
- if result.Ctx != nil && tm.index != "" {
- result.Ctx = mq.SetHeaders(result.Ctx, map[string]string{"index": tm.index})
+ log.Printf("DAG - HTTP_SERVER ~> started on http://%s", addr)
+ tm.Handlers()
+ config := tm.server.TLSConfig()
+ log.Printf("Server listening on http://%s", addr)
+ if config.UseTLS {
+ return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil)
}
- if result.Error != nil {
- metrics.TasksErrors.WithLabelValues(task.Topic).Inc() // Increase error count
- } else {
- metrics.TasksProcessed.WithLabelValues("success").Inc() // Increase processed task count
- }
- return result
-}
-
-func (tm *DAG) check(ctx context.Context, payload []byte) (context.Context, *mq.Task, error) {
- if tm.paused {
- return ctx, nil, fmt.Errorf("unable to process task, error: DAG is not accepting any task")
- }
- if !tm.IsReady() {
- return ctx, nil, fmt.Errorf("unable to process task, error: DAG is not ready yet")
- }
- initialNode, err := tm.parseInitialNode(ctx)
- if err != nil {
- return ctx, nil, err
- }
- if tm.server.SyncMode() {
- ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
- }
- taskID := GetTaskID(ctx)
- if taskID != "" {
- if _, exists := tm.taskContext.Get(taskID); !exists {
- return ctx, nil, fmt.Errorf("provided task ID doesn't exist")
- }
- }
- if taskID == "" {
- taskID = mq.NewID()
- }
- return ctx, mq.NewTask(taskID, payload, initialNode), nil
-}
-
-func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result {
- ctx, task, err := tm.check(ctx, payload)
- if err != nil {
- return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")}
- }
- awaitResponse, _ := mq.GetAwaitResponse(ctx)
- if awaitResponse != "true" {
- headers, ok := mq.GetHeaders(ctx)
- ctxx := context.Background()
- if ok {
- ctxx = mq.SetHeaders(ctxx, headers.AsMap())
- }
- if err := tm.pool.EnqueueTask(ctxx, task, 0); err != nil {
- return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: task.Topic, Status: "FAILED", Error: err}
- }
- return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: task.Topic, Status: "PENDING"}
- }
- return tm.ProcessTask(ctx, task)
+ return http.ListenAndServe(addr, nil)
}
func (tm *DAG) ScheduleTask(ctx context.Context, payload []byte, opts ...mq.SchedulerOption) mq.Result {
- ctx, task, err := tm.check(ctx, payload)
- if err != nil {
- return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")}
+ var taskID string
+ userCtx := UserContext(ctx)
+ if val := userCtx.Get("task_id"); val != "" {
+ taskID = val
+ } else {
+ taskID = mq.NewID()
}
+ t := mq.NewTask(taskID, payload, "")
+
+ ctx = context.WithValue(ctx, "task_id", taskID)
+ userContext := UserContext(ctx)
+ next := userContext.Get("next")
+ manager, ok := tm.taskManager.Get(taskID)
+ resultCh := make(chan mq.Result, 1)
+ if !ok {
+ manager = NewTaskManager(tm, taskID, resultCh, tm.iteratorNodes.Clone())
+ tm.taskManager.Set(taskID, manager)
+ } else {
+ manager.resultCh = resultCh
+ }
+ currentKey := tm.getCurrentNode(manager)
+ currentNode := strings.Split(currentKey, Delimiter)[0]
+ node, exists := tm.nodes.Get(currentNode)
+ method, ok := ctx.Value("method").(string)
+ if method == "GET" && exists && node.NodeType == Page {
+ ctx = context.WithValue(ctx, "initial_node", currentNode)
+ } else if next == "true" {
+ nodes, err := tm.GetNextNodes(currentNode)
+ if err != nil {
+ return mq.Result{Error: err, Ctx: ctx}
+ }
+ if len(nodes) > 0 {
+ ctx = context.WithValue(ctx, "initial_node", nodes[0].ID)
+ }
+ }
+ if currentNodeResult, hasResult := manager.currentNodeResult.Get(currentKey); hasResult {
+ var taskPayload, resultPayload map[string]any
+ if err := json.Unmarshal(payload, &taskPayload); err == nil {
+ if err = json.Unmarshal(currentNodeResult.Payload, &resultPayload); err == nil {
+ for key, val := range resultPayload {
+ taskPayload[key] = val
+ }
+ payload, _ = json.Marshal(taskPayload)
+ }
+ }
+ }
+ firstNode, err := tm.parseInitialNode(ctx)
+ if err != nil {
+ return mq.Result{Error: err, Ctx: ctx}
+ }
+ node, ok = tm.nodes.Get(firstNode)
+ if ok && node.NodeType != Page && t.Payload == nil {
+ return mq.Result{Error: fmt.Errorf("payload is required for node %s", firstNode), Ctx: ctx}
+ }
+ t.Topic = firstNode
+ ctx = context.WithValue(ctx, ContextIndex, "0")
+
headers, ok := mq.GetHeaders(ctx)
ctxx := context.Background()
if ok {
ctxx = mq.SetHeaders(ctxx, headers.AsMap())
}
- tm.pool.Scheduler().AddTask(ctxx, task, opts...)
- return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: task.Topic, Status: "PENDING"}
-}
-
-func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) {
- val := ctx.Value("initial_node")
- initialNode, ok := val.(string)
- if ok {
- return initialNode, nil
- }
- if tm.startNode == "" {
- firstNode := tm.findStartNode()
- if firstNode != nil {
- tm.startNode = firstNode.Key
- }
- }
-
- if tm.startNode == "" {
- return "", fmt.Errorf("initial node not found")
- }
- return tm.startNode, nil
-}
-
-func (tm *DAG) findStartNode() *Node {
- incomingEdges := make(map[string]bool)
- connectedNodes := make(map[string]bool)
- for _, node := range tm.nodes {
- for _, edge := range node.Edges {
- if edge.Type.IsValid() {
- connectedNodes[node.Key] = true
- for _, to := range edge.To {
- connectedNodes[to.Key] = true
- incomingEdges[to.Key] = true
- }
- }
- }
- if cond, ok := tm.conditions[FromNode(node.Key)]; ok {
- for _, target := range cond {
- connectedNodes[string(target)] = true
- incomingEdges[string(target)] = true
- }
- }
- }
- for nodeID, node := range tm.nodes {
- if !incomingEdges[nodeID] && connectedNodes[nodeID] {
- return node
- }
- }
- return nil
-}
-
-func (tm *DAG) Pause(_ context.Context) error {
- tm.paused = true
- return nil
-}
-
-func (tm *DAG) Resume(_ context.Context) error {
- tm.paused = false
- return nil
-}
-
-func (tm *DAG) Close() error {
- for _, n := range tm.nodes {
- err := n.Close()
- if err != nil {
- return err
- }
- }
- return nil
+ tm.pool.Scheduler().AddTask(ctxx, t, opts...)
+ return mq.Result{CreatedAt: t.CreatedAt, TaskID: t.ID, Topic: t.Topic, Status: "PENDING"}
}
func (tm *DAG) PauseConsumer(ctx context.Context, id string) {
@@ -539,74 +499,26 @@ func (tm *DAG) ResumeConsumer(ctx context.Context, id string) {
}
func (tm *DAG) doConsumer(ctx context.Context, id string, action consts.CMD) {
- if node, ok := tm.nodes[id]; ok {
+ if node, ok := tm.nodes.Get(id); ok {
switch action {
case consts.CONSUMER_PAUSE:
err := node.processor.Pause(ctx)
if err == nil {
node.isReady = false
- log.Printf("[INFO] - Consumer %s paused successfully", node.Key)
+ log.Printf("[INFO] - Consumer %s paused successfully", node.ID)
} else {
- log.Printf("[ERROR] - Failed to pause consumer %s: %v", node.Key, err)
+ log.Printf("[ERROR] - Failed to pause consumer %s: %v", node.ID, err)
}
case consts.CONSUMER_RESUME:
err := node.processor.Resume(ctx)
if err == nil {
node.isReady = true
- log.Printf("[INFO] - Consumer %s resumed successfully", node.Key)
+ log.Printf("[INFO] - Consumer %s resumed successfully", node.ID)
} else {
- log.Printf("[ERROR] - Failed to resume consumer %s: %v", node.Key, err)
+ log.Printf("[ERROR] - Failed to resume consumer %s: %v", node.ID, err)
}
}
} else {
log.Printf("[WARNING] - Consumer %s not found", id)
}
}
-
-func (tm *DAG) SetNotifyResponse(callback mq.Callback) {
- tm.server.SetNotifyHandler(callback)
-}
-
-func (tm *DAG) GetNextNodes(key string) ([]*Node, error) {
- node, exists := tm.nodes[key]
- if !exists {
- return nil, fmt.Errorf("Node with key %s does not exist", key)
- }
- var successors []*Node
- for _, edge := range node.Edges {
- successors = append(successors, edge.To...)
- }
- if conds, exists := tm.conditions[FromNode(key)]; exists {
- for _, targetKey := range conds {
- if targetNode, exists := tm.nodes[string(targetKey)]; exists {
- successors = append(successors, targetNode)
- }
- }
- }
- return successors, nil
-}
-
-func (tm *DAG) GetPreviousNodes(key string) ([]*Node, error) {
- var predecessors []*Node
- for _, node := range tm.nodes {
- for _, edge := range node.Edges {
- for _, target := range edge.To {
- if target.Key == key {
- predecessors = append(predecessors, node)
- }
- }
- }
- }
- for fromNode, conds := range tm.conditions {
- for _, targetKey := range conds {
- if string(targetKey) == key {
- node, exists := tm.nodes[string(fromNode)]
- if !exists {
- return nil, fmt.Errorf("Node with key %s does not exist", fromNode)
- }
- predecessors = append(predecessors, node)
- }
- }
- }
- return predecessors, nil
-}
diff --git a/dag/v2/node.go b/dag/node.go
similarity index 99%
rename from dag/v2/node.go
rename to dag/node.go
index bba93ac..66161e8 100644
--- a/dag/v2/node.go
+++ b/dag/node.go
@@ -1,4 +1,4 @@
-package v2
+package dag
import (
"context"
diff --git a/dag/task_manager.go b/dag/task_manager.go
index 0ea8067..3b4e68a 100644
--- a/dag/task_manager.go
+++ b/dag/task_manager.go
@@ -4,247 +4,307 @@ import (
"context"
"encoding/json"
"fmt"
+ "log"
"strings"
"time"
- "github.com/oarkflow/mq/consts"
+ "github.com/oarkflow/mq"
+
"github.com/oarkflow/mq/storage"
"github.com/oarkflow/mq/storage/memory"
-
- "github.com/oarkflow/mq"
)
+type TaskState struct {
+ NodeID string
+ Status mq.Status
+ UpdatedAt time.Time
+ Result mq.Result
+ targetResults storage.IMap[string, mq.Result]
+}
+
+func newTaskState(nodeID string) *TaskState {
+ return &TaskState{
+ NodeID: nodeID,
+ Status: mq.Pending,
+ UpdatedAt: time.Now(),
+ targetResults: memory.New[string, mq.Result](),
+ }
+}
+
+type nodeResult struct {
+ ctx context.Context
+ nodeID string
+ status mq.Status
+ result mq.Result
+}
+
type TaskManager struct {
- createdAt time.Time
- processedAt time.Time
- status string
- dag *DAG
- taskID string
- wg *WaitGroup
- topic string
- result mq.Result
-
- iteratorNodes storage.IMap[string, []Edge]
- taskNodeStatus storage.IMap[string, *taskNodeStatus]
+ taskStates storage.IMap[string, *TaskState]
+ parentNodes storage.IMap[string, string]
+ childNodes storage.IMap[string, int]
+ deferredTasks storage.IMap[string, *task]
+ iteratorNodes storage.IMap[string, []Edge]
+ currentNodePayload storage.IMap[string, json.RawMessage]
+ currentNodeResult storage.IMap[string, mq.Result]
+ result *mq.Result
+ dag *DAG
+ taskID string
+ taskQueue chan *task
+ resultQueue chan nodeResult
+ resultCh chan mq.Result
+ stopCh chan struct{}
}
-func NewTaskManager(d *DAG, taskID string, iteratorNodes storage.IMap[string, []Edge]) *TaskManager {
- return &TaskManager{
- dag: d,
- taskNodeStatus: memory.New[string, *taskNodeStatus](),
- taskID: taskID,
- iteratorNodes: iteratorNodes,
- wg: NewWaitGroup(),
+type task struct {
+ ctx context.Context
+ taskID string
+ nodeID string
+ payload json.RawMessage
+}
+
+func newTask(ctx context.Context, taskID, nodeID string, payload json.RawMessage) *task {
+ return &task{
+ ctx: ctx,
+ taskID: taskID,
+ nodeID: nodeID,
+ payload: payload,
}
}
-func (tm *TaskManager) dispatchFinalResult(ctx context.Context) mq.Result {
- tm.updateTS(&tm.result)
- tm.dag.callbackToConsumer(ctx, tm.result)
- if tm.dag.server.NotifyHandler() != nil {
- _ = tm.dag.server.NotifyHandler()(ctx, tm.result)
+func NewTaskManager(dag *DAG, taskID string, resultCh chan mq.Result, iteratorNodes storage.IMap[string, []Edge]) *TaskManager {
+ tm := &TaskManager{
+ taskStates: memory.New[string, *TaskState](),
+ parentNodes: memory.New[string, string](),
+ childNodes: memory.New[string, int](),
+ deferredTasks: memory.New[string, *task](),
+ currentNodePayload: memory.New[string, json.RawMessage](),
+ currentNodeResult: memory.New[string, mq.Result](),
+ taskQueue: make(chan *task, DefaultChannelSize),
+ resultQueue: make(chan nodeResult, DefaultChannelSize),
+ iteratorNodes: iteratorNodes,
+ stopCh: make(chan struct{}),
+ resultCh: resultCh,
+ taskID: taskID,
+ dag: dag,
}
- tm.dag.taskCleanupCh <- tm.taskID
- tm.topic = tm.result.Topic
- return tm.result
+ go tm.run()
+ go tm.waitForResult()
+ return tm
}
-func (tm *TaskManager) reportNodeResult(result mq.Result, final bool) {
- if tm.dag.reportNodeResultCallback != nil {
- tm.dag.reportNodeResultCallback(result)
- }
+func (tm *TaskManager) ProcessTask(ctx context.Context, startNode string, payload json.RawMessage) {
+ tm.send(ctx, startNode, tm.taskID, payload)
}
-func (tm *TaskManager) SetTotalItems(topic string, i int) {
- if nodeStatus, ok := tm.taskNodeStatus.Get(topic); ok {
- nodeStatus.totalItems = i
+func (tm *TaskManager) send(ctx context.Context, startNode, taskID string, payload json.RawMessage) {
+ if index, ok := ctx.Value(ContextIndex).(string); ok {
+ startNode = strings.Split(startNode, Delimiter)[0]
+ startNode = fmt.Sprintf("%s%s%s", startNode, Delimiter, index)
}
-}
-
-func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) {
- topic := getTopic(ctx, node.Key)
- tm.taskNodeStatus.Set(topic, newNodeStatus(topic))
- defer mq.RecoverPanic(mq.RecoverTitle)
- dag, isDAG := isDAGNode(node)
- if isDAG {
- if tm.dag.server.SyncMode() && !dag.server.SyncMode() {
- dag.server.Options().SetSyncMode(true)
- }
- }
- tm.ChangeNodeStatus(ctx, node.Key, Processing, mq.Result{Payload: payload, Topic: node.Key})
- var result mq.Result
- if tm.dag.server.SyncMode() {
- defer func() {
- if isDAG {
- result.Topic = dag.consumerTopic
- result.TaskID = tm.taskID
- tm.reportNodeResult(result, false)
- tm.handleNextTask(result.Ctx, result)
- } else {
- result.Topic = node.Key
- tm.reportNodeResult(result, false)
- tm.handleNextTask(ctx, result)
- }
- }()
+ if _, exists := tm.taskStates.Get(startNode); !exists {
+ tm.taskStates.Set(startNode, newTaskState(startNode))
}
+ t := newTask(ctx, taskID, startNode, payload)
select {
- case <-ctx.Done():
- result = mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: ctx.Err(), Ctx: ctx}
- tm.reportNodeResult(result, true)
- tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
- return
+ case tm.taskQueue <- t:
default:
- ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: node.Key})
- if tm.dag.server.SyncMode() {
- result = node.ProcessTask(ctx, mq.NewTask(tm.taskID, payload, node.Key))
- if result.Error != nil {
- tm.reportNodeResult(result, true)
- tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
- return
+ log.Println("task queue is full, dropping task.")
+ tm.deferredTasks.Set(taskID, t)
+ }
+}
+
+func (tm *TaskManager) run() {
+ for {
+ select {
+ case <-tm.stopCh:
+ log.Println("Stopping TaskManager")
+ return
+ case task := <-tm.taskQueue:
+ tm.processNode(task)
+ }
+ }
+}
+
+func (tm *TaskManager) waitForResult() {
+ for {
+ select {
+ case <-tm.stopCh:
+ log.Println("Stopping Result Listener")
+ return
+ case nr := <-tm.resultQueue:
+ tm.onNodeCompleted(nr)
+ }
+ }
+}
+
+func (tm *TaskManager) processNode(exec *task) {
+ pureNodeID := strings.Split(exec.nodeID, Delimiter)[0]
+ node, exists := tm.dag.nodes.Get(pureNodeID)
+ if !exists {
+ log.Printf("Node %s does not exist while processing node\n", pureNodeID)
+ return
+ }
+ state, _ := tm.taskStates.Get(exec.nodeID)
+ if state == nil {
+ log.Printf("State for node %s not found; creating new state.\n", exec.nodeID)
+ state = newTaskState(exec.nodeID)
+ tm.taskStates.Set(exec.nodeID, state)
+ }
+ state.Status = mq.Processing
+ state.UpdatedAt = time.Now()
+ tm.currentNodePayload.Clear()
+ tm.currentNodeResult.Clear()
+ tm.currentNodePayload.Set(exec.nodeID, exec.payload)
+ result := node.processor.ProcessTask(exec.ctx, mq.NewTask(exec.taskID, exec.payload, exec.nodeID))
+ tm.currentNodeResult.Set(exec.nodeID, result)
+ state.Result = result
+ result.Topic = node.ID
+ if result.Error != nil {
+ tm.result = &result
+ tm.resultCh <- result
+ tm.processFinalResult(state)
+ return
+ }
+ if node.NodeType == Page {
+ tm.result = &result
+ tm.resultCh <- result
+ return
+ }
+ tm.handleNext(exec.ctx, node, state, result)
+}
+
+func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, result mq.Result, childNode string, dispatchFinal bool) {
+ state.targetResults.Set(childNode, result)
+ state.targetResults.Del(state.NodeID)
+ targetsCount, _ := tm.childNodes.Get(state.NodeID)
+ size := state.targetResults.Size()
+ nodeID := strings.Split(state.NodeID, Delimiter)
+ if size == targetsCount {
+ if size > 1 {
+ aggregatedData := make([]json.RawMessage, size)
+ i := 0
+ state.targetResults.ForEach(func(_ string, rs mq.Result) bool {
+ aggregatedData[i] = rs.Payload
+ i++
+ return true
+ })
+ aggregatedPayload, err := json.Marshal(aggregatedData)
+ if err != nil {
+ panic(err)
}
- return
- }
- err := tm.dag.server.Publish(ctx, mq.NewTask(tm.taskID, payload, node.Key), node.Key)
- if err != nil {
- tm.reportNodeResult(mq.Result{Error: err}, true)
- tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
- return
+ state.Result = mq.Result{Payload: aggregatedPayload, Status: mq.Completed, Ctx: ctx, Topic: state.NodeID}
+ } else if size == 1 {
+ state.Result = state.targetResults.Values()[0]
}
+ state.Status = result.Status
+ state.Result.Status = result.Status
}
-}
-
-func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result {
- defer mq.RecoverPanic(mq.RecoverTitle)
- node, ok := tm.dag.nodes[nodeID]
- if !ok {
- return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)}
+ if state.Result.Payload == nil {
+ state.Result.Payload = result.Payload
}
- if tm.createdAt.IsZero() {
- tm.createdAt = time.Now()
- }
- tm.wg.Add(1)
- go func() {
- ctxx := context.Background()
- if headers, ok := mq.GetHeaders(ctx); ok {
- headers.Set(consts.QueueKey, node.Key)
- headers.Set("index", fmt.Sprintf("%s__%d", node.Key, 0))
- ctxx = mq.SetHeaders(ctxx, headers.AsMap())
- }
- go tm.processNode(ctx, node, payload)
- }()
- tm.wg.Wait()
- requestType, ok := mq.GetHeader(ctx, "request_type")
- if ok && requestType == "render" {
- return tm.renderResult(ctx)
- }
- return tm.dispatchFinalResult(ctx)
-}
-
-func (tm *TaskManager) handleNextTask(ctx context.Context, result mq.Result) mq.Result {
- tm.topic = result.Topic
- defer func() {
- tm.wg.Done()
- mq.RecoverPanic(mq.RecoverTitle)
- }()
- if result.Ctx != nil {
- if headers, ok := mq.GetHeaders(ctx); ok {
- ctx = mq.SetHeaders(result.Ctx, headers.AsMap())
- }
- }
- node, ok := tm.dag.nodes[result.Topic]
- if !ok {
- return result
+ state.UpdatedAt = time.Now()
+ if result.Ctx == nil {
+ result.Ctx = ctx
}
if result.Error != nil {
- tm.reportNodeResult(result, true)
- tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
- return result
+ state.Status = mq.Failed
}
- edges := tm.getConditionalEdges(node, result)
- if len(edges) == 0 {
- tm.reportNodeResult(result, true)
- tm.ChangeNodeStatus(ctx, node.Key, Completed, result)
- return result
- } else {
- tm.reportNodeResult(result, false)
- }
- if node.Type == Page {
- return result
- }
- for _, edge := range edges {
- switch edge.Type {
- case Iterator:
- var items []json.RawMessage
- err := json.Unmarshal(result.Payload, &items)
- if err != nil {
- tm.reportNodeResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err}, false)
- result.Error = err
- tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
- return result
+ pn, ok := tm.parentNodes.Get(state.NodeID)
+ if edges, exists := tm.iteratorNodes.Get(nodeID[0]); exists && state.Status == mq.Completed {
+ state.Status = mq.Processing
+ tm.iteratorNodes.Del(nodeID[0])
+ state.targetResults.Clear()
+ if len(nodeID) == 2 {
+ ctx = context.WithValue(ctx, ContextIndex, nodeID[1])
+ }
+ toProcess := nodeResult{
+ ctx: ctx,
+ nodeID: state.NodeID,
+ status: state.Status,
+ result: state.Result,
+ }
+ tm.handleEdges(toProcess, edges)
+ } else if ok {
+ if targetsCount == size {
+ parentState, _ := tm.taskStates.Get(pn)
+ if parentState != nil {
+ state.Result.Topic = state.NodeID
+ tm.handlePrevious(ctx, parentState, state.Result, state.NodeID, dispatchFinal)
}
- tm.SetTotalItems(getTopic(ctx, edge.From.Key), len(items)*len(edge.To))
- for _, target := range edge.To {
- for i, item := range items {
- tm.wg.Add(1)
- go func(ctx context.Context, target *Node, item json.RawMessage, i int) {
- ctxx := context.Background()
- if headers, ok := mq.GetHeaders(ctx); ok {
- headers.Set(consts.QueueKey, target.Key)
- headers.Set("index", fmt.Sprintf("%s__%d", target.Key, i))
- ctxx = mq.SetHeaders(ctxx, headers.AsMap())
- }
- tm.processNode(ctxx, target, item)
- }(ctx, target, item, i)
+ }
+ } else {
+ tm.result = &state.Result
+ state.Result.Topic = strings.Split(state.NodeID, Delimiter)[0]
+ tm.resultCh <- state.Result
+ tm.processFinalResult(state)
+ }
+}
+
+func (tm *TaskManager) handleNext(ctx context.Context, node *Node, state *TaskState, result mq.Result) {
+ state.UpdatedAt = time.Now()
+ if result.Ctx == nil {
+ result.Ctx = ctx
+ }
+ if result.Error != nil {
+ state.Status = mq.Failed
+ } else {
+ edges := tm.getConditionalEdges(node, result)
+ if len(edges) == 0 {
+ state.Status = mq.Completed
+ }
+ }
+ if result.Status == "" {
+ result.Status = state.Status
+ }
+ select {
+ case tm.resultQueue <- nodeResult{
+ ctx: ctx,
+ nodeID: state.NodeID,
+ result: result,
+ status: state.Status,
+ }:
+ default:
+ log.Println("Result queue is full, dropping result.")
+ }
+}
+
+func (tm *TaskManager) onNodeCompleted(rs nodeResult) {
+ nodeID := strings.Split(rs.nodeID, Delimiter)[0]
+ node, ok := tm.dag.nodes.Get(nodeID)
+ if !ok {
+ return
+ }
+ edges := tm.getConditionalEdges(node, rs.result)
+ hasErrorOrCompleted := rs.result.Error != nil || len(edges) == 0
+ if hasErrorOrCompleted {
+ if index, ok := rs.ctx.Value(ContextIndex).(string); ok {
+ childNode := fmt.Sprintf("%s%s%s", node.ID, Delimiter, index)
+ pn, ok := tm.parentNodes.Get(childNode)
+ if ok {
+ parentState, _ := tm.taskStates.Get(pn)
+ if parentState != nil {
+ pn = strings.Split(pn, Delimiter)[0]
+ tm.handlePrevious(rs.ctx, parentState, rs.result, rs.nodeID, true)
}
}
}
+ return
}
- for _, edge := range edges {
- switch edge.Type {
- case Simple:
- if _, ok := tm.iteratorNodes.Get(edge.From.Key); ok {
- continue
- }
- tm.processEdge(ctx, edge, result)
- }
- }
- return result
-}
-
-func (tm *TaskManager) processEdge(ctx context.Context, edge Edge, result mq.Result) {
- tm.SetTotalItems(getTopic(ctx, edge.From.Key), len(edge.To))
- index, _ := mq.GetHeader(ctx, "index")
- if index != "" && strings.Contains(index, "__") {
- index = strings.Split(index, "__")[1]
- } else {
- index = "0"
- }
- for _, target := range edge.To {
- tm.wg.Add(1)
- go func(ctx context.Context, target *Node, result mq.Result) {
- ctxx := context.Background()
- if headers, ok := mq.GetHeaders(ctx); ok {
- headers.Set(consts.QueueKey, target.Key)
- headers.Set("index", fmt.Sprintf("%s__%s", target.Key, index))
- ctxx = mq.SetHeaders(ctxx, headers.AsMap())
- }
- tm.processNode(ctxx, target, result.Payload)
- }(ctx, target, result)
- }
+ tm.handleEdges(rs, edges)
}
func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge {
edges := make([]Edge, len(node.Edges))
copy(edges, node.Edges)
if result.ConditionStatus != "" {
- if conditions, ok := tm.dag.conditions[FromNode(result.Topic)]; ok {
- if targetNodeKey, ok := conditions[When(result.ConditionStatus)]; ok {
- if targetNode, ok := tm.dag.nodes[string(targetNodeKey)]; ok {
- edges = append(edges, Edge{From: node, To: []*Node{targetNode}})
+ if conditions, ok := tm.dag.conditions[result.Topic]; ok {
+ if targetNodeKey, ok := conditions[result.ConditionStatus]; ok {
+ if targetNode, ok := tm.dag.nodes.Get(targetNodeKey); ok {
+ edges = append(edges, Edge{From: node, To: targetNode})
}
} else if targetNodeKey, ok = conditions["default"]; ok {
- if targetNode, ok := tm.dag.nodes[string(targetNodeKey)]; ok {
- edges = append(edges, Edge{From: node, To: []*Node{targetNode}})
+ if targetNode, ok := tm.dag.nodes.Get(targetNodeKey); ok {
+ edges = append(edges, Edge{From: node, To: targetNode})
}
}
}
@@ -252,123 +312,77 @@ func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge
return edges
}
-func (tm *TaskManager) renderResult(ctx context.Context) mq.Result {
- var rs mq.Result
- tm.updateTS(&rs)
- tm.dag.callbackToConsumer(ctx, rs)
- tm.topic = rs.Topic
- return rs
-}
-
-func (tm *TaskManager) ChangeNodeStatus(ctx context.Context, nodeID string, status NodeStatus, rs mq.Result) {
- topic := nodeID
- if !strings.Contains(nodeID, "__") {
- nodeID = getTopic(ctx, nodeID)
- } else {
- topic = strings.Split(nodeID, "__")[0]
- }
- nodeStatus, ok := tm.taskNodeStatus.Get(nodeID)
- if !ok || nodeStatus == nil {
- return
- }
-
- nodeStatus.markAs(rs, status)
- switch status {
- case Completed:
- canProceed := false
- edges, ok := tm.iteratorNodes.Get(topic)
- if ok {
- if len(edges) == 0 {
- canProceed = true
- } else {
- nodeStatus.status = Processing
- nodeStatus.totalItems = 1
- nodeStatus.itemResults.Clear()
- for _, edge := range edges {
- tm.processEdge(ctx, edge, rs)
+func (tm *TaskManager) handleEdges(currentResult nodeResult, edges []Edge) {
+ for _, edge := range edges {
+ index, ok := currentResult.ctx.Value(ContextIndex).(string)
+ if !ok {
+ index = "0"
+ }
+ parentNode := fmt.Sprintf("%s%s%s", edge.From.ID, Delimiter, index)
+ if edge.Type == Simple {
+ if _, ok := tm.iteratorNodes.Get(edge.From.ID); ok {
+ continue
+ }
+ }
+ if edge.Type == Iterator {
+ var items []json.RawMessage
+ err := json.Unmarshal(currentResult.result.Payload, &items)
+ if err != nil {
+ log.Printf("Error unmarshalling data for node %s: %v\n", edge.To.ID, err)
+ tm.resultQueue <- nodeResult{
+ ctx: currentResult.ctx,
+ nodeID: edge.To.ID,
+ status: mq.Failed,
+ result: mq.Result{Error: err},
}
- tm.iteratorNodes.Del(topic)
+ return
}
- }
- if canProceed || !ok {
- if topic == tm.dag.startNode {
- tm.result = rs
- } else {
- tm.markParentTask(ctx, topic, nodeID, status, rs)
+ tm.childNodes.Set(parentNode, len(items))
+ for i, item := range items {
+ childNode := fmt.Sprintf("%s%s%d", edge.To.ID, Delimiter, i)
+ ctx := context.WithValue(currentResult.ctx, ContextIndex, fmt.Sprintf("%d", i))
+ tm.parentNodes.Set(childNode, parentNode)
+ tm.send(ctx, edge.To.ID, tm.taskID, item)
}
- }
- case Failed:
- if topic == tm.dag.startNode {
- tm.result = rs
} else {
- tm.markParentTask(ctx, topic, nodeID, status, rs)
- }
- }
-}
-
-func (tm *TaskManager) markParentTask(ctx context.Context, topic, nodeID string, status NodeStatus, rs mq.Result) {
- parentNodes, err := tm.dag.GetPreviousNodes(topic)
- if err != nil {
- return
- }
- var index string
- nodeParts := strings.Split(nodeID, "__")
- if len(nodeParts) == 2 {
- index = nodeParts[1]
- }
- for _, parentNode := range parentNodes {
- parentKey := fmt.Sprintf("%s__%s", parentNode.Key, index)
- parentNodeStatus, exists := tm.taskNodeStatus.Get(parentKey)
- if !exists {
- parentKey = fmt.Sprintf("%s__%s", parentNode.Key, "0")
- parentNodeStatus, exists = tm.taskNodeStatus.Get(parentKey)
- }
- if exists {
- parentNodeStatus.itemResults.Set(nodeID, rs)
- if parentNodeStatus.IsDone() {
- rt := tm.prepareResult(ctx, parentNodeStatus)
- tm.ChangeNodeStatus(ctx, parentKey, status, rt)
+ tm.childNodes.Set(parentNode, 1)
+ idx, ok := currentResult.ctx.Value(ContextIndex).(string)
+ if !ok {
+ idx = "0"
}
+ childNode := fmt.Sprintf("%s%s%s", edge.To.ID, Delimiter, idx)
+ ctx := context.WithValue(currentResult.ctx, ContextIndex, idx)
+ tm.parentNodes.Set(childNode, parentNode)
+ tm.send(ctx, edge.To.ID, tm.taskID, currentResult.result.Payload)
}
}
}
-func (tm *TaskManager) prepareResult(ctx context.Context, nodeStatus *taskNodeStatus) mq.Result {
- aggregatedOutput := make([]json.RawMessage, 0)
- var status mq.Status
- var topic string
- var err1 error
- if nodeStatus.totalItems == 1 {
- rs := nodeStatus.itemResults.Values()[0]
- if rs.Ctx == nil {
- rs.Ctx = ctx
+func (tm *TaskManager) retryDeferredTasks() {
+ const maxRetries = 5
+ retries := 0
+ for retries < maxRetries {
+ select {
+ case <-tm.stopCh:
+ log.Println("Stopping Deferred task Retrier")
+ return
+ case <-time.After(RetryInterval):
+ tm.deferredTasks.ForEach(func(taskID string, task *task) bool {
+ tm.send(task.ctx, task.nodeID, taskID, task.payload)
+ retries++
+ return true
+ })
}
- return rs
}
- nodeStatus.itemResults.ForEach(func(key string, result mq.Result) bool {
- if topic == "" {
- topic = result.Topic
- status = result.Status
- }
- if result.Error != nil {
- err1 = result.Error
- return false
- }
- var item json.RawMessage
- err := json.Unmarshal(result.Payload, &item)
- if err != nil {
- err1 = err
- return false
- }
- aggregatedOutput = append(aggregatedOutput, item)
- return true
- })
- if err1 != nil {
- return mq.HandleError(ctx, err1)
- }
- finalOutput, err := json.Marshal(aggregatedOutput)
- if err != nil {
- return mq.HandleError(ctx, err)
- }
- return mq.Result{TaskID: tm.taskID, Payload: finalOutput, Status: status, Topic: topic, Ctx: ctx}
+}
+
+func (tm *TaskManager) processFinalResult(state *TaskState) {
+ state.targetResults.Clear()
+ if tm.dag.finalResult != nil {
+ tm.dag.finalResult(tm.taskID, state.Result)
+ }
+}
+
+func (tm *TaskManager) Stop() {
+ close(tm.stopCh)
}
diff --git a/dag/ui.go b/dag/ui.go
index d2074af..388bbe2 100644
--- a/dag/ui.go
+++ b/dag/ui.go
@@ -9,25 +9,24 @@ import (
func (tm *DAG) PrintGraph() {
fmt.Println("DAG Graph structure:")
- for _, node := range tm.nodes {
- fmt.Printf("Node: %s (%s) -> ", node.Name, node.Key)
- if conditions, ok := tm.conditions[FromNode(node.Key)]; ok {
+ tm.nodes.ForEach(func(_ string, node *Node) bool {
+ fmt.Printf("Node: %s (%s) -> ", node.Label, node.ID)
+ if conditions, ok := tm.conditions[node.ID]; ok {
var c []string
for when, then := range conditions {
- if target, ok := tm.nodes[string(then)]; ok {
- c = append(c, fmt.Sprintf("If [%s] Then %s (%s)", when, target.Name, target.Key))
+ if target, ok := tm.nodes.Get(then); ok {
+ c = append(c, fmt.Sprintf("If [%s] Then %s (%s)", when, target.Label, target.ID))
}
}
fmt.Println(strings.Join(c, ", "))
}
var edges []string
- for _, edge := range node.Edges {
- for _, target := range edge.To {
- edges = append(edges, fmt.Sprintf("%s (%s)", target.Name, target.Key))
- }
+ for _, target := range node.Edges {
+ edges = append(edges, fmt.Sprintf("%s (%s)", target.To.Label, target.To.ID))
}
fmt.Println(strings.Join(edges, ", "))
- }
+ return true
+ })
}
func (tm *DAG) ClassifyEdges(startNodes ...string) (string, bool, error) {
@@ -44,7 +43,7 @@ func (tm *DAG) ClassifyEdges(startNodes ...string) (string, bool, error) {
if startNode == "" {
firstNode := tm.findStartNode()
if firstNode != nil {
- startNode = firstNode.Key
+ startNode = firstNode.ID
}
}
if startNode == "" {
@@ -62,24 +61,22 @@ func (tm *DAG) dfs(v string, visited map[string]bool, discoveryTime, finishedTim
inRecursionStack[v] = true // mark node as part of recursion stack
*timeVal++
discoveryTime[v] = *timeVal
- node := tm.nodes[v]
+ node, _ := tm.nodes.Get(v)
hasCycle := false
var err error
for _, edge := range node.Edges {
- for _, adj := range edge.To {
- if !visited[adj.Key] {
- builder.WriteString(fmt.Sprintf("Traversing Edge: %s -> %s\n", v, adj.Key))
- hasCycle, err := tm.dfs(adj.Key, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder)
- if err != nil {
- return true, err
- }
- if hasCycle {
- return true, nil
- }
- } else if inRecursionStack[adj.Key] {
- cycleMsg := fmt.Sprintf("Cycle detected: %s -> %s\n", v, adj.Key)
- return true, fmt.Errorf(cycleMsg)
+ if !visited[edge.To.ID] {
+ builder.WriteString(fmt.Sprintf("Traversing Edge: %s -> %s\n", v, edge.To.ID))
+ hasCycle, err := tm.dfs(edge.To.ID, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder)
+ if err != nil {
+ return true, err
}
+ if hasCycle {
+ return true, nil
+ }
+ } else if inRecursionStack[edge.To.ID] {
+ cycleMsg := fmt.Sprintf("Cycle detected: %s -> %s\n", v, edge.To.ID)
+ return true, fmt.Errorf(cycleMsg)
}
}
hasCycle, err = tm.handleConditionalEdges(v, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder)
@@ -93,20 +90,20 @@ func (tm *DAG) dfs(v string, visited map[string]bool, discoveryTime, finishedTim
}
func (tm *DAG) handleConditionalEdges(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, time *int, inRecursionStack map[string]bool, builder *strings.Builder) (bool, error) {
- node := tm.nodes[v]
- for when, then := range tm.conditions[FromNode(node.Key)] {
- if targetNode, ok := tm.nodes[string(then)]; ok {
- if !visited[targetNode.Key] {
- builder.WriteString(fmt.Sprintf("Traversing Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.Key))
- hasCycle, err := tm.dfs(targetNode.Key, visited, discoveryTime, finishedTime, time, inRecursionStack, builder)
+ node, _ := tm.nodes.Get(v)
+ for when, then := range tm.conditions[node.ID] {
+ if targetNode, ok := tm.nodes.Get(then); ok {
+ if !visited[targetNode.ID] {
+ builder.WriteString(fmt.Sprintf("Traversing Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.ID))
+ hasCycle, err := tm.dfs(targetNode.ID, visited, discoveryTime, finishedTime, time, inRecursionStack, builder)
if err != nil {
return true, err
}
if hasCycle {
return true, nil
}
- } else if inRecursionStack[targetNode.Key] {
- cycleMsg := fmt.Sprintf("Cycle detected in Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.Key)
+ } else if inRecursionStack[targetNode.ID] {
+ cycleMsg := fmt.Sprintf("Cycle detected in Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.ID)
return true, fmt.Errorf(cycleMsg)
}
}
@@ -146,98 +143,113 @@ func (tm *DAG) ExportDOT() string {
var sb strings.Builder
sb.WriteString(fmt.Sprintf(`digraph "%s" {`, tm.name))
sb.WriteString("\n")
- sb.WriteString(fmt.Sprintf(` label="%s";`, tm.name))
+ sb.WriteString(` label="Enhanced DAG Representation";`)
sb.WriteString("\n")
- sb.WriteString(` labelloc="t";`)
+ sb.WriteString(` labelloc="t"; fontsize=22; fontname="Helvetica";`)
sb.WriteString("\n")
- sb.WriteString(` fontsize=20;`)
+ sb.WriteString(` node [shape=box, fontname="Helvetica", fillcolor="#B3CDE0", fontcolor="#2C3E50", fontsize=10, margin="0.25,0.15", style="rounded,filled"];`)
sb.WriteString("\n")
- sb.WriteString(` node [shape=box, style="rounded,filled", fillcolor="lightgray", fontname="Arial", margin="0.2,0.1"];`)
+ sb.WriteString(` edge [fontname="Helvetica", fontsize=12, arrowsize=0.8];`)
sb.WriteString("\n")
- sb.WriteString(` edge [fontname="Arial", fontsize=12, arrowsize=0.8];`)
- sb.WriteString("\n")
- sb.WriteString(` size="10,10";`)
- sb.WriteString("\n")
- sb.WriteString(` ratio="fill";`)
+ sb.WriteString(` rankdir=TB;`)
sb.WriteString("\n")
sortedNodes := tm.TopologicalSort()
for _, nodeKey := range sortedNodes {
- node := tm.nodes[nodeKey]
- nodeColor := "lightblue"
- sb.WriteString(fmt.Sprintf(` "%s" [label=" %s", fillcolor="%s", id="node_%s"];`, node.Key, node.Name, nodeColor, node.Key))
+ node, _ := tm.nodes.Get(nodeKey)
+ nodeColor := "lightgray"
+ nodeShape := "box"
+ labelSuffix := ""
+
+ // Apply styles based on NodeType
+ switch node.NodeType {
+ case Function:
+ nodeColor = "#D4EDDA"
+ labelSuffix = " [Function]"
+ case Page:
+ nodeColor = "#f0d2d1"
+ labelSuffix = " [Page]"
+ }
+ sb.WriteString(fmt.Sprintf(
+ ` "%s" [label="%s%s", fontcolor="#2C3E50", fillcolor="%s", shape="%s", style="rounded,filled", id="node_%s"];`,
+ node.ID, node.Label, labelSuffix, nodeColor, nodeShape, node.ID))
sb.WriteString("\n")
}
+
+ // Define edges with unique styling by EdgeType
for _, nodeKey := range sortedNodes {
- node := tm.nodes[nodeKey]
+ node, _ := tm.nodes.Get(nodeKey)
for _, edge := range node.Edges {
- var edgeStyle string
+ edgeStyle := "solid"
+ edgeColor := "black"
+ labelSuffix := ""
+
+ // Apply styles based on EdgeType
switch edge.Type {
case Iterator:
edgeStyle = "dashed"
- default:
+ edgeColor = "blue"
+ labelSuffix = " [Iter]"
+ case Simple:
edgeStyle = "solid"
+ edgeColor = "black"
+ labelSuffix = ""
}
- edgeColor := "black"
- for _, to := range edge.To {
- sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="%s", style=%s, fontsize=10, arrowsize=0.6];`, node.Key, to.Key, edge.Label, edgeColor, edgeStyle))
- sb.WriteString("\n")
- }
+ sb.WriteString(fmt.Sprintf(
+ ` "%s" -> "%s" [label="%s%s", color="%s", style="%s"];`,
+ node.ID, edge.To.ID, edge.Label, labelSuffix, edgeColor, edgeStyle))
+ sb.WriteString("\n")
}
}
for fromNodeKey, conditions := range tm.conditions {
for when, then := range conditions {
- if toNode, ok := tm.nodes[string(then)]; ok {
- sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="purple", style=dotted, fontsize=10, arrowsize=0.6];`, fromNodeKey, toNode.Key, when))
+ if toNode, ok := tm.nodes.Get(then); ok {
+ sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="purple", style=dotted, fontsize=10, arrowsize=0.6];`, fromNodeKey, toNode.ID, when))
sb.WriteString("\n")
}
}
}
+
+ // Optional: Group related nodes into subgraphs (e.g., loops)
for _, nodeKey := range sortedNodes {
- node := tm.nodes[nodeKey]
+ node, _ := tm.nodes.Get(nodeKey)
if node.processor != nil {
subDAG, _ := isDAGNode(node)
if subDAG != nil {
sb.WriteString(fmt.Sprintf(` subgraph "cluster_%s" {`, subDAG.name))
sb.WriteString("\n")
- sb.WriteString(fmt.Sprintf(` label=" %s";`, subDAG.name))
+ sb.WriteString(fmt.Sprintf(` label="Subgraph: %s";`, subDAG.name))
sb.WriteString("\n")
- sb.WriteString(` style=dashed;`)
+ sb.WriteString(` style=filled; color=gray90;`)
sb.WriteString("\n")
- sb.WriteString(` bgcolor="lightgray";`)
- sb.WriteString("\n")
- sb.WriteString(` node [shape=rectangle, style="filled", fillcolor="lightblue", fontname="Arial", margin="0.2,0.1"];`)
- sb.WriteString("\n")
- for subNodeKey, subNode := range subDAG.nodes {
- sb.WriteString(fmt.Sprintf(` "%s" [label=" %s"];`, subNodeKey, subNode.Name))
+ subDAG.nodes.ForEach(func(subNodeKey string, subNode *Node) bool {
+ sb.WriteString(fmt.Sprintf(` "%s" [label="%s"];`, subNode.ID, subNode.Label))
sb.WriteString("\n")
- }
- for subNodeKey, subNode := range subDAG.nodes {
+ return true
+ })
+ subDAG.nodes.ForEach(func(subNodeKey string, subNode *Node) bool {
for _, edge := range subNode.Edges {
- for _, to := range edge.To {
- sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="black", style=solid, arrowsize=0.6];`, subNodeKey, to.Key, edge.Label))
- sb.WriteString("\n")
- }
+ sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label="%s"];`, subNodeKey, edge.To.ID, edge.Label))
+ sb.WriteString("\n")
}
- }
- sb.WriteString(` }`)
- sb.WriteString("\n")
- sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="black", style=solid, arrowsize=0.6];`, node.Key, subDAG.startNode, subDAG.name))
- sb.WriteString("\n")
+ return true
+ })
+ sb.WriteString(" }\n")
}
}
}
- sb.WriteString(`}`)
- sb.WriteString("\n")
+
+ sb.WriteString("}\n")
return sb.String()
}
func (tm *DAG) TopologicalSort() (stack []string) {
visited := make(map[string]bool)
- for _, node := range tm.nodes {
- if !visited[node.Key] {
- tm.topologicalSortUtil(node.Key, visited, &stack)
+ tm.nodes.ForEach(func(_ string, node *Node) bool {
+ if !visited[node.ID] {
+ tm.topologicalSortUtil(node.ID, visited, &stack)
}
- }
+ return true
+ })
for i, j := 0, len(stack)-1; i < j; i, j = i+1, j-1 {
stack[i], stack[j] = stack[j], stack[i]
}
@@ -246,13 +258,23 @@ func (tm *DAG) TopologicalSort() (stack []string) {
func (tm *DAG) topologicalSortUtil(v string, visited map[string]bool, stack *[]string) {
visited[v] = true
- node := tm.nodes[v]
+ node, ok := tm.nodes.Get(v)
+ if !ok {
+ fmt.Println("Not found", v)
+ }
for _, edge := range node.Edges {
- for _, to := range edge.To {
- if !visited[to.Key] {
- tm.topologicalSortUtil(to.Key, visited, stack)
- }
+ if !visited[edge.To.ID] {
+ tm.topologicalSortUtil(edge.To.ID, visited, stack)
}
}
*stack = append(*stack, v)
}
+
+func isDAGNode(node *Node) (*DAG, bool) {
+ switch node := node.processor.(type) {
+ case *DAG:
+ return node, true
+ default:
+ return nil, false
+ }
+}
diff --git a/dag/v1/api.go b/dag/v1/api.go
new file mode 100644
index 0000000..bf6be8f
--- /dev/null
+++ b/dag/v1/api.go
@@ -0,0 +1,209 @@
+package v1
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "os"
+ "time"
+
+ "github.com/oarkflow/mq/jsonparser"
+ "github.com/oarkflow/mq/sio"
+
+ "github.com/oarkflow/mq"
+ "github.com/oarkflow/mq/consts"
+ "github.com/oarkflow/mq/metrics"
+)
+
+type Request struct {
+ Payload json.RawMessage `json:"payload"`
+ Interval time.Duration `json:"interval"`
+ Schedule bool `json:"schedule"`
+ Overlap bool `json:"overlap"`
+ Recurring bool `json:"recurring"`
+}
+
+func (tm *DAG) SetupWS() *sio.Server {
+ ws := sio.New(sio.Config{
+ CheckOrigin: func(r *http.Request) bool { return true },
+ EnableCompression: true,
+ })
+ WsEvents(ws)
+ tm.Notifier = ws
+ return ws
+}
+
+func (tm *DAG) Handlers() {
+ metrics.HandleHTTP()
+ http.Handle("/", http.FileServer(http.Dir("webroot")))
+ http.Handle("/notify", tm.SetupWS())
+ http.HandleFunc("GET /render", tm.Render)
+ http.HandleFunc("POST /request", tm.Request)
+ http.HandleFunc("POST /publish", tm.Publish)
+ http.HandleFunc("POST /schedule", tm.Schedule)
+ http.HandleFunc("/pause-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) {
+ id := request.PathValue("id")
+ if id != "" {
+ tm.PauseConsumer(request.Context(), id)
+ }
+ })
+ http.HandleFunc("/resume-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) {
+ id := request.PathValue("id")
+ if id != "" {
+ tm.ResumeConsumer(request.Context(), id)
+ }
+ })
+ http.HandleFunc("/pause", func(w http.ResponseWriter, request *http.Request) {
+ err := tm.Pause(request.Context())
+ if err != nil {
+ http.Error(w, "Failed to pause", http.StatusBadRequest)
+ return
+ }
+ json.NewEncoder(w).Encode(map[string]string{"status": "paused"})
+ })
+ http.HandleFunc("/resume", func(w http.ResponseWriter, request *http.Request) {
+ err := tm.Resume(request.Context())
+ if err != nil {
+ http.Error(w, "Failed to resume", http.StatusBadRequest)
+ return
+ }
+ json.NewEncoder(w).Encode(map[string]string{"status": "resumed"})
+ })
+ http.HandleFunc("/stop", func(w http.ResponseWriter, request *http.Request) {
+ err := tm.Stop(request.Context())
+ if err != nil {
+ http.Error(w, "Failed to read request body", http.StatusBadRequest)
+ return
+ }
+ json.NewEncoder(w).Encode(map[string]string{"status": "stopped"})
+ })
+ http.HandleFunc("/close", func(w http.ResponseWriter, request *http.Request) {
+ err := tm.Close()
+ if err != nil {
+ http.Error(w, "Failed to read request body", http.StatusBadRequest)
+ return
+ }
+ json.NewEncoder(w).Encode(map[string]string{"status": "closed"})
+ })
+ http.HandleFunc("/dot", func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/plain")
+ fmt.Fprintln(w, tm.ExportDOT())
+ })
+ http.HandleFunc("/ui", func(w http.ResponseWriter, r *http.Request) {
+ image := fmt.Sprintf("%s.svg", mq.NewID())
+ err := tm.SaveSVG(image)
+ if err != nil {
+ http.Error(w, "Failed to read request body", http.StatusBadRequest)
+ return
+ }
+ defer os.Remove(image)
+ svgBytes, err := os.ReadFile(image)
+ if err != nil {
+ http.Error(w, "Could not read SVG file", http.StatusInternalServerError)
+ return
+ }
+ w.Header().Set("Content-Type", "image/svg+xml")
+ if _, err := w.Write(svgBytes); err != nil {
+ http.Error(w, "Could not write SVG response", http.StatusInternalServerError)
+ return
+ }
+ })
+}
+
+func (tm *DAG) request(w http.ResponseWriter, r *http.Request, async bool) {
+ if r.Method != http.MethodPost {
+ http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
+ return
+ }
+ var request Request
+ if r.Body != nil {
+ defer r.Body.Close()
+ var err error
+ payload, err := io.ReadAll(r.Body)
+ if err != nil {
+ http.Error(w, "Failed to read request body", http.StatusBadRequest)
+ return
+ }
+ err = json.Unmarshal(payload, &request)
+ if err != nil {
+ http.Error(w, "Failed to unmarshal body", http.StatusBadRequest)
+ return
+ }
+ } else {
+ http.Error(w, "Empty request body", http.StatusBadRequest)
+ return
+ }
+ ctx := r.Context()
+ if async {
+ ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
+ }
+ var opts []mq.SchedulerOption
+ if request.Interval > 0 {
+ opts = append(opts, mq.WithInterval(request.Interval))
+ }
+ if request.Overlap {
+ opts = append(opts, mq.WithOverlap())
+ }
+ if request.Recurring {
+ opts = append(opts, mq.WithRecurring())
+ }
+ ctx = context.WithValue(ctx, "query_params", r.URL.Query())
+ var rs mq.Result
+ if request.Schedule {
+ rs = tm.ScheduleTask(ctx, request.Payload, opts...)
+ } else {
+ rs = tm.Process(ctx, request.Payload)
+ }
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(rs)
+}
+
+func (tm *DAG) Render(w http.ResponseWriter, r *http.Request) {
+ ctx := mq.SetHeaders(r.Context(), map[string]string{consts.AwaitResponseKey: "true", "request_type": "render"})
+ ctx = context.WithValue(ctx, "query_params", r.URL.Query())
+ rs := tm.Process(ctx, nil)
+ content, err := jsonparser.GetString(rs.Payload, "html_content")
+ if err != nil {
+ http.Error(w, "Failed to read request body", http.StatusBadRequest)
+ return
+ }
+ w.Header().Set("Content-Type", consts.TypeHtml)
+ w.Write([]byte(content))
+}
+
+func (tm *DAG) Request(w http.ResponseWriter, r *http.Request) {
+ tm.request(w, r, true)
+}
+
+func (tm *DAG) Publish(w http.ResponseWriter, r *http.Request) {
+ tm.request(w, r, false)
+}
+
+func (tm *DAG) Schedule(w http.ResponseWriter, r *http.Request) {
+ tm.request(w, r, false)
+}
+
+func GetTaskID(ctx context.Context) string {
+ if queryParams := ctx.Value("query_params"); queryParams != nil {
+ if params, ok := queryParams.(url.Values); ok {
+ if id := params.Get("taskID"); id != "" {
+ return id
+ }
+ }
+ }
+ return ""
+}
+
+func CanNextNode(ctx context.Context) string {
+ if queryParams := ctx.Value("query_params"); queryParams != nil {
+ if params, ok := queryParams.(url.Values); ok {
+ if id := params.Get("next"); id != "" {
+ return id
+ }
+ }
+ }
+ return ""
+}
diff --git a/dag/v1/consts.go b/dag/v1/consts.go
new file mode 100644
index 0000000..ef48130
--- /dev/null
+++ b/dag/v1/consts.go
@@ -0,0 +1,26 @@
+package v1
+
+type NodeStatus int
+
+func (c NodeStatus) IsValid() bool { return c >= Pending && c <= Failed }
+
+func (c NodeStatus) String() string {
+ switch c {
+ case Pending:
+ return "Pending"
+ case Processing:
+ return "Processing"
+ case Completed:
+ return "Completed"
+ case Failed:
+ return "Failed"
+ }
+ return ""
+}
+
+const (
+ Pending NodeStatus = iota
+ Processing
+ Completed
+ Failed
+)
diff --git a/dag/v1/dag.go b/dag/v1/dag.go
new file mode 100644
index 0000000..98b38bd
--- /dev/null
+++ b/dag/v1/dag.go
@@ -0,0 +1,612 @@
+package v1
+
+import (
+ "context"
+ "fmt"
+ "log"
+ "net/http"
+ "time"
+
+ "github.com/oarkflow/mq/storage"
+ "github.com/oarkflow/mq/storage/memory"
+
+ "github.com/oarkflow/mq/sio"
+
+ "golang.org/x/time/rate"
+
+ "github.com/oarkflow/mq"
+ "github.com/oarkflow/mq/consts"
+ "github.com/oarkflow/mq/metrics"
+)
+
+type EdgeType int
+
+func (c EdgeType) IsValid() bool { return c >= Simple && c <= Iterator }
+
+const (
+ Simple EdgeType = iota
+ Iterator
+)
+
+type NodeType int
+
+func (c NodeType) IsValid() bool { return c >= Process && c <= Page }
+
+const (
+ Process NodeType = iota
+ Page
+)
+
+type Node struct {
+ processor mq.Processor
+ Name string
+ Type NodeType
+ Key string
+ Edges []Edge
+ isReady bool
+}
+
+func (n *Node) ProcessTask(ctx context.Context, msg *mq.Task) mq.Result {
+ return n.processor.ProcessTask(ctx, msg)
+}
+
+func (n *Node) Close() error {
+ return n.processor.Close()
+}
+
+type Edge struct {
+ Label string
+ From *Node
+ To []*Node
+ Type EdgeType
+}
+
+type (
+ FromNode string
+ When string
+ Then string
+)
+
+type DAG struct {
+ server *mq.Broker
+ consumer *mq.Consumer
+ taskContext storage.IMap[string, *TaskManager]
+ nodes map[string]*Node
+ iteratorNodes storage.IMap[string, []Edge]
+ conditions map[FromNode]map[When]Then
+ pool *mq.Pool
+ taskCleanupCh chan string
+ name string
+ key string
+ startNode string
+ consumerTopic string
+ opts []mq.Option
+ reportNodeResultCallback func(mq.Result)
+ Notifier *sio.Server
+ paused bool
+ Error error
+ report string
+ index string
+}
+
+func (tm *DAG) SetKey(key string) {
+ tm.key = key
+}
+
+func (tm *DAG) ReportNodeResult(callback func(mq.Result)) {
+ tm.reportNodeResultCallback = callback
+}
+
+func (tm *DAG) GetType() string {
+ return tm.key
+}
+
+func (tm *DAG) listenForTaskCleanup() {
+ for taskID := range tm.taskCleanupCh {
+ if tm.server.Options().CleanTaskOnComplete() {
+ tm.taskCleanup(taskID)
+ }
+ }
+}
+
+func (tm *DAG) taskCleanup(taskID string) {
+ tm.taskContext.Del(taskID)
+ log.Printf("DAG - Task %s cleaned up", taskID)
+}
+
+func (tm *DAG) Consume(ctx context.Context) error {
+ if tm.consumer != nil {
+ tm.server.Options().SetSyncMode(true)
+ return tm.consumer.Consume(ctx)
+ }
+ return nil
+}
+
+func (tm *DAG) Stop(ctx context.Context) error {
+ for _, n := range tm.nodes {
+ err := n.processor.Stop(ctx)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (tm *DAG) GetKey() string {
+ return tm.key
+}
+
+func (tm *DAG) AssignTopic(topic string) {
+ tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false), mq.WithBrokerURL(tm.server.URL()))
+ tm.consumerTopic = topic
+}
+
+func NewDAG(name, key string, opts ...mq.Option) *DAG {
+ callback := func(ctx context.Context, result mq.Result) error { return nil }
+ d := &DAG{
+ name: name,
+ key: key,
+ nodes: make(map[string]*Node),
+ iteratorNodes: memory.New[string, []Edge](),
+ taskContext: memory.New[string, *TaskManager](),
+ conditions: make(map[FromNode]map[When]Then),
+ taskCleanupCh: make(chan string),
+ }
+ opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerOnSubscribe(d.onConsumerJoin), mq.WithConsumerOnClose(d.onConsumerClose))
+ d.server = mq.NewBroker(opts...)
+ d.opts = opts
+ options := d.server.Options()
+ d.pool = mq.NewPool(
+ options.NumOfWorkers(),
+ mq.WithTaskQueueSize(options.QueueSize()),
+ mq.WithMaxMemoryLoad(options.MaxMemoryLoad()),
+ mq.WithHandler(d.ProcessTask),
+ mq.WithPoolCallback(callback),
+ mq.WithTaskStorage(options.Storage()),
+ )
+ d.pool.Start(d.server.Options().NumOfWorkers())
+ go d.listenForTaskCleanup()
+ return d
+}
+
+func (tm *DAG) callbackToConsumer(ctx context.Context, result mq.Result) {
+ if tm.consumer != nil {
+ result.Topic = tm.consumerTopic
+ if tm.consumer.Conn() == nil {
+ tm.onTaskCallback(ctx, result)
+ } else {
+ tm.consumer.OnResponse(ctx, result)
+ }
+ }
+}
+
+func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result {
+ if taskContext, ok := tm.taskContext.Get(result.TaskID); ok && result.Topic != "" {
+ return taskContext.handleNextTask(ctx, result)
+ }
+ return mq.Result{}
+}
+
+func (tm *DAG) onConsumerJoin(_ context.Context, topic, _ string) {
+ if node, ok := tm.nodes[topic]; ok {
+ log.Printf("DAG - CONSUMER ~> ready on %s", topic)
+ node.isReady = true
+ }
+}
+
+func (tm *DAG) onConsumerClose(_ context.Context, topic, _ string) {
+ if node, ok := tm.nodes[topic]; ok {
+ log.Printf("DAG - CONSUMER ~> down on %s", topic)
+ node.isReady = false
+ }
+}
+
+func (tm *DAG) SetStartNode(node string) {
+ tm.startNode = node
+}
+
+func (tm *DAG) GetStartNode() string {
+ return tm.startNode
+}
+
+func (tm *DAG) Start(ctx context.Context, addr string) error {
+ // Start the server in a separate goroutine
+ go func() {
+ defer mq.RecoverPanic(mq.RecoverTitle)
+ if err := tm.server.Start(ctx); err != nil {
+ panic(err)
+ }
+ }()
+
+ // Start the node consumers if not in sync mode
+ if !tm.server.SyncMode() {
+ for _, con := range tm.nodes {
+ go func(con *Node) {
+ defer mq.RecoverPanic(mq.RecoverTitle)
+ limiter := rate.NewLimiter(rate.Every(1*time.Second), 1) // Retry every second
+ for {
+ err := con.processor.Consume(ctx)
+ if err != nil {
+ log.Printf("[ERROR] - Consumer %s failed to start: %v", con.Key, err)
+ } else {
+ log.Printf("[INFO] - Consumer %s started successfully", con.Key)
+ break
+ }
+ limiter.Wait(ctx) // Wait with rate limiting before retrying
+ }
+ }(con)
+ }
+ }
+ log.Printf("DAG - HTTP_SERVER ~> started on http://localhost%s", addr)
+ tm.Handlers()
+ config := tm.server.TLSConfig()
+ if config.UseTLS {
+ return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil)
+ }
+ return http.ListenAndServe(addr, nil)
+}
+
+func (tm *DAG) AddDAGNode(name string, key string, dag *DAG, firstNode ...bool) *DAG {
+ dag.AssignTopic(key)
+ tm.nodes[key] = &Node{
+ Name: name,
+ Key: key,
+ processor: dag,
+ isReady: true,
+ }
+ if len(firstNode) > 0 && firstNode[0] {
+ tm.startNode = key
+ }
+ return tm
+}
+
+func (tm *DAG) AddNode(name, key string, handler mq.Processor, firstNode ...bool) *DAG {
+ con := mq.NewConsumer(key, key, handler.ProcessTask, tm.opts...)
+ n := &Node{
+ Name: name,
+ Key: key,
+ processor: con,
+ }
+ if handler.GetType() == "page" {
+ n.Type = Page
+ }
+ if tm.server.SyncMode() {
+ n.isReady = true
+ }
+ tm.nodes[key] = n
+ if len(firstNode) > 0 && firstNode[0] {
+ tm.startNode = key
+ }
+ return tm
+}
+
+func (tm *DAG) AddDeferredNode(name, key string, firstNode ...bool) error {
+ if tm.server.SyncMode() {
+ return fmt.Errorf("DAG cannot have deferred node in Sync Mode")
+ }
+ tm.nodes[key] = &Node{
+ Name: name,
+ Key: key,
+ }
+ if len(firstNode) > 0 && firstNode[0] {
+ tm.startNode = key
+ }
+ return nil
+}
+
+func (tm *DAG) IsReady() bool {
+ for _, node := range tm.nodes {
+ if !node.isReady {
+ return false
+ }
+ }
+ return true
+}
+
+func (tm *DAG) AddCondition(fromNode FromNode, conditions map[When]Then) *DAG {
+ tm.conditions[fromNode] = conditions
+ return tm
+}
+
+func (tm *DAG) AddIterator(label, from string, targets ...string) *DAG {
+ tm.Error = tm.addEdge(Iterator, label, from, targets...)
+ tm.iteratorNodes.Set(from, []Edge{})
+ return tm
+}
+
+func (tm *DAG) AddEdge(label, from string, targets ...string) *DAG {
+ tm.Error = tm.addEdge(Simple, label, from, targets...)
+ return tm
+}
+
+func (tm *DAG) addEdge(edgeType EdgeType, label, from string, targets ...string) error {
+ fromNode, ok := tm.nodes[from]
+ if !ok {
+ return fmt.Errorf("Error: 'from' node %s does not exist\n", from)
+ }
+ var nodes []*Node
+ for _, target := range targets {
+ toNode, ok := tm.nodes[target]
+ if !ok {
+ return fmt.Errorf("Error: 'from' node %s does not exist\n", target)
+ }
+ nodes = append(nodes, toNode)
+ }
+ edge := Edge{From: fromNode, To: nodes, Type: edgeType, Label: label}
+ fromNode.Edges = append(fromNode.Edges, edge)
+ if edgeType != Iterator {
+ if edges, ok := tm.iteratorNodes.Get(fromNode.Key); ok {
+ edges = append(edges, edge)
+ tm.iteratorNodes.Set(fromNode.Key, edges)
+ }
+ }
+ return nil
+}
+
+func (tm *DAG) Validate() error {
+ report, hasCycle, err := tm.ClassifyEdges()
+ if hasCycle || err != nil {
+ tm.Error = err
+ return err
+ }
+ tm.report = report
+ return nil
+}
+
+func (tm *DAG) GetReport() string {
+ return tm.report
+}
+
+func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
+ if task.ID == "" {
+ task.ID = mq.NewID()
+ }
+ if index, ok := mq.GetHeader(ctx, "index"); ok {
+ tm.index = index
+ }
+ manager, exists := tm.taskContext.Get(task.ID)
+ if !exists {
+ manager = NewTaskManager(tm, task.ID, tm.iteratorNodes)
+ manager.createdAt = task.CreatedAt
+ tm.taskContext.Set(task.ID, manager)
+ }
+
+ if tm.consumer != nil {
+ initialNode, err := tm.parseInitialNode(ctx)
+ if err != nil {
+ metrics.TasksErrors.WithLabelValues("unknown").Inc() // Increase error count
+ return mq.Result{Error: err}
+ }
+ task.Topic = initialNode
+ }
+ if manager.topic != "" {
+ task.Topic = manager.topic
+ canNext := CanNextNode(ctx)
+ if canNext != "" {
+ if n, ok := tm.nodes[task.Topic]; ok {
+ if len(n.Edges) > 0 {
+ task.Topic = n.Edges[0].To[0].Key
+ }
+ }
+ } else {
+ }
+ }
+ result := manager.processTask(ctx, task.Topic, task.Payload)
+ if result.Ctx != nil && tm.index != "" {
+ result.Ctx = mq.SetHeaders(result.Ctx, map[string]string{"index": tm.index})
+ }
+ if result.Error != nil {
+ metrics.TasksErrors.WithLabelValues(task.Topic).Inc() // Increase error count
+ } else {
+ metrics.TasksProcessed.WithLabelValues("success").Inc() // Increase processed task count
+ }
+ return result
+}
+
+func (tm *DAG) check(ctx context.Context, payload []byte) (context.Context, *mq.Task, error) {
+ if tm.paused {
+ return ctx, nil, fmt.Errorf("unable to process task, error: DAG is not accepting any task")
+ }
+ if !tm.IsReady() {
+ return ctx, nil, fmt.Errorf("unable to process task, error: DAG is not ready yet")
+ }
+ initialNode, err := tm.parseInitialNode(ctx)
+ if err != nil {
+ return ctx, nil, err
+ }
+ if tm.server.SyncMode() {
+ ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
+ }
+ taskID := GetTaskID(ctx)
+ if taskID != "" {
+ if _, exists := tm.taskContext.Get(taskID); !exists {
+ return ctx, nil, fmt.Errorf("provided task ID doesn't exist")
+ }
+ }
+ if taskID == "" {
+ taskID = mq.NewID()
+ }
+ return ctx, mq.NewTask(taskID, payload, initialNode), nil
+}
+
+func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result {
+ ctx, task, err := tm.check(ctx, payload)
+ if err != nil {
+ return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")}
+ }
+ awaitResponse, _ := mq.GetAwaitResponse(ctx)
+ if awaitResponse != "true" {
+ headers, ok := mq.GetHeaders(ctx)
+ ctxx := context.Background()
+ if ok {
+ ctxx = mq.SetHeaders(ctxx, headers.AsMap())
+ }
+ if err := tm.pool.EnqueueTask(ctxx, task, 0); err != nil {
+ return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: task.Topic, Status: "FAILED", Error: err}
+ }
+ return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: task.Topic, Status: "PENDING"}
+ }
+ return tm.ProcessTask(ctx, task)
+}
+
+func (tm *DAG) ScheduleTask(ctx context.Context, payload []byte, opts ...mq.SchedulerOption) mq.Result {
+ ctx, task, err := tm.check(ctx, payload)
+ if err != nil {
+ return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")}
+ }
+ headers, ok := mq.GetHeaders(ctx)
+ ctxx := context.Background()
+ if ok {
+ ctxx = mq.SetHeaders(ctxx, headers.AsMap())
+ }
+ tm.pool.Scheduler().AddTask(ctxx, task, opts...)
+ return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: task.Topic, Status: "PENDING"}
+}
+
+func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) {
+ val := ctx.Value("initial_node")
+ initialNode, ok := val.(string)
+ if ok {
+ return initialNode, nil
+ }
+ if tm.startNode == "" {
+ firstNode := tm.findStartNode()
+ if firstNode != nil {
+ tm.startNode = firstNode.Key
+ }
+ }
+
+ if tm.startNode == "" {
+ return "", fmt.Errorf("initial node not found")
+ }
+ return tm.startNode, nil
+}
+
+func (tm *DAG) findStartNode() *Node {
+ incomingEdges := make(map[string]bool)
+ connectedNodes := make(map[string]bool)
+ for _, node := range tm.nodes {
+ for _, edge := range node.Edges {
+ if edge.Type.IsValid() {
+ connectedNodes[node.Key] = true
+ for _, to := range edge.To {
+ connectedNodes[to.Key] = true
+ incomingEdges[to.Key] = true
+ }
+ }
+ }
+ if cond, ok := tm.conditions[FromNode(node.Key)]; ok {
+ for _, target := range cond {
+ connectedNodes[string(target)] = true
+ incomingEdges[string(target)] = true
+ }
+ }
+ }
+ for nodeID, node := range tm.nodes {
+ if !incomingEdges[nodeID] && connectedNodes[nodeID] {
+ return node
+ }
+ }
+ return nil
+}
+
+func (tm *DAG) Pause(_ context.Context) error {
+ tm.paused = true
+ return nil
+}
+
+func (tm *DAG) Resume(_ context.Context) error {
+ tm.paused = false
+ return nil
+}
+
+func (tm *DAG) Close() error {
+ for _, n := range tm.nodes {
+ err := n.Close()
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (tm *DAG) PauseConsumer(ctx context.Context, id string) {
+ tm.doConsumer(ctx, id, consts.CONSUMER_PAUSE)
+}
+
+func (tm *DAG) ResumeConsumer(ctx context.Context, id string) {
+ tm.doConsumer(ctx, id, consts.CONSUMER_RESUME)
+}
+
+func (tm *DAG) doConsumer(ctx context.Context, id string, action consts.CMD) {
+ if node, ok := tm.nodes[id]; ok {
+ switch action {
+ case consts.CONSUMER_PAUSE:
+ err := node.processor.Pause(ctx)
+ if err == nil {
+ node.isReady = false
+ log.Printf("[INFO] - Consumer %s paused successfully", node.Key)
+ } else {
+ log.Printf("[ERROR] - Failed to pause consumer %s: %v", node.Key, err)
+ }
+ case consts.CONSUMER_RESUME:
+ err := node.processor.Resume(ctx)
+ if err == nil {
+ node.isReady = true
+ log.Printf("[INFO] - Consumer %s resumed successfully", node.Key)
+ } else {
+ log.Printf("[ERROR] - Failed to resume consumer %s: %v", node.Key, err)
+ }
+ }
+ } else {
+ log.Printf("[WARNING] - Consumer %s not found", id)
+ }
+}
+
+func (tm *DAG) SetNotifyResponse(callback mq.Callback) {
+ tm.server.SetNotifyHandler(callback)
+}
+
+func (tm *DAG) GetNextNodes(key string) ([]*Node, error) {
+ node, exists := tm.nodes[key]
+ if !exists {
+ return nil, fmt.Errorf("Node with key %s does not exist", key)
+ }
+ var successors []*Node
+ for _, edge := range node.Edges {
+ successors = append(successors, edge.To...)
+ }
+ if conds, exists := tm.conditions[FromNode(key)]; exists {
+ for _, targetKey := range conds {
+ if targetNode, exists := tm.nodes[string(targetKey)]; exists {
+ successors = append(successors, targetNode)
+ }
+ }
+ }
+ return successors, nil
+}
+
+func (tm *DAG) GetPreviousNodes(key string) ([]*Node, error) {
+ var predecessors []*Node
+ for _, node := range tm.nodes {
+ for _, edge := range node.Edges {
+ for _, target := range edge.To {
+ if target.Key == key {
+ predecessors = append(predecessors, node)
+ }
+ }
+ }
+ }
+ for fromNode, conds := range tm.conditions {
+ for _, targetKey := range conds {
+ if string(targetKey) == key {
+ node, exists := tm.nodes[string(fromNode)]
+ if !exists {
+ return nil, fmt.Errorf("Node with key %s does not exist", fromNode)
+ }
+ predecessors = append(predecessors, node)
+ }
+ }
+ }
+ return predecessors, nil
+}
diff --git a/dag/v2/operation.go b/dag/v1/operation.go
similarity index 99%
rename from dag/v2/operation.go
rename to dag/v1/operation.go
index f926bad..9500bb6 100644
--- a/dag/v2/operation.go
+++ b/dag/v1/operation.go
@@ -1,4 +1,4 @@
-package v2
+package v1
import (
"context"
diff --git a/dag/v2/operations.go b/dag/v1/operations.go
similarity index 99%
rename from dag/v2/operations.go
rename to dag/v1/operations.go
index 7678fdf..a7fc07f 100644
--- a/dag/v2/operations.go
+++ b/dag/v1/operations.go
@@ -1,4 +1,4 @@
-package v2
+package v1
import (
"sync"
diff --git a/dag/v1/task_manager.go b/dag/v1/task_manager.go
new file mode 100644
index 0000000..27c8848
--- /dev/null
+++ b/dag/v1/task_manager.go
@@ -0,0 +1,374 @@
+package v1
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/oarkflow/mq/consts"
+ "github.com/oarkflow/mq/storage"
+ "github.com/oarkflow/mq/storage/memory"
+
+ "github.com/oarkflow/mq"
+)
+
+type TaskManager struct {
+ createdAt time.Time
+ processedAt time.Time
+ status string
+ dag *DAG
+ taskID string
+ wg *WaitGroup
+ topic string
+ result mq.Result
+
+ iteratorNodes storage.IMap[string, []Edge]
+ taskNodeStatus storage.IMap[string, *taskNodeStatus]
+}
+
+func NewTaskManager(d *DAG, taskID string, iteratorNodes storage.IMap[string, []Edge]) *TaskManager {
+ return &TaskManager{
+ dag: d,
+ taskNodeStatus: memory.New[string, *taskNodeStatus](),
+ taskID: taskID,
+ iteratorNodes: iteratorNodes,
+ wg: NewWaitGroup(),
+ }
+}
+
+func (tm *TaskManager) dispatchFinalResult(ctx context.Context) mq.Result {
+ tm.updateTS(&tm.result)
+ tm.dag.callbackToConsumer(ctx, tm.result)
+ if tm.dag.server.NotifyHandler() != nil {
+ _ = tm.dag.server.NotifyHandler()(ctx, tm.result)
+ }
+ tm.dag.taskCleanupCh <- tm.taskID
+ tm.topic = tm.result.Topic
+ return tm.result
+}
+
+func (tm *TaskManager) reportNodeResult(result mq.Result, final bool) {
+ if tm.dag.reportNodeResultCallback != nil {
+ tm.dag.reportNodeResultCallback(result)
+ }
+}
+
+func (tm *TaskManager) SetTotalItems(topic string, i int) {
+ if nodeStatus, ok := tm.taskNodeStatus.Get(topic); ok {
+ nodeStatus.totalItems = i
+ }
+}
+
+func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) {
+ topic := getTopic(ctx, node.Key)
+ tm.taskNodeStatus.Set(topic, newNodeStatus(topic))
+ defer mq.RecoverPanic(mq.RecoverTitle)
+ dag, isDAG := isDAGNode(node)
+ if isDAG {
+ if tm.dag.server.SyncMode() && !dag.server.SyncMode() {
+ dag.server.Options().SetSyncMode(true)
+ }
+ }
+ tm.ChangeNodeStatus(ctx, node.Key, Processing, mq.Result{Payload: payload, Topic: node.Key})
+ var result mq.Result
+ if tm.dag.server.SyncMode() {
+ defer func() {
+ if isDAG {
+ result.Topic = dag.consumerTopic
+ result.TaskID = tm.taskID
+ tm.reportNodeResult(result, false)
+ tm.handleNextTask(result.Ctx, result)
+ } else {
+ result.Topic = node.Key
+ tm.reportNodeResult(result, false)
+ tm.handleNextTask(ctx, result)
+ }
+ }()
+ }
+ select {
+ case <-ctx.Done():
+ result = mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: ctx.Err(), Ctx: ctx}
+ tm.reportNodeResult(result, true)
+ tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
+ return
+ default:
+ ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: node.Key})
+ if tm.dag.server.SyncMode() {
+ result = node.ProcessTask(ctx, mq.NewTask(tm.taskID, payload, node.Key))
+ if result.Error != nil {
+ tm.reportNodeResult(result, true)
+ tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
+ return
+ }
+ return
+ }
+ err := tm.dag.server.Publish(ctx, mq.NewTask(tm.taskID, payload, node.Key), node.Key)
+ if err != nil {
+ tm.reportNodeResult(mq.Result{Error: err}, true)
+ tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
+ return
+ }
+ }
+}
+
+func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result {
+ defer mq.RecoverPanic(mq.RecoverTitle)
+ node, ok := tm.dag.nodes[nodeID]
+ if !ok {
+ return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)}
+ }
+ if tm.createdAt.IsZero() {
+ tm.createdAt = time.Now()
+ }
+ tm.wg.Add(1)
+ go func() {
+ ctxx := context.Background()
+ if headers, ok := mq.GetHeaders(ctx); ok {
+ headers.Set(consts.QueueKey, node.Key)
+ headers.Set("index", fmt.Sprintf("%s__%d", node.Key, 0))
+ ctxx = mq.SetHeaders(ctxx, headers.AsMap())
+ }
+ go tm.processNode(ctx, node, payload)
+ }()
+ tm.wg.Wait()
+ requestType, ok := mq.GetHeader(ctx, "request_type")
+ if ok && requestType == "render" {
+ return tm.renderResult(ctx)
+ }
+ return tm.dispatchFinalResult(ctx)
+}
+
+func (tm *TaskManager) handleNextTask(ctx context.Context, result mq.Result) mq.Result {
+ tm.topic = result.Topic
+ defer func() {
+ tm.wg.Done()
+ mq.RecoverPanic(mq.RecoverTitle)
+ }()
+ if result.Ctx != nil {
+ if headers, ok := mq.GetHeaders(ctx); ok {
+ ctx = mq.SetHeaders(result.Ctx, headers.AsMap())
+ }
+ }
+ node, ok := tm.dag.nodes[result.Topic]
+ if !ok {
+ return result
+ }
+ if result.Error != nil {
+ tm.reportNodeResult(result, true)
+ tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
+ return result
+ }
+ edges := tm.getConditionalEdges(node, result)
+ if len(edges) == 0 {
+ tm.reportNodeResult(result, true)
+ tm.ChangeNodeStatus(ctx, node.Key, Completed, result)
+ return result
+ } else {
+ tm.reportNodeResult(result, false)
+ }
+ if node.Type == Page {
+ return result
+ }
+ for _, edge := range edges {
+ switch edge.Type {
+ case Iterator:
+ var items []json.RawMessage
+ err := json.Unmarshal(result.Payload, &items)
+ if err != nil {
+ tm.reportNodeResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err}, false)
+ result.Error = err
+ tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
+ return result
+ }
+ tm.SetTotalItems(getTopic(ctx, edge.From.Key), len(items)*len(edge.To))
+ for _, target := range edge.To {
+ for i, item := range items {
+ tm.wg.Add(1)
+ go func(ctx context.Context, target *Node, item json.RawMessage, i int) {
+ ctxx := context.Background()
+ if headers, ok := mq.GetHeaders(ctx); ok {
+ headers.Set(consts.QueueKey, target.Key)
+ headers.Set("index", fmt.Sprintf("%s__%d", target.Key, i))
+ ctxx = mq.SetHeaders(ctxx, headers.AsMap())
+ }
+ tm.processNode(ctxx, target, item)
+ }(ctx, target, item, i)
+ }
+ }
+ }
+ }
+ for _, edge := range edges {
+ switch edge.Type {
+ case Simple:
+ if _, ok := tm.iteratorNodes.Get(edge.From.Key); ok {
+ continue
+ }
+ tm.processEdge(ctx, edge, result)
+ }
+ }
+ return result
+}
+
+func (tm *TaskManager) processEdge(ctx context.Context, edge Edge, result mq.Result) {
+ tm.SetTotalItems(getTopic(ctx, edge.From.Key), len(edge.To))
+ index, _ := mq.GetHeader(ctx, "index")
+ if index != "" && strings.Contains(index, "__") {
+ index = strings.Split(index, "__")[1]
+ } else {
+ index = "0"
+ }
+ for _, target := range edge.To {
+ tm.wg.Add(1)
+ go func(ctx context.Context, target *Node, result mq.Result) {
+ ctxx := context.Background()
+ if headers, ok := mq.GetHeaders(ctx); ok {
+ headers.Set(consts.QueueKey, target.Key)
+ headers.Set("index", fmt.Sprintf("%s__%s", target.Key, index))
+ ctxx = mq.SetHeaders(ctxx, headers.AsMap())
+ }
+ tm.processNode(ctxx, target, result.Payload)
+ }(ctx, target, result)
+ }
+}
+
+func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge {
+ edges := make([]Edge, len(node.Edges))
+ copy(edges, node.Edges)
+ if result.ConditionStatus != "" {
+ if conditions, ok := tm.dag.conditions[FromNode(result.Topic)]; ok {
+ if targetNodeKey, ok := conditions[When(result.ConditionStatus)]; ok {
+ if targetNode, ok := tm.dag.nodes[string(targetNodeKey)]; ok {
+ edges = append(edges, Edge{From: node, To: []*Node{targetNode}})
+ }
+ } else if targetNodeKey, ok = conditions["default"]; ok {
+ if targetNode, ok := tm.dag.nodes[string(targetNodeKey)]; ok {
+ edges = append(edges, Edge{From: node, To: []*Node{targetNode}})
+ }
+ }
+ }
+ }
+ return edges
+}
+
+func (tm *TaskManager) renderResult(ctx context.Context) mq.Result {
+ var rs mq.Result
+ tm.updateTS(&rs)
+ tm.dag.callbackToConsumer(ctx, rs)
+ tm.topic = rs.Topic
+ return rs
+}
+
+func (tm *TaskManager) ChangeNodeStatus(ctx context.Context, nodeID string, status NodeStatus, rs mq.Result) {
+ topic := nodeID
+ if !strings.Contains(nodeID, "__") {
+ nodeID = getTopic(ctx, nodeID)
+ } else {
+ topic = strings.Split(nodeID, "__")[0]
+ }
+ nodeStatus, ok := tm.taskNodeStatus.Get(nodeID)
+ if !ok || nodeStatus == nil {
+ return
+ }
+
+ nodeStatus.markAs(rs, status)
+ switch status {
+ case Completed:
+ canProceed := false
+ edges, ok := tm.iteratorNodes.Get(topic)
+ if ok {
+ if len(edges) == 0 {
+ canProceed = true
+ } else {
+ nodeStatus.status = Processing
+ nodeStatus.totalItems = 1
+ nodeStatus.itemResults.Clear()
+ for _, edge := range edges {
+ tm.processEdge(ctx, edge, rs)
+ }
+ tm.iteratorNodes.Del(topic)
+ }
+ }
+ if canProceed || !ok {
+ if topic == tm.dag.startNode {
+ tm.result = rs
+ } else {
+ tm.markParentTask(ctx, topic, nodeID, status, rs)
+ }
+ }
+ case Failed:
+ if topic == tm.dag.startNode {
+ tm.result = rs
+ } else {
+ tm.markParentTask(ctx, topic, nodeID, status, rs)
+ }
+ }
+}
+
+func (tm *TaskManager) markParentTask(ctx context.Context, topic, nodeID string, status NodeStatus, rs mq.Result) {
+ parentNodes, err := tm.dag.GetPreviousNodes(topic)
+ if err != nil {
+ return
+ }
+ var index string
+ nodeParts := strings.Split(nodeID, "__")
+ if len(nodeParts) == 2 {
+ index = nodeParts[1]
+ }
+ for _, parentNode := range parentNodes {
+ parentKey := fmt.Sprintf("%s__%s", parentNode.Key, index)
+ parentNodeStatus, exists := tm.taskNodeStatus.Get(parentKey)
+ if !exists {
+ parentKey = fmt.Sprintf("%s__%s", parentNode.Key, "0")
+ parentNodeStatus, exists = tm.taskNodeStatus.Get(parentKey)
+ }
+ if exists {
+ parentNodeStatus.itemResults.Set(nodeID, rs)
+ if parentNodeStatus.IsDone() {
+ rt := tm.prepareResult(ctx, parentNodeStatus)
+ tm.ChangeNodeStatus(ctx, parentKey, status, rt)
+ }
+ }
+ }
+}
+
+func (tm *TaskManager) prepareResult(ctx context.Context, nodeStatus *taskNodeStatus) mq.Result {
+ aggregatedOutput := make([]json.RawMessage, 0)
+ var status mq.Status
+ var topic string
+ var err1 error
+ if nodeStatus.totalItems == 1 {
+ rs := nodeStatus.itemResults.Values()[0]
+ if rs.Ctx == nil {
+ rs.Ctx = ctx
+ }
+ return rs
+ }
+ nodeStatus.itemResults.ForEach(func(key string, result mq.Result) bool {
+ if topic == "" {
+ topic = result.Topic
+ status = result.Status
+ }
+ if result.Error != nil {
+ err1 = result.Error
+ return false
+ }
+ var item json.RawMessage
+ err := json.Unmarshal(result.Payload, &item)
+ if err != nil {
+ err1 = err
+ return false
+ }
+ aggregatedOutput = append(aggregatedOutput, item)
+ return true
+ })
+ if err1 != nil {
+ return mq.HandleError(ctx, err1)
+ }
+ finalOutput, err := json.Marshal(aggregatedOutput)
+ if err != nil {
+ return mq.HandleError(ctx, err)
+ }
+ return mq.Result{TaskID: tm.taskID, Payload: finalOutput, Status: status, Topic: topic, Ctx: ctx}
+}
diff --git a/dag/v2/ui.go b/dag/v1/ui.go
similarity index 54%
rename from dag/v2/ui.go
rename to dag/v1/ui.go
index bd32cc5..78fb02b 100644
--- a/dag/v2/ui.go
+++ b/dag/v1/ui.go
@@ -1,4 +1,4 @@
-package v2
+package v1
import (
"fmt"
@@ -9,24 +9,25 @@ import (
func (tm *DAG) PrintGraph() {
fmt.Println("DAG Graph structure:")
- tm.nodes.ForEach(func(_ string, node *Node) bool {
- fmt.Printf("Node: %s (%s) -> ", node.Label, node.ID)
- if conditions, ok := tm.conditions[node.ID]; ok {
+ for _, node := range tm.nodes {
+ fmt.Printf("Node: %s (%s) -> ", node.Name, node.Key)
+ if conditions, ok := tm.conditions[FromNode(node.Key)]; ok {
var c []string
for when, then := range conditions {
- if target, ok := tm.nodes.Get(then); ok {
- c = append(c, fmt.Sprintf("If [%s] Then %s (%s)", when, target.Label, target.ID))
+ if target, ok := tm.nodes[string(then)]; ok {
+ c = append(c, fmt.Sprintf("If [%s] Then %s (%s)", when, target.Name, target.Key))
}
}
fmt.Println(strings.Join(c, ", "))
}
var edges []string
- for _, target := range node.Edges {
- edges = append(edges, fmt.Sprintf("%s (%s)", target.To.Label, target.To.ID))
+ for _, edge := range node.Edges {
+ for _, target := range edge.To {
+ edges = append(edges, fmt.Sprintf("%s (%s)", target.Name, target.Key))
+ }
}
fmt.Println(strings.Join(edges, ", "))
- return true
- })
+ }
}
func (tm *DAG) ClassifyEdges(startNodes ...string) (string, bool, error) {
@@ -43,7 +44,7 @@ func (tm *DAG) ClassifyEdges(startNodes ...string) (string, bool, error) {
if startNode == "" {
firstNode := tm.findStartNode()
if firstNode != nil {
- startNode = firstNode.ID
+ startNode = firstNode.Key
}
}
if startNode == "" {
@@ -61,22 +62,24 @@ func (tm *DAG) dfs(v string, visited map[string]bool, discoveryTime, finishedTim
inRecursionStack[v] = true // mark node as part of recursion stack
*timeVal++
discoveryTime[v] = *timeVal
- node, _ := tm.nodes.Get(v)
+ node := tm.nodes[v]
hasCycle := false
var err error
for _, edge := range node.Edges {
- if !visited[edge.To.ID] {
- builder.WriteString(fmt.Sprintf("Traversing Edge: %s -> %s\n", v, edge.To.ID))
- hasCycle, err := tm.dfs(edge.To.ID, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder)
- if err != nil {
- return true, err
+ for _, adj := range edge.To {
+ if !visited[adj.Key] {
+ builder.WriteString(fmt.Sprintf("Traversing Edge: %s -> %s\n", v, adj.Key))
+ hasCycle, err := tm.dfs(adj.Key, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder)
+ if err != nil {
+ return true, err
+ }
+ if hasCycle {
+ return true, nil
+ }
+ } else if inRecursionStack[adj.Key] {
+ cycleMsg := fmt.Sprintf("Cycle detected: %s -> %s\n", v, adj.Key)
+ return true, fmt.Errorf(cycleMsg)
}
- if hasCycle {
- return true, nil
- }
- } else if inRecursionStack[edge.To.ID] {
- cycleMsg := fmt.Sprintf("Cycle detected: %s -> %s\n", v, edge.To.ID)
- return true, fmt.Errorf(cycleMsg)
}
}
hasCycle, err = tm.handleConditionalEdges(v, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder)
@@ -90,20 +93,20 @@ func (tm *DAG) dfs(v string, visited map[string]bool, discoveryTime, finishedTim
}
func (tm *DAG) handleConditionalEdges(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, time *int, inRecursionStack map[string]bool, builder *strings.Builder) (bool, error) {
- node, _ := tm.nodes.Get(v)
- for when, then := range tm.conditions[node.ID] {
- if targetNode, ok := tm.nodes.Get(then); ok {
- if !visited[targetNode.ID] {
- builder.WriteString(fmt.Sprintf("Traversing Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.ID))
- hasCycle, err := tm.dfs(targetNode.ID, visited, discoveryTime, finishedTime, time, inRecursionStack, builder)
+ node := tm.nodes[v]
+ for when, then := range tm.conditions[FromNode(node.Key)] {
+ if targetNode, ok := tm.nodes[string(then)]; ok {
+ if !visited[targetNode.Key] {
+ builder.WriteString(fmt.Sprintf("Traversing Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.Key))
+ hasCycle, err := tm.dfs(targetNode.Key, visited, discoveryTime, finishedTime, time, inRecursionStack, builder)
if err != nil {
return true, err
}
if hasCycle {
return true, nil
}
- } else if inRecursionStack[targetNode.ID] {
- cycleMsg := fmt.Sprintf("Cycle detected in Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.ID)
+ } else if inRecursionStack[targetNode.Key] {
+ cycleMsg := fmt.Sprintf("Cycle detected in Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.Key)
return true, fmt.Errorf(cycleMsg)
}
}
@@ -143,113 +146,98 @@ func (tm *DAG) ExportDOT() string {
var sb strings.Builder
sb.WriteString(fmt.Sprintf(`digraph "%s" {`, tm.name))
sb.WriteString("\n")
- sb.WriteString(` label="Enhanced DAG Representation";`)
+ sb.WriteString(fmt.Sprintf(` label="%s";`, tm.name))
sb.WriteString("\n")
- sb.WriteString(` labelloc="t"; fontsize=22; fontname="Helvetica";`)
+ sb.WriteString(` labelloc="t";`)
sb.WriteString("\n")
- sb.WriteString(` node [shape=box, fontname="Helvetica", fillcolor="#B3CDE0", fontcolor="#2C3E50", fontsize=10, margin="0.25,0.15", style="rounded,filled"];`)
+ sb.WriteString(` fontsize=20;`)
sb.WriteString("\n")
- sb.WriteString(` edge [fontname="Helvetica", fontsize=12, arrowsize=0.8];`)
+ sb.WriteString(` node [shape=box, style="rounded,filled", fillcolor="lightgray", fontname="Arial", margin="0.2,0.1"];`)
sb.WriteString("\n")
- sb.WriteString(` rankdir=TB;`)
+ sb.WriteString(` edge [fontname="Arial", fontsize=12, arrowsize=0.8];`)
+ sb.WriteString("\n")
+ sb.WriteString(` size="10,10";`)
+ sb.WriteString("\n")
+ sb.WriteString(` ratio="fill";`)
sb.WriteString("\n")
sortedNodes := tm.TopologicalSort()
for _, nodeKey := range sortedNodes {
- node, _ := tm.nodes.Get(nodeKey)
- nodeColor := "lightgray"
- nodeShape := "box"
- labelSuffix := ""
-
- // Apply styles based on NodeType
- switch node.NodeType {
- case Function:
- nodeColor = "#D4EDDA"
- labelSuffix = " [Function]"
- case Page:
- nodeColor = "#F08080"
- labelSuffix = " [Page]"
- }
- sb.WriteString(fmt.Sprintf(
- ` "%s" [label="%s%s", fontcolor="#2C3E50", fillcolor="%s", shape="%s", style="rounded,filled", id="node_%s"];`,
- node.ID, node.Label, labelSuffix, nodeColor, nodeShape, node.ID))
+ node := tm.nodes[nodeKey]
+ nodeColor := "lightblue"
+ sb.WriteString(fmt.Sprintf(` "%s" [label=" %s", fillcolor="%s", id="node_%s"];`, node.Key, node.Name, nodeColor, node.Key))
sb.WriteString("\n")
}
-
- // Define edges with unique styling by EdgeType
for _, nodeKey := range sortedNodes {
- node, _ := tm.nodes.Get(nodeKey)
+ node := tm.nodes[nodeKey]
for _, edge := range node.Edges {
- edgeStyle := "solid"
- edgeColor := "black"
- labelSuffix := ""
-
- // Apply styles based on EdgeType
+ var edgeStyle string
switch edge.Type {
case Iterator:
edgeStyle = "dashed"
- edgeColor = "blue"
- labelSuffix = " [Iter]"
- case Simple:
+ default:
edgeStyle = "solid"
- edgeColor = "black"
- labelSuffix = ""
}
- sb.WriteString(fmt.Sprintf(
- ` "%s" -> "%s" [label="%s%s", color="%s", style="%s"];`,
- node.ID, edge.To.ID, edge.Label, labelSuffix, edgeColor, edgeStyle))
- sb.WriteString("\n")
- }
- }
- for fromNodeKey, conditions := range tm.conditions {
- for when, then := range conditions {
- if toNode, ok := tm.nodes.Get(then); ok {
- sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="purple", style=dotted, fontsize=10, arrowsize=0.6];`, fromNodeKey, toNode.ID, when))
+ edgeColor := "black"
+ for _, to := range edge.To {
+ sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="%s", style=%s, fontsize=10, arrowsize=0.6];`, node.Key, to.Key, edge.Label, edgeColor, edgeStyle))
+ sb.WriteString("\n")
+ }
+ }
+ }
+ for fromNodeKey, conditions := range tm.conditions {
+ for when, then := range conditions {
+ if toNode, ok := tm.nodes[string(then)]; ok {
+ sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="purple", style=dotted, fontsize=10, arrowsize=0.6];`, fromNodeKey, toNode.Key, when))
sb.WriteString("\n")
}
}
}
-
- // Optional: Group related nodes into subgraphs (e.g., loops)
for _, nodeKey := range sortedNodes {
- node, _ := tm.nodes.Get(nodeKey)
+ node := tm.nodes[nodeKey]
if node.processor != nil {
subDAG, _ := isDAGNode(node)
if subDAG != nil {
sb.WriteString(fmt.Sprintf(` subgraph "cluster_%s" {`, subDAG.name))
sb.WriteString("\n")
- sb.WriteString(fmt.Sprintf(` label="Subgraph: %s";`, subDAG.name))
+ sb.WriteString(fmt.Sprintf(` label=" %s";`, subDAG.name))
sb.WriteString("\n")
- sb.WriteString(` style=filled; color=gray90;`)
+ sb.WriteString(` style=dashed;`)
sb.WriteString("\n")
- subDAG.nodes.ForEach(func(subNodeKey string, subNode *Node) bool {
- sb.WriteString(fmt.Sprintf(` "%s" [label="%s"];`, subNode.ID, subNode.Label))
+ sb.WriteString(` bgcolor="lightgray";`)
+ sb.WriteString("\n")
+ sb.WriteString(` node [shape=rectangle, style="filled", fillcolor="lightblue", fontname="Arial", margin="0.2,0.1"];`)
+ sb.WriteString("\n")
+ for subNodeKey, subNode := range subDAG.nodes {
+ sb.WriteString(fmt.Sprintf(` "%s" [label=" %s"];`, subNodeKey, subNode.Name))
sb.WriteString("\n")
- return true
- })
- subDAG.nodes.ForEach(func(subNodeKey string, subNode *Node) bool {
+ }
+ for subNodeKey, subNode := range subDAG.nodes {
for _, edge := range subNode.Edges {
- sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label="%s"];`, subNodeKey, edge.To.ID, edge.Label))
- sb.WriteString("\n")
+ for _, to := range edge.To {
+ sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="black", style=solid, arrowsize=0.6];`, subNodeKey, to.Key, edge.Label))
+ sb.WriteString("\n")
+ }
}
- return true
- })
- sb.WriteString(" }\n")
+ }
+ sb.WriteString(` }`)
+ sb.WriteString("\n")
+ sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="black", style=solid, arrowsize=0.6];`, node.Key, subDAG.startNode, subDAG.name))
+ sb.WriteString("\n")
}
}
}
-
- sb.WriteString("}\n")
+ sb.WriteString(`}`)
+ sb.WriteString("\n")
return sb.String()
}
func (tm *DAG) TopologicalSort() (stack []string) {
visited := make(map[string]bool)
- tm.nodes.ForEach(func(_ string, node *Node) bool {
- if !visited[node.ID] {
- tm.topologicalSortUtil(node.ID, visited, &stack)
+ for _, node := range tm.nodes {
+ if !visited[node.Key] {
+ tm.topologicalSortUtil(node.Key, visited, &stack)
}
- return true
- })
+ }
for i, j := 0, len(stack)-1; i < j; i, j = i+1, j-1 {
stack[i], stack[j] = stack[j], stack[i]
}
@@ -258,23 +246,13 @@ func (tm *DAG) TopologicalSort() (stack []string) {
func (tm *DAG) topologicalSortUtil(v string, visited map[string]bool, stack *[]string) {
visited[v] = true
- node, ok := tm.nodes.Get(v)
- if !ok {
- fmt.Println("Not found", v)
- }
+ node := tm.nodes[v]
for _, edge := range node.Edges {
- if !visited[edge.To.ID] {
- tm.topologicalSortUtil(edge.To.ID, visited, stack)
+ for _, to := range edge.To {
+ if !visited[to.Key] {
+ tm.topologicalSortUtil(to.Key, visited, stack)
+ }
}
}
*stack = append(*stack, v)
}
-
-func isDAGNode(node *Node) (*DAG, bool) {
- switch node := node.processor.(type) {
- case *DAG:
- return node, true
- default:
- return nil, false
- }
-}
diff --git a/dag/utils.go b/dag/v1/utils.go
similarity index 98%
rename from dag/utils.go
rename to dag/v1/utils.go
index efd8642..52c085f 100644
--- a/dag/utils.go
+++ b/dag/v1/utils.go
@@ -1,4 +1,4 @@
-package dag
+package v1
import (
"context"
diff --git a/dag/waitgroup.go b/dag/v1/waitgroup.go
similarity index 98%
rename from dag/waitgroup.go
rename to dag/v1/waitgroup.go
index e63f9bf..dfb84cc 100644
--- a/dag/waitgroup.go
+++ b/dag/v1/waitgroup.go
@@ -1,4 +1,4 @@
-package dag
+package v1
import (
"sync"
diff --git a/dag/v2/websocket.go b/dag/v1/websocket.go
similarity index 98%
rename from dag/v2/websocket.go
rename to dag/v1/websocket.go
index 0f0534a..4d4b544 100644
--- a/dag/v2/websocket.go
+++ b/dag/v1/websocket.go
@@ -1,4 +1,4 @@
-package v2
+package v1
import (
"encoding/json"
diff --git a/dag/v2/api.go b/dag/v2/api.go
deleted file mode 100644
index f2b6f4a..0000000
--- a/dag/v2/api.go
+++ /dev/null
@@ -1,152 +0,0 @@
-package v2
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "net/http"
- "os"
- "strings"
-
- "github.com/oarkflow/mq"
- "github.com/oarkflow/mq/sio"
-
- "github.com/oarkflow/mq/consts"
- "github.com/oarkflow/mq/jsonparser"
-)
-
-func renderNotFound(w http.ResponseWriter) {
- html := []byte(`
-
-`)
- w.Header().Set(consts.ContentType, consts.TypeHtml)
- w.Write(html)
-}
-
-func (tm *DAG) render(w http.ResponseWriter, r *http.Request) {
- ctx, data, err := parse(r)
- if err != nil {
- http.Error(w, err.Error(), http.StatusNotFound)
- return
- }
- accept := r.Header.Get("Accept")
- userCtx := UserContext(ctx)
- ctx = context.WithValue(ctx, "method", r.Method)
- if r.Method == "GET" && userCtx.Get("task_id") != "" {
- manager, ok := tm.taskManager.Get(userCtx.Get("task_id"))
- if !ok || manager == nil {
- if strings.Contains(accept, "text/html") || accept == "" {
- renderNotFound(w)
- return
- }
- http.Error(w, fmt.Sprintf(`{"message": "%s"}`, "task not found"), http.StatusInternalServerError)
- return
- }
- }
- result := tm.Process(ctx, data)
- if result.Error != nil {
- http.Error(w, fmt.Sprintf(`{"message": "%s"}`, result.Error.Error()), http.StatusInternalServerError)
- return
- }
- contentType, ok := result.Ctx.Value(consts.ContentType).(string)
- if !ok {
- contentType = consts.TypeJson
- }
- switch contentType {
- case consts.TypeHtml:
- w.Header().Set(consts.ContentType, consts.TypeHtml)
- data, err := jsonparser.GetString(result.Payload, "html_content")
- if err != nil {
- return
- }
- w.Write([]byte(data))
- default:
- if r.Method != "POST" {
- http.Error(w, `{"message": "not allowed"}`, http.StatusMethodNotAllowed)
- return
- }
- w.Header().Set(consts.ContentType, consts.TypeJson)
- json.NewEncoder(w).Encode(result.Payload)
- }
-}
-
-func (tm *DAG) taskStatusHandler(w http.ResponseWriter, r *http.Request) {
- taskID := r.URL.Query().Get("taskID")
- if taskID == "" {
- http.Error(w, `{"message": "taskID is missing"}`, http.StatusBadRequest)
- return
- }
- manager, ok := tm.taskManager.Get(taskID)
- if !ok {
- http.Error(w, `{"message": "Invalid TaskID"}`, http.StatusNotFound)
- return
- }
- result := make(map[string]TaskState)
- manager.taskStates.ForEach(func(key string, value *TaskState) bool {
- key = strings.Split(key, Delimiter)[0]
- nodeID := strings.Split(value.NodeID, Delimiter)[0]
- rs := jsonparser.Delete(value.Result.Payload, "html_content")
- status := value.Status
- if status == mq.Processing {
- status = mq.Completed
- }
- state := TaskState{
- NodeID: nodeID,
- Status: status,
- UpdatedAt: value.UpdatedAt,
- Result: mq.Result{
- Payload: rs,
- Error: value.Result.Error,
- Status: status,
- },
- }
- result[key] = state
- return true
- })
- w.Header().Set(consts.ContentType, consts.TypeJson)
- json.NewEncoder(w).Encode(result)
-}
-
-func (tm *DAG) SetupWS() *sio.Server {
- ws := sio.New(sio.Config{
- CheckOrigin: func(r *http.Request) bool { return true },
- EnableCompression: true,
- })
- WsEvents(ws)
- tm.Notifier = ws
- return ws
-}
-
-func (tm *DAG) Handlers() {
- http.Handle("/", http.FileServer(http.Dir("webroot")))
- http.Handle("/notify", tm.SetupWS())
- http.HandleFunc("/process", tm.render)
- http.HandleFunc("/request", tm.render)
- http.HandleFunc("/task/status", tm.taskStatusHandler)
- http.HandleFunc("/dot", func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "text/plain")
- fmt.Fprintln(w, tm.ExportDOT())
- })
- http.HandleFunc("/ui", func(w http.ResponseWriter, r *http.Request) {
- image := fmt.Sprintf("%s.svg", mq.NewID())
- err := tm.SaveSVG(image)
- if err != nil {
- http.Error(w, "Failed to read request body", http.StatusBadRequest)
- return
- }
- defer os.Remove(image)
- svgBytes, err := os.ReadFile(image)
- if err != nil {
- http.Error(w, "Could not read SVG file", http.StatusInternalServerError)
- return
- }
- w.Header().Set("Content-Type", "image/svg+xml")
- if _, err := w.Write(svgBytes); err != nil {
- http.Error(w, "Could not write SVG response", http.StatusInternalServerError)
- return
- }
- })
-}
diff --git a/dag/v2/consts.go b/dag/v2/consts.go
deleted file mode 100644
index 3010831..0000000
--- a/dag/v2/consts.go
+++ /dev/null
@@ -1,28 +0,0 @@
-package v2
-
-import "time"
-
-const (
- Delimiter = "___"
- ContextIndex = "index"
- DefaultChannelSize = 1000
- RetryInterval = 5 * time.Second
-)
-
-type NodeType int
-
-func (c NodeType) IsValid() bool { return c >= Function && c <= Page }
-
-const (
- Function NodeType = iota
- Page
-)
-
-type EdgeType int
-
-func (c EdgeType) IsValid() bool { return c >= Simple && c <= Iterator }
-
-const (
- Simple EdgeType = iota
- Iterator
-)
diff --git a/dag/v2/dag.go b/dag/v2/dag.go
deleted file mode 100644
index 0e0cf28..0000000
--- a/dag/v2/dag.go
+++ /dev/null
@@ -1,524 +0,0 @@
-package v2
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "log"
- "net/http"
- "strings"
- "time"
-
- "golang.org/x/time/rate"
-
- "github.com/oarkflow/mq/consts"
- "github.com/oarkflow/mq/sio"
-
- "github.com/oarkflow/mq"
- "github.com/oarkflow/mq/storage"
- "github.com/oarkflow/mq/storage/memory"
-)
-
-type Node struct {
- NodeType NodeType
- Label string
- ID string
- Edges []Edge
- processor mq.Processor
- isReady bool
-}
-
-type Edge struct {
- From *Node
- To *Node
- Type EdgeType
- Label string
-}
-
-type DAG struct {
- server *mq.Broker
- consumer *mq.Consumer
- nodes storage.IMap[string, *Node]
- taskManager storage.IMap[string, *TaskManager]
- iteratorNodes storage.IMap[string, []Edge]
- finalResult func(taskID string, result mq.Result)
- pool *mq.Pool
- name string
- key string
- startNode string
- opts []mq.Option
- conditions map[string]map[string]string
- consumerTopic string
- reportNodeResultCallback func(mq.Result)
- Error error
- Notifier *sio.Server
- paused bool
- report string
-}
-
-func NewDAG(name, key string, finalResultCallback func(taskID string, result mq.Result), opts ...mq.Option) *DAG {
- callback := func(ctx context.Context, result mq.Result) error { return nil }
- d := &DAG{
- name: name,
- key: key,
- nodes: memory.New[string, *Node](),
- taskManager: memory.New[string, *TaskManager](),
- iteratorNodes: memory.New[string, []Edge](),
- conditions: make(map[string]map[string]string),
- finalResult: finalResultCallback,
- }
- opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerOnSubscribe(d.onConsumerJoin), mq.WithConsumerOnClose(d.onConsumerClose))
- d.server = mq.NewBroker(opts...)
- d.opts = opts
- options := d.server.Options()
- d.pool = mq.NewPool(
- options.NumOfWorkers(),
- mq.WithTaskQueueSize(options.QueueSize()),
- mq.WithMaxMemoryLoad(options.MaxMemoryLoad()),
- mq.WithHandler(d.ProcessTask),
- mq.WithPoolCallback(callback),
- mq.WithTaskStorage(options.Storage()),
- )
- d.pool.Start(d.server.Options().NumOfWorkers())
- return d
-}
-
-func (tm *DAG) SetKey(key string) {
- tm.key = key
-}
-
-func (tm *DAG) ReportNodeResult(callback func(mq.Result)) {
- tm.reportNodeResultCallback = callback
-}
-
-func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result {
- if manager, ok := tm.taskManager.Get(result.TaskID); ok && result.Topic != "" {
- manager.onNodeCompleted(nodeResult{
- ctx: ctx,
- nodeID: result.Topic,
- status: result.Status,
- result: result,
- })
- }
- return mq.Result{}
-}
-
-func (tm *DAG) GetType() string {
- return tm.key
-}
-
-func (tm *DAG) Consume(ctx context.Context) error {
- if tm.consumer != nil {
- tm.server.Options().SetSyncMode(true)
- return tm.consumer.Consume(ctx)
- }
- return nil
-}
-
-func (tm *DAG) Stop(ctx context.Context) error {
- tm.nodes.ForEach(func(_ string, n *Node) bool {
- err := n.processor.Stop(ctx)
- if err != nil {
- return false
- }
- return true
- })
- return nil
-}
-
-func (tm *DAG) GetKey() string {
- return tm.key
-}
-
-func (tm *DAG) AssignTopic(topic string) {
- tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false), mq.WithBrokerURL(tm.server.URL()))
- tm.consumerTopic = topic
-}
-
-func (tm *DAG) callbackToConsumer(ctx context.Context, result mq.Result) {
- if tm.consumer != nil {
- result.Topic = tm.consumerTopic
- if tm.consumer.Conn() == nil {
- tm.onTaskCallback(ctx, result)
- } else {
- tm.consumer.OnResponse(ctx, result)
- }
- }
-}
-
-func (tm *DAG) onConsumerJoin(_ context.Context, topic, _ string) {
- if node, ok := tm.nodes.Get(topic); ok {
- log.Printf("DAG - CONSUMER ~> ready on %s", topic)
- node.isReady = true
- }
-}
-
-func (tm *DAG) onConsumerClose(_ context.Context, topic, _ string) {
- if node, ok := tm.nodes.Get(topic); ok {
- log.Printf("DAG - CONSUMER ~> down on %s", topic)
- node.isReady = false
- }
-}
-
-func (tm *DAG) Pause(_ context.Context) error {
- tm.paused = true
- return nil
-}
-
-func (tm *DAG) Resume(_ context.Context) error {
- tm.paused = false
- return nil
-}
-
-func (tm *DAG) Close() error {
- var err error
- tm.nodes.ForEach(func(_ string, n *Node) bool {
- err = n.processor.Close()
- if err != nil {
- return false
- }
- return true
- })
- return nil
-}
-
-func (tm *DAG) SetStartNode(node string) {
- tm.startNode = node
-}
-
-func (tm *DAG) SetNotifyResponse(callback mq.Callback) {
- tm.server.SetNotifyHandler(callback)
-}
-
-func (tm *DAG) GetStartNode() string {
- return tm.startNode
-}
-
-func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) *DAG {
- tm.conditions[fromNode] = conditions
- return tm
-}
-
-func (tm *DAG) AddNode(nodeType NodeType, name, nodeID string, handler mq.Processor, startNode ...bool) *DAG {
- if tm.Error != nil {
- return tm
- }
- con := mq.NewConsumer(nodeID, nodeID, handler.ProcessTask)
- n := &Node{
- Label: name,
- ID: nodeID,
- NodeType: nodeType,
- processor: con,
- }
- if tm.server != nil && tm.server.SyncMode() {
- n.isReady = true
- }
- tm.nodes.Set(nodeID, n)
- if len(startNode) > 0 && startNode[0] {
- tm.startNode = nodeID
- }
- return tm
-}
-
-func (tm *DAG) AddDeferredNode(nodeType NodeType, name, key string, firstNode ...bool) error {
- if tm.server.SyncMode() {
- return fmt.Errorf("DAG cannot have deferred node in Sync Mode")
- }
- tm.nodes.Set(key, &Node{
- Label: name,
- ID: key,
- NodeType: nodeType,
- })
- if len(firstNode) > 0 && firstNode[0] {
- tm.startNode = key
- }
- return nil
-}
-
-func (tm *DAG) IsReady() bool {
- var isReady bool
- tm.nodes.ForEach(func(_ string, n *Node) bool {
- if !n.isReady {
- return false
- }
- isReady = true
- return true
- })
- return isReady
-}
-
-func (tm *DAG) AddEdge(edgeType EdgeType, label, from string, targets ...string) *DAG {
- if tm.Error != nil {
- return tm
- }
- if edgeType == Iterator {
- tm.iteratorNodes.Set(from, []Edge{})
- }
- node, ok := tm.nodes.Get(from)
- if !ok {
- tm.Error = fmt.Errorf("node not found %s", from)
- return tm
- }
- for _, target := range targets {
- if targetNode, ok := tm.nodes.Get(target); ok {
- edge := Edge{From: node, To: targetNode, Type: edgeType, Label: label}
- node.Edges = append(node.Edges, edge)
- if edgeType != Iterator {
- if edges, ok := tm.iteratorNodes.Get(node.ID); ok {
- edges = append(edges, edge)
- tm.iteratorNodes.Set(node.ID, edges)
- }
- }
- }
- }
- return tm
-}
-
-func (tm *DAG) getCurrentNode(manager *TaskManager) string {
- if manager.currentNodePayload.Size() == 0 {
- return ""
- }
- return manager.currentNodePayload.Keys()[0]
-}
-
-func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
- ctx = context.WithValue(ctx, "task_id", task.ID)
- userContext := UserContext(ctx)
- next := userContext.Get("next")
- manager, ok := tm.taskManager.Get(task.ID)
- resultCh := make(chan mq.Result, 1)
- if !ok {
- manager = NewTaskManager(tm, task.ID, resultCh, tm.iteratorNodes.Clone())
- tm.taskManager.Set(task.ID, manager)
- } else {
- manager.resultCh = resultCh
- }
- currentKey := tm.getCurrentNode(manager)
- currentNode := strings.Split(currentKey, Delimiter)[0]
- node, exists := tm.nodes.Get(currentNode)
- method, ok := ctx.Value("method").(string)
- if method == "GET" && exists && node.NodeType == Page {
- ctx = context.WithValue(ctx, "initial_node", currentNode)
- /*
- if isLastNode, err := tm.IsLastNode(currentNode); err != nil && isLastNode {
- if manager.result != nil {
- fmt.Println(string(manager.result.Payload))
- resultCh <- *manager.result
- return <-resultCh
- }
- }
- */
- if manager.result != nil {
- task.Payload = manager.result.Payload
- }
- } else if next == "true" {
- nodes, err := tm.GetNextNodes(currentNode)
- if err != nil {
- return mq.Result{Error: err, Ctx: ctx}
- }
- if len(nodes) > 0 {
- ctx = context.WithValue(ctx, "initial_node", nodes[0].ID)
- }
- }
- if currentNodeResult, hasResult := manager.currentNodeResult.Get(currentKey); hasResult {
- var taskPayload, resultPayload map[string]any
- if err := json.Unmarshal(task.Payload, &taskPayload); err == nil {
- if err = json.Unmarshal(currentNodeResult.Payload, &resultPayload); err == nil {
- for key, val := range resultPayload {
- taskPayload[key] = val
- }
- task.Payload, _ = json.Marshal(taskPayload)
- }
- }
- }
- firstNode, err := tm.parseInitialNode(ctx)
- if err != nil {
- return mq.Result{Error: err, Ctx: ctx}
- }
- node, ok = tm.nodes.Get(firstNode)
- if ok && node.NodeType != Page && task.Payload == nil {
- return mq.Result{Error: fmt.Errorf("payload is required for node %s", firstNode), Ctx: ctx}
- }
- task.Topic = firstNode
- ctx = context.WithValue(ctx, ContextIndex, "0")
- manager.ProcessTask(ctx, firstNode, task.Payload)
- return <-resultCh
-}
-
-func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result {
- var taskID string
- userCtx := UserContext(ctx)
- if val := userCtx.Get("task_id"); val != "" {
- taskID = val
- } else {
- taskID = mq.NewID()
- }
- return tm.ProcessTask(ctx, mq.NewTask(taskID, payload, ""))
-}
-
-func (tm *DAG) Validate() error {
- report, hasCycle, err := tm.ClassifyEdges()
- if hasCycle || err != nil {
- tm.Error = err
- return err
- }
- tm.report = report
- return nil
-}
-
-func (tm *DAG) GetReport() string {
- return tm.report
-}
-
-func (tm *DAG) AddDAGNode(name string, key string, dag *DAG, firstNode ...bool) *DAG {
- dag.AssignTopic(key)
- tm.nodes.Set(key, &Node{
- Label: name,
- ID: key,
- processor: dag,
- isReady: true,
- })
- if len(firstNode) > 0 && firstNode[0] {
- tm.startNode = key
- }
- return tm
-}
-
-func (tm *DAG) Start(ctx context.Context, addr string) error {
- // Start the server in a separate goroutine
- go func() {
- defer mq.RecoverPanic(mq.RecoverTitle)
- if err := tm.server.Start(ctx); err != nil {
- panic(err)
- }
- }()
-
- // Start the node consumers if not in sync mode
- if !tm.server.SyncMode() {
- tm.nodes.ForEach(func(_ string, con *Node) bool {
- go func(con *Node) {
- defer mq.RecoverPanic(mq.RecoverTitle)
- limiter := rate.NewLimiter(rate.Every(1*time.Second), 1) // Retry every second
- for {
- err := con.processor.Consume(ctx)
- if err != nil {
- log.Printf("[ERROR] - Consumer %s failed to start: %v", con.ID, err)
- } else {
- log.Printf("[INFO] - Consumer %s started successfully", con.ID)
- break
- }
- limiter.Wait(ctx) // Wait with rate limiting before retrying
- }
- }(con)
- return true
- })
- }
- log.Printf("DAG - HTTP_SERVER ~> started on http://%s", addr)
- tm.Handlers()
- config := tm.server.TLSConfig()
- log.Printf("Server listening on http://%s", addr)
- if config.UseTLS {
- return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil)
- }
- return http.ListenAndServe(addr, nil)
-}
-
-func (tm *DAG) ScheduleTask(ctx context.Context, payload []byte, opts ...mq.SchedulerOption) mq.Result {
- var taskID string
- userCtx := UserContext(ctx)
- if val := userCtx.Get("task_id"); val != "" {
- taskID = val
- } else {
- taskID = mq.NewID()
- }
- t := mq.NewTask(taskID, payload, "")
-
- ctx = context.WithValue(ctx, "task_id", taskID)
- userContext := UserContext(ctx)
- next := userContext.Get("next")
- manager, ok := tm.taskManager.Get(taskID)
- resultCh := make(chan mq.Result, 1)
- if !ok {
- manager = NewTaskManager(tm, taskID, resultCh, tm.iteratorNodes.Clone())
- tm.taskManager.Set(taskID, manager)
- } else {
- manager.resultCh = resultCh
- }
- currentKey := tm.getCurrentNode(manager)
- currentNode := strings.Split(currentKey, Delimiter)[0]
- node, exists := tm.nodes.Get(currentNode)
- method, ok := ctx.Value("method").(string)
- if method == "GET" && exists && node.NodeType == Page {
- ctx = context.WithValue(ctx, "initial_node", currentNode)
- } else if next == "true" {
- nodes, err := tm.GetNextNodes(currentNode)
- if err != nil {
- return mq.Result{Error: err, Ctx: ctx}
- }
- if len(nodes) > 0 {
- ctx = context.WithValue(ctx, "initial_node", nodes[0].ID)
- }
- }
- if currentNodeResult, hasResult := manager.currentNodeResult.Get(currentKey); hasResult {
- var taskPayload, resultPayload map[string]any
- if err := json.Unmarshal(payload, &taskPayload); err == nil {
- if err = json.Unmarshal(currentNodeResult.Payload, &resultPayload); err == nil {
- for key, val := range resultPayload {
- taskPayload[key] = val
- }
- payload, _ = json.Marshal(taskPayload)
- }
- }
- }
- firstNode, err := tm.parseInitialNode(ctx)
- if err != nil {
- return mq.Result{Error: err, Ctx: ctx}
- }
- node, ok = tm.nodes.Get(firstNode)
- if ok && node.NodeType != Page && t.Payload == nil {
- return mq.Result{Error: fmt.Errorf("payload is required for node %s", firstNode), Ctx: ctx}
- }
- t.Topic = firstNode
- ctx = context.WithValue(ctx, ContextIndex, "0")
-
- headers, ok := mq.GetHeaders(ctx)
- ctxx := context.Background()
- if ok {
- ctxx = mq.SetHeaders(ctxx, headers.AsMap())
- }
- tm.pool.Scheduler().AddTask(ctxx, t, opts...)
- return mq.Result{CreatedAt: t.CreatedAt, TaskID: t.ID, Topic: t.Topic, Status: "PENDING"}
-}
-
-func (tm *DAG) PauseConsumer(ctx context.Context, id string) {
- tm.doConsumer(ctx, id, consts.CONSUMER_PAUSE)
-}
-
-func (tm *DAG) ResumeConsumer(ctx context.Context, id string) {
- tm.doConsumer(ctx, id, consts.CONSUMER_RESUME)
-}
-
-func (tm *DAG) doConsumer(ctx context.Context, id string, action consts.CMD) {
- if node, ok := tm.nodes.Get(id); ok {
- switch action {
- case consts.CONSUMER_PAUSE:
- err := node.processor.Pause(ctx)
- if err == nil {
- node.isReady = false
- log.Printf("[INFO] - Consumer %s paused successfully", node.ID)
- } else {
- log.Printf("[ERROR] - Failed to pause consumer %s: %v", node.ID, err)
- }
- case consts.CONSUMER_RESUME:
- err := node.processor.Resume(ctx)
- if err == nil {
- node.isReady = true
- log.Printf("[INFO] - Consumer %s resumed successfully", node.ID)
- } else {
- log.Printf("[ERROR] - Failed to resume consumer %s: %v", node.ID, err)
- }
- }
- } else {
- log.Printf("[WARNING] - Consumer %s not found", id)
- }
-}
diff --git a/dag/v2/task_manager.go b/dag/v2/task_manager.go
deleted file mode 100644
index caf8dbf..0000000
--- a/dag/v2/task_manager.go
+++ /dev/null
@@ -1,388 +0,0 @@
-package v2
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "log"
- "strings"
- "time"
-
- "github.com/oarkflow/mq"
-
- "github.com/oarkflow/mq/storage"
- "github.com/oarkflow/mq/storage/memory"
-)
-
-type TaskState struct {
- NodeID string
- Status mq.Status
- UpdatedAt time.Time
- Result mq.Result
- targetResults storage.IMap[string, mq.Result]
-}
-
-func newTaskState(nodeID string) *TaskState {
- return &TaskState{
- NodeID: nodeID,
- Status: mq.Pending,
- UpdatedAt: time.Now(),
- targetResults: memory.New[string, mq.Result](),
- }
-}
-
-type nodeResult struct {
- ctx context.Context
- nodeID string
- status mq.Status
- result mq.Result
-}
-
-type TaskManager struct {
- taskStates storage.IMap[string, *TaskState]
- parentNodes storage.IMap[string, string]
- childNodes storage.IMap[string, int]
- deferredTasks storage.IMap[string, *task]
- iteratorNodes storage.IMap[string, []Edge]
- currentNodePayload storage.IMap[string, json.RawMessage]
- currentNodeResult storage.IMap[string, mq.Result]
- result *mq.Result
- dag *DAG
- taskID string
- taskQueue chan *task
- resultQueue chan nodeResult
- resultCh chan mq.Result
- stopCh chan struct{}
-}
-
-type task struct {
- ctx context.Context
- taskID string
- nodeID string
- payload json.RawMessage
-}
-
-func newTask(ctx context.Context, taskID, nodeID string, payload json.RawMessage) *task {
- return &task{
- ctx: ctx,
- taskID: taskID,
- nodeID: nodeID,
- payload: payload,
- }
-}
-
-func NewTaskManager(dag *DAG, taskID string, resultCh chan mq.Result, iteratorNodes storage.IMap[string, []Edge]) *TaskManager {
- tm := &TaskManager{
- taskStates: memory.New[string, *TaskState](),
- parentNodes: memory.New[string, string](),
- childNodes: memory.New[string, int](),
- deferredTasks: memory.New[string, *task](),
- currentNodePayload: memory.New[string, json.RawMessage](),
- currentNodeResult: memory.New[string, mq.Result](),
- taskQueue: make(chan *task, DefaultChannelSize),
- resultQueue: make(chan nodeResult, DefaultChannelSize),
- iteratorNodes: iteratorNodes,
- stopCh: make(chan struct{}),
- resultCh: resultCh,
- taskID: taskID,
- dag: dag,
- }
- go tm.run()
- go tm.waitForResult()
- return tm
-}
-
-func (tm *TaskManager) ProcessTask(ctx context.Context, startNode string, payload json.RawMessage) {
- tm.send(ctx, startNode, tm.taskID, payload)
-}
-
-func (tm *TaskManager) send(ctx context.Context, startNode, taskID string, payload json.RawMessage) {
- if index, ok := ctx.Value(ContextIndex).(string); ok {
- startNode = strings.Split(startNode, Delimiter)[0]
- startNode = fmt.Sprintf("%s%s%s", startNode, Delimiter, index)
- }
- if _, exists := tm.taskStates.Get(startNode); !exists {
- tm.taskStates.Set(startNode, newTaskState(startNode))
- }
- t := newTask(ctx, taskID, startNode, payload)
- select {
- case tm.taskQueue <- t:
- default:
- log.Println("task queue is full, dropping task.")
- tm.deferredTasks.Set(taskID, t)
- }
-}
-
-func (tm *TaskManager) run() {
- for {
- select {
- case <-tm.stopCh:
- log.Println("Stopping TaskManager")
- return
- case task := <-tm.taskQueue:
- tm.processNode(task)
- }
- }
-}
-
-func (tm *TaskManager) waitForResult() {
- for {
- select {
- case <-tm.stopCh:
- log.Println("Stopping Result Listener")
- return
- case nr := <-tm.resultQueue:
- tm.onNodeCompleted(nr)
- }
- }
-}
-
-func (tm *TaskManager) processNode(exec *task) {
- pureNodeID := strings.Split(exec.nodeID, Delimiter)[0]
- node, exists := tm.dag.nodes.Get(pureNodeID)
- if !exists {
- log.Printf("Node %s does not exist while processing node\n", pureNodeID)
- return
- }
- state, _ := tm.taskStates.Get(exec.nodeID)
- if state == nil {
- log.Printf("State for node %s not found; creating new state.\n", exec.nodeID)
- state = newTaskState(exec.nodeID)
- tm.taskStates.Set(exec.nodeID, state)
- }
- state.Status = mq.Processing
- state.UpdatedAt = time.Now()
- tm.currentNodePayload.Clear()
- tm.currentNodeResult.Clear()
- tm.currentNodePayload.Set(exec.nodeID, exec.payload)
- result := node.processor.ProcessTask(exec.ctx, mq.NewTask(exec.taskID, exec.payload, exec.nodeID))
- tm.currentNodeResult.Set(exec.nodeID, result)
- state.Result = result
- result.Topic = node.ID
- if result.Error != nil {
- tm.result = &result
- tm.resultCh <- result
- tm.processFinalResult(state)
- return
- }
- if node.NodeType == Page {
- tm.result = &result
- tm.resultCh <- result
- return
- }
- tm.handleNext(exec.ctx, node, state, result)
-}
-
-func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, result mq.Result, childNode string, dispatchFinal bool) {
- state.targetResults.Set(childNode, result)
- state.targetResults.Del(state.NodeID)
- targetsCount, _ := tm.childNodes.Get(state.NodeID)
- size := state.targetResults.Size()
- nodeID := strings.Split(state.NodeID, Delimiter)
- if size == targetsCount {
- if size > 1 {
- aggregatedData := make([]json.RawMessage, size)
- i := 0
- state.targetResults.ForEach(func(_ string, rs mq.Result) bool {
- aggregatedData[i] = rs.Payload
- i++
- return true
- })
- aggregatedPayload, err := json.Marshal(aggregatedData)
- if err != nil {
- panic(err)
- }
- state.Result = mq.Result{Payload: aggregatedPayload, Status: mq.Completed, Ctx: ctx, Topic: state.NodeID}
- } else if size == 1 {
- state.Result = state.targetResults.Values()[0]
- }
- state.Status = result.Status
- state.Result.Status = result.Status
- }
- if state.Result.Payload == nil {
- state.Result.Payload = result.Payload
- }
- state.UpdatedAt = time.Now()
- if result.Ctx == nil {
- result.Ctx = ctx
- }
- if result.Error != nil {
- state.Status = mq.Failed
- }
- pn, ok := tm.parentNodes.Get(state.NodeID)
- if edges, exists := tm.iteratorNodes.Get(nodeID[0]); exists && state.Status == mq.Completed {
- state.Status = mq.Processing
- tm.iteratorNodes.Del(nodeID[0])
- state.targetResults.Clear()
- if len(nodeID) == 2 {
- ctx = context.WithValue(ctx, ContextIndex, nodeID[1])
- }
- toProcess := nodeResult{
- ctx: ctx,
- nodeID: state.NodeID,
- status: state.Status,
- result: state.Result,
- }
- tm.handleEdges(toProcess, edges)
- } else if ok {
- if targetsCount == size {
- parentState, _ := tm.taskStates.Get(pn)
- if parentState != nil {
- state.Result.Topic = state.NodeID
- tm.handlePrevious(ctx, parentState, state.Result, state.NodeID, dispatchFinal)
- }
- }
- } else {
- tm.result = &state.Result
- state.Result.Topic = strings.Split(state.NodeID, Delimiter)[0]
- tm.resultCh <- state.Result
- tm.processFinalResult(state)
- }
-}
-
-func (tm *TaskManager) handleNext(ctx context.Context, node *Node, state *TaskState, result mq.Result) {
- state.UpdatedAt = time.Now()
- if result.Ctx == nil {
- result.Ctx = ctx
- }
- if result.Error != nil {
- state.Status = mq.Failed
- } else {
- edges := tm.getConditionalEdges(node, result)
- if len(edges) == 0 {
- state.Status = mq.Completed
- }
- }
- if result.Status == "" {
- result.Status = state.Status
- }
- select {
- case tm.resultQueue <- nodeResult{
- ctx: ctx,
- nodeID: state.NodeID,
- result: result,
- status: state.Status,
- }:
- default:
- log.Println("Result queue is full, dropping result.")
- }
-}
-
-func (tm *TaskManager) onNodeCompleted(rs nodeResult) {
- nodeID := strings.Split(rs.nodeID, Delimiter)[0]
- node, ok := tm.dag.nodes.Get(nodeID)
- if !ok {
- return
- }
- edges := tm.getConditionalEdges(node, rs.result)
- hasErrorOrCompleted := rs.result.Error != nil || len(edges) == 0
- if hasErrorOrCompleted {
- if index, ok := rs.ctx.Value(ContextIndex).(string); ok {
- childNode := fmt.Sprintf("%s%s%s", node.ID, Delimiter, index)
- pn, ok := tm.parentNodes.Get(childNode)
- if ok {
- parentState, _ := tm.taskStates.Get(pn)
- if parentState != nil {
- pn = strings.Split(pn, Delimiter)[0]
- tm.handlePrevious(rs.ctx, parentState, rs.result, rs.nodeID, true)
- }
- }
- }
- return
- }
- tm.handleEdges(rs, edges)
-}
-
-func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge {
- edges := make([]Edge, len(node.Edges))
- copy(edges, node.Edges)
- if result.ConditionStatus != "" {
- if conditions, ok := tm.dag.conditions[result.Topic]; ok {
- if targetNodeKey, ok := conditions[result.ConditionStatus]; ok {
- if targetNode, ok := tm.dag.nodes.Get(targetNodeKey); ok {
- edges = append(edges, Edge{From: node, To: targetNode})
- }
- } else if targetNodeKey, ok = conditions["default"]; ok {
- if targetNode, ok := tm.dag.nodes.Get(targetNodeKey); ok {
- edges = append(edges, Edge{From: node, To: targetNode})
- }
- }
- }
- }
- return edges
-}
-
-func (tm *TaskManager) handleEdges(currentResult nodeResult, edges []Edge) {
- for _, edge := range edges {
- index, ok := currentResult.ctx.Value(ContextIndex).(string)
- if !ok {
- index = "0"
- }
- parentNode := fmt.Sprintf("%s%s%s", edge.From.ID, Delimiter, index)
- if edge.Type == Simple {
- if _, ok := tm.iteratorNodes.Get(edge.From.ID); ok {
- continue
- }
- }
- if edge.Type == Iterator {
- var items []json.RawMessage
- err := json.Unmarshal(currentResult.result.Payload, &items)
- if err != nil {
- log.Printf("Error unmarshalling data for node %s: %v\n", edge.To.ID, err)
- tm.resultQueue <- nodeResult{
- ctx: currentResult.ctx,
- nodeID: edge.To.ID,
- status: mq.Failed,
- result: mq.Result{Error: err},
- }
- return
- }
- tm.childNodes.Set(parentNode, len(items))
- for i, item := range items {
- childNode := fmt.Sprintf("%s%s%d", edge.To.ID, Delimiter, i)
- ctx := context.WithValue(currentResult.ctx, ContextIndex, fmt.Sprintf("%d", i))
- tm.parentNodes.Set(childNode, parentNode)
- tm.send(ctx, edge.To.ID, tm.taskID, item)
- }
- } else {
- tm.childNodes.Set(parentNode, 1)
- idx, ok := currentResult.ctx.Value(ContextIndex).(string)
- if !ok {
- idx = "0"
- }
- childNode := fmt.Sprintf("%s%s%s", edge.To.ID, Delimiter, idx)
- ctx := context.WithValue(currentResult.ctx, ContextIndex, idx)
- tm.parentNodes.Set(childNode, parentNode)
- tm.send(ctx, edge.To.ID, tm.taskID, currentResult.result.Payload)
- }
- }
-}
-
-func (tm *TaskManager) retryDeferredTasks() {
- const maxRetries = 5
- retries := 0
- for retries < maxRetries {
- select {
- case <-tm.stopCh:
- log.Println("Stopping Deferred task Retrier")
- return
- case <-time.After(RetryInterval):
- tm.deferredTasks.ForEach(func(taskID string, task *task) bool {
- tm.send(task.ctx, task.nodeID, taskID, task.payload)
- retries++
- return true
- })
- }
- }
-}
-
-func (tm *TaskManager) processFinalResult(state *TaskState) {
- state.targetResults.Clear()
- if tm.dag.finalResult != nil {
- tm.dag.finalResult(tm.taskID, state.Result)
- }
-}
-
-func (tm *TaskManager) Stop() {
- close(tm.stopCh)
-}
diff --git a/dag/websocket.go b/dag/websocket.go
index 70f2218..4f2c8cf 100644
--- a/dag/websocket.go
+++ b/dag/websocket.go
@@ -11,7 +11,6 @@ func WsEvents(s *sio.Server) {
}
func join(s *sio.Socket, data []byte) {
- //just one room at a time for the simple example
currentRooms := s.GetRooms()
for _, room := range currentRooms {
s.Leave(room)
diff --git a/examples/dag.go b/examples/dag.go
index 9a30d8e..6edc258 100644
--- a/examples/dag.go
+++ b/examples/dag.go
@@ -4,14 +4,13 @@ import (
"context"
"encoding/json"
"fmt"
- v2 "github.com/oarkflow/mq/dag/v2"
-
"github.com/oarkflow/mq"
+ "github.com/oarkflow/mq/dag"
"github.com/oarkflow/mq/examples/tasks"
)
func main() {
- f := v2.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) {
+ f := dag.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) {
fmt.Printf("Final result for task %s: %s\n", taskID, string(result.Payload))
}, mq.WithSyncMode(true))
f.SetNotifyResponse(func(ctx context.Context, result mq.Result) error {
@@ -25,39 +24,38 @@ func main() {
if err != nil {
panic(err)
}
- f.Start(context.Background(), ":8083")
+ f.Start(context.Background(), ":8082")
sendData(f)
}
-func subDAG() *v2.DAG {
- f := v2.NewDAG("Sub DAG", "sub-dag", func(taskID string, result mq.Result) {
+func subDAG() *dag.DAG {
+ f := dag.NewDAG("Sub DAG", "sub-dag", func(taskID string, result mq.Result) {
fmt.Printf("Final result for task %s: %s\n", taskID, string(result.Payload))
}, mq.WithSyncMode(true))
f.
- AddNode(v2.Function, "Store data", "store:data", &tasks.StoreData{Operation: v2.Operation{Type: "process"}}, true).
- AddNode(v2.Function, "Send SMS", "send:sms", &tasks.SendSms{Operation: v2.Operation{Type: "process"}}).
- AddNode(v2.Function, "Notification", "notification", &tasks.InAppNotification{Operation: v2.Operation{Type: "process"}}).
- AddEdge(v2.Simple, "Store Payload to send sms", "store:data", "send:sms").
- AddEdge(v2.Simple, "Store Payload to notification", "send:sms", "notification")
+ AddNode(dag.Function, "Store data", "store:data", &tasks.StoreData{Operation: dag.Operation{Type: "process"}}, true).
+ AddNode(dag.Function, "Send SMS", "send:sms", &tasks.SendSms{Operation: dag.Operation{Type: "process"}}).
+ AddNode(dag.Function, "Notification", "notification", &tasks.InAppNotification{Operation: dag.Operation{Type: "process"}}).
+ AddEdge(dag.Simple, "Store Payload to send sms", "store:data", "send:sms").
+ AddEdge(dag.Simple, "Store Payload to notification", "send:sms", "notification")
return f
}
-func setup(f *v2.DAG) {
+func setup(f *dag.DAG) {
f.
- AddNode(v2.Function, "Email Delivery", "email:deliver", &tasks.EmailDelivery{Operation: v2.Operation{Type: "process"}}).
- AddNode(v2.Function, "Prepare Email", "prepare:email", &tasks.PrepareEmail{Operation: v2.Operation{Type: "process"}}).
- AddNode(v2.Function, "Get Input", "get:input", &tasks.GetData{Operation: v2.Operation{Type: "input"}}, true).
- AddNode(v2.Function, "Final Payload", "final", &tasks.Final{Operation: v2.Operation{Type: "page"}}).
- AddNode(v2.Function, "Iterator Processor", "loop", &tasks.Loop{Operation: v2.Operation{Type: "loop"}}).
- AddNode(v2.Function, "Condition", "condition", &tasks.Condition{Operation: v2.Operation{Type: "condition"}}).
+ AddNode(dag.Function, "Email Delivery", "email:deliver", &tasks.EmailDelivery{Operation: dag.Operation{Type: "process"}}).
+ AddNode(dag.Function, "Prepare Email", "prepare:email", &tasks.PrepareEmail{Operation: dag.Operation{Type: "process"}}).
+ AddNode(dag.Function, "Get Input", "get:input", &tasks.GetData{Operation: dag.Operation{Type: "input"}}, true).
+ AddNode(dag.Function, "Iterator Processor", "loop", &tasks.Loop{Operation: dag.Operation{Type: "loop"}}).
+ AddNode(dag.Function, "Condition", "condition", &tasks.Condition{Operation: dag.Operation{Type: "condition"}}).
AddDAGNode("Persistent", "persistent", subDAG()).
- AddEdge(v2.Simple, "Get input to loop", "get:input", "loop").
- AddEdge(v2.Iterator, "Loop to prepare email", "loop", "prepare:email").
- AddEdge(v2.Simple, "Prepare Email to condition", "prepare:email", "condition").
+ AddEdge(dag.Simple, "Get input to loop", "get:input", "loop").
+ AddEdge(dag.Iterator, "Loop to prepare email", "loop", "prepare:email").
+ AddEdge(dag.Simple, "Prepare Email to condition", "prepare:email", "condition").
AddCondition("condition", map[string]string{"pass": "email:deliver", "fail": "persistent"})
}
-func sendData(f *v2.DAG) {
+func sendData(f *dag.DAG) {
data := []map[string]any{
{"phone": "+123456789", "email": "abc.xyz@gmail.com"}, {"phone": "+98765412", "email": "xyz.abc@gmail.com"},
}
diff --git a/examples/dag_consumer.go b/examples/dag_consumer.go
index b4b9ad9..921e5e2 100644
--- a/examples/dag_consumer.go
+++ b/examples/dag_consumer.go
@@ -3,31 +3,30 @@ package main
import (
"context"
"fmt"
- v2 "github.com/oarkflow/mq/dag/v2"
-
"github.com/oarkflow/mq"
+ "github.com/oarkflow/mq/dag"
"github.com/oarkflow/mq/examples/tasks"
)
func main() {
- d := v2.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) {
+ d := dag.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) {
fmt.Println("Final", string(result.Payload))
},
mq.WithSyncMode(true),
mq.WithNotifyResponse(tasks.NotifyResponse),
)
- d.AddNode(v2.Function, "C", "C", &tasks.Node3{}, true)
- d.AddNode(v2.Function, "D", "D", &tasks.Node4{})
- d.AddNode(v2.Function, "E", "E", &tasks.Node5{})
- d.AddNode(v2.Function, "F", "F", &tasks.Node6{})
- d.AddNode(v2.Function, "G", "G", &tasks.Node7{})
- d.AddNode(v2.Function, "H", "H", &tasks.Node8{})
+ d.AddNode(dag.Function, "C", "C", &tasks.Node3{}, true)
+ d.AddNode(dag.Function, "D", "D", &tasks.Node4{})
+ d.AddNode(dag.Function, "E", "E", &tasks.Node5{})
+ d.AddNode(dag.Function, "F", "F", &tasks.Node6{})
+ d.AddNode(dag.Function, "G", "G", &tasks.Node7{})
+ d.AddNode(dag.Function, "H", "H", &tasks.Node8{})
d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"})
- d.AddEdge(v2.Simple, "Label 1", "B", "C")
- d.AddEdge(v2.Simple, "Label 2", "D", "F")
- d.AddEdge(v2.Simple, "Label 3", "E", "F")
- d.AddEdge(v2.Simple, "Label 4", "F", "G", "H")
+ d.AddEdge(dag.Simple, "Label 1", "B", "C")
+ d.AddEdge(dag.Simple, "Label 2", "D", "F")
+ d.AddEdge(dag.Simple, "Label 3", "E", "F")
+ d.AddEdge(dag.Simple, "Label 4", "F", "G", "H")
d.AssignTopic("queue")
err := d.Consume(context.Background())
if err != nil {
diff --git a/examples/form.go b/examples/form.go
index d20ea26..9c37aaa 100644
--- a/examples/form.go
+++ b/examples/form.go
@@ -4,25 +4,25 @@ import (
"context"
"encoding/json"
"fmt"
+ "github.com/oarkflow/mq/dag"
"github.com/oarkflow/jet"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/consts"
- v2 "github.com/oarkflow/mq/dag/v2"
)
func main() {
- flow := v2.NewDAG("Multi-Step Form", "multi-step-form", func(taskID string, result mq.Result) {
+ flow := dag.NewDAG("Multi-Step Form", "multi-step-form", func(taskID string, result mq.Result) {
fmt.Printf("Final result for task %s: %s\n", taskID, string(result.Payload))
})
- flow.AddNode(v2.Page, "FormStep1", "FormStep1", &FormStep1{})
- flow.AddNode(v2.Page, "FormStep2", "FormStep2", &FormStep2{})
- flow.AddNode(v2.Page, "FormResult", "FormResult", &FormResult{})
+ flow.AddNode(dag.Page, "FormStep1", "FormStep1", &FormStep1{})
+ flow.AddNode(dag.Page, "FormStep2", "FormStep2", &FormStep2{})
+ flow.AddNode(dag.Page, "FormResult", "FormResult", &FormResult{})
// Define edges
- flow.AddEdge(v2.Simple, "FormStep1", "FormStep1", "FormStep2")
- flow.AddEdge(v2.Simple, "FormStep2", "FormStep2", "FormResult")
+ flow.AddEdge(dag.Simple, "FormStep1", "FormStep1", "FormStep2")
+ flow.AddEdge(dag.Simple, "FormStep2", "FormStep2", "FormResult")
// Start the flow
if flow.Error != nil {
@@ -32,7 +32,7 @@ func main() {
}
type FormStep1 struct {
- v2.Operation
+ dag.Operation
}
func (p *FormStep1) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -68,7 +68,7 @@ func (p *FormStep1) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
}
type FormStep2 struct {
- v2.Operation
+ dag.Operation
}
func (p *FormStep2) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -111,7 +111,7 @@ func (p *FormStep2) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
}
type FormResult struct {
- v2.Operation
+ dag.Operation
}
func (p *FormResult) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
diff --git a/examples/subdag.go b/examples/subdag.go
index 55d9d4d..4c197c9 100644
--- a/examples/subdag.go
+++ b/examples/subdag.go
@@ -3,21 +3,21 @@ package main
import (
"context"
"fmt"
+ "github.com/oarkflow/mq/dag/v1"
"github.com/oarkflow/mq/examples/tasks"
"github.com/oarkflow/mq"
- "github.com/oarkflow/mq/dag"
)
func main() {
- d := dag.NewDAG(
+ d := v1.NewDAG(
"Sample DAG",
"sample-dag",
mq.WithSyncMode(true),
mq.WithNotifyResponse(tasks.NotifyResponse),
)
- subDag := dag.NewDAG(
+ subDag := v1.NewDAG(
"Sub DAG",
"D",
mq.WithNotifyResponse(tasks.NotifySubDAGResponse),
@@ -35,7 +35,7 @@ func main() {
d.AddDAGNode("D", "D", subDag)
d.AddNode("E", "E", &tasks.Node5{})
d.AddIterator("Send each item", "A", "B")
- d.AddCondition("C", map[dag.When]dag.Then{"PASS": "D", "FAIL": "E"})
+ d.AddCondition("C", map[v1.When]v1.Then{"PASS": "D", "FAIL": "E"})
d.AddEdge("Label 1", "B", "C")
fmt.Println(d.ExportDOT())
diff --git a/examples/tasks/operations.go b/examples/tasks/operations.go
index 62e0d6b..dea31d0 100644
--- a/examples/tasks/operations.go
+++ b/examples/tasks/operations.go
@@ -2,9 +2,8 @@ package tasks
import (
"context"
- v2 "github.com/oarkflow/mq/dag/v2"
-
"github.com/oarkflow/json"
+ v2 "github.com/oarkflow/mq/dag"
"github.com/oarkflow/mq"
)
diff --git a/examples/tasks/tasks.go b/examples/tasks/tasks.go
index a72d7db..dcc2e47 100644
--- a/examples/tasks/tasks.go
+++ b/examples/tasks/tasks.go
@@ -4,7 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
- v2 "github.com/oarkflow/mq/dag/v2"
+ v2 "github.com/oarkflow/mq/dag"
"log"
"github.com/oarkflow/mq"
diff --git a/examples/v2.go b/examples/v2.go
index 1d86632..a259cca 100644
--- a/examples/v2.go
+++ b/examples/v2.go
@@ -5,16 +5,16 @@ import (
"encoding/json"
"fmt"
"github.com/oarkflow/mq"
+ "github.com/oarkflow/mq/dag"
"os"
"github.com/oarkflow/jet"
"github.com/oarkflow/mq/consts"
- v2 "github.com/oarkflow/mq/dag/v2"
)
type Form struct {
- v2.Operation
+ dag.Operation
}
func (p *Form) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -38,7 +38,7 @@ func (p *Form) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
}
type NodeA struct {
- v2.Operation
+ dag.Operation
}
func (p *NodeA) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -52,7 +52,7 @@ func (p *NodeA) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
}
type NodeB struct {
- v2.Operation
+ dag.Operation
}
func (p *NodeB) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -66,7 +66,7 @@ func (p *NodeB) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
}
type NodeC struct {
- v2.Operation
+ dag.Operation
}
func (p *NodeC) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -80,7 +80,7 @@ func (p *NodeC) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
}
type Result struct {
- v2.Operation
+ dag.Operation
}
func (p *Result) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -115,16 +115,16 @@ func notify(taskID string, result mq.Result) {
}
func main() {
- flow := v2.NewDAG("Sample DAG", "sample-dag", notify)
- flow.AddNode(v2.Page, "Form", "Form", &Form{})
- flow.AddNode(v2.Function, "NodeA", "NodeA", &NodeA{})
- flow.AddNode(v2.Function, "NodeB", "NodeB", &NodeB{})
- flow.AddNode(v2.Function, "NodeC", "NodeC", &NodeC{})
- flow.AddNode(v2.Page, "Result", "Result", &Result{})
- flow.AddEdge(v2.Simple, "Form", "Form", "NodeA")
- flow.AddEdge(v2.Simple, "NodeA", "NodeA", "NodeB")
- flow.AddEdge(v2.Simple, "NodeB", "NodeB", "NodeC")
- flow.AddEdge(v2.Simple, "NodeC", "NodeC", "Result")
+ flow := dag.NewDAG("Sample DAG", "sample-dag", notify)
+ flow.AddNode(dag.Page, "Form", "Form", &Form{})
+ flow.AddNode(dag.Function, "NodeA", "NodeA", &NodeA{})
+ flow.AddNode(dag.Function, "NodeB", "NodeB", &NodeB{})
+ flow.AddNode(dag.Function, "NodeC", "NodeC", &NodeC{})
+ flow.AddNode(dag.Page, "Result", "Result", &Result{})
+ flow.AddEdge(dag.Simple, "Form", "Form", "NodeA")
+ flow.AddEdge(dag.Simple, "NodeA", "NodeA", "NodeB")
+ flow.AddEdge(dag.Simple, "NodeB", "NodeB", "NodeC")
+ flow.AddEdge(dag.Simple, "NodeC", "NodeC", "Result")
if flow.Error != nil {
panic(flow.Error)
}
diff --git a/examples/v3.go b/examples/v3.go
index 9197d33..bb589fb 100644
--- a/examples/v3.go
+++ b/examples/v3.go
@@ -5,23 +5,23 @@ import (
"encoding/json"
"fmt"
"github.com/oarkflow/mq"
- v2 "github.com/oarkflow/mq/dag/v2"
+ "github.com/oarkflow/mq/dag"
)
func main() {
- flow := v2.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) {
+ flow := dag.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) {
// fmt.Printf("Final result for task %s: %s\n", taskID, string(result.Payload))
})
- flow.AddNode(v2.Function, "GetData", "GetData", &GetData{}, true)
- flow.AddNode(v2.Function, "Loop", "Loop", &Loop{})
- flow.AddNode(v2.Function, "ValidateAge", "ValidateAge", &ValidateAge{})
- flow.AddNode(v2.Function, "ValidateGender", "ValidateGender", &ValidateGender{})
- flow.AddNode(v2.Function, "Final", "Final", &Final{})
+ flow.AddNode(dag.Function, "GetData", "GetData", &GetData{}, true)
+ flow.AddNode(dag.Function, "Loop", "Loop", &Loop{})
+ flow.AddNode(dag.Function, "ValidateAge", "ValidateAge", &ValidateAge{})
+ flow.AddNode(dag.Function, "ValidateGender", "ValidateGender", &ValidateGender{})
+ flow.AddNode(dag.Function, "Final", "Final", &Final{})
- flow.AddEdge(v2.Simple, "GetData", "GetData", "Loop")
- flow.AddEdge(v2.Iterator, "Validate age for each item", "Loop", "ValidateAge")
+ flow.AddEdge(dag.Simple, "GetData", "GetData", "Loop")
+ flow.AddEdge(dag.Iterator, "Validate age for each item", "Loop", "ValidateAge")
flow.AddCondition("ValidateAge", map[string]string{"pass": "ValidateGender"})
- flow.AddEdge(v2.Simple, "Mark as Done", "Loop", "Final")
+ flow.AddEdge(dag.Simple, "Mark as Done", "Loop", "Final")
// flow.Start(":8080")
data := []byte(`[{"age": "15", "gender": "female"}, {"age": "18", "gender": "male"}]`)
@@ -38,7 +38,7 @@ func main() {
}
type GetData struct {
- v2.Operation
+ dag.Operation
}
func (p *GetData) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -46,7 +46,7 @@ func (p *GetData) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
}
type Loop struct {
- v2.Operation
+ dag.Operation
}
func (p *Loop) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -54,7 +54,7 @@ func (p *Loop) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
}
type ValidateAge struct {
- v2.Operation
+ dag.Operation
}
func (p *ValidateAge) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -73,7 +73,7 @@ func (p *ValidateAge) ProcessTask(ctx context.Context, task *mq.Task) mq.Result
}
type ValidateGender struct {
- v2.Operation
+ dag.Operation
}
func (p *ValidateGender) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -87,7 +87,7 @@ func (p *ValidateGender) ProcessTask(ctx context.Context, task *mq.Task) mq.Resu
}
type Final struct {
- v2.Operation
+ dag.Operation
}
func (p *Final) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {