update: dependencies

Oarkflow
2024-11-23 10:51:22 +05:45
parent 6b87422a1a
commit 1c84e18d0c
30 changed files with 2291 additions and 2296 deletions


@@ -4,26 +4,110 @@ import (
    "context"
    "encoding/json"
    "fmt"
    "io"
    "net/http"
    "net/url"
    "os"
    "time"
    "github.com/oarkflow/mq/jsonparser"
    "github.com/oarkflow/mq/sio"
    "github.com/oarkflow/mq"
    "github.com/oarkflow/mq/consts"
    "github.com/oarkflow/mq/metrics"
)

import (
    "context"
    "encoding/json"
    "fmt"
    "net/http"
    "os"
    "strings"
    "github.com/oarkflow/mq"
    "github.com/oarkflow/mq/sio"
    "github.com/oarkflow/mq/consts"
    "github.com/oarkflow/mq/jsonparser"
)
type Request struct {
    Payload   json.RawMessage `json:"payload"`
    Interval  time.Duration   `json:"interval"`
    Schedule  bool            `json:"schedule"`
    Overlap   bool            `json:"overlap"`
    Recurring bool            `json:"recurring"`
}

func renderNotFound(w http.ResponseWriter) {
    html := []byte(`
    <div>
        <h1>task not found</h1>
        <p><a href="/process">Back to home</a></p>
    </div>
`)
w.Header().Set(consts.ContentType, consts.TypeHtml)
w.Write(html)
}
func (tm *DAG) render(w http.ResponseWriter, r *http.Request) {
ctx, data, err := parse(r)
if err != nil {
http.Error(w, err.Error(), http.StatusNotFound)
return
}
accept := r.Header.Get("Accept")
userCtx := UserContext(ctx)
ctx = context.WithValue(ctx, "method", r.Method)
if r.Method == "GET" && userCtx.Get("task_id") != "" {
manager, ok := tm.taskManager.Get(userCtx.Get("task_id"))
if !ok || manager == nil {
if strings.Contains(accept, "text/html") || accept == "" {
renderNotFound(w)
return
}
http.Error(w, fmt.Sprintf(`{"message": "%s"}`, "task not found"), http.StatusInternalServerError)
return
}
}
result := tm.Process(ctx, data)
if result.Error != nil {
http.Error(w, fmt.Sprintf(`{"message": "%s"}`, result.Error.Error()), http.StatusInternalServerError)
return
}
contentType, ok := result.Ctx.Value(consts.ContentType).(string)
if !ok {
contentType = consts.TypeJson
}
switch contentType {
case consts.TypeHtml:
w.Header().Set(consts.ContentType, consts.TypeHtml)
data, err := jsonparser.GetString(result.Payload, "html_content")
if err != nil {
return
}
w.Write([]byte(data))
default:
if r.Method != "POST" {
http.Error(w, `{"message": "not allowed"}`, http.StatusMethodNotAllowed)
return
}
w.Header().Set(consts.ContentType, consts.TypeJson)
json.NewEncoder(w).Encode(result.Payload)
}
}
func (tm *DAG) taskStatusHandler(w http.ResponseWriter, r *http.Request) {
taskID := r.URL.Query().Get("taskID")
if taskID == "" {
http.Error(w, `{"message": "taskID is missing"}`, http.StatusBadRequest)
return
}
manager, ok := tm.taskManager.Get(taskID)
if !ok {
http.Error(w, `{"message": "Invalid TaskID"}`, http.StatusNotFound)
return
}
result := make(map[string]TaskState)
manager.taskStates.ForEach(func(key string, value *TaskState) bool {
key = strings.Split(key, Delimiter)[0]
nodeID := strings.Split(value.NodeID, Delimiter)[0]
rs := jsonparser.Delete(value.Result.Payload, "html_content")
status := value.Status
if status == mq.Processing {
status = mq.Completed
}
state := TaskState{
NodeID: nodeID,
Status: status,
UpdatedAt: value.UpdatedAt,
Result: mq.Result{
Payload: rs,
Error: value.Result.Error,
Status: status,
},
}
result[key] = state
return true
})
w.Header().Set(consts.ContentType, consts.TypeJson)
json.NewEncoder(w).Encode(result)
}
func (tm *DAG) SetupWS() *sio.Server {
@@ -37,57 +121,11 @@ func (tm *DAG) SetupWS() *sio.Server {
}
func (tm *DAG) Handlers() {
    http.Handle("/", http.FileServer(http.Dir("webroot")))
    http.Handle("/notify", tm.SetupWS())
    http.HandleFunc("/process", tm.render)
    http.HandleFunc("/request", tm.render)
    http.HandleFunc("/task/status", tm.taskStatusHandler)
    metrics.HandleHTTP()
    http.HandleFunc("GET /render", tm.Render)
    http.HandleFunc("POST /request", tm.Request)
    http.HandleFunc("POST /publish", tm.Publish)
http.HandleFunc("POST /schedule", tm.Schedule)
http.HandleFunc("/pause-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) {
id := request.PathValue("id")
if id != "" {
tm.PauseConsumer(request.Context(), id)
}
})
http.HandleFunc("/resume-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) {
id := request.PathValue("id")
if id != "" {
tm.ResumeConsumer(request.Context(), id)
}
})
http.HandleFunc("/pause", func(w http.ResponseWriter, request *http.Request) {
err := tm.Pause(request.Context())
if err != nil {
http.Error(w, "Failed to pause", http.StatusBadRequest)
return
}
json.NewEncoder(w).Encode(map[string]string{"status": "paused"})
})
http.HandleFunc("/resume", func(w http.ResponseWriter, request *http.Request) {
err := tm.Resume(request.Context())
if err != nil {
http.Error(w, "Failed to resume", http.StatusBadRequest)
return
}
json.NewEncoder(w).Encode(map[string]string{"status": "resumed"})
})
http.HandleFunc("/stop", func(w http.ResponseWriter, request *http.Request) {
err := tm.Stop(request.Context())
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
json.NewEncoder(w).Encode(map[string]string{"status": "stopped"})
})
http.HandleFunc("/close", func(w http.ResponseWriter, request *http.Request) {
err := tm.Close()
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
json.NewEncoder(w).Encode(map[string]string{"status": "closed"})
})
http.HandleFunc("/dot", func(w http.ResponseWriter, r *http.Request) { http.HandleFunc("/dot", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", "text/plain")
fmt.Fprintln(w, tm.ExportDOT()) fmt.Fprintln(w, tm.ExportDOT())
@@ -112,98 +150,3 @@ func (tm *DAG) Handlers() {
} }
}) })
} }
func (tm *DAG) request(w http.ResponseWriter, r *http.Request, async bool) {
if r.Method != http.MethodPost {
http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
return
}
var request Request
if r.Body != nil {
defer r.Body.Close()
var err error
payload, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
err = json.Unmarshal(payload, &request)
if err != nil {
http.Error(w, "Failed to unmarshal body", http.StatusBadRequest)
return
}
} else {
http.Error(w, "Empty request body", http.StatusBadRequest)
return
}
ctx := r.Context()
if async {
ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
}
var opts []mq.SchedulerOption
if request.Interval > 0 {
opts = append(opts, mq.WithInterval(request.Interval))
}
if request.Overlap {
opts = append(opts, mq.WithOverlap())
}
if request.Recurring {
opts = append(opts, mq.WithRecurring())
}
ctx = context.WithValue(ctx, "query_params", r.URL.Query())
var rs mq.Result
if request.Schedule {
rs = tm.ScheduleTask(ctx, request.Payload, opts...)
} else {
rs = tm.Process(ctx, request.Payload)
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(rs)
}
func (tm *DAG) Render(w http.ResponseWriter, r *http.Request) {
ctx := mq.SetHeaders(r.Context(), map[string]string{consts.AwaitResponseKey: "true", "request_type": "render"})
ctx = context.WithValue(ctx, "query_params", r.URL.Query())
rs := tm.Process(ctx, nil)
content, err := jsonparser.GetString(rs.Payload, "html_content")
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", consts.TypeHtml)
w.Write([]byte(content))
}
func (tm *DAG) Request(w http.ResponseWriter, r *http.Request) {
tm.request(w, r, true)
}
func (tm *DAG) Publish(w http.ResponseWriter, r *http.Request) {
tm.request(w, r, false)
}
func (tm *DAG) Schedule(w http.ResponseWriter, r *http.Request) {
tm.request(w, r, false)
}
func GetTaskID(ctx context.Context) string {
if queryParams := ctx.Value("query_params"); queryParams != nil {
if params, ok := queryParams.(url.Values); ok {
if id := params.Get("taskID"); id != "" {
return id
}
}
}
return ""
}
func CanNextNode(ctx context.Context) string {
if queryParams := ctx.Value("query_params"); queryParams != nil {
if params, ok := queryParams.(url.Values); ok {
if id := params.Get("next"); id != "" {
return id
}
}
}
return ""
}
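For orientation, a minimal client-side sketch of the reworked endpoints registered in Handlers(). The base URL, JSON payload shape, and task ID below are illustrative placeholders, not values taken from this commit; only the /request and /task/status routes and the taskID query parameter come from the code above.

package main

import (
    "bytes"
    "fmt"
    "io"
    "net/http"
)

func main() {
    base := "http://localhost:8080" // assumed address; Start(ctx, addr) decides the real one

    // POST /request now goes through tm.render: the JSON body is processed by the DAG
    // and the final node's payload is written back as JSON.
    resp, err := http.Post(base+"/request", "application/json",
        bytes.NewBufferString(`{"data": [{"phone": "+123456789", "email": "abc@gmail.com"}]}`))
    if err != nil {
        panic(err)
    }
    body, _ := io.ReadAll(resp.Body)
    resp.Body.Close()
    fmt.Println("process result:", string(body))

    // GET /task/status?taskID=<id> returns per-node TaskState entries as JSON
    // (the task ID here is a placeholder; it must come from the running task).
    status, err := http.Get(base + "/task/status?taskID=<task-id>")
    if err != nil {
        panic(err)
    }
    out, _ := io.ReadAll(status.Body)
    status.Body.Close()
    fmt.Println("task status:", string(out))
}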


@@ -1,26 +1,28 @@
package dag

type NodeStatus int
func (c NodeStatus) IsValid() bool { return c >= Pending && c <= Failed }
func (c NodeStatus) String() string {
    switch c {
    case Pending:
        return "Pending"
    case Processing:
        return "Processing"
    case Completed:
        return "Completed"
    case Failed:
        return "Failed"
    }
    return ""
}
const (
    Pending NodeStatus = iota
    Processing
    Completed
    Failed
)
type NodeType int
func (c NodeType) IsValid() bool { return c >= Function && c <= Page }
const (
    Function NodeType = iota
    Page
)
type EdgeType int
func (c EdgeType) IsValid() bool { return c >= Simple && c <= Iterator }
const (
    Simple EdgeType = iota
    Iterator
)

import "time"
const (
    Delimiter          = "___"
    ContextIndex       = "index"
    DefaultChannelSize = 1000
    RetryInterval      = 5 * time.Second
)


@@ -1,4 +1,4 @@
package v2
package dag
import (
    "context"


@@ -2,155 +2,70 @@ package dag
import (
    "context"
    "fmt"
    "log"
    "net/http"
    "time"
    "github.com/oarkflow/mq/storage"
    "github.com/oarkflow/mq/storage/memory"
    "github.com/oarkflow/mq/sio"
    "golang.org/x/time/rate"
    "github.com/oarkflow/mq"
    "github.com/oarkflow/mq/consts"
    "github.com/oarkflow/mq/metrics"
)
type EdgeType int
func (c EdgeType) IsValid() bool { return c >= Simple && c <= Iterator }
const (
    Simple EdgeType = iota
    Iterator
)
type NodeType int
func (c NodeType) IsValid() bool { return c >= Process && c <= Page }
const (
    Process NodeType = iota
    Page
)

import (
    "context"
    "encoding/json"
    "fmt"
    "log"
    "net/http"
    "strings"
    "time"
    "golang.org/x/time/rate"
    "github.com/oarkflow/mq"
    "github.com/oarkflow/mq/consts"
    "github.com/oarkflow/mq/sio"
    "github.com/oarkflow/mq/storage"
    "github.com/oarkflow/mq/storage/memory"
)
type Node struct {
    processor mq.Processor
    Name      string
    Type      NodeType
    Key       string
    Edges     []Edge
    isReady   bool
}
func (n *Node) ProcessTask(ctx context.Context, msg *mq.Task) mq.Result {
    return n.processor.ProcessTask(ctx, msg)
}
func (n *Node) Close() error {
    return n.processor.Close()
}
type Edge struct {
    From  *Node
    To    []*Node
    Type  EdgeType
    Label string
}
type (
    FromNode string
    When     string
    Then     string
)
type DAG struct {
    server                   *mq.Broker
    consumer                 *mq.Consumer
    taskContext              storage.IMap[string, *TaskManager]
    nodes                    map[string]*Node
    iteratorNodes            storage.IMap[string, []Edge]
    conditions               map[FromNode]map[When]Then
    pool                     *mq.Pool
    taskCleanupCh            chan string
    name                     string
    key                      string
    startNode                string
    consumerTopic            string
    opts                     []mq.Option
    reportNodeResultCallback func(mq.Result)
    Error                    error
    Notifier                 *sio.Server
    paused                   bool
    report                   string
    index                    string
}

type Node struct {
    NodeType  NodeType
    Label     string
    ID        string
    Edges     []Edge
    processor mq.Processor
    isReady   bool
}
type Edge struct {
    Label string
    From  *Node
    To    *Node
    Type  EdgeType
}
type DAG struct {
    server                   *mq.Broker
    consumer                 *mq.Consumer
    nodes                    storage.IMap[string, *Node]
    taskManager              storage.IMap[string, *TaskManager]
    iteratorNodes            storage.IMap[string, []Edge]
    finalResult              func(taskID string, result mq.Result)
    pool                     *mq.Pool
    name                     string
    key                      string
    startNode                string
    opts                     []mq.Option
    conditions               map[string]map[string]string
    consumerTopic            string
    reportNodeResultCallback func(mq.Result)
    Notifier                 *sio.Server
    paused                   bool
    Error                    error
    report                   string
}
func (tm *DAG) SetKey(key string) {
tm.key = key
}
func (tm *DAG) ReportNodeResult(callback func(mq.Result)) {
tm.reportNodeResultCallback = callback
}
func (tm *DAG) GetType() string {
return tm.key
}
func (tm *DAG) listenForTaskCleanup() {
for taskID := range tm.taskCleanupCh {
if tm.server.Options().CleanTaskOnComplete() {
tm.taskCleanup(taskID)
}
}
}
func (tm *DAG) taskCleanup(taskID string) {
tm.taskContext.Del(taskID)
log.Printf("DAG - Task %s cleaned up", taskID)
}
func (tm *DAG) Consume(ctx context.Context) error {
if tm.consumer != nil {
tm.server.Options().SetSyncMode(true)
return tm.consumer.Consume(ctx)
}
return nil
}
func (tm *DAG) Stop(ctx context.Context) error {
for _, n := range tm.nodes {
err := n.processor.Stop(ctx)
if err != nil {
return err
}
}
return nil
}
func (tm *DAG) GetKey() string {
return tm.key
}
func (tm *DAG) AssignTopic(topic string) {
tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false), mq.WithBrokerURL(tm.server.URL()))
tm.consumerTopic = topic
}
func NewDAG(name, key string, opts ...mq.Option) *DAG {
    callback := func(ctx context.Context, result mq.Result) error { return nil }
    d := &DAG{
        name:          name,
        key:           key,
        nodes:         make(map[string]*Node),
        iteratorNodes: memory.New[string, []Edge](),
        taskContext:   memory.New[string, *TaskManager](),
        conditions:    make(map[FromNode]map[When]Then),
        taskCleanupCh: make(chan string),
    }
    opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerOnSubscribe(d.onConsumerJoin), mq.WithConsumerOnClose(d.onConsumerClose))
    d.server = mq.NewBroker(opts...)

func NewDAG(name, key string, finalResultCallback func(taskID string, result mq.Result), opts ...mq.Option) *DAG {
    callback := func(ctx context.Context, result mq.Result) error { return nil }
    d := &DAG{
        name:          name,
        key:           key,
        nodes:         memory.New[string, *Node](),
        taskManager:   memory.New[string, *TaskManager](),
        iteratorNodes: memory.New[string, []Edge](),
        conditions:    make(map[string]map[string]string),
        finalResult:   finalResultCallback,
    }
    opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerOnSubscribe(d.onConsumerJoin), mq.WithConsumerOnClose(d.onConsumerClose))
    d.server = mq.NewBroker(opts...)
@@ -165,10 +80,61 @@ func NewDAG(name, key string, opts ...mq.Option) *DAG {
        mq.WithTaskStorage(options.Storage()),
    )
    d.pool.Start(d.server.Options().NumOfWorkers())
    go d.listenForTaskCleanup()
    return d
}
func (tm *DAG) SetKey(key string) {
tm.key = key
}
func (tm *DAG) ReportNodeResult(callback func(mq.Result)) {
tm.reportNodeResultCallback = callback
}
func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result {
if manager, ok := tm.taskManager.Get(result.TaskID); ok && result.Topic != "" {
manager.onNodeCompleted(nodeResult{
ctx: ctx,
nodeID: result.Topic,
status: result.Status,
result: result,
})
}
return mq.Result{}
}
func (tm *DAG) GetType() string {
return tm.key
}
func (tm *DAG) Consume(ctx context.Context) error {
if tm.consumer != nil {
tm.server.Options().SetSyncMode(true)
return tm.consumer.Consume(ctx)
}
return nil
}
func (tm *DAG) Stop(ctx context.Context) error {
tm.nodes.ForEach(func(_ string, n *Node) bool {
err := n.processor.Stop(ctx)
if err != nil {
return false
}
return true
})
return nil
}
func (tm *DAG) GetKey() string {
return tm.key
}
func (tm *DAG) AssignTopic(topic string) {
tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false), mq.WithBrokerURL(tm.server.URL()))
tm.consumerTopic = topic
}
func (tm *DAG) callbackToConsumer(ctx context.Context, result mq.Result) {
    if tm.consumer != nil {
        result.Topic = tm.consumerTopic
@@ -180,114 +146,89 @@ func (tm *DAG) callbackToConsumer(ctx context.Context, result mq.Result) {
    }
}
func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result {
    if taskContext, ok := tm.taskContext.Get(result.TaskID); ok && result.Topic != "" {
        return taskContext.handleNextTask(ctx, result)
    }
    return mq.Result{}
}
func (tm *DAG) onConsumerJoin(_ context.Context, topic, _ string) {
    if node, ok := tm.nodes[topic]; ok {
    if node, ok := tm.nodes.Get(topic); ok {
        log.Printf("DAG - CONSUMER ~> ready on %s", topic)
        node.isReady = true
    }
}
func (tm *DAG) onConsumerClose(_ context.Context, topic, _ string) {
    if node, ok := tm.nodes[topic]; ok {
    if node, ok := tm.nodes.Get(topic); ok {
        log.Printf("DAG - CONSUMER ~> down on %s", topic)
        node.isReady = false
    }
}
func (tm *DAG) Pause(_ context.Context) error {
tm.paused = true
return nil
}
func (tm *DAG) Resume(_ context.Context) error {
tm.paused = false
return nil
}
func (tm *DAG) Close() error {
var err error
tm.nodes.ForEach(func(_ string, n *Node) bool {
err = n.processor.Close()
if err != nil {
return false
}
return true
})
return nil
}
func (tm *DAG) SetStartNode(node string) {
    tm.startNode = node
}
func (tm *DAG) SetNotifyResponse(callback mq.Callback) {
    tm.server.SetNotifyHandler(callback)
}
func (tm *DAG) GetStartNode() string {
    return tm.startNode
}
func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) *DAG {
    tm.conditions[fromNode] = conditions
    return tm
}

func (tm *DAG) Start(ctx context.Context, addr string) error {
    // Start the server in a separate goroutine
go func() {
defer mq.RecoverPanic(mq.RecoverTitle)
if err := tm.server.Start(ctx); err != nil {
panic(err)
}
}()
// Start the node consumers if not in sync mode
if !tm.server.SyncMode() {
for _, con := range tm.nodes {
go func(con *Node) {
defer mq.RecoverPanic(mq.RecoverTitle)
limiter := rate.NewLimiter(rate.Every(1*time.Second), 1) // Retry every second
for {
err := con.processor.Consume(ctx)
if err != nil {
log.Printf("[ERROR] - Consumer %s failed to start: %v", con.Key, err)
} else {
log.Printf("[INFO] - Consumer %s started successfully", con.Key)
break
}
limiter.Wait(ctx) // Wait with rate limiting before retrying
}
}(con)
}
}
log.Printf("DAG - HTTP_SERVER ~> started on http://localhost%s", addr)
tm.Handlers()
config := tm.server.TLSConfig()
if config.UseTLS {
return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil)
}
return http.ListenAndServe(addr, nil)
}
func (tm *DAG) AddDAGNode(name string, key string, dag *DAG, firstNode ...bool) *DAG {
dag.AssignTopic(key)
tm.nodes[key] = &Node{
Name: name,
Key: key,
processor: dag,
isReady: true,
}
if len(firstNode) > 0 && firstNode[0] {
tm.startNode = key
}
    return tm
}
func (tm *DAG) AddNode(name, key string, handler mq.Processor, firstNode ...bool) *DAG {
    con := mq.NewConsumer(key, key, handler.ProcessTask, tm.opts...)
    n := &Node{
        Name:      name,
        Key:       key,
        processor: con,
    }
    if handler.GetType() == "page" {
        n.Type = Page
    }
    if tm.server.SyncMode() {
        n.isReady = true
    }
    tm.nodes[key] = n
    if len(firstNode) > 0 && firstNode[0] {
        tm.startNode = key
    }
    return tm
}

func (tm *DAG) AddNode(nodeType NodeType, name, nodeID string, handler mq.Processor, startNode ...bool) *DAG {
    if tm.Error != nil {
        return tm
    }
    con := mq.NewConsumer(nodeID, nodeID, handler.ProcessTask)
    n := &Node{
        Label:     name,
        ID:        nodeID,
        NodeType:  nodeType,
        processor: con,
    }
    if tm.server != nil && tm.server.SyncMode() {
        n.isReady = true
    }
    tm.nodes.Set(nodeID, n)
    if len(startNode) > 0 && startNode[0] {
        tm.startNode = nodeID
    }
    return tm
}
func (tm *DAG) AddDeferredNode(name, key string, firstNode ...bool) error {
func (tm *DAG) AddDeferredNode(nodeType NodeType, name, key string, firstNode ...bool) error {
    if tm.server.SyncMode() {
        return fmt.Errorf("DAG cannot have deferred node in Sync Mode")
    }
    tm.nodes[key] = &Node{
        Name: name,
        Key:  key,
    }
    tm.nodes.Set(key, &Node{
        Label:    name,
        ID:       key,
        NodeType: nodeType,
    })
    if len(firstNode) > 0 && firstNode[0] {
        tm.startNode = key
    }
@@ -295,52 +236,124 @@ func (tm *DAG) AddDeferredNode(name, key string, firstNode ...bool) error {
}
func (tm *DAG) IsReady() bool {
    for _, node := range tm.nodes {
        if !node.isReady {
            return false
        }
    }
    return true
}

func (tm *DAG) IsReady() bool {
    var isReady bool
    tm.nodes.ForEach(func(_ string, n *Node) bool {
        if !n.isReady {
            return false
        }
        isReady = true
        return true
    })
    return isReady
}

func (tm *DAG) AddCondition(fromNode FromNode, conditions map[When]Then) *DAG {
    tm.conditions[fromNode] = conditions
    return tm
}
func (tm *DAG) AddIterator(label, from string, targets ...string) *DAG {
    tm.Error = tm.addEdge(Iterator, label, from, targets...)
    tm.iteratorNodes.Set(from, []Edge{})
    return tm
}
func (tm *DAG) AddEdge(label, from string, targets ...string) *DAG {
    tm.Error = tm.addEdge(Simple, label, from, targets...)
    return tm
}
func (tm *DAG) addEdge(edgeType EdgeType, label, from string, targets ...string) error {
    fromNode, ok := tm.nodes[from]
    if !ok {
        return fmt.Errorf("Error: 'from' node %s does not exist\n", from)
    }
    var nodes []*Node
    for _, target := range targets {
        toNode, ok := tm.nodes[target]
        if !ok {
            return fmt.Errorf("Error: 'from' node %s does not exist\n", target)
        }
        nodes = append(nodes, toNode)
    }
    edge := Edge{From: fromNode, To: nodes, Type: edgeType, Label: label}
    fromNode.Edges = append(fromNode.Edges, edge)
    if edgeType != Iterator {
        if edges, ok := tm.iteratorNodes.Get(fromNode.Key); ok {
            edges = append(edges, edge)
            tm.iteratorNodes.Set(fromNode.Key, edges)
        }
    }
    return nil
}

func (tm *DAG) AddEdge(edgeType EdgeType, label, from string, targets ...string) *DAG {
    if tm.Error != nil {
        return tm
    }
    if edgeType == Iterator {
        tm.iteratorNodes.Set(from, []Edge{})
    }
    node, ok := tm.nodes.Get(from)
    if !ok {
        tm.Error = fmt.Errorf("node not found %s", from)
        return tm
    }
    for _, target := range targets {
        if targetNode, ok := tm.nodes.Get(target); ok {
            edge := Edge{From: node, To: targetNode, Type: edgeType, Label: label}
            node.Edges = append(node.Edges, edge)
            if edgeType != Iterator {
                if edges, ok := tm.iteratorNodes.Get(node.ID); ok {
                    edges = append(edges, edge)
                    tm.iteratorNodes.Set(node.ID, edges)
                }
            }
        }
    }
    return tm
}
func (tm *DAG) getCurrentNode(manager *TaskManager) string {
if manager.currentNodePayload.Size() == 0 {
return ""
}
return manager.currentNodePayload.Keys()[0]
}
func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
ctx = context.WithValue(ctx, "task_id", task.ID)
userContext := UserContext(ctx)
next := userContext.Get("next")
manager, ok := tm.taskManager.Get(task.ID)
resultCh := make(chan mq.Result, 1)
if !ok {
manager = NewTaskManager(tm, task.ID, resultCh, tm.iteratorNodes.Clone())
tm.taskManager.Set(task.ID, manager)
} else {
manager.resultCh = resultCh
}
currentKey := tm.getCurrentNode(manager)
currentNode := strings.Split(currentKey, Delimiter)[0]
node, exists := tm.nodes.Get(currentNode)
method, ok := ctx.Value("method").(string)
if method == "GET" && exists && node.NodeType == Page {
ctx = context.WithValue(ctx, "initial_node", currentNode)
/*
if isLastNode, err := tm.IsLastNode(currentNode); err != nil && isLastNode {
if manager.result != nil {
fmt.Println(string(manager.result.Payload))
resultCh <- *manager.result
return <-resultCh
}
}
*/
if manager.result != nil {
task.Payload = manager.result.Payload
}
} else if next == "true" {
nodes, err := tm.GetNextNodes(currentNode)
if err != nil {
return mq.Result{Error: err, Ctx: ctx}
}
if len(nodes) > 0 {
ctx = context.WithValue(ctx, "initial_node", nodes[0].ID)
}
}
if currentNodeResult, hasResult := manager.currentNodeResult.Get(currentKey); hasResult {
var taskPayload, resultPayload map[string]any
if err := json.Unmarshal(task.Payload, &taskPayload); err == nil {
if err = json.Unmarshal(currentNodeResult.Payload, &resultPayload); err == nil {
for key, val := range resultPayload {
taskPayload[key] = val
}
task.Payload, _ = json.Marshal(taskPayload)
}
}
}
firstNode, err := tm.parseInitialNode(ctx)
if err != nil {
return mq.Result{Error: err, Ctx: ctx}
}
node, ok = tm.nodes.Get(firstNode)
if ok && node.NodeType != Page && task.Payload == nil {
return mq.Result{Error: fmt.Errorf("payload is required for node %s", firstNode), Ctx: ctx}
}
task.Topic = firstNode
ctx = context.WithValue(ctx, ContextIndex, "0")
manager.ProcessTask(ctx, firstNode, task.Payload)
return <-resultCh
}
func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result {
var taskID string
userCtx := UserContext(ctx)
if val := userCtx.Get("task_id"); val != "" {
taskID = val
} else {
taskID = mq.NewID()
}
return tm.ProcessTask(ctx, mq.NewTask(taskID, payload, ""))
} }
func (tm *DAG) Validate() error { func (tm *DAG) Validate() error {
@@ -357,177 +370,124 @@ func (tm *DAG) GetReport() string {
return tm.report return tm.report
} }
func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { func (tm *DAG) AddDAGNode(name string, key string, dag *DAG, firstNode ...bool) *DAG {
if task.ID == "" { dag.AssignTopic(key)
task.ID = mq.NewID() tm.nodes.Set(key, &Node{
Label: name,
ID: key,
processor: dag,
isReady: true,
})
if len(firstNode) > 0 && firstNode[0] {
tm.startNode = key
} }
if index, ok := mq.GetHeader(ctx, "index"); ok { return tm
tm.index = index
}
manager, exists := tm.taskContext.Get(task.ID)
if !exists {
manager = NewTaskManager(tm, task.ID, tm.iteratorNodes)
manager.createdAt = task.CreatedAt
tm.taskContext.Set(task.ID, manager)
} }
if tm.consumer != nil { func (tm *DAG) Start(ctx context.Context, addr string) error {
initialNode, err := tm.parseInitialNode(ctx) // Start the server in a separate goroutine
go func() {
defer mq.RecoverPanic(mq.RecoverTitle)
if err := tm.server.Start(ctx); err != nil {
panic(err)
}
}()
// Start the node consumers if not in sync mode
if !tm.server.SyncMode() {
tm.nodes.ForEach(func(_ string, con *Node) bool {
go func(con *Node) {
defer mq.RecoverPanic(mq.RecoverTitle)
limiter := rate.NewLimiter(rate.Every(1*time.Second), 1) // Retry every second
for {
err := con.processor.Consume(ctx)
if err != nil { if err != nil {
metrics.TasksErrors.WithLabelValues("unknown").Inc() // Increase error count log.Printf("[ERROR] - Consumer %s failed to start: %v", con.ID, err)
return mq.Result{Error: err}
}
task.Topic = initialNode
}
if manager.topic != "" {
task.Topic = manager.topic
canNext := CanNextNode(ctx)
if canNext != "" {
if n, ok := tm.nodes[task.Topic]; ok {
if len(n.Edges) > 0 {
task.Topic = n.Edges[0].To[0].Key
}
}
} else { } else {
log.Printf("[INFO] - Consumer %s started successfully", con.ID)
break
} }
limiter.Wait(ctx) // Wait with rate limiting before retrying
} }
result := manager.processTask(ctx, task.Topic, task.Payload) }(con)
if result.Ctx != nil && tm.index != "" { return true
result.Ctx = mq.SetHeaders(result.Ctx, map[string]string{"index": tm.index}) })
} }
if result.Error != nil { log.Printf("DAG - HTTP_SERVER ~> started on http://%s", addr)
metrics.TasksErrors.WithLabelValues(task.Topic).Inc() // Increase error count tm.Handlers()
} else { config := tm.server.TLSConfig()
metrics.TasksProcessed.WithLabelValues("success").Inc() // Increase processed task count log.Printf("Server listening on http://%s", addr)
if config.UseTLS {
return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil)
} }
return result return http.ListenAndServe(addr, nil)
}
func (tm *DAG) check(ctx context.Context, payload []byte) (context.Context, *mq.Task, error) {
if tm.paused {
return ctx, nil, fmt.Errorf("unable to process task, error: DAG is not accepting any task")
}
if !tm.IsReady() {
return ctx, nil, fmt.Errorf("unable to process task, error: DAG is not ready yet")
}
initialNode, err := tm.parseInitialNode(ctx)
if err != nil {
return ctx, nil, err
}
if tm.server.SyncMode() {
ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
}
taskID := GetTaskID(ctx)
if taskID != "" {
if _, exists := tm.taskContext.Get(taskID); !exists {
return ctx, nil, fmt.Errorf("provided task ID doesn't exist")
}
}
if taskID == "" {
taskID = mq.NewID()
}
return ctx, mq.NewTask(taskID, payload, initialNode), nil
}
func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result {
ctx, task, err := tm.check(ctx, payload)
if err != nil {
return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")}
}
awaitResponse, _ := mq.GetAwaitResponse(ctx)
if awaitResponse != "true" {
headers, ok := mq.GetHeaders(ctx)
ctxx := context.Background()
if ok {
ctxx = mq.SetHeaders(ctxx, headers.AsMap())
}
if err := tm.pool.EnqueueTask(ctxx, task, 0); err != nil {
return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: task.Topic, Status: "FAILED", Error: err}
}
return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: task.Topic, Status: "PENDING"}
}
return tm.ProcessTask(ctx, task)
} }
func (tm *DAG) ScheduleTask(ctx context.Context, payload []byte, opts ...mq.SchedulerOption) mq.Result { func (tm *DAG) ScheduleTask(ctx context.Context, payload []byte, opts ...mq.SchedulerOption) mq.Result {
ctx, task, err := tm.check(ctx, payload) var taskID string
if err != nil { userCtx := UserContext(ctx)
return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")} if val := userCtx.Get("task_id"); val != "" {
taskID = val
} else {
taskID = mq.NewID()
} }
t := mq.NewTask(taskID, payload, "")
ctx = context.WithValue(ctx, "task_id", taskID)
userContext := UserContext(ctx)
next := userContext.Get("next")
manager, ok := tm.taskManager.Get(taskID)
resultCh := make(chan mq.Result, 1)
if !ok {
manager = NewTaskManager(tm, taskID, resultCh, tm.iteratorNodes.Clone())
tm.taskManager.Set(taskID, manager)
} else {
manager.resultCh = resultCh
}
currentKey := tm.getCurrentNode(manager)
currentNode := strings.Split(currentKey, Delimiter)[0]
node, exists := tm.nodes.Get(currentNode)
method, ok := ctx.Value("method").(string)
if method == "GET" && exists && node.NodeType == Page {
ctx = context.WithValue(ctx, "initial_node", currentNode)
} else if next == "true" {
nodes, err := tm.GetNextNodes(currentNode)
if err != nil {
return mq.Result{Error: err, Ctx: ctx}
}
if len(nodes) > 0 {
ctx = context.WithValue(ctx, "initial_node", nodes[0].ID)
}
}
if currentNodeResult, hasResult := manager.currentNodeResult.Get(currentKey); hasResult {
var taskPayload, resultPayload map[string]any
if err := json.Unmarshal(payload, &taskPayload); err == nil {
if err = json.Unmarshal(currentNodeResult.Payload, &resultPayload); err == nil {
for key, val := range resultPayload {
taskPayload[key] = val
}
payload, _ = json.Marshal(taskPayload)
}
}
}
firstNode, err := tm.parseInitialNode(ctx)
if err != nil {
return mq.Result{Error: err, Ctx: ctx}
}
node, ok = tm.nodes.Get(firstNode)
if ok && node.NodeType != Page && t.Payload == nil {
return mq.Result{Error: fmt.Errorf("payload is required for node %s", firstNode), Ctx: ctx}
}
t.Topic = firstNode
ctx = context.WithValue(ctx, ContextIndex, "0")
headers, ok := mq.GetHeaders(ctx) headers, ok := mq.GetHeaders(ctx)
ctxx := context.Background() ctxx := context.Background()
if ok { if ok {
ctxx = mq.SetHeaders(ctxx, headers.AsMap()) ctxx = mq.SetHeaders(ctxx, headers.AsMap())
} }
tm.pool.Scheduler().AddTask(ctxx, task, opts...) tm.pool.Scheduler().AddTask(ctxx, t, opts...)
return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: task.Topic, Status: "PENDING"} return mq.Result{CreatedAt: t.CreatedAt, TaskID: t.ID, Topic: t.Topic, Status: "PENDING"}
}
func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) {
val := ctx.Value("initial_node")
initialNode, ok := val.(string)
if ok {
return initialNode, nil
}
if tm.startNode == "" {
firstNode := tm.findStartNode()
if firstNode != nil {
tm.startNode = firstNode.Key
}
}
if tm.startNode == "" {
return "", fmt.Errorf("initial node not found")
}
return tm.startNode, nil
}
func (tm *DAG) findStartNode() *Node {
incomingEdges := make(map[string]bool)
connectedNodes := make(map[string]bool)
for _, node := range tm.nodes {
for _, edge := range node.Edges {
if edge.Type.IsValid() {
connectedNodes[node.Key] = true
for _, to := range edge.To {
connectedNodes[to.Key] = true
incomingEdges[to.Key] = true
}
}
}
if cond, ok := tm.conditions[FromNode(node.Key)]; ok {
for _, target := range cond {
connectedNodes[string(target)] = true
incomingEdges[string(target)] = true
}
}
}
for nodeID, node := range tm.nodes {
if !incomingEdges[nodeID] && connectedNodes[nodeID] {
return node
}
}
return nil
}
func (tm *DAG) Pause(_ context.Context) error {
tm.paused = true
return nil
}
func (tm *DAG) Resume(_ context.Context) error {
tm.paused = false
return nil
}
func (tm *DAG) Close() error {
for _, n := range tm.nodes {
err := n.Close()
if err != nil {
return err
}
}
return nil
} }
func (tm *DAG) PauseConsumer(ctx context.Context, id string) {
@@ -539,74 +499,26 @@ func (tm *DAG) ResumeConsumer(ctx context.Context, id string) {
}
func (tm *DAG) doConsumer(ctx context.Context, id string, action consts.CMD) {
    if node, ok := tm.nodes[id]; ok {
    if node, ok := tm.nodes.Get(id); ok {
        switch action {
        case consts.CONSUMER_PAUSE:
            err := node.processor.Pause(ctx)
            if err == nil {
                node.isReady = false
                log.Printf("[INFO] - Consumer %s paused successfully", node.Key)
                log.Printf("[INFO] - Consumer %s paused successfully", node.ID)
            } else {
                log.Printf("[ERROR] - Failed to pause consumer %s: %v", node.Key, err)
                log.Printf("[ERROR] - Failed to pause consumer %s: %v", node.ID, err)
            }
        case consts.CONSUMER_RESUME:
            err := node.processor.Resume(ctx)
            if err == nil {
                node.isReady = true
                log.Printf("[INFO] - Consumer %s resumed successfully", node.Key)
                log.Printf("[INFO] - Consumer %s resumed successfully", node.ID)
            } else {
                log.Printf("[ERROR] - Failed to resume consumer %s: %v", node.Key, err)
                log.Printf("[ERROR] - Failed to resume consumer %s: %v", node.ID, err)
            }
        }
    } else {
        log.Printf("[WARNING] - Consumer %s not found", id)
    }
}
func (tm *DAG) SetNotifyResponse(callback mq.Callback) {
tm.server.SetNotifyHandler(callback)
}
func (tm *DAG) GetNextNodes(key string) ([]*Node, error) {
node, exists := tm.nodes[key]
if !exists {
return nil, fmt.Errorf("Node with key %s does not exist", key)
}
var successors []*Node
for _, edge := range node.Edges {
successors = append(successors, edge.To...)
}
if conds, exists := tm.conditions[FromNode(key)]; exists {
for _, targetKey := range conds {
if targetNode, exists := tm.nodes[string(targetKey)]; exists {
successors = append(successors, targetNode)
}
}
}
return successors, nil
}
func (tm *DAG) GetPreviousNodes(key string) ([]*Node, error) {
var predecessors []*Node
for _, node := range tm.nodes {
for _, edge := range node.Edges {
for _, target := range edge.To {
if target.Key == key {
predecessors = append(predecessors, node)
}
}
}
}
for fromNode, conds := range tm.conditions {
for _, targetKey := range conds {
if string(targetKey) == key {
node, exists := tm.nodes[string(fromNode)]
if !exists {
return nil, fmt.Errorf("Node with key %s does not exist", fromNode)
}
predecessors = append(predecessors, node)
}
}
}
return predecessors, nil
}
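As a reference for the new string-keyed condition map, here is a small self-contained sketch (plain Go, no mq types) of the lookup that AddCondition feeds and getConditionalEdges performs: the finished node's condition status resolves to the next node ID, with a "default" fallback. The node IDs and statuses below are illustrative, not taken from this commit.

package main

import "fmt"

// resolveCondition mirrors the lookup in getConditionalEdges: for the node that just
// finished (topic), pick the next node ID mapped to the returned condition status,
// falling back to the "default" entry when that status has no mapping.
func resolveCondition(conditions map[string]map[string]string, topic, status string) (string, bool) {
    nodeConds, ok := conditions[topic]
    if !ok {
        return "", false
    }
    if target, ok := nodeConds[status]; ok {
        return target, true
    }
    if target, ok := nodeConds["default"]; ok {
        return target, true
    }
    return "", false
}

func main() {
    // Condition table in the string-keyed form AddCondition now accepts.
    conditions := map[string]map[string]string{
        "validate": {"pass": "store", "fail": "notify", "default": "store"},
    }
    next, _ := resolveCondition(conditions, "validate", "fail")
    fmt.Println(next) // notify
    next, _ = resolveCondition(conditions, "validate", "unknown")
    fmt.Println(next) // store (default fallback)
}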


@@ -1,4 +1,4 @@
package v2 package dag
import ( import (
"context" "context"


@@ -4,247 +4,307 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"strings" "strings"
"time" "time"
"github.com/oarkflow/mq/consts" "github.com/oarkflow/mq"
"github.com/oarkflow/mq/storage" "github.com/oarkflow/mq/storage"
"github.com/oarkflow/mq/storage/memory" "github.com/oarkflow/mq/storage/memory"
"github.com/oarkflow/mq"
) )
type TaskState struct {
NodeID string
Status mq.Status
UpdatedAt time.Time
Result mq.Result
targetResults storage.IMap[string, mq.Result]
}
func newTaskState(nodeID string) *TaskState {
return &TaskState{
NodeID: nodeID,
Status: mq.Pending,
UpdatedAt: time.Now(),
targetResults: memory.New[string, mq.Result](),
}
}
type nodeResult struct {
ctx context.Context
nodeID string
status mq.Status
result mq.Result
}
type TaskManager struct {
    createdAt      time.Time
    processedAt    time.Time
    status         string
    dag            *DAG
    taskID         string
    wg             *WaitGroup
    topic          string
    result         mq.Result
    iteratorNodes  storage.IMap[string, []Edge]
    taskNodeStatus storage.IMap[string, *taskNodeStatus]
}
func NewTaskManager(d *DAG, taskID string, iteratorNodes storage.IMap[string, []Edge]) *TaskManager {
    return &TaskManager{
        dag:            d,
        taskNodeStatus: memory.New[string, *taskNodeStatus](),
        taskID:         taskID,
        iteratorNodes:  iteratorNodes,
        wg:             NewWaitGroup(),
    }
}

type TaskManager struct {
    taskStates         storage.IMap[string, *TaskState]
    parentNodes        storage.IMap[string, string]
    childNodes         storage.IMap[string, int]
    deferredTasks      storage.IMap[string, *task]
    iteratorNodes      storage.IMap[string, []Edge]
    currentNodePayload storage.IMap[string, json.RawMessage]
    currentNodeResult  storage.IMap[string, mq.Result]
    result             *mq.Result
    dag                *DAG
    taskID             string
    taskQueue          chan *task
    resultQueue        chan nodeResult
    resultCh           chan mq.Result
    stopCh             chan struct{}
}
type task struct {
    ctx     context.Context
    taskID  string
    nodeID  string
    payload json.RawMessage
}
func newTask(ctx context.Context, taskID, nodeID string, payload json.RawMessage) *task {
    return &task{
        ctx:     ctx,
        taskID:  taskID,
        nodeID:  nodeID,
        payload: payload,
    }
}
func NewTaskManager(dag *DAG, taskID string, resultCh chan mq.Result, iteratorNodes storage.IMap[string, []Edge]) *TaskManager {
    tm := &TaskManager{
        taskStates:         memory.New[string, *TaskState](),
        parentNodes:        memory.New[string, string](),
        childNodes:         memory.New[string, int](),
        deferredTasks:      memory.New[string, *task](),
        currentNodePayload: memory.New[string, json.RawMessage](),
        currentNodeResult:  memory.New[string, mq.Result](),
        taskQueue:          make(chan *task, DefaultChannelSize),
        resultQueue:        make(chan nodeResult, DefaultChannelSize),
        iteratorNodes:      iteratorNodes,
        stopCh:             make(chan struct{}),
        resultCh:           resultCh,
        taskID:             taskID,
        dag:                dag,
    }
    go tm.run()
    go tm.waitForResult()
    return tm
}
func (tm *TaskManager) dispatchFinalResult(ctx context.Context) mq.Result { func (tm *TaskManager) ProcessTask(ctx context.Context, startNode string, payload json.RawMessage) {
tm.updateTS(&tm.result) tm.send(ctx, startNode, tm.taskID, payload)
tm.dag.callbackToConsumer(ctx, tm.result)
if tm.dag.server.NotifyHandler() != nil {
_ = tm.dag.server.NotifyHandler()(ctx, tm.result)
}
tm.dag.taskCleanupCh <- tm.taskID
tm.topic = tm.result.Topic
return tm.result
} }
func (tm *TaskManager) reportNodeResult(result mq.Result, final bool) { func (tm *TaskManager) send(ctx context.Context, startNode, taskID string, payload json.RawMessage) {
if tm.dag.reportNodeResultCallback != nil { if index, ok := ctx.Value(ContextIndex).(string); ok {
tm.dag.reportNodeResultCallback(result) startNode = strings.Split(startNode, Delimiter)[0]
startNode = fmt.Sprintf("%s%s%s", startNode, Delimiter, index)
} }
if _, exists := tm.taskStates.Get(startNode); !exists {
tm.taskStates.Set(startNode, newTaskState(startNode))
} }
t := newTask(ctx, taskID, startNode, payload)
func (tm *TaskManager) SetTotalItems(topic string, i int) {
if nodeStatus, ok := tm.taskNodeStatus.Get(topic); ok {
nodeStatus.totalItems = i
}
}
func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) {
topic := getTopic(ctx, node.Key)
tm.taskNodeStatus.Set(topic, newNodeStatus(topic))
defer mq.RecoverPanic(mq.RecoverTitle)
dag, isDAG := isDAGNode(node)
if isDAG {
if tm.dag.server.SyncMode() && !dag.server.SyncMode() {
dag.server.Options().SetSyncMode(true)
}
}
tm.ChangeNodeStatus(ctx, node.Key, Processing, mq.Result{Payload: payload, Topic: node.Key})
var result mq.Result
if tm.dag.server.SyncMode() {
defer func() {
if isDAG {
result.Topic = dag.consumerTopic
result.TaskID = tm.taskID
tm.reportNodeResult(result, false)
tm.handleNextTask(result.Ctx, result)
} else {
result.Topic = node.Key
tm.reportNodeResult(result, false)
tm.handleNextTask(ctx, result)
}
}()
}
select { select {
case <-ctx.Done(): case tm.taskQueue <- t:
result = mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: ctx.Err(), Ctx: ctx}
tm.reportNodeResult(result, true)
tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
return
default: default:
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: node.Key}) log.Println("task queue is full, dropping task.")
if tm.dag.server.SyncMode() { tm.deferredTasks.Set(taskID, t)
result = node.ProcessTask(ctx, mq.NewTask(tm.taskID, payload, node.Key)) }
}
func (tm *TaskManager) run() {
for {
select {
case <-tm.stopCh:
log.Println("Stopping TaskManager")
return
case task := <-tm.taskQueue:
tm.processNode(task)
}
}
}
func (tm *TaskManager) waitForResult() {
for {
select {
case <-tm.stopCh:
log.Println("Stopping Result Listener")
return
case nr := <-tm.resultQueue:
tm.onNodeCompleted(nr)
}
}
}
func (tm *TaskManager) processNode(exec *task) {
pureNodeID := strings.Split(exec.nodeID, Delimiter)[0]
node, exists := tm.dag.nodes.Get(pureNodeID)
if !exists {
log.Printf("Node %s does not exist while processing node\n", pureNodeID)
return
}
state, _ := tm.taskStates.Get(exec.nodeID)
if state == nil {
log.Printf("State for node %s not found; creating new state.\n", exec.nodeID)
state = newTaskState(exec.nodeID)
tm.taskStates.Set(exec.nodeID, state)
}
state.Status = mq.Processing
state.UpdatedAt = time.Now()
tm.currentNodePayload.Clear()
tm.currentNodeResult.Clear()
tm.currentNodePayload.Set(exec.nodeID, exec.payload)
result := node.processor.ProcessTask(exec.ctx, mq.NewTask(exec.taskID, exec.payload, exec.nodeID))
tm.currentNodeResult.Set(exec.nodeID, result)
state.Result = result
result.Topic = node.ID
if result.Error != nil { if result.Error != nil {
tm.reportNodeResult(result, true) tm.result = &result
tm.ChangeNodeStatus(ctx, node.Key, Failed, result) tm.resultCh <- result
tm.processFinalResult(state)
return return
} }
if node.NodeType == Page {
tm.result = &result
tm.resultCh <- result
return return
} }
err := tm.dag.server.Publish(ctx, mq.NewTask(tm.taskID, payload, node.Key), node.Key) tm.handleNext(exec.ctx, node, state, result)
}
func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, result mq.Result, childNode string, dispatchFinal bool) {
state.targetResults.Set(childNode, result)
state.targetResults.Del(state.NodeID)
targetsCount, _ := tm.childNodes.Get(state.NodeID)
size := state.targetResults.Size()
nodeID := strings.Split(state.NodeID, Delimiter)
if size == targetsCount {
if size > 1 {
aggregatedData := make([]json.RawMessage, size)
i := 0
state.targetResults.ForEach(func(_ string, rs mq.Result) bool {
aggregatedData[i] = rs.Payload
i++
return true
})
aggregatedPayload, err := json.Marshal(aggregatedData)
if err != nil { if err != nil {
tm.reportNodeResult(mq.Result{Error: err}, true) panic(err)
tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
return
} }
state.Result = mq.Result{Payload: aggregatedPayload, Status: mq.Completed, Ctx: ctx, Topic: state.NodeID}
} else if size == 1 {
state.Result = state.targetResults.Values()[0]
} }
state.Status = result.Status
state.Result.Status = result.Status
} }
if state.Result.Payload == nil {
func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result { state.Result.Payload = result.Payload
defer mq.RecoverPanic(mq.RecoverTitle)
node, ok := tm.dag.nodes[nodeID]
if !ok {
return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)}
} }
if tm.createdAt.IsZero() { state.UpdatedAt = time.Now()
tm.createdAt = time.Now() if result.Ctx == nil {
} result.Ctx = ctx
tm.wg.Add(1)
go func() {
ctxx := context.Background()
if headers, ok := mq.GetHeaders(ctx); ok {
headers.Set(consts.QueueKey, node.Key)
headers.Set("index", fmt.Sprintf("%s__%d", node.Key, 0))
ctxx = mq.SetHeaders(ctxx, headers.AsMap())
}
go tm.processNode(ctx, node, payload)
}()
tm.wg.Wait()
requestType, ok := mq.GetHeader(ctx, "request_type")
if ok && requestType == "render" {
return tm.renderResult(ctx)
}
return tm.dispatchFinalResult(ctx)
}
func (tm *TaskManager) handleNextTask(ctx context.Context, result mq.Result) mq.Result {
tm.topic = result.Topic
defer func() {
tm.wg.Done()
mq.RecoverPanic(mq.RecoverTitle)
}()
if result.Ctx != nil {
if headers, ok := mq.GetHeaders(ctx); ok {
ctx = mq.SetHeaders(result.Ctx, headers.AsMap())
}
}
node, ok := tm.dag.nodes[result.Topic]
if !ok {
return result
} }
if result.Error != nil { if result.Error != nil {
tm.reportNodeResult(result, true) state.Status = mq.Failed
tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
return result
} }
pn, ok := tm.parentNodes.Get(state.NodeID)
if edges, exists := tm.iteratorNodes.Get(nodeID[0]); exists && state.Status == mq.Completed {
state.Status = mq.Processing
tm.iteratorNodes.Del(nodeID[0])
state.targetResults.Clear()
if len(nodeID) == 2 {
ctx = context.WithValue(ctx, ContextIndex, nodeID[1])
}
toProcess := nodeResult{
ctx: ctx,
nodeID: state.NodeID,
status: state.Status,
result: state.Result,
}
tm.handleEdges(toProcess, edges)
} else if ok {
if targetsCount == size {
parentState, _ := tm.taskStates.Get(pn)
if parentState != nil {
state.Result.Topic = state.NodeID
tm.handlePrevious(ctx, parentState, state.Result, state.NodeID, dispatchFinal)
}
}
} else {
tm.result = &state.Result
state.Result.Topic = strings.Split(state.NodeID, Delimiter)[0]
tm.resultCh <- state.Result
tm.processFinalResult(state)
}
}
func (tm *TaskManager) handleNext(ctx context.Context, node *Node, state *TaskState, result mq.Result) {
state.UpdatedAt = time.Now()
if result.Ctx == nil {
result.Ctx = ctx
}
if result.Error != nil {
state.Status = mq.Failed
} else {
edges := tm.getConditionalEdges(node, result) edges := tm.getConditionalEdges(node, result)
if len(edges) == 0 { if len(edges) == 0 {
tm.reportNodeResult(result, true) state.Status = mq.Completed
tm.ChangeNodeStatus(ctx, node.Key, Completed, result)
return result
} else {
tm.reportNodeResult(result, false)
}
if node.Type == Page {
return result
}
for _, edge := range edges {
switch edge.Type {
case Iterator:
var items []json.RawMessage
err := json.Unmarshal(result.Payload, &items)
if err != nil {
tm.reportNodeResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err}, false)
result.Error = err
tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
return result
}
tm.SetTotalItems(getTopic(ctx, edge.From.Key), len(items)*len(edge.To))
for _, target := range edge.To {
for i, item := range items {
tm.wg.Add(1)
go func(ctx context.Context, target *Node, item json.RawMessage, i int) {
ctxx := context.Background()
if headers, ok := mq.GetHeaders(ctx); ok {
headers.Set(consts.QueueKey, target.Key)
headers.Set("index", fmt.Sprintf("%s__%d", target.Key, i))
ctxx = mq.SetHeaders(ctxx, headers.AsMap())
}
tm.processNode(ctxx, target, item)
}(ctx, target, item, i)
} }
} }
if result.Status == "" {
result.Status = state.Status
} }
select {
case tm.resultQueue <- nodeResult{
ctx: ctx,
nodeID: state.NodeID,
result: result,
status: state.Status,
}:
default:
log.Println("Result queue is full, dropping result.")
} }
for _, edge := range edges {
switch edge.Type {
case Simple:
if _, ok := tm.iteratorNodes.Get(edge.From.Key); ok {
continue
}
tm.processEdge(ctx, edge, result)
}
}
return result
} }
func (tm *TaskManager) processEdge(ctx context.Context, edge Edge, result mq.Result) { func (tm *TaskManager) onNodeCompleted(rs nodeResult) {
tm.SetTotalItems(getTopic(ctx, edge.From.Key), len(edge.To)) nodeID := strings.Split(rs.nodeID, Delimiter)[0]
index, _ := mq.GetHeader(ctx, "index") node, ok := tm.dag.nodes.Get(nodeID)
if index != "" && strings.Contains(index, "__") { if !ok {
index = strings.Split(index, "__")[1] return
} else {
index = "0"
} }
for _, target := range edge.To { edges := tm.getConditionalEdges(node, rs.result)
tm.wg.Add(1) hasErrorOrCompleted := rs.result.Error != nil || len(edges) == 0
go func(ctx context.Context, target *Node, result mq.Result) { if hasErrorOrCompleted {
ctxx := context.Background() if index, ok := rs.ctx.Value(ContextIndex).(string); ok {
if headers, ok := mq.GetHeaders(ctx); ok { childNode := fmt.Sprintf("%s%s%s", node.ID, Delimiter, index)
headers.Set(consts.QueueKey, target.Key) pn, ok := tm.parentNodes.Get(childNode)
headers.Set("index", fmt.Sprintf("%s__%s", target.Key, index)) if ok {
ctxx = mq.SetHeaders(ctxx, headers.AsMap()) parentState, _ := tm.taskStates.Get(pn)
if parentState != nil {
pn = strings.Split(pn, Delimiter)[0]
tm.handlePrevious(rs.ctx, parentState, rs.result, rs.nodeID, true)
} }
tm.processNode(ctxx, target, result.Payload)
}(ctx, target, result)
} }
} }
return
}
tm.handleEdges(rs, edges)
}
func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge { func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge {
edges := make([]Edge, len(node.Edges)) edges := make([]Edge, len(node.Edges))
copy(edges, node.Edges) copy(edges, node.Edges)
if result.ConditionStatus != "" { if result.ConditionStatus != "" {
if conditions, ok := tm.dag.conditions[FromNode(result.Topic)]; ok { if conditions, ok := tm.dag.conditions[result.Topic]; ok {
if targetNodeKey, ok := conditions[When(result.ConditionStatus)]; ok { if targetNodeKey, ok := conditions[result.ConditionStatus]; ok {
if targetNode, ok := tm.dag.nodes[string(targetNodeKey)]; ok { if targetNode, ok := tm.dag.nodes.Get(targetNodeKey); ok {
edges = append(edges, Edge{From: node, To: []*Node{targetNode}}) edges = append(edges, Edge{From: node, To: targetNode})
} }
} else if targetNodeKey, ok = conditions["default"]; ok { } else if targetNodeKey, ok = conditions["default"]; ok {
if targetNode, ok := tm.dag.nodes[string(targetNodeKey)]; ok { if targetNode, ok := tm.dag.nodes.Get(targetNodeKey); ok {
edges = append(edges, Edge{From: node, To: []*Node{targetNode}}) edges = append(edges, Edge{From: node, To: targetNode})
} }
} }
} }
@@ -252,123 +312,77 @@ func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge
return edges return edges
} }
func (tm *TaskManager) renderResult(ctx context.Context) mq.Result { func (tm *TaskManager) handleEdges(currentResult nodeResult, edges []Edge) {
var rs mq.Result
tm.updateTS(&rs)
tm.dag.callbackToConsumer(ctx, rs)
tm.topic = rs.Topic
return rs
}
func (tm *TaskManager) ChangeNodeStatus(ctx context.Context, nodeID string, status NodeStatus, rs mq.Result) {
topic := nodeID
if !strings.Contains(nodeID, "__") {
nodeID = getTopic(ctx, nodeID)
} else {
topic = strings.Split(nodeID, "__")[0]
}
nodeStatus, ok := tm.taskNodeStatus.Get(nodeID)
if !ok || nodeStatus == nil {
return
}
nodeStatus.markAs(rs, status)
switch status {
case Completed:
canProceed := false
edges, ok := tm.iteratorNodes.Get(topic)
if ok {
if len(edges) == 0 {
canProceed = true
} else {
nodeStatus.status = Processing
nodeStatus.totalItems = 1
nodeStatus.itemResults.Clear()
for _, edge := range edges { for _, edge := range edges {
tm.processEdge(ctx, edge, rs) index, ok := currentResult.ctx.Value(ContextIndex).(string)
if !ok {
index = "0"
} }
tm.iteratorNodes.Del(topic) parentNode := fmt.Sprintf("%s%s%s", edge.From.ID, Delimiter, index)
if edge.Type == Simple {
if _, ok := tm.iteratorNodes.Get(edge.From.ID); ok {
continue
} }
} }
if canProceed || !ok { if edge.Type == Iterator {
if topic == tm.dag.startNode { var items []json.RawMessage
tm.result = rs err := json.Unmarshal(currentResult.result.Payload, &items)
} else {
tm.markParentTask(ctx, topic, nodeID, status, rs)
}
}
case Failed:
if topic == tm.dag.startNode {
tm.result = rs
} else {
tm.markParentTask(ctx, topic, nodeID, status, rs)
}
}
}
func (tm *TaskManager) markParentTask(ctx context.Context, topic, nodeID string, status NodeStatus, rs mq.Result) {
parentNodes, err := tm.dag.GetPreviousNodes(topic)
if err != nil { if err != nil {
log.Printf("Error unmarshalling data for node %s: %v\n", edge.To.ID, err)
tm.resultQueue <- nodeResult{
ctx: currentResult.ctx,
nodeID: edge.To.ID,
status: mq.Failed,
result: mq.Result{Error: err},
}
return return
} }
var index string tm.childNodes.Set(parentNode, len(items))
nodeParts := strings.Split(nodeID, "__") for i, item := range items {
if len(nodeParts) == 2 { childNode := fmt.Sprintf("%s%s%d", edge.To.ID, Delimiter, i)
index = nodeParts[1] ctx := context.WithValue(currentResult.ctx, ContextIndex, fmt.Sprintf("%d", i))
tm.parentNodes.Set(childNode, parentNode)
tm.send(ctx, edge.To.ID, tm.taskID, item)
} }
for _, parentNode := range parentNodes { } else {
parentKey := fmt.Sprintf("%s__%s", parentNode.Key, index) tm.childNodes.Set(parentNode, 1)
parentNodeStatus, exists := tm.taskNodeStatus.Get(parentKey) idx, ok := currentResult.ctx.Value(ContextIndex).(string)
if !exists { if !ok {
parentKey = fmt.Sprintf("%s__%s", parentNode.Key, "0") idx = "0"
parentNodeStatus, exists = tm.taskNodeStatus.Get(parentKey)
}
if exists {
parentNodeStatus.itemResults.Set(nodeID, rs)
if parentNodeStatus.IsDone() {
rt := tm.prepareResult(ctx, parentNodeStatus)
tm.ChangeNodeStatus(ctx, parentKey, status, rt)
} }
childNode := fmt.Sprintf("%s%s%s", edge.To.ID, Delimiter, idx)
ctx := context.WithValue(currentResult.ctx, ContextIndex, idx)
tm.parentNodes.Set(childNode, parentNode)
tm.send(ctx, edge.To.ID, tm.taskID, currentResult.result.Payload)
} }
} }
} }
func (tm *TaskManager) prepareResult(ctx context.Context, nodeStatus *taskNodeStatus) mq.Result { func (tm *TaskManager) retryDeferredTasks() {
aggregatedOutput := make([]json.RawMessage, 0) const maxRetries = 5
var status mq.Status retries := 0
var topic string for retries < maxRetries {
var err1 error select {
if nodeStatus.totalItems == 1 { case <-tm.stopCh:
rs := nodeStatus.itemResults.Values()[0] log.Println("Stopping Deferred task Retrier")
if rs.Ctx == nil { return
rs.Ctx = ctx case <-time.After(RetryInterval):
} tm.deferredTasks.ForEach(func(taskID string, task *task) bool {
return rs tm.send(task.ctx, task.nodeID, taskID, task.payload)
} retries++
nodeStatus.itemResults.ForEach(func(key string, result mq.Result) bool {
if topic == "" {
topic = result.Topic
status = result.Status
}
if result.Error != nil {
err1 = result.Error
return false
}
var item json.RawMessage
err := json.Unmarshal(result.Payload, &item)
if err != nil {
err1 = err
return false
}
aggregatedOutput = append(aggregatedOutput, item)
return true return true
}) })
if err1 != nil {
return mq.HandleError(ctx, err1)
} }
finalOutput, err := json.Marshal(aggregatedOutput)
if err != nil {
return mq.HandleError(ctx, err)
} }
return mq.Result{TaskID: tm.taskID, Payload: finalOutput, Status: status, Topic: topic, Ctx: ctx} }
func (tm *TaskManager) processFinalResult(state *TaskState) {
state.targetResults.Clear()
if tm.dag.finalResult != nil {
tm.dag.finalResult(tm.taskID, state.Result)
}
}
func (tm *TaskManager) Stop() {
close(tm.stopCh)
} }
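To illustrate the composite-key scheme the task manager relies on (node ID joined to an index with Delimiter) for Iterator fan-out and result aggregation, a standalone sketch follows. It mirrors the handleEdges/handlePrevious flow only in spirit, not the implementation; the node name and payload are made up.

package main

import (
    "encoding/json"
    "fmt"
    "sort"
    "strings"
)

const delimiter = "___" // same separator the Delimiter constant provides

// fanOut mirrors what handleEdges does for an Iterator edge: each element of a JSON
// array payload becomes its own child task keyed "<nodeID>___<index>".
func fanOut(nodeID string, payload []byte) (map[string]json.RawMessage, error) {
    var items []json.RawMessage
    if err := json.Unmarshal(payload, &items); err != nil {
        return nil, err
    }
    children := make(map[string]json.RawMessage, len(items))
    for i, item := range items {
        children[fmt.Sprintf("%s%s%d", nodeID, delimiter, i)] = item
    }
    return children, nil
}

// aggregate mirrors the branch of handlePrevious that fires once every child result is
// in: the child payloads are re-marshalled into a single JSON array for the parent.
func aggregate(children map[string]json.RawMessage) ([]byte, error) {
    keys := make([]string, 0, len(children))
    for k := range children {
        keys = append(keys, k)
    }
    sort.Strings(keys) // keeps this small example in its original item order
    out := make([]json.RawMessage, 0, len(children))
    for _, k := range keys {
        out = append(out, children[k])
    }
    return json.Marshal(out)
}

func main() {
    children, err := fanOut("send", []byte(`[{"to":"a"},{"to":"b"}]`))
    if err != nil {
        panic(err)
    }
    for key, item := range children {
        fmt.Println(strings.Split(key, delimiter), string(item))
    }
    merged, _ := aggregate(children)
    fmt.Println("aggregated:", string(merged))
}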

dag/ui.go

@@ -9,25 +9,24 @@ import (
func (tm *DAG) PrintGraph() {
    fmt.Println("DAG Graph structure:")
    for _, node := range tm.nodes {
        fmt.Printf("Node: %s (%s) -> ", node.Name, node.Key)
        if conditions, ok := tm.conditions[FromNode(node.Key)]; ok {
            var c []string
            for when, then := range conditions {
                if target, ok := tm.nodes[string(then)]; ok {
                    c = append(c, fmt.Sprintf("If [%s] Then %s (%s)", when, target.Name, target.Key))
                }
            }
            fmt.Println(strings.Join(c, ", "))
        }
        var edges []string
        for _, edge := range node.Edges {
            for _, target := range edge.To {
                edges = append(edges, fmt.Sprintf("%s (%s)", target.Name, target.Key))
            }
        }
        fmt.Println(strings.Join(edges, ", "))
    }
}

func (tm *DAG) PrintGraph() {
    fmt.Println("DAG Graph structure:")
    tm.nodes.ForEach(func(_ string, node *Node) bool {
        fmt.Printf("Node: %s (%s) -> ", node.Label, node.ID)
        if conditions, ok := tm.conditions[node.ID]; ok {
            var c []string
            for when, then := range conditions {
                if target, ok := tm.nodes.Get(then); ok {
                    c = append(c, fmt.Sprintf("If [%s] Then %s (%s)", when, target.Label, target.ID))
                }
            }
            fmt.Println(strings.Join(c, ", "))
        }
        var edges []string
        for _, target := range node.Edges {
            edges = append(edges, fmt.Sprintf("%s (%s)", target.To.Label, target.To.ID))
        }
        fmt.Println(strings.Join(edges, ", "))
        return true
    })
}
func (tm *DAG) ClassifyEdges(startNodes ...string) (string, bool, error) {
@@ -44,7 +43,7 @@ func (tm *DAG) ClassifyEdges(startNodes ...string) (string, bool, error) {
    if startNode == "" {
        firstNode := tm.findStartNode()
        if firstNode != nil {
            startNode = firstNode.Key
            startNode = firstNode.ID
        }
    }
    if startNode == "" {
@@ -62,26 +61,24 @@ func (tm *DAG) dfs(v string, visited map[string]bool, discoveryTime, finishedTim
inRecursionStack[v] = true // mark node as part of recursion stack inRecursionStack[v] = true // mark node as part of recursion stack
*timeVal++ *timeVal++
discoveryTime[v] = *timeVal discoveryTime[v] = *timeVal
node := tm.nodes[v] node, _ := tm.nodes.Get(v)
hasCycle := false hasCycle := false
var err error var err error
for _, edge := range node.Edges { for _, edge := range node.Edges {
for _, adj := range edge.To { if !visited[edge.To.ID] {
if !visited[adj.Key] { builder.WriteString(fmt.Sprintf("Traversing Edge: %s -> %s\n", v, edge.To.ID))
builder.WriteString(fmt.Sprintf("Traversing Edge: %s -> %s\n", v, adj.Key)) hasCycle, err := tm.dfs(edge.To.ID, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder)
hasCycle, err := tm.dfs(adj.Key, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder)
if err != nil { if err != nil {
return true, err return true, err
} }
if hasCycle { if hasCycle {
return true, nil return true, nil
} }
} else if inRecursionStack[adj.Key] { } else if inRecursionStack[edge.To.ID] {
cycleMsg := fmt.Sprintf("Cycle detected: %s -> %s\n", v, adj.Key) cycleMsg := fmt.Sprintf("Cycle detected: %s -> %s\n", v, edge.To.ID)
return true, fmt.Errorf(cycleMsg) return true, fmt.Errorf(cycleMsg)
} }
} }
}
hasCycle, err = tm.handleConditionalEdges(v, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder) hasCycle, err = tm.handleConditionalEdges(v, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder)
if err != nil { if err != nil {
return true, err return true, err
@@ -93,20 +90,20 @@ func (tm *DAG) dfs(v string, visited map[string]bool, discoveryTime, finishedTim
} }
func (tm *DAG) handleConditionalEdges(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, time *int, inRecursionStack map[string]bool, builder *strings.Builder) (bool, error) { func (tm *DAG) handleConditionalEdges(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, time *int, inRecursionStack map[string]bool, builder *strings.Builder) (bool, error) {
node := tm.nodes[v] node, _ := tm.nodes.Get(v)
for when, then := range tm.conditions[FromNode(node.Key)] { for when, then := range tm.conditions[node.ID] {
if targetNode, ok := tm.nodes[string(then)]; ok { if targetNode, ok := tm.nodes.Get(then); ok {
if !visited[targetNode.Key] { if !visited[targetNode.ID] {
builder.WriteString(fmt.Sprintf("Traversing Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.Key)) builder.WriteString(fmt.Sprintf("Traversing Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.ID))
hasCycle, err := tm.dfs(targetNode.Key, visited, discoveryTime, finishedTime, time, inRecursionStack, builder) hasCycle, err := tm.dfs(targetNode.ID, visited, discoveryTime, finishedTime, time, inRecursionStack, builder)
if err != nil { if err != nil {
return true, err return true, err
} }
if hasCycle { if hasCycle {
return true, nil return true, nil
} }
} else if inRecursionStack[targetNode.Key] { } else if inRecursionStack[targetNode.ID] {
cycleMsg := fmt.Sprintf("Cycle detected in Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.Key) cycleMsg := fmt.Sprintf("Cycle detected in Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.ID)
return true, fmt.Errorf(cycleMsg) return true, fmt.Errorf(cycleMsg)
} }
} }
@@ -146,98 +143,113 @@ func (tm *DAG) ExportDOT() string {
var sb strings.Builder var sb strings.Builder
sb.WriteString(fmt.Sprintf(`digraph "%s" {`, tm.name)) sb.WriteString(fmt.Sprintf(`digraph "%s" {`, tm.name))
sb.WriteString("\n") sb.WriteString("\n")
sb.WriteString(fmt.Sprintf(` label="%s";`, tm.name)) sb.WriteString(` label="Enhanced DAG Representation";`)
sb.WriteString("\n") sb.WriteString("\n")
sb.WriteString(` labelloc="t";`) sb.WriteString(` labelloc="t"; fontsize=22; fontname="Helvetica";`)
sb.WriteString("\n") sb.WriteString("\n")
sb.WriteString(` fontsize=20;`) sb.WriteString(` node [shape=box, fontname="Helvetica", fillcolor="#B3CDE0", fontcolor="#2C3E50", fontsize=10, margin="0.25,0.15", style="rounded,filled"];`)
sb.WriteString("\n") sb.WriteString("\n")
sb.WriteString(` node [shape=box, style="rounded,filled", fillcolor="lightgray", fontname="Arial", margin="0.2,0.1"];`) sb.WriteString(` edge [fontname="Helvetica", fontsize=12, arrowsize=0.8];`)
sb.WriteString("\n") sb.WriteString("\n")
sb.WriteString(` edge [fontname="Arial", fontsize=12, arrowsize=0.8];`) sb.WriteString(` rankdir=TB;`)
sb.WriteString("\n")
sb.WriteString(` size="10,10";`)
sb.WriteString("\n")
sb.WriteString(` ratio="fill";`)
sb.WriteString("\n") sb.WriteString("\n")
sortedNodes := tm.TopologicalSort() sortedNodes := tm.TopologicalSort()
for _, nodeKey := range sortedNodes { for _, nodeKey := range sortedNodes {
node := tm.nodes[nodeKey] node, _ := tm.nodes.Get(nodeKey)
nodeColor := "lightblue" nodeColor := "lightgray"
sb.WriteString(fmt.Sprintf(` "%s" [label=" %s", fillcolor="%s", id="node_%s"];`, node.Key, node.Name, nodeColor, node.Key)) nodeShape := "box"
labelSuffix := ""
// Apply styles based on NodeType
switch node.NodeType {
case Function:
nodeColor = "#D4EDDA"
labelSuffix = " [Function]"
case Page:
nodeColor = "#f0d2d1"
labelSuffix = " [Page]"
}
sb.WriteString(fmt.Sprintf(
` "%s" [label="%s%s", fontcolor="#2C3E50", fillcolor="%s", shape="%s", style="rounded,filled", id="node_%s"];`,
node.ID, node.Label, labelSuffix, nodeColor, nodeShape, node.ID))
sb.WriteString("\n") sb.WriteString("\n")
} }
// Define edges with unique styling by EdgeType
for _, nodeKey := range sortedNodes { for _, nodeKey := range sortedNodes {
node := tm.nodes[nodeKey] node, _ := tm.nodes.Get(nodeKey)
for _, edge := range node.Edges { for _, edge := range node.Edges {
var edgeStyle string edgeStyle := "solid"
edgeColor := "black"
labelSuffix := ""
// Apply styles based on EdgeType
switch edge.Type { switch edge.Type {
case Iterator: case Iterator:
edgeStyle = "dashed" edgeStyle = "dashed"
default: edgeColor = "blue"
labelSuffix = " [Iter]"
case Simple:
edgeStyle = "solid" edgeStyle = "solid"
edgeColor = "black"
labelSuffix = ""
} }
edgeColor := "black" sb.WriteString(fmt.Sprintf(
for _, to := range edge.To { ` "%s" -> "%s" [label="%s%s", color="%s", style="%s"];`,
sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="%s", style=%s, fontsize=10, arrowsize=0.6];`, node.Key, to.Key, edge.Label, edgeColor, edgeStyle)) node.ID, edge.To.ID, edge.Label, labelSuffix, edgeColor, edgeStyle))
sb.WriteString("\n") sb.WriteString("\n")
} }
} }
}
for fromNodeKey, conditions := range tm.conditions { for fromNodeKey, conditions := range tm.conditions {
for when, then := range conditions { for when, then := range conditions {
if toNode, ok := tm.nodes[string(then)]; ok { if toNode, ok := tm.nodes.Get(then); ok {
sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="purple", style=dotted, fontsize=10, arrowsize=0.6];`, fromNodeKey, toNode.Key, when)) sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="purple", style=dotted, fontsize=10, arrowsize=0.6];`, fromNodeKey, toNode.ID, when))
sb.WriteString("\n") sb.WriteString("\n")
} }
} }
} }
// Optional: Group related nodes into subgraphs (e.g., loops)
for _, nodeKey := range sortedNodes { for _, nodeKey := range sortedNodes {
node := tm.nodes[nodeKey] node, _ := tm.nodes.Get(nodeKey)
if node.processor != nil { if node.processor != nil {
subDAG, _ := isDAGNode(node) subDAG, _ := isDAGNode(node)
if subDAG != nil { if subDAG != nil {
sb.WriteString(fmt.Sprintf(` subgraph "cluster_%s" {`, subDAG.name)) sb.WriteString(fmt.Sprintf(` subgraph "cluster_%s" {`, subDAG.name))
sb.WriteString("\n") sb.WriteString("\n")
sb.WriteString(fmt.Sprintf(` label=" %s";`, subDAG.name)) sb.WriteString(fmt.Sprintf(` label="Subgraph: %s";`, subDAG.name))
sb.WriteString("\n") sb.WriteString("\n")
sb.WriteString(` style=dashed;`) sb.WriteString(` style=filled; color=gray90;`)
sb.WriteString("\n") sb.WriteString("\n")
sb.WriteString(` bgcolor="lightgray";`) subDAG.nodes.ForEach(func(subNodeKey string, subNode *Node) bool {
sb.WriteString(fmt.Sprintf(` "%s" [label="%s"];`, subNode.ID, subNode.Label))
sb.WriteString("\n") sb.WriteString("\n")
sb.WriteString(` node [shape=rectangle, style="filled", fillcolor="lightblue", fontname="Arial", margin="0.2,0.1"];`) return true
sb.WriteString("\n") })
for subNodeKey, subNode := range subDAG.nodes { subDAG.nodes.ForEach(func(subNodeKey string, subNode *Node) bool {
sb.WriteString(fmt.Sprintf(` "%s" [label=" %s"];`, subNodeKey, subNode.Name))
sb.WriteString("\n")
}
for subNodeKey, subNode := range subDAG.nodes {
for _, edge := range subNode.Edges { for _, edge := range subNode.Edges {
for _, to := range edge.To { sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label="%s"];`, subNodeKey, edge.To.ID, edge.Label))
sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="black", style=solid, arrowsize=0.6];`, subNodeKey, to.Key, edge.Label))
sb.WriteString("\n") sb.WriteString("\n")
} }
} return true
} })
sb.WriteString(` }`) sb.WriteString(" }\n")
sb.WriteString("\n")
sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="black", style=solid, arrowsize=0.6];`, node.Key, subDAG.startNode, subDAG.name))
sb.WriteString("\n")
} }
} }
} }
sb.WriteString(`}`)
sb.WriteString("\n") sb.WriteString("}\n")
return sb.String() return sb.String()
} }
func (tm *DAG) TopologicalSort() (stack []string) { func (tm *DAG) TopologicalSort() (stack []string) {
visited := make(map[string]bool) visited := make(map[string]bool)
for _, node := range tm.nodes { tm.nodes.ForEach(func(_ string, node *Node) bool {
if !visited[node.Key] { if !visited[node.ID] {
tm.topologicalSortUtil(node.Key, visited, &stack) tm.topologicalSortUtil(node.ID, visited, &stack)
}
} }
return true
})
for i, j := 0, len(stack)-1; i < j; i, j = i+1, j-1 { for i, j := 0, len(stack)-1; i < j; i, j = i+1, j-1 {
stack[i], stack[j] = stack[j], stack[i] stack[i], stack[j] = stack[j], stack[i]
} }
@@ -246,13 +258,23 @@ func (tm *DAG) TopologicalSort() (stack []string) {
func (tm *DAG) topologicalSortUtil(v string, visited map[string]bool, stack *[]string) { func (tm *DAG) topologicalSortUtil(v string, visited map[string]bool, stack *[]string) {
visited[v] = true visited[v] = true
node := tm.nodes[v] node, ok := tm.nodes.Get(v)
for _, edge := range node.Edges { if !ok {
for _, to := range edge.To { fmt.Println("Not found", v)
if !visited[to.Key] {
tm.topologicalSortUtil(to.Key, visited, stack)
} }
for _, edge := range node.Edges {
if !visited[edge.To.ID] {
tm.topologicalSortUtil(edge.To.ID, visited, stack)
} }
} }
*stack = append(*stack, v) *stack = append(*stack, v)
} }
func isDAGNode(node *Node) (*DAG, bool) {
switch node := node.processor.(type) {
case *DAG:
return node, true
default:
return nil, false
}
}
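
ExportDOT only returns DOT text, so rendering still needs Graphviz. A minimal, hedged sketch (file names and the `flow` variable are placeholders; the `dot` CLI must be installed locally):

// Assumed imports: log, os, os/exec; `flow` is a *DAG built elsewhere.
if err := os.WriteFile("flow.dot", []byte(flow.ExportDOT()), 0o644); err != nil {
	log.Fatal(err)
}
// Render with the Graphviz CLI, equivalent to: dot -Tsvg flow.dot -o flow.svg
if err := exec.Command("dot", "-Tsvg", "flow.dot", "-o", "flow.svg").Run(); err != nil {
	log.Fatal(err)
}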

dag/v1/api.go

@@ -0,0 +1,209 @@
package v1
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"time"
"github.com/oarkflow/mq/jsonparser"
"github.com/oarkflow/mq/sio"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/metrics"
)
type Request struct {
Payload json.RawMessage `json:"payload"`
Interval time.Duration `json:"interval"`
Schedule bool `json:"schedule"`
Overlap bool `json:"overlap"`
Recurring bool `json:"recurring"`
}
func (tm *DAG) SetupWS() *sio.Server {
ws := sio.New(sio.Config{
CheckOrigin: func(r *http.Request) bool { return true },
EnableCompression: true,
})
WsEvents(ws)
tm.Notifier = ws
return ws
}
func (tm *DAG) Handlers() {
metrics.HandleHTTP()
http.Handle("/", http.FileServer(http.Dir("webroot")))
http.Handle("/notify", tm.SetupWS())
http.HandleFunc("GET /render", tm.Render)
http.HandleFunc("POST /request", tm.Request)
http.HandleFunc("POST /publish", tm.Publish)
http.HandleFunc("POST /schedule", tm.Schedule)
http.HandleFunc("/pause-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) {
id := request.PathValue("id")
if id != "" {
tm.PauseConsumer(request.Context(), id)
}
})
http.HandleFunc("/resume-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) {
id := request.PathValue("id")
if id != "" {
tm.ResumeConsumer(request.Context(), id)
}
})
http.HandleFunc("/pause", func(w http.ResponseWriter, request *http.Request) {
err := tm.Pause(request.Context())
if err != nil {
http.Error(w, "Failed to pause", http.StatusBadRequest)
return
}
json.NewEncoder(w).Encode(map[string]string{"status": "paused"})
})
http.HandleFunc("/resume", func(w http.ResponseWriter, request *http.Request) {
err := tm.Resume(request.Context())
if err != nil {
http.Error(w, "Failed to resume", http.StatusBadRequest)
return
}
json.NewEncoder(w).Encode(map[string]string{"status": "resumed"})
})
http.HandleFunc("/stop", func(w http.ResponseWriter, request *http.Request) {
err := tm.Stop(request.Context())
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
json.NewEncoder(w).Encode(map[string]string{"status": "stopped"})
})
http.HandleFunc("/close", func(w http.ResponseWriter, request *http.Request) {
err := tm.Close()
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
json.NewEncoder(w).Encode(map[string]string{"status": "closed"})
})
http.HandleFunc("/dot", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
fmt.Fprintln(w, tm.ExportDOT())
})
http.HandleFunc("/ui", func(w http.ResponseWriter, r *http.Request) {
image := fmt.Sprintf("%s.svg", mq.NewID())
err := tm.SaveSVG(image)
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
defer os.Remove(image)
svgBytes, err := os.ReadFile(image)
if err != nil {
http.Error(w, "Could not read SVG file", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "image/svg+xml")
if _, err := w.Write(svgBytes); err != nil {
http.Error(w, "Could not write SVG response", http.StatusInternalServerError)
return
}
})
}
func (tm *DAG) request(w http.ResponseWriter, r *http.Request, async bool) {
if r.Method != http.MethodPost {
http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
return
}
var request Request
if r.Body != nil {
defer r.Body.Close()
var err error
payload, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
err = json.Unmarshal(payload, &request)
if err != nil {
http.Error(w, "Failed to unmarshal body", http.StatusBadRequest)
return
}
} else {
http.Error(w, "Empty request body", http.StatusBadRequest)
return
}
ctx := r.Context()
if async {
ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
}
var opts []mq.SchedulerOption
if request.Interval > 0 {
opts = append(opts, mq.WithInterval(request.Interval))
}
if request.Overlap {
opts = append(opts, mq.WithOverlap())
}
if request.Recurring {
opts = append(opts, mq.WithRecurring())
}
ctx = context.WithValue(ctx, "query_params", r.URL.Query())
var rs mq.Result
if request.Schedule {
rs = tm.ScheduleTask(ctx, request.Payload, opts...)
} else {
rs = tm.Process(ctx, request.Payload)
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(rs)
}
func (tm *DAG) Render(w http.ResponseWriter, r *http.Request) {
ctx := mq.SetHeaders(r.Context(), map[string]string{consts.AwaitResponseKey: "true", "request_type": "render"})
ctx = context.WithValue(ctx, "query_params", r.URL.Query())
rs := tm.Process(ctx, nil)
content, err := jsonparser.GetString(rs.Payload, "html_content")
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", consts.TypeHtml)
w.Write([]byte(content))
}
func (tm *DAG) Request(w http.ResponseWriter, r *http.Request) {
tm.request(w, r, true)
}
func (tm *DAG) Publish(w http.ResponseWriter, r *http.Request) {
tm.request(w, r, false)
}
func (tm *DAG) Schedule(w http.ResponseWriter, r *http.Request) {
tm.request(w, r, false)
}
func GetTaskID(ctx context.Context) string {
if queryParams := ctx.Value("query_params"); queryParams != nil {
if params, ok := queryParams.(url.Values); ok {
if id := params.Get("taskID"); id != "" {
return id
}
}
}
return ""
}
func CanNextNode(ctx context.Context) string {
if queryParams := ctx.Value("query_params"); queryParams != nil {
if params, ok := queryParams.(url.Values); ok {
if id := params.Get("next"); id != "" {
return id
}
}
}
return ""
}
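
The Request struct above drives both /request and /schedule. A hedged client sketch follows; the listen address, payload, and interval are placeholders, and note that Interval is a time.Duration, so it travels as nanoseconds in the JSON body.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"time"
)

func main() {
	// Mirrors the Request struct accepted by the handlers above.
	body, _ := json.Marshal(map[string]any{
		"payload":   json.RawMessage(`{"user_id": 1}`),
		"interval":  30 * time.Second, // time.Duration marshals as nanoseconds
		"schedule":  true,
		"recurring": true,
	})
	resp, err := http.Post("http://localhost:8080/schedule", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status) // the handler echoes the mq.Result as JSON
}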

dag/v1/consts.go

@@ -0,0 +1,26 @@
package v1
type NodeStatus int
func (c NodeStatus) IsValid() bool { return c >= Pending && c <= Failed }
func (c NodeStatus) String() string {
switch c {
case Pending:
return "Pending"
case Processing:
return "Processing"
case Completed:
return "Completed"
case Failed:
return "Failed"
}
return ""
}
const (
Pending NodeStatus = iota
Processing
Completed
Failed
)

dag/v1/dag.go

@@ -0,0 +1,612 @@
package v1
import (
"context"
"fmt"
"log"
"net/http"
"time"
"github.com/oarkflow/mq/storage"
"github.com/oarkflow/mq/storage/memory"
"github.com/oarkflow/mq/sio"
"golang.org/x/time/rate"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/metrics"
)
type EdgeType int
func (c EdgeType) IsValid() bool { return c >= Simple && c <= Iterator }
const (
Simple EdgeType = iota
Iterator
)
type NodeType int
func (c NodeType) IsValid() bool { return c >= Process && c <= Page }
const (
Process NodeType = iota
Page
)
type Node struct {
processor mq.Processor
Name string
Type NodeType
Key string
Edges []Edge
isReady bool
}
func (n *Node) ProcessTask(ctx context.Context, msg *mq.Task) mq.Result {
return n.processor.ProcessTask(ctx, msg)
}
func (n *Node) Close() error {
return n.processor.Close()
}
type Edge struct {
Label string
From *Node
To []*Node
Type EdgeType
}
type (
FromNode string
When string
Then string
)
type DAG struct {
server *mq.Broker
consumer *mq.Consumer
taskContext storage.IMap[string, *TaskManager]
nodes map[string]*Node
iteratorNodes storage.IMap[string, []Edge]
conditions map[FromNode]map[When]Then
pool *mq.Pool
taskCleanupCh chan string
name string
key string
startNode string
consumerTopic string
opts []mq.Option
reportNodeResultCallback func(mq.Result)
Notifier *sio.Server
paused bool
Error error
report string
index string
}
func (tm *DAG) SetKey(key string) {
tm.key = key
}
func (tm *DAG) ReportNodeResult(callback func(mq.Result)) {
tm.reportNodeResultCallback = callback
}
func (tm *DAG) GetType() string {
return tm.key
}
func (tm *DAG) listenForTaskCleanup() {
for taskID := range tm.taskCleanupCh {
if tm.server.Options().CleanTaskOnComplete() {
tm.taskCleanup(taskID)
}
}
}
func (tm *DAG) taskCleanup(taskID string) {
tm.taskContext.Del(taskID)
log.Printf("DAG - Task %s cleaned up", taskID)
}
func (tm *DAG) Consume(ctx context.Context) error {
if tm.consumer != nil {
tm.server.Options().SetSyncMode(true)
return tm.consumer.Consume(ctx)
}
return nil
}
func (tm *DAG) Stop(ctx context.Context) error {
for _, n := range tm.nodes {
err := n.processor.Stop(ctx)
if err != nil {
return err
}
}
return nil
}
func (tm *DAG) GetKey() string {
return tm.key
}
func (tm *DAG) AssignTopic(topic string) {
tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false), mq.WithBrokerURL(tm.server.URL()))
tm.consumerTopic = topic
}
func NewDAG(name, key string, opts ...mq.Option) *DAG {
callback := func(ctx context.Context, result mq.Result) error { return nil }
d := &DAG{
name: name,
key: key,
nodes: make(map[string]*Node),
iteratorNodes: memory.New[string, []Edge](),
taskContext: memory.New[string, *TaskManager](),
conditions: make(map[FromNode]map[When]Then),
taskCleanupCh: make(chan string),
}
opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerOnSubscribe(d.onConsumerJoin), mq.WithConsumerOnClose(d.onConsumerClose))
d.server = mq.NewBroker(opts...)
d.opts = opts
options := d.server.Options()
d.pool = mq.NewPool(
options.NumOfWorkers(),
mq.WithTaskQueueSize(options.QueueSize()),
mq.WithMaxMemoryLoad(options.MaxMemoryLoad()),
mq.WithHandler(d.ProcessTask),
mq.WithPoolCallback(callback),
mq.WithTaskStorage(options.Storage()),
)
d.pool.Start(d.server.Options().NumOfWorkers())
go d.listenForTaskCleanup()
return d
}
func (tm *DAG) callbackToConsumer(ctx context.Context, result mq.Result) {
if tm.consumer != nil {
result.Topic = tm.consumerTopic
if tm.consumer.Conn() == nil {
tm.onTaskCallback(ctx, result)
} else {
tm.consumer.OnResponse(ctx, result)
}
}
}
func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result {
if taskContext, ok := tm.taskContext.Get(result.TaskID); ok && result.Topic != "" {
return taskContext.handleNextTask(ctx, result)
}
return mq.Result{}
}
func (tm *DAG) onConsumerJoin(_ context.Context, topic, _ string) {
if node, ok := tm.nodes[topic]; ok {
log.Printf("DAG - CONSUMER ~> ready on %s", topic)
node.isReady = true
}
}
func (tm *DAG) onConsumerClose(_ context.Context, topic, _ string) {
if node, ok := tm.nodes[topic]; ok {
log.Printf("DAG - CONSUMER ~> down on %s", topic)
node.isReady = false
}
}
func (tm *DAG) SetStartNode(node string) {
tm.startNode = node
}
func (tm *DAG) GetStartNode() string {
return tm.startNode
}
func (tm *DAG) Start(ctx context.Context, addr string) error {
// Start the server in a separate goroutine
go func() {
defer mq.RecoverPanic(mq.RecoverTitle)
if err := tm.server.Start(ctx); err != nil {
panic(err)
}
}()
// Start the node consumers if not in sync mode
if !tm.server.SyncMode() {
for _, con := range tm.nodes {
go func(con *Node) {
defer mq.RecoverPanic(mq.RecoverTitle)
limiter := rate.NewLimiter(rate.Every(1*time.Second), 1) // Retry every second
for {
err := con.processor.Consume(ctx)
if err != nil {
log.Printf("[ERROR] - Consumer %s failed to start: %v", con.Key, err)
} else {
log.Printf("[INFO] - Consumer %s started successfully", con.Key)
break
}
limiter.Wait(ctx) // Wait with rate limiting before retrying
}
}(con)
}
}
log.Printf("DAG - HTTP_SERVER ~> started on http://localhost%s", addr)
tm.Handlers()
config := tm.server.TLSConfig()
if config.UseTLS {
return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil)
}
return http.ListenAndServe(addr, nil)
}
func (tm *DAG) AddDAGNode(name string, key string, dag *DAG, firstNode ...bool) *DAG {
dag.AssignTopic(key)
tm.nodes[key] = &Node{
Name: name,
Key: key,
processor: dag,
isReady: true,
}
if len(firstNode) > 0 && firstNode[0] {
tm.startNode = key
}
return tm
}
func (tm *DAG) AddNode(name, key string, handler mq.Processor, firstNode ...bool) *DAG {
con := mq.NewConsumer(key, key, handler.ProcessTask, tm.opts...)
n := &Node{
Name: name,
Key: key,
processor: con,
}
if handler.GetType() == "page" {
n.Type = Page
}
if tm.server.SyncMode() {
n.isReady = true
}
tm.nodes[key] = n
if len(firstNode) > 0 && firstNode[0] {
tm.startNode = key
}
return tm
}
func (tm *DAG) AddDeferredNode(name, key string, firstNode ...bool) error {
if tm.server.SyncMode() {
return fmt.Errorf("DAG cannot have deferred node in Sync Mode")
}
tm.nodes[key] = &Node{
Name: name,
Key: key,
}
if len(firstNode) > 0 && firstNode[0] {
tm.startNode = key
}
return nil
}
func (tm *DAG) IsReady() bool {
for _, node := range tm.nodes {
if !node.isReady {
return false
}
}
return true
}
func (tm *DAG) AddCondition(fromNode FromNode, conditions map[When]Then) *DAG {
tm.conditions[fromNode] = conditions
return tm
}
func (tm *DAG) AddIterator(label, from string, targets ...string) *DAG {
tm.Error = tm.addEdge(Iterator, label, from, targets...)
tm.iteratorNodes.Set(from, []Edge{})
return tm
}
func (tm *DAG) AddEdge(label, from string, targets ...string) *DAG {
tm.Error = tm.addEdge(Simple, label, from, targets...)
return tm
}
func (tm *DAG) addEdge(edgeType EdgeType, label, from string, targets ...string) error {
fromNode, ok := tm.nodes[from]
if !ok {
return fmt.Errorf("Error: 'from' node %s does not exist\n", from)
}
var nodes []*Node
for _, target := range targets {
toNode, ok := tm.nodes[target]
if !ok {
return fmt.Errorf("Error: 'from' node %s does not exist\n", target)
}
nodes = append(nodes, toNode)
}
edge := Edge{From: fromNode, To: nodes, Type: edgeType, Label: label}
fromNode.Edges = append(fromNode.Edges, edge)
if edgeType != Iterator {
if edges, ok := tm.iteratorNodes.Get(fromNode.Key); ok {
edges = append(edges, edge)
tm.iteratorNodes.Set(fromNode.Key, edges)
}
}
return nil
}
func (tm *DAG) Validate() error {
report, hasCycle, err := tm.ClassifyEdges()
if hasCycle || err != nil {
tm.Error = err
return err
}
tm.report = report
return nil
}
func (tm *DAG) GetReport() string {
return tm.report
}
func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
if task.ID == "" {
task.ID = mq.NewID()
}
if index, ok := mq.GetHeader(ctx, "index"); ok {
tm.index = index
}
manager, exists := tm.taskContext.Get(task.ID)
if !exists {
manager = NewTaskManager(tm, task.ID, tm.iteratorNodes)
manager.createdAt = task.CreatedAt
tm.taskContext.Set(task.ID, manager)
}
if tm.consumer != nil {
initialNode, err := tm.parseInitialNode(ctx)
if err != nil {
metrics.TasksErrors.WithLabelValues("unknown").Inc() // Increase error count
return mq.Result{Error: err}
}
task.Topic = initialNode
}
if manager.topic != "" {
task.Topic = manager.topic
canNext := CanNextNode(ctx)
if canNext != "" {
if n, ok := tm.nodes[task.Topic]; ok {
if len(n.Edges) > 0 {
task.Topic = n.Edges[0].To[0].Key
}
}
}
}
result := manager.processTask(ctx, task.Topic, task.Payload)
if result.Ctx != nil && tm.index != "" {
result.Ctx = mq.SetHeaders(result.Ctx, map[string]string{"index": tm.index})
}
if result.Error != nil {
metrics.TasksErrors.WithLabelValues(task.Topic).Inc() // Increase error count
} else {
metrics.TasksProcessed.WithLabelValues("success").Inc() // Increase processed task count
}
return result
}
func (tm *DAG) check(ctx context.Context, payload []byte) (context.Context, *mq.Task, error) {
if tm.paused {
return ctx, nil, fmt.Errorf("unable to process task, error: DAG is not accepting any task")
}
if !tm.IsReady() {
return ctx, nil, fmt.Errorf("unable to process task, error: DAG is not ready yet")
}
initialNode, err := tm.parseInitialNode(ctx)
if err != nil {
return ctx, nil, err
}
if tm.server.SyncMode() {
ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
}
taskID := GetTaskID(ctx)
if taskID != "" {
if _, exists := tm.taskContext.Get(taskID); !exists {
return ctx, nil, fmt.Errorf("provided task ID doesn't exist")
}
}
if taskID == "" {
taskID = mq.NewID()
}
return ctx, mq.NewTask(taskID, payload, initialNode), nil
}
func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result {
ctx, task, err := tm.check(ctx, payload)
if err != nil {
return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")}
}
awaitResponse, _ := mq.GetAwaitResponse(ctx)
if awaitResponse != "true" {
headers, ok := mq.GetHeaders(ctx)
ctxx := context.Background()
if ok {
ctxx = mq.SetHeaders(ctxx, headers.AsMap())
}
if err := tm.pool.EnqueueTask(ctxx, task, 0); err != nil {
return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: task.Topic, Status: "FAILED", Error: err}
}
return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: task.Topic, Status: "PENDING"}
}
return tm.ProcessTask(ctx, task)
}
func (tm *DAG) ScheduleTask(ctx context.Context, payload []byte, opts ...mq.SchedulerOption) mq.Result {
ctx, task, err := tm.check(ctx, payload)
if err != nil {
return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")}
}
headers, ok := mq.GetHeaders(ctx)
ctxx := context.Background()
if ok {
ctxx = mq.SetHeaders(ctxx, headers.AsMap())
}
tm.pool.Scheduler().AddTask(ctxx, task, opts...)
return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: task.Topic, Status: "PENDING"}
}
func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) {
val := ctx.Value("initial_node")
initialNode, ok := val.(string)
if ok {
return initialNode, nil
}
if tm.startNode == "" {
firstNode := tm.findStartNode()
if firstNode != nil {
tm.startNode = firstNode.Key
}
}
if tm.startNode == "" {
return "", fmt.Errorf("initial node not found")
}
return tm.startNode, nil
}
func (tm *DAG) findStartNode() *Node {
incomingEdges := make(map[string]bool)
connectedNodes := make(map[string]bool)
for _, node := range tm.nodes {
for _, edge := range node.Edges {
if edge.Type.IsValid() {
connectedNodes[node.Key] = true
for _, to := range edge.To {
connectedNodes[to.Key] = true
incomingEdges[to.Key] = true
}
}
}
if cond, ok := tm.conditions[FromNode(node.Key)]; ok {
for _, target := range cond {
connectedNodes[string(target)] = true
incomingEdges[string(target)] = true
}
}
}
for nodeID, node := range tm.nodes {
if !incomingEdges[nodeID] && connectedNodes[nodeID] {
return node
}
}
return nil
}
func (tm *DAG) Pause(_ context.Context) error {
tm.paused = true
return nil
}
func (tm *DAG) Resume(_ context.Context) error {
tm.paused = false
return nil
}
func (tm *DAG) Close() error {
for _, n := range tm.nodes {
err := n.Close()
if err != nil {
return err
}
}
return nil
}
func (tm *DAG) PauseConsumer(ctx context.Context, id string) {
tm.doConsumer(ctx, id, consts.CONSUMER_PAUSE)
}
func (tm *DAG) ResumeConsumer(ctx context.Context, id string) {
tm.doConsumer(ctx, id, consts.CONSUMER_RESUME)
}
func (tm *DAG) doConsumer(ctx context.Context, id string, action consts.CMD) {
if node, ok := tm.nodes[id]; ok {
switch action {
case consts.CONSUMER_PAUSE:
err := node.processor.Pause(ctx)
if err == nil {
node.isReady = false
log.Printf("[INFO] - Consumer %s paused successfully", node.Key)
} else {
log.Printf("[ERROR] - Failed to pause consumer %s: %v", node.Key, err)
}
case consts.CONSUMER_RESUME:
err := node.processor.Resume(ctx)
if err == nil {
node.isReady = true
log.Printf("[INFO] - Consumer %s resumed successfully", node.Key)
} else {
log.Printf("[ERROR] - Failed to resume consumer %s: %v", node.Key, err)
}
}
} else {
log.Printf("[WARNING] - Consumer %s not found", id)
}
}
func (tm *DAG) SetNotifyResponse(callback mq.Callback) {
tm.server.SetNotifyHandler(callback)
}
func (tm *DAG) GetNextNodes(key string) ([]*Node, error) {
node, exists := tm.nodes[key]
if !exists {
return nil, fmt.Errorf("Node with key %s does not exist", key)
}
var successors []*Node
for _, edge := range node.Edges {
successors = append(successors, edge.To...)
}
if conds, exists := tm.conditions[FromNode(key)]; exists {
for _, targetKey := range conds {
if targetNode, exists := tm.nodes[string(targetKey)]; exists {
successors = append(successors, targetNode)
}
}
}
return successors, nil
}
func (tm *DAG) GetPreviousNodes(key string) ([]*Node, error) {
var predecessors []*Node
for _, node := range tm.nodes {
for _, edge := range node.Edges {
for _, target := range edge.To {
if target.Key == key {
predecessors = append(predecessors, node)
}
}
}
}
for fromNode, conds := range tm.conditions {
for _, targetKey := range conds {
if string(targetKey) == key {
node, exists := tm.nodes[string(fromNode)]
if !exists {
return nil, fmt.Errorf("Node with key %s does not exist", fromNode)
}
predecessors = append(predecessors, node)
}
}
}
return predecessors, nil
}
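
A minimal, hedged sketch of wiring and running a v1 DAG with the API above. It assumes the package is imported as v1 and that fetchUsers and notifyUser are pre-existing mq.Processor implementations whose results are JSON; none of these names come from the commit itself.

// Assumed imports: context, fmt, log, github.com/oarkflow/mq/dag/v1
flow := v1.NewDAG("user-notify", "user-notify")
flow.AddNode("Fetch Users", "fetch-users", fetchUsers, true) // trailing true marks the start node
flow.AddNode("Notify User", "notify-user", notifyUser)
flow.AddIterator("each user", "fetch-users", "notify-user") // fans the JSON array from fetch-users out item by item
if err := flow.Validate(); err != nil {
	log.Fatalf("invalid DAG: %v", err)
}
result := flow.Process(context.Background(), []byte(`{"limit": 10}`))
fmt.Println(result.Status, string(result.Payload))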


@@ -1,4 +1,4 @@
-package v2
+package v1
import (
	"context"


@@ -1,4 +1,4 @@
-package v2
+package v1
import (
	"sync"

dag/v1/task_manager.go

@@ -0,0 +1,374 @@
package v1
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/storage"
"github.com/oarkflow/mq/storage/memory"
"github.com/oarkflow/mq"
)
type TaskManager struct {
createdAt time.Time
processedAt time.Time
status string
dag *DAG
taskID string
wg *WaitGroup
topic string
result mq.Result
iteratorNodes storage.IMap[string, []Edge]
taskNodeStatus storage.IMap[string, *taskNodeStatus]
}
func NewTaskManager(d *DAG, taskID string, iteratorNodes storage.IMap[string, []Edge]) *TaskManager {
return &TaskManager{
dag: d,
taskNodeStatus: memory.New[string, *taskNodeStatus](),
taskID: taskID,
iteratorNodes: iteratorNodes,
wg: NewWaitGroup(),
}
}
func (tm *TaskManager) dispatchFinalResult(ctx context.Context) mq.Result {
tm.updateTS(&tm.result)
tm.dag.callbackToConsumer(ctx, tm.result)
if tm.dag.server.NotifyHandler() != nil {
_ = tm.dag.server.NotifyHandler()(ctx, tm.result)
}
tm.dag.taskCleanupCh <- tm.taskID
tm.topic = tm.result.Topic
return tm.result
}
func (tm *TaskManager) reportNodeResult(result mq.Result, final bool) {
if tm.dag.reportNodeResultCallback != nil {
tm.dag.reportNodeResultCallback(result)
}
}
func (tm *TaskManager) SetTotalItems(topic string, i int) {
if nodeStatus, ok := tm.taskNodeStatus.Get(topic); ok {
nodeStatus.totalItems = i
}
}
func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) {
topic := getTopic(ctx, node.Key)
tm.taskNodeStatus.Set(topic, newNodeStatus(topic))
defer mq.RecoverPanic(mq.RecoverTitle)
dag, isDAG := isDAGNode(node)
if isDAG {
if tm.dag.server.SyncMode() && !dag.server.SyncMode() {
dag.server.Options().SetSyncMode(true)
}
}
tm.ChangeNodeStatus(ctx, node.Key, Processing, mq.Result{Payload: payload, Topic: node.Key})
var result mq.Result
if tm.dag.server.SyncMode() {
defer func() {
if isDAG {
result.Topic = dag.consumerTopic
result.TaskID = tm.taskID
tm.reportNodeResult(result, false)
tm.handleNextTask(result.Ctx, result)
} else {
result.Topic = node.Key
tm.reportNodeResult(result, false)
tm.handleNextTask(ctx, result)
}
}()
}
select {
case <-ctx.Done():
result = mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: ctx.Err(), Ctx: ctx}
tm.reportNodeResult(result, true)
tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
return
default:
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: node.Key})
if tm.dag.server.SyncMode() {
result = node.ProcessTask(ctx, mq.NewTask(tm.taskID, payload, node.Key))
if result.Error != nil {
tm.reportNodeResult(result, true)
tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
return
}
return
}
err := tm.dag.server.Publish(ctx, mq.NewTask(tm.taskID, payload, node.Key), node.Key)
if err != nil {
tm.reportNodeResult(mq.Result{Error: err}, true)
tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
return
}
}
}
func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result {
defer mq.RecoverPanic(mq.RecoverTitle)
node, ok := tm.dag.nodes[nodeID]
if !ok {
return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)}
}
if tm.createdAt.IsZero() {
tm.createdAt = time.Now()
}
tm.wg.Add(1)
go func() {
ctxx := context.Background()
if headers, ok := mq.GetHeaders(ctx); ok {
headers.Set(consts.QueueKey, node.Key)
headers.Set("index", fmt.Sprintf("%s__%d", node.Key, 0))
ctxx = mq.SetHeaders(ctxx, headers.AsMap())
}
go tm.processNode(ctx, node, payload)
}()
tm.wg.Wait()
requestType, ok := mq.GetHeader(ctx, "request_type")
if ok && requestType == "render" {
return tm.renderResult(ctx)
}
return tm.dispatchFinalResult(ctx)
}
func (tm *TaskManager) handleNextTask(ctx context.Context, result mq.Result) mq.Result {
tm.topic = result.Topic
defer func() {
tm.wg.Done()
mq.RecoverPanic(mq.RecoverTitle)
}()
if result.Ctx != nil {
if headers, ok := mq.GetHeaders(ctx); ok {
ctx = mq.SetHeaders(result.Ctx, headers.AsMap())
}
}
node, ok := tm.dag.nodes[result.Topic]
if !ok {
return result
}
if result.Error != nil {
tm.reportNodeResult(result, true)
tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
return result
}
edges := tm.getConditionalEdges(node, result)
if len(edges) == 0 {
tm.reportNodeResult(result, true)
tm.ChangeNodeStatus(ctx, node.Key, Completed, result)
return result
} else {
tm.reportNodeResult(result, false)
}
if node.Type == Page {
return result
}
for _, edge := range edges {
switch edge.Type {
case Iterator:
var items []json.RawMessage
err := json.Unmarshal(result.Payload, &items)
if err != nil {
tm.reportNodeResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err}, false)
result.Error = err
tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
return result
}
tm.SetTotalItems(getTopic(ctx, edge.From.Key), len(items)*len(edge.To))
for _, target := range edge.To {
for i, item := range items {
tm.wg.Add(1)
go func(ctx context.Context, target *Node, item json.RawMessage, i int) {
ctxx := context.Background()
if headers, ok := mq.GetHeaders(ctx); ok {
headers.Set(consts.QueueKey, target.Key)
headers.Set("index", fmt.Sprintf("%s__%d", target.Key, i))
ctxx = mq.SetHeaders(ctxx, headers.AsMap())
}
tm.processNode(ctxx, target, item)
}(ctx, target, item, i)
}
}
}
}
for _, edge := range edges {
switch edge.Type {
case Simple:
if _, ok := tm.iteratorNodes.Get(edge.From.Key); ok {
continue
}
tm.processEdge(ctx, edge, result)
}
}
return result
}
func (tm *TaskManager) processEdge(ctx context.Context, edge Edge, result mq.Result) {
tm.SetTotalItems(getTopic(ctx, edge.From.Key), len(edge.To))
index, _ := mq.GetHeader(ctx, "index")
if index != "" && strings.Contains(index, "__") {
index = strings.Split(index, "__")[1]
} else {
index = "0"
}
for _, target := range edge.To {
tm.wg.Add(1)
go func(ctx context.Context, target *Node, result mq.Result) {
ctxx := context.Background()
if headers, ok := mq.GetHeaders(ctx); ok {
headers.Set(consts.QueueKey, target.Key)
headers.Set("index", fmt.Sprintf("%s__%s", target.Key, index))
ctxx = mq.SetHeaders(ctxx, headers.AsMap())
}
tm.processNode(ctxx, target, result.Payload)
}(ctx, target, result)
}
}
func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge {
edges := make([]Edge, len(node.Edges))
copy(edges, node.Edges)
if result.ConditionStatus != "" {
if conditions, ok := tm.dag.conditions[FromNode(result.Topic)]; ok {
if targetNodeKey, ok := conditions[When(result.ConditionStatus)]; ok {
if targetNode, ok := tm.dag.nodes[string(targetNodeKey)]; ok {
edges = append(edges, Edge{From: node, To: []*Node{targetNode}})
}
} else if targetNodeKey, ok = conditions["default"]; ok {
if targetNode, ok := tm.dag.nodes[string(targetNodeKey)]; ok {
edges = append(edges, Edge{From: node, To: []*Node{targetNode}})
}
}
}
}
return edges
}
func (tm *TaskManager) renderResult(ctx context.Context) mq.Result {
var rs mq.Result
tm.updateTS(&rs)
tm.dag.callbackToConsumer(ctx, rs)
tm.topic = rs.Topic
return rs
}
func (tm *TaskManager) ChangeNodeStatus(ctx context.Context, nodeID string, status NodeStatus, rs mq.Result) {
topic := nodeID
if !strings.Contains(nodeID, "__") {
nodeID = getTopic(ctx, nodeID)
} else {
topic = strings.Split(nodeID, "__")[0]
}
nodeStatus, ok := tm.taskNodeStatus.Get(nodeID)
if !ok || nodeStatus == nil {
return
}
nodeStatus.markAs(rs, status)
switch status {
case Completed:
canProceed := false
edges, ok := tm.iteratorNodes.Get(topic)
if ok {
if len(edges) == 0 {
canProceed = true
} else {
nodeStatus.status = Processing
nodeStatus.totalItems = 1
nodeStatus.itemResults.Clear()
for _, edge := range edges {
tm.processEdge(ctx, edge, rs)
}
tm.iteratorNodes.Del(topic)
}
}
if canProceed || !ok {
if topic == tm.dag.startNode {
tm.result = rs
} else {
tm.markParentTask(ctx, topic, nodeID, status, rs)
}
}
case Failed:
if topic == tm.dag.startNode {
tm.result = rs
} else {
tm.markParentTask(ctx, topic, nodeID, status, rs)
}
}
}
func (tm *TaskManager) markParentTask(ctx context.Context, topic, nodeID string, status NodeStatus, rs mq.Result) {
parentNodes, err := tm.dag.GetPreviousNodes(topic)
if err != nil {
return
}
var index string
nodeParts := strings.Split(nodeID, "__")
if len(nodeParts) == 2 {
index = nodeParts[1]
}
for _, parentNode := range parentNodes {
parentKey := fmt.Sprintf("%s__%s", parentNode.Key, index)
parentNodeStatus, exists := tm.taskNodeStatus.Get(parentKey)
if !exists {
parentKey = fmt.Sprintf("%s__%s", parentNode.Key, "0")
parentNodeStatus, exists = tm.taskNodeStatus.Get(parentKey)
}
if exists {
parentNodeStatus.itemResults.Set(nodeID, rs)
if parentNodeStatus.IsDone() {
rt := tm.prepareResult(ctx, parentNodeStatus)
tm.ChangeNodeStatus(ctx, parentKey, status, rt)
}
}
}
}
func (tm *TaskManager) prepareResult(ctx context.Context, nodeStatus *taskNodeStatus) mq.Result {
aggregatedOutput := make([]json.RawMessage, 0)
var status mq.Status
var topic string
var err1 error
if nodeStatus.totalItems == 1 {
rs := nodeStatus.itemResults.Values()[0]
if rs.Ctx == nil {
rs.Ctx = ctx
}
return rs
}
nodeStatus.itemResults.ForEach(func(key string, result mq.Result) bool {
if topic == "" {
topic = result.Topic
status = result.Status
}
if result.Error != nil {
err1 = result.Error
return false
}
var item json.RawMessage
err := json.Unmarshal(result.Payload, &item)
if err != nil {
err1 = err
return false
}
aggregatedOutput = append(aggregatedOutput, item)
return true
})
if err1 != nil {
return mq.HandleError(ctx, err1)
}
finalOutput, err := json.Marshal(aggregatedOutput)
if err != nil {
return mq.HandleError(ctx, err)
}
return mq.Result{TaskID: tm.taskID, Payload: finalOutput, Status: status, Topic: topic, Ctx: ctx}
}
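
A hedged sketch of how conditional routing ties together: AddCondition (declared in dag/v1/dag.go above) maps a result's ConditionStatus to the next node key, and getConditionalEdges falls back to the "default" entry. The node keys, the v1 import alias, and the surrounding processor are placeholders, not part of this commit.

// Wiring the branches:
flow.AddCondition("validate-user", map[v1.When]v1.Then{
	"valid":   "store-user",
	"invalid": "reject-user",
	"default": "reject-user",
})

// Inside the validate-user processor's ProcessTask, the branch is picked by
// setting ConditionStatus on the returned result:
return mq.Result{TaskID: task.ID, Payload: task.Payload, ConditionStatus: "valid"}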


@@ -1,4 +1,4 @@
-package v2
+package v1
import (
	"fmt"
@@ -9,24 +9,25 @@ import (
func (tm *DAG) PrintGraph() { func (tm *DAG) PrintGraph() {
fmt.Println("DAG Graph structure:") fmt.Println("DAG Graph structure:")
tm.nodes.ForEach(func(_ string, node *Node) bool { for _, node := range tm.nodes {
fmt.Printf("Node: %s (%s) -> ", node.Label, node.ID) fmt.Printf("Node: %s (%s) -> ", node.Name, node.Key)
if conditions, ok := tm.conditions[node.ID]; ok { if conditions, ok := tm.conditions[FromNode(node.Key)]; ok {
var c []string var c []string
for when, then := range conditions { for when, then := range conditions {
if target, ok := tm.nodes.Get(then); ok { if target, ok := tm.nodes[string(then)]; ok {
c = append(c, fmt.Sprintf("If [%s] Then %s (%s)", when, target.Label, target.ID)) c = append(c, fmt.Sprintf("If [%s] Then %s (%s)", when, target.Name, target.Key))
} }
} }
fmt.Println(strings.Join(c, ", ")) fmt.Println(strings.Join(c, ", "))
} }
var edges []string var edges []string
for _, target := range node.Edges { for _, edge := range node.Edges {
edges = append(edges, fmt.Sprintf("%s (%s)", target.To.Label, target.To.ID)) for _, target := range edge.To {
edges = append(edges, fmt.Sprintf("%s (%s)", target.Name, target.Key))
}
} }
fmt.Println(strings.Join(edges, ", ")) fmt.Println(strings.Join(edges, ", "))
return true }
})
} }
func (tm *DAG) ClassifyEdges(startNodes ...string) (string, bool, error) { func (tm *DAG) ClassifyEdges(startNodes ...string) (string, bool, error) {
@@ -43,7 +44,7 @@ func (tm *DAG) ClassifyEdges(startNodes ...string) (string, bool, error) {
if startNode == "" { if startNode == "" {
firstNode := tm.findStartNode() firstNode := tm.findStartNode()
if firstNode != nil { if firstNode != nil {
startNode = firstNode.ID startNode = firstNode.Key
} }
} }
if startNode == "" { if startNode == "" {
@@ -61,24 +62,26 @@ func (tm *DAG) dfs(v string, visited map[string]bool, discoveryTime, finishedTim
inRecursionStack[v] = true // mark node as part of recursion stack inRecursionStack[v] = true // mark node as part of recursion stack
*timeVal++ *timeVal++
discoveryTime[v] = *timeVal discoveryTime[v] = *timeVal
node, _ := tm.nodes.Get(v) node := tm.nodes[v]
hasCycle := false hasCycle := false
var err error var err error
for _, edge := range node.Edges { for _, edge := range node.Edges {
if !visited[edge.To.ID] { for _, adj := range edge.To {
builder.WriteString(fmt.Sprintf("Traversing Edge: %s -> %s\n", v, edge.To.ID)) if !visited[adj.Key] {
hasCycle, err := tm.dfs(edge.To.ID, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder) builder.WriteString(fmt.Sprintf("Traversing Edge: %s -> %s\n", v, adj.Key))
hasCycle, err := tm.dfs(adj.Key, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder)
if err != nil { if err != nil {
return true, err return true, err
} }
if hasCycle { if hasCycle {
return true, nil return true, nil
} }
} else if inRecursionStack[edge.To.ID] { } else if inRecursionStack[adj.Key] {
cycleMsg := fmt.Sprintf("Cycle detected: %s -> %s\n", v, edge.To.ID) cycleMsg := fmt.Sprintf("Cycle detected: %s -> %s\n", v, adj.Key)
return true, fmt.Errorf(cycleMsg) return true, fmt.Errorf(cycleMsg)
} }
} }
}
hasCycle, err = tm.handleConditionalEdges(v, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder) hasCycle, err = tm.handleConditionalEdges(v, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder)
if err != nil { if err != nil {
return true, err return true, err
@@ -90,20 +93,20 @@ func (tm *DAG) dfs(v string, visited map[string]bool, discoveryTime, finishedTim
} }
func (tm *DAG) handleConditionalEdges(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, time *int, inRecursionStack map[string]bool, builder *strings.Builder) (bool, error) { func (tm *DAG) handleConditionalEdges(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, time *int, inRecursionStack map[string]bool, builder *strings.Builder) (bool, error) {
node, _ := tm.nodes.Get(v) node := tm.nodes[v]
for when, then := range tm.conditions[node.ID] { for when, then := range tm.conditions[FromNode(node.Key)] {
if targetNode, ok := tm.nodes.Get(then); ok { if targetNode, ok := tm.nodes[string(then)]; ok {
if !visited[targetNode.ID] { if !visited[targetNode.Key] {
builder.WriteString(fmt.Sprintf("Traversing Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.ID)) builder.WriteString(fmt.Sprintf("Traversing Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.Key))
hasCycle, err := tm.dfs(targetNode.ID, visited, discoveryTime, finishedTime, time, inRecursionStack, builder) hasCycle, err := tm.dfs(targetNode.Key, visited, discoveryTime, finishedTime, time, inRecursionStack, builder)
if err != nil { if err != nil {
return true, err return true, err
} }
if hasCycle { if hasCycle {
return true, nil return true, nil
} }
} else if inRecursionStack[targetNode.ID] { } else if inRecursionStack[targetNode.Key] {
cycleMsg := fmt.Sprintf("Cycle detected in Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.ID) cycleMsg := fmt.Sprintf("Cycle detected in Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.Key)
return true, fmt.Errorf(cycleMsg) return true, fmt.Errorf(cycleMsg)
} }
} }
@@ -143,113 +146,98 @@ func (tm *DAG) ExportDOT() string {
var sb strings.Builder var sb strings.Builder
sb.WriteString(fmt.Sprintf(`digraph "%s" {`, tm.name)) sb.WriteString(fmt.Sprintf(`digraph "%s" {`, tm.name))
sb.WriteString("\n") sb.WriteString("\n")
sb.WriteString(` label="Enhanced DAG Representation";`) sb.WriteString(fmt.Sprintf(` label="%s";`, tm.name))
sb.WriteString("\n") sb.WriteString("\n")
sb.WriteString(` labelloc="t"; fontsize=22; fontname="Helvetica";`) sb.WriteString(` labelloc="t";`)
sb.WriteString("\n") sb.WriteString("\n")
sb.WriteString(` node [shape=box, fontname="Helvetica", fillcolor="#B3CDE0", fontcolor="#2C3E50", fontsize=10, margin="0.25,0.15", style="rounded,filled"];`) sb.WriteString(` fontsize=20;`)
sb.WriteString("\n") sb.WriteString("\n")
sb.WriteString(` edge [fontname="Helvetica", fontsize=12, arrowsize=0.8];`) sb.WriteString(` node [shape=box, style="rounded,filled", fillcolor="lightgray", fontname="Arial", margin="0.2,0.1"];`)
sb.WriteString("\n") sb.WriteString("\n")
sb.WriteString(` rankdir=TB;`) sb.WriteString(` edge [fontname="Arial", fontsize=12, arrowsize=0.8];`)
sb.WriteString("\n")
sb.WriteString(` size="10,10";`)
sb.WriteString("\n")
sb.WriteString(` ratio="fill";`)
sb.WriteString("\n") sb.WriteString("\n")
sortedNodes := tm.TopologicalSort() sortedNodes := tm.TopologicalSort()
for _, nodeKey := range sortedNodes { for _, nodeKey := range sortedNodes {
node, _ := tm.nodes.Get(nodeKey) node := tm.nodes[nodeKey]
nodeColor := "lightgray" nodeColor := "lightblue"
nodeShape := "box" sb.WriteString(fmt.Sprintf(` "%s" [label=" %s", fillcolor="%s", id="node_%s"];`, node.Key, node.Name, nodeColor, node.Key))
labelSuffix := ""
// Apply styles based on NodeType
switch node.NodeType {
case Function:
nodeColor = "#D4EDDA"
labelSuffix = " [Function]"
case Page:
nodeColor = "#F08080"
labelSuffix = " [Page]"
}
sb.WriteString(fmt.Sprintf(
` "%s" [label="%s%s", fontcolor="#2C3E50", fillcolor="%s", shape="%s", style="rounded,filled", id="node_%s"];`,
node.ID, node.Label, labelSuffix, nodeColor, nodeShape, node.ID))
sb.WriteString("\n") sb.WriteString("\n")
} }
// Define edges with unique styling by EdgeType
for _, nodeKey := range sortedNodes { for _, nodeKey := range sortedNodes {
node, _ := tm.nodes.Get(nodeKey) node := tm.nodes[nodeKey]
for _, edge := range node.Edges { for _, edge := range node.Edges {
edgeStyle := "solid" var edgeStyle string
edgeColor := "black"
labelSuffix := ""
// Apply styles based on EdgeType
switch edge.Type { switch edge.Type {
case Iterator: case Iterator:
edgeStyle = "dashed" edgeStyle = "dashed"
edgeColor = "blue" default:
labelSuffix = " [Iter]"
case Simple:
edgeStyle = "solid" edgeStyle = "solid"
edgeColor = "black"
labelSuffix = ""
} }
sb.WriteString(fmt.Sprintf( edgeColor := "black"
` "%s" -> "%s" [label="%s%s", color="%s", style="%s"];`, for _, to := range edge.To {
node.ID, edge.To.ID, edge.Label, labelSuffix, edgeColor, edgeStyle)) sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="%s", style=%s, fontsize=10, arrowsize=0.6];`, node.Key, to.Key, edge.Label, edgeColor, edgeStyle))
sb.WriteString("\n") sb.WriteString("\n")
} }
} }
}
for fromNodeKey, conditions := range tm.conditions { for fromNodeKey, conditions := range tm.conditions {
for when, then := range conditions { for when, then := range conditions {
if toNode, ok := tm.nodes.Get(then); ok { if toNode, ok := tm.nodes[string(then)]; ok {
sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="purple", style=dotted, fontsize=10, arrowsize=0.6];`, fromNodeKey, toNode.ID, when)) sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="purple", style=dotted, fontsize=10, arrowsize=0.6];`, fromNodeKey, toNode.Key, when))
sb.WriteString("\n") sb.WriteString("\n")
} }
} }
} }
// Optional: Group related nodes into subgraphs (e.g., loops)
for _, nodeKey := range sortedNodes { for _, nodeKey := range sortedNodes {
node, _ := tm.nodes.Get(nodeKey) node := tm.nodes[nodeKey]
if node.processor != nil { if node.processor != nil {
subDAG, _ := isDAGNode(node) subDAG, _ := isDAGNode(node)
if subDAG != nil { if subDAG != nil {
sb.WriteString(fmt.Sprintf(` subgraph "cluster_%s" {`, subDAG.name)) sb.WriteString(fmt.Sprintf(` subgraph "cluster_%s" {`, subDAG.name))
sb.WriteString("\n") sb.WriteString("\n")
sb.WriteString(fmt.Sprintf(` label="Subgraph: %s";`, subDAG.name)) sb.WriteString(fmt.Sprintf(` label=" %s";`, subDAG.name))
sb.WriteString("\n") sb.WriteString("\n")
sb.WriteString(` style=filled; color=gray90;`) sb.WriteString(` style=dashed;`)
sb.WriteString("\n") sb.WriteString("\n")
subDAG.nodes.ForEach(func(subNodeKey string, subNode *Node) bool { sb.WriteString(` bgcolor="lightgray";`)
sb.WriteString(fmt.Sprintf(` "%s" [label="%s"];`, subNode.ID, subNode.Label))
sb.WriteString("\n") sb.WriteString("\n")
return true sb.WriteString(` node [shape=rectangle, style="filled", fillcolor="lightblue", fontname="Arial", margin="0.2,0.1"];`)
}) sb.WriteString("\n")
subDAG.nodes.ForEach(func(subNodeKey string, subNode *Node) bool { for subNodeKey, subNode := range subDAG.nodes {
sb.WriteString(fmt.Sprintf(` "%s" [label=" %s"];`, subNodeKey, subNode.Name))
sb.WriteString("\n")
}
for subNodeKey, subNode := range subDAG.nodes {
for _, edge := range subNode.Edges { for _, edge := range subNode.Edges {
sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label="%s"];`, subNodeKey, edge.To.ID, edge.Label)) for _, to := range edge.To {
sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="black", style=solid, arrowsize=0.6];`, subNodeKey, to.Key, edge.Label))
sb.WriteString("\n") sb.WriteString("\n")
} }
return true }
}) }
sb.WriteString(" }\n") sb.WriteString(` }`)
sb.WriteString("\n")
sb.WriteString(fmt.Sprintf(` "%s" -> "%s" [label=" %s", color="black", style=solid, arrowsize=0.6];`, node.Key, subDAG.startNode, subDAG.name))
sb.WriteString("\n")
} }
} }
} }
sb.WriteString(`}`)
sb.WriteString("}\n") sb.WriteString("\n")
return sb.String() return sb.String()
} }
func (tm *DAG) TopologicalSort() (stack []string) { func (tm *DAG) TopologicalSort() (stack []string) {
visited := make(map[string]bool) visited := make(map[string]bool)
tm.nodes.ForEach(func(_ string, node *Node) bool { for _, node := range tm.nodes {
if !visited[node.ID] { if !visited[node.Key] {
tm.topologicalSortUtil(node.ID, visited, &stack) tm.topologicalSortUtil(node.Key, visited, &stack)
}
} }
return true
})
for i, j := 0, len(stack)-1; i < j; i, j = i+1, j-1 { for i, j := 0, len(stack)-1; i < j; i, j = i+1, j-1 {
stack[i], stack[j] = stack[j], stack[i] stack[i], stack[j] = stack[j], stack[i]
} }
@@ -258,23 +246,13 @@ func (tm *DAG) TopologicalSort() (stack []string) {
func (tm *DAG) topologicalSortUtil(v string, visited map[string]bool, stack *[]string) { func (tm *DAG) topologicalSortUtil(v string, visited map[string]bool, stack *[]string) {
visited[v] = true visited[v] = true
node, ok := tm.nodes.Get(v) node := tm.nodes[v]
if !ok {
fmt.Println("Not found", v)
}
for _, edge := range node.Edges { for _, edge := range node.Edges {
if !visited[edge.To.ID] { for _, to := range edge.To {
tm.topologicalSortUtil(edge.To.ID, visited, stack) if !visited[to.Key] {
tm.topologicalSortUtil(to.Key, visited, stack)
}
} }
} }
*stack = append(*stack, v) *stack = append(*stack, v)
} }
func isDAGNode(node *Node) (*DAG, bool) {
switch node := node.processor.(type) {
case *DAG:
return node, true
default:
return nil, false
}
}


@@ -1,4 +1,4 @@
-package dag
+package v1
import (
	"context"


@@ -1,4 +1,4 @@
-package dag
+package v1
import (
	"sync"


@@ -1,4 +1,4 @@
-package v2
+package v1
import (
	"encoding/json"


@@ -1,152 +0,0 @@
package v2
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"strings"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/sio"
"github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/jsonparser"
)
func renderNotFound(w http.ResponseWriter) {
html := []byte(`
<div>
<h1>task not found</h1>
<p><a href="/process">Back to home</a></p>
</div>
`)
w.Header().Set(consts.ContentType, consts.TypeHtml)
w.Write(html)
}
func (tm *DAG) render(w http.ResponseWriter, r *http.Request) {
ctx, data, err := parse(r)
if err != nil {
http.Error(w, err.Error(), http.StatusNotFound)
return
}
accept := r.Header.Get("Accept")
userCtx := UserContext(ctx)
ctx = context.WithValue(ctx, "method", r.Method)
if r.Method == "GET" && userCtx.Get("task_id") != "" {
manager, ok := tm.taskManager.Get(userCtx.Get("task_id"))
if !ok || manager == nil {
if strings.Contains(accept, "text/html") || accept == "" {
renderNotFound(w)
return
}
http.Error(w, fmt.Sprintf(`{"message": "%s"}`, "task not found"), http.StatusInternalServerError)
return
}
}
result := tm.Process(ctx, data)
if result.Error != nil {
http.Error(w, fmt.Sprintf(`{"message": "%s"}`, result.Error.Error()), http.StatusInternalServerError)
return
}
contentType, ok := result.Ctx.Value(consts.ContentType).(string)
if !ok {
contentType = consts.TypeJson
}
switch contentType {
case consts.TypeHtml:
w.Header().Set(consts.ContentType, consts.TypeHtml)
data, err := jsonparser.GetString(result.Payload, "html_content")
if err != nil {
return
}
w.Write([]byte(data))
default:
if r.Method != "POST" {
http.Error(w, `{"message": "not allowed"}`, http.StatusMethodNotAllowed)
return
}
w.Header().Set(consts.ContentType, consts.TypeJson)
json.NewEncoder(w).Encode(result.Payload)
}
}
func (tm *DAG) taskStatusHandler(w http.ResponseWriter, r *http.Request) {
taskID := r.URL.Query().Get("taskID")
if taskID == "" {
http.Error(w, `{"message": "taskID is missing"}`, http.StatusBadRequest)
return
}
manager, ok := tm.taskManager.Get(taskID)
if !ok {
http.Error(w, `{"message": "Invalid TaskID"}`, http.StatusNotFound)
return
}
result := make(map[string]TaskState)
manager.taskStates.ForEach(func(key string, value *TaskState) bool {
key = strings.Split(key, Delimiter)[0]
nodeID := strings.Split(value.NodeID, Delimiter)[0]
rs := jsonparser.Delete(value.Result.Payload, "html_content")
status := value.Status
if status == mq.Processing {
status = mq.Completed
}
state := TaskState{
NodeID: nodeID,
Status: status,
UpdatedAt: value.UpdatedAt,
Result: mq.Result{
Payload: rs,
Error: value.Result.Error,
Status: status,
},
}
result[key] = state
return true
})
w.Header().Set(consts.ContentType, consts.TypeJson)
json.NewEncoder(w).Encode(result)
}
func (tm *DAG) SetupWS() *sio.Server {
ws := sio.New(sio.Config{
CheckOrigin: func(r *http.Request) bool { return true },
EnableCompression: true,
})
WsEvents(ws)
tm.Notifier = ws
return ws
}
func (tm *DAG) Handlers() {
http.Handle("/", http.FileServer(http.Dir("webroot")))
http.Handle("/notify", tm.SetupWS())
http.HandleFunc("/process", tm.render)
http.HandleFunc("/request", tm.render)
http.HandleFunc("/task/status", tm.taskStatusHandler)
http.HandleFunc("/dot", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
fmt.Fprintln(w, tm.ExportDOT())
})
http.HandleFunc("/ui", func(w http.ResponseWriter, r *http.Request) {
image := fmt.Sprintf("%s.svg", mq.NewID())
err := tm.SaveSVG(image)
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
defer os.Remove(image)
svgBytes, err := os.ReadFile(image)
if err != nil {
http.Error(w, "Could not read SVG file", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "image/svg+xml")
if _, err := w.Write(svgBytes); err != nil {
http.Error(w, "Could not write SVG response", http.StatusInternalServerError)
return
}
})
}

View File

@@ -1,28 +0,0 @@
package v2
import "time"
const (
Delimiter = "___"
ContextIndex = "index"
DefaultChannelSize = 1000
RetryInterval = 5 * time.Second
)
type NodeType int
func (c NodeType) IsValid() bool { return c >= Function && c <= Page }
const (
Function NodeType = iota
Page
)
type EdgeType int
func (c EdgeType) IsValid() bool { return c >= Simple && c <= Iterator }
const (
Simple EdgeType = iota
Iterator
)
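
These constants underpin the node-key scheme used throughout the task manager below: a state key is the node ID plus Delimiter plus an iteration index, and the pure node ID is recovered by splitting on the delimiter. A minimal, self-contained sketch of that convention (the node ID and index are illustrative):

package main

import (
	"fmt"
	"strings"
)

const Delimiter = "___" // mirrors the constant defined above

func main() {
	// Compose a task-state key the way TaskManager.send does.
	key := fmt.Sprintf("%s%s%s", "prepare:email", Delimiter, "2")
	// Recover the pure node ID and the iteration index.
	parts := strings.Split(key, Delimiter)
	fmt.Println(key, "->", parts[0], parts[1]) // prepare:email___2 -> prepare:email 2
}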

View File

@@ -1,524 +0,0 @@
package v2
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"strings"
"time"
"golang.org/x/time/rate"
"github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/sio"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/storage"
"github.com/oarkflow/mq/storage/memory"
)
type Node struct {
NodeType NodeType
Label string
ID string
Edges []Edge
processor mq.Processor
isReady bool
}
type Edge struct {
From *Node
To *Node
Type EdgeType
Label string
}
type DAG struct {
server *mq.Broker
consumer *mq.Consumer
nodes storage.IMap[string, *Node]
taskManager storage.IMap[string, *TaskManager]
iteratorNodes storage.IMap[string, []Edge]
finalResult func(taskID string, result mq.Result)
pool *mq.Pool
name string
key string
startNode string
opts []mq.Option
conditions map[string]map[string]string
consumerTopic string
reportNodeResultCallback func(mq.Result)
Error error
Notifier *sio.Server
paused bool
report string
}
func NewDAG(name, key string, finalResultCallback func(taskID string, result mq.Result), opts ...mq.Option) *DAG {
callback := func(ctx context.Context, result mq.Result) error { return nil }
d := &DAG{
name: name,
key: key,
nodes: memory.New[string, *Node](),
taskManager: memory.New[string, *TaskManager](),
iteratorNodes: memory.New[string, []Edge](),
conditions: make(map[string]map[string]string),
finalResult: finalResultCallback,
}
opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerOnSubscribe(d.onConsumerJoin), mq.WithConsumerOnClose(d.onConsumerClose))
d.server = mq.NewBroker(opts...)
d.opts = opts
options := d.server.Options()
d.pool = mq.NewPool(
options.NumOfWorkers(),
mq.WithTaskQueueSize(options.QueueSize()),
mq.WithMaxMemoryLoad(options.MaxMemoryLoad()),
mq.WithHandler(d.ProcessTask),
mq.WithPoolCallback(callback),
mq.WithTaskStorage(options.Storage()),
)
d.pool.Start(d.server.Options().NumOfWorkers())
return d
}
func (tm *DAG) SetKey(key string) {
tm.key = key
}
func (tm *DAG) ReportNodeResult(callback func(mq.Result)) {
tm.reportNodeResultCallback = callback
}
func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result {
if manager, ok := tm.taskManager.Get(result.TaskID); ok && result.Topic != "" {
manager.onNodeCompleted(nodeResult{
ctx: ctx,
nodeID: result.Topic,
status: result.Status,
result: result,
})
}
return mq.Result{}
}
func (tm *DAG) GetType() string {
return tm.key
}
func (tm *DAG) Consume(ctx context.Context) error {
if tm.consumer != nil {
tm.server.Options().SetSyncMode(true)
return tm.consumer.Consume(ctx)
}
return nil
}
func (tm *DAG) Stop(ctx context.Context) error {
var err error
tm.nodes.ForEach(func(_ string, n *Node) bool {
err = n.processor.Stop(ctx)
if err != nil {
return false
}
return true
})
return err
}
func (tm *DAG) GetKey() string {
return tm.key
}
func (tm *DAG) AssignTopic(topic string) {
tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false), mq.WithBrokerURL(tm.server.URL()))
tm.consumerTopic = topic
}
func (tm *DAG) callbackToConsumer(ctx context.Context, result mq.Result) {
if tm.consumer != nil {
result.Topic = tm.consumerTopic
if tm.consumer.Conn() == nil {
tm.onTaskCallback(ctx, result)
} else {
tm.consumer.OnResponse(ctx, result)
}
}
}
func (tm *DAG) onConsumerJoin(_ context.Context, topic, _ string) {
if node, ok := tm.nodes.Get(topic); ok {
log.Printf("DAG - CONSUMER ~> ready on %s", topic)
node.isReady = true
}
}
func (tm *DAG) onConsumerClose(_ context.Context, topic, _ string) {
if node, ok := tm.nodes.Get(topic); ok {
log.Printf("DAG - CONSUMER ~> down on %s", topic)
node.isReady = false
}
}
func (tm *DAG) Pause(_ context.Context) error {
tm.paused = true
return nil
}
func (tm *DAG) Resume(_ context.Context) error {
tm.paused = false
return nil
}
func (tm *DAG) Close() error {
var err error
tm.nodes.ForEach(func(_ string, n *Node) bool {
err = n.processor.Close()
if err != nil {
return false
}
return true
})
return err
}
func (tm *DAG) SetStartNode(node string) {
tm.startNode = node
}
func (tm *DAG) SetNotifyResponse(callback mq.Callback) {
tm.server.SetNotifyHandler(callback)
}
func (tm *DAG) GetStartNode() string {
return tm.startNode
}
func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) *DAG {
tm.conditions[fromNode] = conditions
return tm
}
func (tm *DAG) AddNode(nodeType NodeType, name, nodeID string, handler mq.Processor, startNode ...bool) *DAG {
if tm.Error != nil {
return tm
}
con := mq.NewConsumer(nodeID, nodeID, handler.ProcessTask)
n := &Node{
Label: name,
ID: nodeID,
NodeType: nodeType,
processor: con,
}
if tm.server != nil && tm.server.SyncMode() {
n.isReady = true
}
tm.nodes.Set(nodeID, n)
if len(startNode) > 0 && startNode[0] {
tm.startNode = nodeID
}
return tm
}
func (tm *DAG) AddDeferredNode(nodeType NodeType, name, key string, firstNode ...bool) error {
if tm.server.SyncMode() {
return fmt.Errorf("DAG cannot have deferred node in Sync Mode")
}
tm.nodes.Set(key, &Node{
Label: name,
ID: key,
NodeType: nodeType,
})
if len(firstNode) > 0 && firstNode[0] {
tm.startNode = key
}
return nil
}
func (tm *DAG) IsReady() bool {
isReady := true
tm.nodes.ForEach(func(_ string, n *Node) bool {
if !n.isReady {
isReady = false
return false
}
return true
})
return isReady
}
func (tm *DAG) AddEdge(edgeType EdgeType, label, from string, targets ...string) *DAG {
if tm.Error != nil {
return tm
}
if edgeType == Iterator {
tm.iteratorNodes.Set(from, []Edge{})
}
node, ok := tm.nodes.Get(from)
if !ok {
tm.Error = fmt.Errorf("node not found %s", from)
return tm
}
for _, target := range targets {
if targetNode, ok := tm.nodes.Get(target); ok {
edge := Edge{From: node, To: targetNode, Type: edgeType, Label: label}
node.Edges = append(node.Edges, edge)
if edgeType != Iterator {
if edges, ok := tm.iteratorNodes.Get(node.ID); ok {
edges = append(edges, edge)
tm.iteratorNodes.Set(node.ID, edges)
}
}
}
}
return tm
}
func (tm *DAG) getCurrentNode(manager *TaskManager) string {
if manager.currentNodePayload.Size() == 0 {
return ""
}
return manager.currentNodePayload.Keys()[0]
}
func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
ctx = context.WithValue(ctx, "task_id", task.ID)
userContext := UserContext(ctx)
next := userContext.Get("next")
manager, ok := tm.taskManager.Get(task.ID)
resultCh := make(chan mq.Result, 1)
if !ok {
manager = NewTaskManager(tm, task.ID, resultCh, tm.iteratorNodes.Clone())
tm.taskManager.Set(task.ID, manager)
} else {
manager.resultCh = resultCh
}
currentKey := tm.getCurrentNode(manager)
currentNode := strings.Split(currentKey, Delimiter)[0]
node, exists := tm.nodes.Get(currentNode)
method, ok := ctx.Value("method").(string)
if method == "GET" && exists && node.NodeType == Page {
ctx = context.WithValue(ctx, "initial_node", currentNode)
/*
if isLastNode, err := tm.IsLastNode(currentNode); err != nil && isLastNode {
if manager.result != nil {
fmt.Println(string(manager.result.Payload))
resultCh <- *manager.result
return <-resultCh
}
}
*/
if manager.result != nil {
task.Payload = manager.result.Payload
}
} else if next == "true" {
nodes, err := tm.GetNextNodes(currentNode)
if err != nil {
return mq.Result{Error: err, Ctx: ctx}
}
if len(nodes) > 0 {
ctx = context.WithValue(ctx, "initial_node", nodes[0].ID)
}
}
if currentNodeResult, hasResult := manager.currentNodeResult.Get(currentKey); hasResult {
var taskPayload, resultPayload map[string]any
if err := json.Unmarshal(task.Payload, &taskPayload); err == nil {
if err = json.Unmarshal(currentNodeResult.Payload, &resultPayload); err == nil {
for key, val := range resultPayload {
taskPayload[key] = val
}
task.Payload, _ = json.Marshal(taskPayload)
}
}
}
firstNode, err := tm.parseInitialNode(ctx)
if err != nil {
return mq.Result{Error: err, Ctx: ctx}
}
node, ok = tm.nodes.Get(firstNode)
if ok && node.NodeType != Page && task.Payload == nil {
return mq.Result{Error: fmt.Errorf("payload is required for node %s", firstNode), Ctx: ctx}
}
task.Topic = firstNode
ctx = context.WithValue(ctx, ContextIndex, "0")
manager.ProcessTask(ctx, firstNode, task.Payload)
return <-resultCh
}
func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result {
var taskID string
userCtx := UserContext(ctx)
if val := userCtx.Get("task_id"); val != "" {
taskID = val
} else {
taskID = mq.NewID()
}
return tm.ProcessTask(ctx, mq.NewTask(taskID, payload, ""))
}
func (tm *DAG) Validate() error {
report, hasCycle, err := tm.ClassifyEdges()
if hasCycle || err != nil {
tm.Error = err
return err
}
tm.report = report
return nil
}
func (tm *DAG) GetReport() string {
return tm.report
}
func (tm *DAG) AddDAGNode(name string, key string, dag *DAG, firstNode ...bool) *DAG {
dag.AssignTopic(key)
tm.nodes.Set(key, &Node{
Label: name,
ID: key,
processor: dag,
isReady: true,
})
if len(firstNode) > 0 && firstNode[0] {
tm.startNode = key
}
return tm
}
func (tm *DAG) Start(ctx context.Context, addr string) error {
// Start the server in a separate goroutine
go func() {
defer mq.RecoverPanic(mq.RecoverTitle)
if err := tm.server.Start(ctx); err != nil {
panic(err)
}
}()
// Start the node consumers if not in sync mode
if !tm.server.SyncMode() {
tm.nodes.ForEach(func(_ string, con *Node) bool {
go func(con *Node) {
defer mq.RecoverPanic(mq.RecoverTitle)
limiter := rate.NewLimiter(rate.Every(1*time.Second), 1) // Retry every second
for {
err := con.processor.Consume(ctx)
if err != nil {
log.Printf("[ERROR] - Consumer %s failed to start: %v", con.ID, err)
} else {
log.Printf("[INFO] - Consumer %s started successfully", con.ID)
break
}
limiter.Wait(ctx) // Wait with rate limiting before retrying
}
}(con)
return true
})
}
log.Printf("DAG - HTTP_SERVER ~> started on http://%s", addr)
tm.Handlers()
config := tm.server.TLSConfig()
log.Printf("Server listening on http://%s", addr)
if config.UseTLS {
return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil)
}
return http.ListenAndServe(addr, nil)
}
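
As a usage note, a minimal sketch of bringing a DAG online through Start, assuming the package is imported as github.com/oarkflow/mq/dag as the updated examples do; the graph here is left empty purely for brevity, and the name, key, and address are illustrative (a real graph adds nodes and edges first):

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/oarkflow/mq"
	"github.com/oarkflow/mq/dag"
)

func main() {
	flow := dag.NewDAG("Demo", "demo", func(taskID string, result mq.Result) {
		fmt.Printf("final result for %s: %s\n", taskID, string(result.Payload))
	})
	// Start launches the broker, starts the node consumers when not in sync
	// mode, registers the HTTP handlers, and then blocks serving on addr.
	if err := flow.Start(context.Background(), ":8083"); err != nil {
		log.Fatal(err)
	}
}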
func (tm *DAG) ScheduleTask(ctx context.Context, payload []byte, opts ...mq.SchedulerOption) mq.Result {
var taskID string
userCtx := UserContext(ctx)
if val := userCtx.Get("task_id"); val != "" {
taskID = val
} else {
taskID = mq.NewID()
}
t := mq.NewTask(taskID, payload, "")
ctx = context.WithValue(ctx, "task_id", taskID)
userContext := UserContext(ctx)
next := userContext.Get("next")
manager, ok := tm.taskManager.Get(taskID)
resultCh := make(chan mq.Result, 1)
if !ok {
manager = NewTaskManager(tm, taskID, resultCh, tm.iteratorNodes.Clone())
tm.taskManager.Set(taskID, manager)
} else {
manager.resultCh = resultCh
}
currentKey := tm.getCurrentNode(manager)
currentNode := strings.Split(currentKey, Delimiter)[0]
node, exists := tm.nodes.Get(currentNode)
method, ok := ctx.Value("method").(string)
if method == "GET" && exists && node.NodeType == Page {
ctx = context.WithValue(ctx, "initial_node", currentNode)
} else if next == "true" {
nodes, err := tm.GetNextNodes(currentNode)
if err != nil {
return mq.Result{Error: err, Ctx: ctx}
}
if len(nodes) > 0 {
ctx = context.WithValue(ctx, "initial_node", nodes[0].ID)
}
}
if currentNodeResult, hasResult := manager.currentNodeResult.Get(currentKey); hasResult {
var taskPayload, resultPayload map[string]any
if err := json.Unmarshal(payload, &taskPayload); err == nil {
if err = json.Unmarshal(currentNodeResult.Payload, &resultPayload); err == nil {
for key, val := range resultPayload {
taskPayload[key] = val
}
payload, _ = json.Marshal(taskPayload)
}
}
}
firstNode, err := tm.parseInitialNode(ctx)
if err != nil {
return mq.Result{Error: err, Ctx: ctx}
}
node, ok = tm.nodes.Get(firstNode)
if ok && node.NodeType != Page && t.Payload == nil {
return mq.Result{Error: fmt.Errorf("payload is required for node %s", firstNode), Ctx: ctx}
}
t.Topic = firstNode
ctx = context.WithValue(ctx, ContextIndex, "0")
headers, ok := mq.GetHeaders(ctx)
ctxx := context.Background()
if ok {
ctxx = mq.SetHeaders(ctxx, headers.AsMap())
}
tm.pool.Scheduler().AddTask(ctxx, t, opts...)
return mq.Result{CreatedAt: t.CreatedAt, TaskID: t.ID, Topic: t.Topic, Status: "PENDING"}
}
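
Unlike Process, which blocks until a result arrives on resultCh, ScheduleTask only hands the task to the pool's scheduler and returns immediately with a PENDING result. A minimal sketch of calling it, assuming flow is a *dag.DAG that already has its nodes, edges, and start node configured (assumed imports: context, fmt, github.com/oarkflow/mq/dag; the payload is illustrative):

// scheduleOne enqueues a single payload on flow's scheduler; the returned
// result carries the generated task ID and a PENDING status.
func scheduleOne(flow *dag.DAG) {
	payload := []byte(`{"email": "abc.xyz@gmail.com"}`)
	res := flow.ScheduleTask(context.Background(), payload)
	fmt.Println("scheduled task", res.TaskID, "status:", res.Status)
}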
func (tm *DAG) PauseConsumer(ctx context.Context, id string) {
tm.doConsumer(ctx, id, consts.CONSUMER_PAUSE)
}
func (tm *DAG) ResumeConsumer(ctx context.Context, id string) {
tm.doConsumer(ctx, id, consts.CONSUMER_RESUME)
}
func (tm *DAG) doConsumer(ctx context.Context, id string, action consts.CMD) {
if node, ok := tm.nodes.Get(id); ok {
switch action {
case consts.CONSUMER_PAUSE:
err := node.processor.Pause(ctx)
if err == nil {
node.isReady = false
log.Printf("[INFO] - Consumer %s paused successfully", node.ID)
} else {
log.Printf("[ERROR] - Failed to pause consumer %s: %v", node.ID, err)
}
case consts.CONSUMER_RESUME:
err := node.processor.Resume(ctx)
if err == nil {
node.isReady = true
log.Printf("[INFO] - Consumer %s resumed successfully", node.ID)
} else {
log.Printf("[ERROR] - Failed to resume consumer %s: %v", node.ID, err)
}
}
} else {
log.Printf("[WARNING] - Consumer %s not found", id)
}
}

View File

@@ -1,388 +0,0 @@
package v2
import (
"context"
"encoding/json"
"fmt"
"log"
"strings"
"time"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/storage"
"github.com/oarkflow/mq/storage/memory"
)
type TaskState struct {
NodeID string
Status mq.Status
UpdatedAt time.Time
Result mq.Result
targetResults storage.IMap[string, mq.Result]
}
func newTaskState(nodeID string) *TaskState {
return &TaskState{
NodeID: nodeID,
Status: mq.Pending,
UpdatedAt: time.Now(),
targetResults: memory.New[string, mq.Result](),
}
}
type nodeResult struct {
ctx context.Context
nodeID string
status mq.Status
result mq.Result
}
type TaskManager struct {
taskStates storage.IMap[string, *TaskState]
parentNodes storage.IMap[string, string]
childNodes storage.IMap[string, int]
deferredTasks storage.IMap[string, *task]
iteratorNodes storage.IMap[string, []Edge]
currentNodePayload storage.IMap[string, json.RawMessage]
currentNodeResult storage.IMap[string, mq.Result]
result *mq.Result
dag *DAG
taskID string
taskQueue chan *task
resultQueue chan nodeResult
resultCh chan mq.Result
stopCh chan struct{}
}
type task struct {
ctx context.Context
taskID string
nodeID string
payload json.RawMessage
}
func newTask(ctx context.Context, taskID, nodeID string, payload json.RawMessage) *task {
return &task{
ctx: ctx,
taskID: taskID,
nodeID: nodeID,
payload: payload,
}
}
func NewTaskManager(dag *DAG, taskID string, resultCh chan mq.Result, iteratorNodes storage.IMap[string, []Edge]) *TaskManager {
tm := &TaskManager{
taskStates: memory.New[string, *TaskState](),
parentNodes: memory.New[string, string](),
childNodes: memory.New[string, int](),
deferredTasks: memory.New[string, *task](),
currentNodePayload: memory.New[string, json.RawMessage](),
currentNodeResult: memory.New[string, mq.Result](),
taskQueue: make(chan *task, DefaultChannelSize),
resultQueue: make(chan nodeResult, DefaultChannelSize),
iteratorNodes: iteratorNodes,
stopCh: make(chan struct{}),
resultCh: resultCh,
taskID: taskID,
dag: dag,
}
go tm.run()
go tm.waitForResult()
return tm
}
func (tm *TaskManager) ProcessTask(ctx context.Context, startNode string, payload json.RawMessage) {
tm.send(ctx, startNode, tm.taskID, payload)
}
func (tm *TaskManager) send(ctx context.Context, startNode, taskID string, payload json.RawMessage) {
if index, ok := ctx.Value(ContextIndex).(string); ok {
startNode = strings.Split(startNode, Delimiter)[0]
startNode = fmt.Sprintf("%s%s%s", startNode, Delimiter, index)
}
if _, exists := tm.taskStates.Get(startNode); !exists {
tm.taskStates.Set(startNode, newTaskState(startNode))
}
t := newTask(ctx, taskID, startNode, payload)
select {
case tm.taskQueue <- t:
default:
log.Println("task queue is full, dropping task.")
tm.deferredTasks.Set(taskID, t)
}
}
func (tm *TaskManager) run() {
for {
select {
case <-tm.stopCh:
log.Println("Stopping TaskManager")
return
case task := <-tm.taskQueue:
tm.processNode(task)
}
}
}
func (tm *TaskManager) waitForResult() {
for {
select {
case <-tm.stopCh:
log.Println("Stopping Result Listener")
return
case nr := <-tm.resultQueue:
tm.onNodeCompleted(nr)
}
}
}
func (tm *TaskManager) processNode(exec *task) {
pureNodeID := strings.Split(exec.nodeID, Delimiter)[0]
node, exists := tm.dag.nodes.Get(pureNodeID)
if !exists {
log.Printf("Node %s does not exist while processing node\n", pureNodeID)
return
}
state, _ := tm.taskStates.Get(exec.nodeID)
if state == nil {
log.Printf("State for node %s not found; creating new state.\n", exec.nodeID)
state = newTaskState(exec.nodeID)
tm.taskStates.Set(exec.nodeID, state)
}
state.Status = mq.Processing
state.UpdatedAt = time.Now()
tm.currentNodePayload.Clear()
tm.currentNodeResult.Clear()
tm.currentNodePayload.Set(exec.nodeID, exec.payload)
result := node.processor.ProcessTask(exec.ctx, mq.NewTask(exec.taskID, exec.payload, exec.nodeID))
tm.currentNodeResult.Set(exec.nodeID, result)
state.Result = result
result.Topic = node.ID
if result.Error != nil {
tm.result = &result
tm.resultCh <- result
tm.processFinalResult(state)
return
}
if node.NodeType == Page {
tm.result = &result
tm.resultCh <- result
return
}
tm.handleNext(exec.ctx, node, state, result)
}
func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, result mq.Result, childNode string, dispatchFinal bool) {
state.targetResults.Set(childNode, result)
state.targetResults.Del(state.NodeID)
targetsCount, _ := tm.childNodes.Get(state.NodeID)
size := state.targetResults.Size()
nodeID := strings.Split(state.NodeID, Delimiter)
if size == targetsCount {
if size > 1 {
aggregatedData := make([]json.RawMessage, size)
i := 0
state.targetResults.ForEach(func(_ string, rs mq.Result) bool {
aggregatedData[i] = rs.Payload
i++
return true
})
aggregatedPayload, err := json.Marshal(aggregatedData)
if err != nil {
panic(err)
}
state.Result = mq.Result{Payload: aggregatedPayload, Status: mq.Completed, Ctx: ctx, Topic: state.NodeID}
} else if size == 1 {
state.Result = state.targetResults.Values()[0]
}
state.Status = result.Status
state.Result.Status = result.Status
}
if state.Result.Payload == nil {
state.Result.Payload = result.Payload
}
state.UpdatedAt = time.Now()
if result.Ctx == nil {
result.Ctx = ctx
}
if result.Error != nil {
state.Status = mq.Failed
}
pn, ok := tm.parentNodes.Get(state.NodeID)
if edges, exists := tm.iteratorNodes.Get(nodeID[0]); exists && state.Status == mq.Completed {
state.Status = mq.Processing
tm.iteratorNodes.Del(nodeID[0])
state.targetResults.Clear()
if len(nodeID) == 2 {
ctx = context.WithValue(ctx, ContextIndex, nodeID[1])
}
toProcess := nodeResult{
ctx: ctx,
nodeID: state.NodeID,
status: state.Status,
result: state.Result,
}
tm.handleEdges(toProcess, edges)
} else if ok {
if targetsCount == size {
parentState, _ := tm.taskStates.Get(pn)
if parentState != nil {
state.Result.Topic = state.NodeID
tm.handlePrevious(ctx, parentState, state.Result, state.NodeID, dispatchFinal)
}
}
} else {
tm.result = &state.Result
state.Result.Topic = strings.Split(state.NodeID, Delimiter)[0]
tm.resultCh <- state.Result
tm.processFinalResult(state)
}
}
func (tm *TaskManager) handleNext(ctx context.Context, node *Node, state *TaskState, result mq.Result) {
state.UpdatedAt = time.Now()
if result.Ctx == nil {
result.Ctx = ctx
}
if result.Error != nil {
state.Status = mq.Failed
} else {
edges := tm.getConditionalEdges(node, result)
if len(edges) == 0 {
state.Status = mq.Completed
}
}
if result.Status == "" {
result.Status = state.Status
}
select {
case tm.resultQueue <- nodeResult{
ctx: ctx,
nodeID: state.NodeID,
result: result,
status: state.Status,
}:
default:
log.Println("Result queue is full, dropping result.")
}
}
func (tm *TaskManager) onNodeCompleted(rs nodeResult) {
nodeID := strings.Split(rs.nodeID, Delimiter)[0]
node, ok := tm.dag.nodes.Get(nodeID)
if !ok {
return
}
edges := tm.getConditionalEdges(node, rs.result)
hasErrorOrCompleted := rs.result.Error != nil || len(edges) == 0
if hasErrorOrCompleted {
if index, ok := rs.ctx.Value(ContextIndex).(string); ok {
childNode := fmt.Sprintf("%s%s%s", node.ID, Delimiter, index)
pn, ok := tm.parentNodes.Get(childNode)
if ok {
parentState, _ := tm.taskStates.Get(pn)
if parentState != nil {
pn = strings.Split(pn, Delimiter)[0]
tm.handlePrevious(rs.ctx, parentState, rs.result, rs.nodeID, true)
}
}
}
return
}
tm.handleEdges(rs, edges)
}
func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge {
edges := make([]Edge, len(node.Edges))
copy(edges, node.Edges)
if result.ConditionStatus != "" {
if conditions, ok := tm.dag.conditions[result.Topic]; ok {
if targetNodeKey, ok := conditions[result.ConditionStatus]; ok {
if targetNode, ok := tm.dag.nodes.Get(targetNodeKey); ok {
edges = append(edges, Edge{From: node, To: targetNode})
}
} else if targetNodeKey, ok = conditions["default"]; ok {
if targetNode, ok := tm.dag.nodes.Get(targetNodeKey); ok {
edges = append(edges, Edge{From: node, To: targetNode})
}
}
}
}
return edges
}
func (tm *TaskManager) handleEdges(currentResult nodeResult, edges []Edge) {
for _, edge := range edges {
index, ok := currentResult.ctx.Value(ContextIndex).(string)
if !ok {
index = "0"
}
parentNode := fmt.Sprintf("%s%s%s", edge.From.ID, Delimiter, index)
if edge.Type == Simple {
if _, ok := tm.iteratorNodes.Get(edge.From.ID); ok {
continue
}
}
if edge.Type == Iterator {
var items []json.RawMessage
err := json.Unmarshal(currentResult.result.Payload, &items)
if err != nil {
log.Printf("Error unmarshalling data for node %s: %v\n", edge.To.ID, err)
tm.resultQueue <- nodeResult{
ctx: currentResult.ctx,
nodeID: edge.To.ID,
status: mq.Failed,
result: mq.Result{Error: err},
}
return
}
tm.childNodes.Set(parentNode, len(items))
for i, item := range items {
childNode := fmt.Sprintf("%s%s%d", edge.To.ID, Delimiter, i)
ctx := context.WithValue(currentResult.ctx, ContextIndex, fmt.Sprintf("%d", i))
tm.parentNodes.Set(childNode, parentNode)
tm.send(ctx, edge.To.ID, tm.taskID, item)
}
} else {
tm.childNodes.Set(parentNode, 1)
idx, ok := currentResult.ctx.Value(ContextIndex).(string)
if !ok {
idx = "0"
}
childNode := fmt.Sprintf("%s%s%s", edge.To.ID, Delimiter, idx)
ctx := context.WithValue(currentResult.ctx, ContextIndex, idx)
tm.parentNodes.Set(childNode, parentNode)
tm.send(ctx, edge.To.ID, tm.taskID, currentResult.result.Payload)
}
}
}
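
To make the Iterator branch above concrete: the upstream payload must be a JSON array, and each element is dispatched as its own child task keyed by the target node ID plus Delimiter plus its index. A small, self-contained sketch of that decode-and-fan-out step (the payload values and node name are illustrative):

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// The same decode handleEdges performs before fanning out an Iterator edge.
	payload := []byte(`[{"age":"15","gender":"female"},{"age":"18","gender":"male"}]`)
	var items []json.RawMessage
	if err := json.Unmarshal(payload, &items); err != nil {
		panic(err)
	}
	for i, item := range items {
		// Each element would become a child task such as "ValidateAge___0", "ValidateAge___1".
		fmt.Printf("child %d -> %s\n", i, item)
	}
}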
func (tm *TaskManager) retryDeferredTasks() {
const maxRetries = 5
retries := 0
for retries < maxRetries {
select {
case <-tm.stopCh:
log.Println("Stopping Deferred task Retrier")
return
case <-time.After(RetryInterval):
tm.deferredTasks.ForEach(func(taskID string, task *task) bool {
tm.send(task.ctx, task.nodeID, taskID, task.payload)
retries++
return true
})
}
}
}
func (tm *TaskManager) processFinalResult(state *TaskState) {
state.targetResults.Clear()
if tm.dag.finalResult != nil {
tm.dag.finalResult(tm.taskID, state.Result)
}
}
func (tm *TaskManager) Stop() {
close(tm.stopCh)
}

View File

@@ -11,7 +11,6 @@ func WsEvents(s *sio.Server) {
} }
func join(s *sio.Socket, data []byte) { func join(s *sio.Socket, data []byte) {
//just one room at a time for the simple example
currentRooms := s.GetRooms() currentRooms := s.GetRooms()
for _, room := range currentRooms { for _, room := range currentRooms {
s.Leave(room) s.Leave(room)

View File

@@ -4,14 +4,13 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
v2 "github.com/oarkflow/mq/dag/v2"
"github.com/oarkflow/mq" "github.com/oarkflow/mq"
"github.com/oarkflow/mq/dag"
"github.com/oarkflow/mq/examples/tasks" "github.com/oarkflow/mq/examples/tasks"
) )
func main() { func main() {
f := v2.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) { f := dag.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) {
fmt.Printf("Final result for task %s: %s\n", taskID, string(result.Payload)) fmt.Printf("Final result for task %s: %s\n", taskID, string(result.Payload))
}, mq.WithSyncMode(true)) }, mq.WithSyncMode(true))
f.SetNotifyResponse(func(ctx context.Context, result mq.Result) error { f.SetNotifyResponse(func(ctx context.Context, result mq.Result) error {
@@ -25,39 +24,38 @@ func main() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
f.Start(context.Background(), ":8083") f.Start(context.Background(), ":8082")
sendData(f) sendData(f)
} }
func subDAG() *v2.DAG { func subDAG() *dag.DAG {
f := v2.NewDAG("Sub DAG", "sub-dag", func(taskID string, result mq.Result) { f := dag.NewDAG("Sub DAG", "sub-dag", func(taskID string, result mq.Result) {
fmt.Printf("Final result for task %s: %s\n", taskID, string(result.Payload)) fmt.Printf("Final result for task %s: %s\n", taskID, string(result.Payload))
}, mq.WithSyncMode(true)) }, mq.WithSyncMode(true))
f. f.
AddNode(v2.Function, "Store data", "store:data", &tasks.StoreData{Operation: v2.Operation{Type: "process"}}, true). AddNode(dag.Function, "Store data", "store:data", &tasks.StoreData{Operation: dag.Operation{Type: "process"}}, true).
AddNode(v2.Function, "Send SMS", "send:sms", &tasks.SendSms{Operation: v2.Operation{Type: "process"}}). AddNode(dag.Function, "Send SMS", "send:sms", &tasks.SendSms{Operation: dag.Operation{Type: "process"}}).
AddNode(v2.Function, "Notification", "notification", &tasks.InAppNotification{Operation: v2.Operation{Type: "process"}}). AddNode(dag.Function, "Notification", "notification", &tasks.InAppNotification{Operation: dag.Operation{Type: "process"}}).
AddEdge(v2.Simple, "Store Payload to send sms", "store:data", "send:sms"). AddEdge(dag.Simple, "Store Payload to send sms", "store:data", "send:sms").
AddEdge(v2.Simple, "Store Payload to notification", "send:sms", "notification") AddEdge(dag.Simple, "Store Payload to notification", "send:sms", "notification")
return f return f
} }
func setup(f *v2.DAG) { func setup(f *dag.DAG) {
f. f.
AddNode(v2.Function, "Email Delivery", "email:deliver", &tasks.EmailDelivery{Operation: v2.Operation{Type: "process"}}). AddNode(dag.Function, "Email Delivery", "email:deliver", &tasks.EmailDelivery{Operation: dag.Operation{Type: "process"}}).
AddNode(v2.Function, "Prepare Email", "prepare:email", &tasks.PrepareEmail{Operation: v2.Operation{Type: "process"}}). AddNode(dag.Function, "Prepare Email", "prepare:email", &tasks.PrepareEmail{Operation: dag.Operation{Type: "process"}}).
AddNode(v2.Function, "Get Input", "get:input", &tasks.GetData{Operation: v2.Operation{Type: "input"}}, true). AddNode(dag.Function, "Get Input", "get:input", &tasks.GetData{Operation: dag.Operation{Type: "input"}}, true).
AddNode(v2.Function, "Final Payload", "final", &tasks.Final{Operation: v2.Operation{Type: "page"}}). AddNode(dag.Function, "Iterator Processor", "loop", &tasks.Loop{Operation: dag.Operation{Type: "loop"}}).
AddNode(v2.Function, "Iterator Processor", "loop", &tasks.Loop{Operation: v2.Operation{Type: "loop"}}). AddNode(dag.Function, "Condition", "condition", &tasks.Condition{Operation: dag.Operation{Type: "condition"}}).
AddNode(v2.Function, "Condition", "condition", &tasks.Condition{Operation: v2.Operation{Type: "condition"}}).
AddDAGNode("Persistent", "persistent", subDAG()). AddDAGNode("Persistent", "persistent", subDAG()).
AddEdge(v2.Simple, "Get input to loop", "get:input", "loop"). AddEdge(dag.Simple, "Get input to loop", "get:input", "loop").
AddEdge(v2.Iterator, "Loop to prepare email", "loop", "prepare:email"). AddEdge(dag.Iterator, "Loop to prepare email", "loop", "prepare:email").
AddEdge(v2.Simple, "Prepare Email to condition", "prepare:email", "condition"). AddEdge(dag.Simple, "Prepare Email to condition", "prepare:email", "condition").
AddCondition("condition", map[string]string{"pass": "email:deliver", "fail": "persistent"}) AddCondition("condition", map[string]string{"pass": "email:deliver", "fail": "persistent"})
} }
func sendData(f *v2.DAG) { func sendData(f *dag.DAG) {
data := []map[string]any{ data := []map[string]any{
{"phone": "+123456789", "email": "abc.xyz@gmail.com"}, {"phone": "+98765412", "email": "xyz.abc@gmail.com"}, {"phone": "+123456789", "email": "abc.xyz@gmail.com"}, {"phone": "+98765412", "email": "xyz.abc@gmail.com"},
} }

View File

@@ -3,31 +3,30 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
v2 "github.com/oarkflow/mq/dag/v2"
"github.com/oarkflow/mq" "github.com/oarkflow/mq"
"github.com/oarkflow/mq/dag"
"github.com/oarkflow/mq/examples/tasks" "github.com/oarkflow/mq/examples/tasks"
) )
func main() { func main() {
d := v2.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) { d := dag.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) {
fmt.Println("Final", string(result.Payload)) fmt.Println("Final", string(result.Payload))
}, },
mq.WithSyncMode(true), mq.WithSyncMode(true),
mq.WithNotifyResponse(tasks.NotifyResponse), mq.WithNotifyResponse(tasks.NotifyResponse),
) )
d.AddNode(v2.Function, "C", "C", &tasks.Node3{}, true) d.AddNode(dag.Function, "C", "C", &tasks.Node3{}, true)
d.AddNode(v2.Function, "D", "D", &tasks.Node4{}) d.AddNode(dag.Function, "D", "D", &tasks.Node4{})
d.AddNode(v2.Function, "E", "E", &tasks.Node5{}) d.AddNode(dag.Function, "E", "E", &tasks.Node5{})
d.AddNode(v2.Function, "F", "F", &tasks.Node6{}) d.AddNode(dag.Function, "F", "F", &tasks.Node6{})
d.AddNode(v2.Function, "G", "G", &tasks.Node7{}) d.AddNode(dag.Function, "G", "G", &tasks.Node7{})
d.AddNode(v2.Function, "H", "H", &tasks.Node8{}) d.AddNode(dag.Function, "H", "H", &tasks.Node8{})
d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"}) d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"})
d.AddEdge(v2.Simple, "Label 1", "B", "C") d.AddEdge(dag.Simple, "Label 1", "B", "C")
d.AddEdge(v2.Simple, "Label 2", "D", "F") d.AddEdge(dag.Simple, "Label 2", "D", "F")
d.AddEdge(v2.Simple, "Label 3", "E", "F") d.AddEdge(dag.Simple, "Label 3", "E", "F")
d.AddEdge(v2.Simple, "Label 4", "F", "G", "H") d.AddEdge(dag.Simple, "Label 4", "F", "G", "H")
d.AssignTopic("queue") d.AssignTopic("queue")
err := d.Consume(context.Background()) err := d.Consume(context.Background())
if err != nil { if err != nil {

View File

@@ -4,25 +4,25 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/oarkflow/mq/dag"
"github.com/oarkflow/jet" "github.com/oarkflow/jet"
"github.com/oarkflow/mq" "github.com/oarkflow/mq"
"github.com/oarkflow/mq/consts" "github.com/oarkflow/mq/consts"
v2 "github.com/oarkflow/mq/dag/v2"
) )
func main() { func main() {
flow := v2.NewDAG("Multi-Step Form", "multi-step-form", func(taskID string, result mq.Result) { flow := dag.NewDAG("Multi-Step Form", "multi-step-form", func(taskID string, result mq.Result) {
fmt.Printf("Final result for task %s: %s\n", taskID, string(result.Payload)) fmt.Printf("Final result for task %s: %s\n", taskID, string(result.Payload))
}) })
flow.AddNode(v2.Page, "FormStep1", "FormStep1", &FormStep1{}) flow.AddNode(dag.Page, "FormStep1", "FormStep1", &FormStep1{})
flow.AddNode(v2.Page, "FormStep2", "FormStep2", &FormStep2{}) flow.AddNode(dag.Page, "FormStep2", "FormStep2", &FormStep2{})
flow.AddNode(v2.Page, "FormResult", "FormResult", &FormResult{}) flow.AddNode(dag.Page, "FormResult", "FormResult", &FormResult{})
// Define edges // Define edges
flow.AddEdge(v2.Simple, "FormStep1", "FormStep1", "FormStep2") flow.AddEdge(dag.Simple, "FormStep1", "FormStep1", "FormStep2")
flow.AddEdge(v2.Simple, "FormStep2", "FormStep2", "FormResult") flow.AddEdge(dag.Simple, "FormStep2", "FormStep2", "FormResult")
// Start the flow // Start the flow
if flow.Error != nil { if flow.Error != nil {
@@ -32,7 +32,7 @@ func main() {
} }
type FormStep1 struct { type FormStep1 struct {
v2.Operation dag.Operation
} }
func (p *FormStep1) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { func (p *FormStep1) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -68,7 +68,7 @@ func (p *FormStep1) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
} }
type FormStep2 struct { type FormStep2 struct {
v2.Operation dag.Operation
} }
func (p *FormStep2) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { func (p *FormStep2) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -111,7 +111,7 @@ func (p *FormStep2) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
} }
type FormResult struct { type FormResult struct {
v2.Operation dag.Operation
} }
func (p *FormResult) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { func (p *FormResult) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {

View File

@@ -3,21 +3,21 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/oarkflow/mq/dag/v1"
"github.com/oarkflow/mq/examples/tasks" "github.com/oarkflow/mq/examples/tasks"
"github.com/oarkflow/mq" "github.com/oarkflow/mq"
"github.com/oarkflow/mq/dag"
) )
func main() { func main() {
d := dag.NewDAG( d := v1.NewDAG(
"Sample DAG", "Sample DAG",
"sample-dag", "sample-dag",
mq.WithSyncMode(true), mq.WithSyncMode(true),
mq.WithNotifyResponse(tasks.NotifyResponse), mq.WithNotifyResponse(tasks.NotifyResponse),
) )
subDag := dag.NewDAG( subDag := v1.NewDAG(
"Sub DAG", "Sub DAG",
"D", "D",
mq.WithNotifyResponse(tasks.NotifySubDAGResponse), mq.WithNotifyResponse(tasks.NotifySubDAGResponse),
@@ -35,7 +35,7 @@ func main() {
d.AddDAGNode("D", "D", subDag) d.AddDAGNode("D", "D", subDag)
d.AddNode("E", "E", &tasks.Node5{}) d.AddNode("E", "E", &tasks.Node5{})
d.AddIterator("Send each item", "A", "B") d.AddIterator("Send each item", "A", "B")
d.AddCondition("C", map[dag.When]dag.Then{"PASS": "D", "FAIL": "E"}) d.AddCondition("C", map[v1.When]v1.Then{"PASS": "D", "FAIL": "E"})
d.AddEdge("Label 1", "B", "C") d.AddEdge("Label 1", "B", "C")
fmt.Println(d.ExportDOT()) fmt.Println(d.ExportDOT())

View File

@@ -2,9 +2,8 @@ package tasks
import ( import (
"context" "context"
v2 "github.com/oarkflow/mq/dag/v2"
"github.com/oarkflow/json" "github.com/oarkflow/json"
v2 "github.com/oarkflow/mq/dag"
"github.com/oarkflow/mq" "github.com/oarkflow/mq"
) )

View File

@@ -4,7 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
v2 "github.com/oarkflow/mq/dag/v2" v2 "github.com/oarkflow/mq/dag"
"log" "log"
"github.com/oarkflow/mq" "github.com/oarkflow/mq"

View File

@@ -5,16 +5,16 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/oarkflow/mq" "github.com/oarkflow/mq"
"github.com/oarkflow/mq/dag"
"os" "os"
"github.com/oarkflow/jet" "github.com/oarkflow/jet"
"github.com/oarkflow/mq/consts" "github.com/oarkflow/mq/consts"
v2 "github.com/oarkflow/mq/dag/v2"
) )
type Form struct { type Form struct {
v2.Operation dag.Operation
} }
func (p *Form) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { func (p *Form) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -38,7 +38,7 @@ func (p *Form) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
} }
type NodeA struct { type NodeA struct {
v2.Operation dag.Operation
} }
func (p *NodeA) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { func (p *NodeA) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -52,7 +52,7 @@ func (p *NodeA) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
} }
type NodeB struct { type NodeB struct {
v2.Operation dag.Operation
} }
func (p *NodeB) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { func (p *NodeB) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -66,7 +66,7 @@ func (p *NodeB) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
} }
type NodeC struct { type NodeC struct {
v2.Operation dag.Operation
} }
func (p *NodeC) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { func (p *NodeC) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -80,7 +80,7 @@ func (p *NodeC) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
} }
type Result struct { type Result struct {
v2.Operation dag.Operation
} }
func (p *Result) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { func (p *Result) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -115,16 +115,16 @@ func notify(taskID string, result mq.Result) {
} }
func main() { func main() {
flow := v2.NewDAG("Sample DAG", "sample-dag", notify) flow := dag.NewDAG("Sample DAG", "sample-dag", notify)
flow.AddNode(v2.Page, "Form", "Form", &Form{}) flow.AddNode(dag.Page, "Form", "Form", &Form{})
flow.AddNode(v2.Function, "NodeA", "NodeA", &NodeA{}) flow.AddNode(dag.Function, "NodeA", "NodeA", &NodeA{})
flow.AddNode(v2.Function, "NodeB", "NodeB", &NodeB{}) flow.AddNode(dag.Function, "NodeB", "NodeB", &NodeB{})
flow.AddNode(v2.Function, "NodeC", "NodeC", &NodeC{}) flow.AddNode(dag.Function, "NodeC", "NodeC", &NodeC{})
flow.AddNode(v2.Page, "Result", "Result", &Result{}) flow.AddNode(dag.Page, "Result", "Result", &Result{})
flow.AddEdge(v2.Simple, "Form", "Form", "NodeA") flow.AddEdge(dag.Simple, "Form", "Form", "NodeA")
flow.AddEdge(v2.Simple, "NodeA", "NodeA", "NodeB") flow.AddEdge(dag.Simple, "NodeA", "NodeA", "NodeB")
flow.AddEdge(v2.Simple, "NodeB", "NodeB", "NodeC") flow.AddEdge(dag.Simple, "NodeB", "NodeB", "NodeC")
flow.AddEdge(v2.Simple, "NodeC", "NodeC", "Result") flow.AddEdge(dag.Simple, "NodeC", "NodeC", "Result")
if flow.Error != nil { if flow.Error != nil {
panic(flow.Error) panic(flow.Error)
} }

View File

@@ -5,23 +5,23 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/oarkflow/mq" "github.com/oarkflow/mq"
v2 "github.com/oarkflow/mq/dag/v2" "github.com/oarkflow/mq/dag"
) )
func main() { func main() {
flow := v2.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) { flow := dag.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) {
// fmt.Printf("Final result for task %s: %s\n", taskID, string(result.Payload)) // fmt.Printf("Final result for task %s: %s\n", taskID, string(result.Payload))
}) })
flow.AddNode(v2.Function, "GetData", "GetData", &GetData{}, true) flow.AddNode(dag.Function, "GetData", "GetData", &GetData{}, true)
flow.AddNode(v2.Function, "Loop", "Loop", &Loop{}) flow.AddNode(dag.Function, "Loop", "Loop", &Loop{})
flow.AddNode(v2.Function, "ValidateAge", "ValidateAge", &ValidateAge{}) flow.AddNode(dag.Function, "ValidateAge", "ValidateAge", &ValidateAge{})
flow.AddNode(v2.Function, "ValidateGender", "ValidateGender", &ValidateGender{}) flow.AddNode(dag.Function, "ValidateGender", "ValidateGender", &ValidateGender{})
flow.AddNode(v2.Function, "Final", "Final", &Final{}) flow.AddNode(dag.Function, "Final", "Final", &Final{})
flow.AddEdge(v2.Simple, "GetData", "GetData", "Loop") flow.AddEdge(dag.Simple, "GetData", "GetData", "Loop")
flow.AddEdge(v2.Iterator, "Validate age for each item", "Loop", "ValidateAge") flow.AddEdge(dag.Iterator, "Validate age for each item", "Loop", "ValidateAge")
flow.AddCondition("ValidateAge", map[string]string{"pass": "ValidateGender"}) flow.AddCondition("ValidateAge", map[string]string{"pass": "ValidateGender"})
flow.AddEdge(v2.Simple, "Mark as Done", "Loop", "Final") flow.AddEdge(dag.Simple, "Mark as Done", "Loop", "Final")
// flow.Start(":8080") // flow.Start(":8080")
data := []byte(`[{"age": "15", "gender": "female"}, {"age": "18", "gender": "male"}]`) data := []byte(`[{"age": "15", "gender": "female"}, {"age": "18", "gender": "male"}]`)
@@ -38,7 +38,7 @@ func main() {
} }
type GetData struct { type GetData struct {
v2.Operation dag.Operation
} }
func (p *GetData) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { func (p *GetData) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -46,7 +46,7 @@ func (p *GetData) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
} }
type Loop struct { type Loop struct {
v2.Operation dag.Operation
} }
func (p *Loop) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { func (p *Loop) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -54,7 +54,7 @@ func (p *Loop) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
} }
type ValidateAge struct { type ValidateAge struct {
v2.Operation dag.Operation
} }
func (p *ValidateAge) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { func (p *ValidateAge) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -73,7 +73,7 @@ func (p *ValidateAge) ProcessTask(ctx context.Context, task *mq.Task) mq.Result
} }
type ValidateGender struct { type ValidateGender struct {
v2.Operation dag.Operation
} }
func (p *ValidateGender) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { func (p *ValidateGender) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
@@ -87,7 +87,7 @@ func (p *ValidateGender) ProcessTask(ctx context.Context, task *mq.Task) mq.Resu
} }
type Final struct { type Final struct {
v2.Operation dag.Operation
} }
func (p *Final) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { func (p *Final) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {