Files
mq/dag/dag.go
2025-02-17 20:39:57 +05:45

438 lines
11 KiB
Go

package dag
import (
"context"
"encoding/json"
"fmt"
"github.com/oarkflow/mq/logger"
"log"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/recover"
"github.com/oarkflow/form"
"golang.org/x/time/rate"
"github.com/oarkflow/mq/sio"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/storage"
"github.com/oarkflow/mq/storage/memory"
)
// Node is a single vertex of the DAG.
type Node struct {
	// processor executes work routed to this node. It is left nil for nodes
	// registered through AddDeferredNode.
	processor mq.Processor
	// Label is the human-readable name supplied when the node was added.
	Label string
	// ID is the key under which the node is stored in DAG.nodes.
	ID string
	// Edges holds the outgoing edges appended by AddEdge.
	Edges []Edge
	// NodeType classifies the node (e.g. Page).
	NodeType NodeType
	// isReady is set by AddNode when the broker runs in sync mode and is
	// always set by AddDAGNode; IsReady inspects it.
	isReady bool
}
// Edge is a directed connection between two nodes, created by AddEdge.
type Edge struct {
	// From is the source node of the edge.
	From *Node
	// To is the target node of the edge.
	To *Node
	// Label is the caller-supplied label passed to AddEdge.
	Label string
	// Type is the edge kind (e.g. Iterator).
	Type EdgeType
}
// DAG is a workflow graph of nodes connected by edges, backed by an mq broker
// and a worker pool. Build one with NewDAG.
type DAG struct {
	// nodes maps node ID to its Node.
	nodes storage.IMap[string, *Node]
	// taskManager maps task ID to the TaskManager driving that task.
	taskManager storage.IMap[string, *TaskManager]
	// iteratorNodes maps the ID of a node that has an Iterator edge to the
	// non-iterator edges subsequently added from it. A clone is handed to each
	// new TaskManager.
	iteratorNodes storage.IMap[string, []Edge]
	// Error records a construction/validation failure; once set, AddNode and
	// AddEdge become no-ops.
	Error error
	// conditions holds the condition maps registered via AddCondition, keyed
	// by source node ID. NOTE(review): inner key/value schema not visible here.
	conditions map[string]map[string]string
	// consumer is not referenced in this section of the file.
	consumer *mq.Consumer
	// finalResult is the callback supplied to NewDAG.
	finalResult func(taskID string, result mq.Result)
	// pool runs ProcessTask for queued work and owns the scheduler used by
	// ScheduleTask.
	pool *mq.Pool
	// Notifier is not referenced in this section of the file.
	Notifier *sio.Server
	// server is the broker created in NewDAG.
	server *mq.Broker
	// reportNodeResultCallback is set through ReportNodeResult.
	reportNodeResultCallback func(mq.Result)
	// key is returned by GetKey/GetType and set by NewDAG/SetKey.
	key string
	// consumerTopic is not referenced in this section; presumably set by
	// AssignTopic — confirm.
	consumerTopic string
	// startNode is the ID of the entry node.
	startNode string
	// name is the DAG name supplied to NewDAG.
	name string
	// report is the edge-classification report stored by Validate.
	report string
	// hasPageNode is set when a Page node is added through AddNode; it makes
	// ProcessTask wait for a result without a timeout.
	hasPageNode bool
	// paused is not referenced in this section of the file.
	paused bool
}
// NewDAG creates a DAG called name and identified by key. Final results are
// delivered to finalResultCallback. The given opts configure the underlying
// broker; NewDAG appends its own task-callback, consumer-subscribe and
// consumer-close hooks to them. A worker pool sized from the broker options
// is created with ProcessTask as its handler and started before returning.
func NewDAG(name, key string, finalResultCallback func(taskID string, result mq.Result), opts ...mq.Option) *DAG {
	d := &DAG{
		name:          name,
		key:           key,
		nodes:         memory.New[string, *Node](),
		taskManager:   memory.New[string, *TaskManager](),
		iteratorNodes: memory.New[string, []Edge](),
		conditions:    make(map[string]map[string]string),
		finalResult:   finalResultCallback,
	}

	// Broker hooks route task results and consumer lifecycle events back into d.
	brokerOpts := append(opts,
		mq.WithCallback(d.onTaskCallback),
		mq.WithConsumerOnSubscribe(d.onConsumerJoin),
		mq.WithConsumerOnClose(d.onConsumerClose),
	)
	d.server = mq.NewBroker(brokerOpts...)

	// The pool-level callback is a no-op.
	noop := func(ctx context.Context, result mq.Result) error { return nil }
	brokerCfg := d.server.Options()
	d.pool = mq.NewPool(
		brokerCfg.NumOfWorkers(),
		mq.WithTaskQueueSize(brokerCfg.QueueSize()),
		mq.WithMaxMemoryLoad(brokerCfg.MaxMemoryLoad()),
		mq.WithHandler(d.ProcessTask),
		mq.WithPoolCallback(noop),
		mq.WithTaskStorage(brokerCfg.Storage()),
	)
	d.pool.Start(d.server.Options().NumOfWorkers())
	return d
}
// SetKey replaces the key later returned by GetKey and GetType.
func (tm *DAG) SetKey(key string) {
	tm.key = key
}
// ReportNodeResult stores callback as the DAG's node-result reporter,
// replacing any previous one. NOTE(review): presumably invoked by the task
// manager as individual nodes finish — confirm.
func (tm *DAG) ReportNodeResult(callback func(mq.Result)) {
	tm.reportNodeResultCallback = callback
}
// onTaskCallback is the broker callback registered in NewDAG. When result
// belongs to a task that still has a TaskManager and carries a non-empty
// topic, it is forwarded to that manager as a completion of the node named by
// the topic. It always returns an empty mq.Result.
func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result {
	manager, found := tm.taskManager.Get(result.TaskID)
	if !found || result.Topic == "" {
		return mq.Result{}
	}
	completion := nodeResult{
		ctx:    ctx,
		nodeID: result.Topic,
		status: result.Status,
		result: result,
	}
	manager.onNodeCompleted(completion)
	return mq.Result{}
}
// GetType reports the DAG's type, which is its key.
func (tm *DAG) GetType() string {
	return tm.GetKey()
}
// Stop stops the processor of every registered node.
//
// Nodes without a processor (those registered via AddDeferredNode) are
// skipped. Every node is attempted even if an earlier one fails, so a single
// failure does not leave the rest running. The first failure is returned,
// wrapped with the ID of the node that produced it; nil means all stopped.
func (tm *DAG) Stop(ctx context.Context) error {
	var firstErr error
	tm.nodes.ForEach(func(id string, n *Node) bool {
		if n.processor == nil {
			return true
		}
		if err := n.processor.Stop(ctx); err != nil && firstErr == nil {
			firstErr = fmt.Errorf("stopping node %s: %w", id, err)
		}
		return true
	})
	return firstErr
}
// GetKey returns the key assigned by NewDAG or SetKey.
func (tm *DAG) GetKey() string {
	key := tm.key
	return key
}
// SetNotifyResponse installs callback as the broker's notify handler.
func (tm *DAG) SetNotifyResponse(callback mq.Callback) {
	broker := tm.server
	broker.SetNotifyHandler(callback)
}
// SetStartNode marks the node with ID node as the DAG's entry node.
func (tm *DAG) SetStartNode(node string) {
	tm.startNode = node
}
// GetStartNode returns the ID of the entry node ("" if none was set).
func (tm *DAG) GetStartNode() string {
	start := tm.startNode
	return start
}
// AddCondition stores conditions under fromNode, replacing any map previously
// stored for that node, and returns tm for chaining. The map is stored as-is
// (not copied). NOTE(review): the key/value meaning of conditions is defined
// by its consumers, not visible here.
func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) *DAG {
	byNode := tm.conditions
	byNode[fromNode] = conditions
	return tm
}
// AddNode registers a node with ID nodeID and label name. The node is backed
// by a consumer on topic nodeID that runs handler.ProcessTask.
//
// The call is a no-op when tm.Error is already set. The node starts out ready
// when the broker runs in sync mode. Passing true as startNode makes it the
// entry node, and adding a Page node flags the DAG as containing one. Returns
// tm for chaining.
func (tm *DAG) AddNode(nodeType NodeType, name, nodeID string, handler mq.Processor, startNode ...bool) *DAG {
	if tm.Error != nil {
		return tm
	}
	consumer := mq.NewConsumer(nodeID, nodeID, handler.ProcessTask)
	syncMode := tm.server != nil && tm.server.SyncMode()
	tm.nodes.Set(nodeID, &Node{
		Label:     name,
		ID:        nodeID,
		NodeType:  nodeType,
		processor: consumer,
		isReady:   syncMode,
	})
	if len(startNode) != 0 && startNode[0] {
		tm.startNode = nodeID
	}
	if nodeType == Page {
		tm.hasPageNode = true
	}
	return tm
}
// AddDeferredNode registers a node with ID key and label name that has no
// processor of its own. Such nodes are rejected when the broker runs in sync
// mode.
//
// Passing true as firstNode makes it the entry node. As in AddNode, adding a
// Page node flags the DAG as containing one, so ProcessTask waits for its
// result without a timeout.
func (tm *DAG) AddDeferredNode(nodeType NodeType, name, key string, firstNode ...bool) error {
	// Same nil-server guard as AddNode.
	if tm.server != nil && tm.server.SyncMode() {
		return fmt.Errorf("DAG cannot have deferred node in Sync Mode")
	}
	tm.nodes.Set(key, &Node{
		Label:    name,
		ID:       key,
		NodeType: nodeType,
	})
	if len(firstNode) > 0 && firstNode[0] {
		tm.startNode = key
	}
	if nodeType == Page {
		tm.hasPageNode = true
	}
	return nil
}
// IsReady reports whether the DAG has at least one node and every node is
// ready. Iteration stops at the first node that is not ready.
func (tm *DAG) IsReady() bool {
	seen, allReady := false, true
	tm.nodes.ForEach(func(_ string, n *Node) bool {
		seen = true
		if !n.isReady {
			// One unready node makes the whole DAG unready, regardless of
			// any ready nodes visited before it.
			allReady = false
			return false
		}
		return true
	})
	// An empty DAG is reported as not ready.
	return seen && allReady
}
// AddEdge adds one edge of edgeType and label from the node from to each
// registered node in targets, appending it to the source node's Edges.
//
// The call is a no-op when tm.Error is already set. An unknown source sets
// tm.Error. Unknown targets are skipped.
//
// An Iterator edge records from in iteratorNodes (once). Later non-iterator
// edges from a recorded node are also appended to its iteratorNodes entry.
// Returns tm for chaining.
func (tm *DAG) AddEdge(edgeType EdgeType, label, from string, targets ...string) *DAG {
	if tm.Error != nil {
		return tm
	}
	node, ok := tm.nodes.Get(from)
	if !ok {
		tm.Error = fmt.Errorf("node not found %s", from)
		return tm
	}
	// Register the iterator source only after it is known to exist, and only
	// once, so a second Iterator edge does not wipe edges collected earlier.
	if edgeType == Iterator {
		if _, tracked := tm.iteratorNodes.Get(from); !tracked {
			tm.iteratorNodes.Set(from, []Edge{})
		}
	}
	for _, target := range targets {
		targetNode, found := tm.nodes.Get(target)
		if !found {
			// NOTE(review): unknown targets are skipped silently.
			continue
		}
		edge := Edge{From: node, To: targetNode, Type: edgeType, Label: label}
		node.Edges = append(node.Edges, edge)
		if edgeType == Iterator {
			continue
		}
		if edges, tracked := tm.iteratorNodes.Get(node.ID); tracked {
			tm.iteratorNodes.Set(node.ID, append(edges, edge))
		}
	}
	return tm
}
// getCurrentNode returns the first key of manager.currentNodePayload, or ""
// when it holds no entries.
//
// The keys are read once and their length checked, instead of calling Size()
// and then Keys()[0]. That way an entry removed between the two calls cannot
// cause an index-out-of-range panic. NOTE(review): if Keys() order is not
// deterministic, the chosen key is arbitrary when several entries exist —
// confirm.
func (tm *DAG) getCurrentNode(manager *TaskManager) string {
	keys := manager.currentNodePayload.Keys()
	if len(keys) == 0 {
		return ""
	}
	return keys[0]
}
// Logger returns the logger configured in the broker's options.
func (tm *DAG) Logger() logger.Logger {
	opts := tm.server.Options()
	return opts.Logger()
}
// ProcessTask runs task through the DAG and waits for its result.
//
// A TaskManager is looked up (or created) for task.ID and given a fresh result
// channel. The entry node is chosen by parseInitialNode, which reads ctx; this
// method may first seed ctx with "initial_node":
//   - a GET request whose current node is a Page node re-enters that node,
//     restoring the manager's last result payload if there is one;
//   - otherwise, when the user context has next == "true", the first node
//     returned by GetNextNodes for the current node is used.
//
// Any stored result for the current node is merged over the task payload
// (result keys win) before dispatch. When the DAG contains a Page node the
// wait has no time limit; otherwise it gives up after 30 seconds. In both
// cases the wait ends early if ctx is done.
func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
	const resultTimeout = 30 * time.Second

	// NOTE(review): plain string context keys ("task_id", "method",
	// "initial_node") are shared with code outside this function, so they
	// are kept as-is even though typed keys are preferred.
	ctx = context.WithValue(ctx, "task_id", task.ID)
	userContext := form.UserContext(ctx)
	next := userContext.Get("next")
	manager, ok := tm.taskManager.Get(task.ID)
	resultCh := make(chan mq.Result, 1)
	if !ok {
		manager = NewTaskManager(tm, task.ID, resultCh, tm.iteratorNodes.Clone())
		tm.taskManager.Set(task.ID, manager)
	} else {
		manager.resultCh = resultCh
	}
	currentKey := tm.getCurrentNode(manager)
	currentNode := strings.Split(currentKey, Delimiter)[0]
	node, exists := tm.nodes.Get(currentNode)
	method, _ := ctx.Value("method").(string)
	if method == "GET" && exists && node.NodeType == Page {
		ctx = context.WithValue(ctx, "initial_node", currentNode)
		if manager.result != nil {
			task.Payload = manager.result.Payload
		}
	} else if next == "true" {
		nodes, err := tm.GetNextNodes(currentNode)
		if err != nil {
			return mq.Result{Error: err, Ctx: ctx}
		}
		if len(nodes) > 0 {
			ctx = context.WithValue(ctx, "initial_node", nodes[0].ID)
		}
	}
	if currentNodeResult, hasResult := manager.currentNodeResult.Get(currentKey); hasResult {
		var taskPayload, resultPayload map[string]any
		if err := json.Unmarshal(task.Payload, &taskPayload); err == nil {
			if err = json.Unmarshal(currentNodeResult.Payload, &resultPayload); err == nil {
				// A JSON "null" task payload decodes to a nil map; allocate
				// one before writing into it, otherwise the merge panics.
				if taskPayload == nil && len(resultPayload) > 0 {
					taskPayload = make(map[string]any, len(resultPayload))
				}
				for key, val := range resultPayload {
					taskPayload[key] = val
				}
				// Re-encoding values just decoded from JSON cannot fail.
				task.Payload, _ = json.Marshal(taskPayload)
			}
		}
	}
	firstNode, err := tm.parseInitialNode(ctx)
	if err != nil {
		return mq.Result{Error: err, Ctx: ctx}
	}
	task.Topic = firstNode
	ctx = context.WithValue(ctx, ContextIndex, "0")
	manager.ProcessTask(ctx, firstNode, task.Payload)

	if tm.hasPageNode {
		// Page flows may wait on user interaction, so there is no time limit,
		// but a cancelled caller must not leave this goroutine blocked forever.
		select {
		case result := <-resultCh:
			return result
		case <-ctx.Done():
			return mq.Result{Error: ctx.Err(), Ctx: ctx}
		}
	}
	timer := time.NewTimer(resultTimeout)
	defer timer.Stop()
	select {
	case result := <-resultCh:
		return result
	case <-ctx.Done():
		return mq.Result{Error: ctx.Err(), Ctx: ctx}
	case <-timer.C:
		return mq.Result{
			Error: fmt.Errorf("timeout waiting for task result"),
			Ctx:   ctx,
		}
	}
}
// Process wraps payload in a new task and runs it synchronously via
// ProcessTask. The task ID comes from the user context's "task_id" value
// when it is non-empty; otherwise a fresh ID is generated.
func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result {
	taskID := form.UserContext(ctx).Get("task_id")
	if taskID == "" {
		taskID = mq.NewID()
	}
	task := mq.NewTask(taskID, payload, "")
	return tm.ProcessTask(ctx, task)
}
// Validate classifies the DAG's edges with ClassifyEdges.
//
// On success the classification report is stored (see GetReport) and nil is
// returned. If classification fails, or a cycle is detected, the error is
// recorded in tm.Error and returned.
func (tm *DAG) Validate() error {
	report, hasCycle, err := tm.ClassifyEdges()
	if err == nil && hasCycle {
		// ClassifyEdges can flag a cycle without returning an error. Surface
		// it so a cyclic graph is not treated as valid.
		err = fmt.Errorf("dag %q contains a cycle", tm.name)
	}
	if err != nil {
		tm.Error = err
		return err
	}
	tm.report = report
	return nil
}
// GetReport returns the edge-classification report stored by the last
// successful Validate call ("" before that).
func (tm *DAG) GetReport() string {
	report := tm.report
	return report
}
// AddDAGNode embeds dag as a ready node with ID key and label name. It assigns
// key as the sub-DAG's topic and uses the sub-DAG itself as the node's
// processor.
//
// Like AddNode, the call is a no-op when tm.Error is already set, and adding a
// Page node flags the DAG as containing one. Passing true as firstNode makes
// it the entry node. Returns tm for chaining.
func (tm *DAG) AddDAGNode(nodeType NodeType, name string, key string, dag *DAG, firstNode ...bool) *DAG {
	if tm.Error != nil {
		return tm
	}
	dag.AssignTopic(key)
	tm.nodes.Set(key, &Node{
		Label:     name,
		ID:        key,
		NodeType:  nodeType,
		processor: dag,
		isReady:   true,
	})
	if len(firstNode) > 0 && firstNode[0] {
		tm.startNode = key
	}
	if nodeType == Page {
		tm.hasPageNode = true
	}
	return tm
}
// Start launches the broker in the background. Outside sync mode it also
// starts, per node with a processor, a goroutine that retries Consume at most
// once per second until it succeeds or ctx is done. It then serves the DAG's
// HTTP handlers on addr, blocking until the server stops, and returns the
// server's error.
func (tm *DAG) Start(ctx context.Context, addr string) error {
	go func() {
		defer mq.RecoverPanic(mq.RecoverTitle)
		if err := tm.server.Start(ctx); err != nil {
			// NOTE(review): presumably caught by the deferred RecoverPanic, so
			// a broker failure is logged rather than crashing the process.
			panic(err)
		}
	}()
	if !tm.server.SyncMode() {
		tm.nodes.ForEach(func(_ string, con *Node) bool {
			// Deferred nodes have no processor; calling Consume would panic.
			if con.processor == nil {
				return true
			}
			go func(con *Node) {
				defer mq.RecoverPanic(mq.RecoverTitle)
				limiter := rate.NewLimiter(rate.Every(1*time.Second), 1)
				for {
					err := con.processor.Consume(ctx)
					if err == nil {
						log.Printf("[INFO] - Consumer %s started successfully", con.ID)
						return
					}
					log.Printf("[ERROR] - Consumer %s failed to start: %v", con.ID, err)
					// Wait fails immediately once ctx is cancelled. Ignoring that
					// would turn this retry loop into a busy spin.
					if werr := limiter.Wait(ctx); werr != nil {
						log.Printf("[ERROR] - Consumer %s giving up: %v", con.ID, werr)
						return
					}
				}
			}(con)
			return true
		})
	}
	app := fiber.New()
	app.Use(recover.New(recover.Config{
		EnableStackTrace: true,
	}))
	tm.Handlers(app)
	return app.Listen(addr)
}
// ScheduleTask prepares a task for payload and hands it to the pool's
// scheduler with opts, returning immediately with a Pending result.
//
// The task ID is taken from the user context's "task_id" value, or generated.
// The entry node is resolved the same way ProcessTask resolves it, and any
// stored result for the current node is merged over payload (result keys win).
// The scheduled run receives a fresh background context carrying only the
// request headers from ctx, so it is not tied to ctx's cancellation.
// NOTE(review): the manager's new resultCh is not read here.
func (tm *DAG) ScheduleTask(ctx context.Context, payload []byte, opts ...mq.SchedulerOption) mq.Result {
	var taskID string
	userCtx := form.UserContext(ctx)
	if val := userCtx.Get("task_id"); val != "" {
		taskID = val
	} else {
		taskID = mq.NewID()
	}
	ctx = context.WithValue(ctx, "task_id", taskID)
	userContext := form.UserContext(ctx)
	next := userContext.Get("next")
	manager, ok := tm.taskManager.Get(taskID)
	resultCh := make(chan mq.Result, 1)
	if !ok {
		manager = NewTaskManager(tm, taskID, resultCh, tm.iteratorNodes.Clone())
		tm.taskManager.Set(taskID, manager)
	} else {
		manager.resultCh = resultCh
	}
	currentKey := tm.getCurrentNode(manager)
	currentNode := strings.Split(currentKey, Delimiter)[0]
	node, exists := tm.nodes.Get(currentNode)
	method, _ := ctx.Value("method").(string)
	if method == "GET" && exists && node.NodeType == Page {
		ctx = context.WithValue(ctx, "initial_node", currentNode)
	} else if next == "true" {
		nodes, err := tm.GetNextNodes(currentNode)
		if err != nil {
			return mq.Result{Error: err, Ctx: ctx}
		}
		if len(nodes) > 0 {
			ctx = context.WithValue(ctx, "initial_node", nodes[0].ID)
		}
	}
	if currentNodeResult, hasResult := manager.currentNodeResult.Get(currentKey); hasResult {
		var taskPayload, resultPayload map[string]any
		if err := json.Unmarshal(payload, &taskPayload); err == nil {
			if err = json.Unmarshal(currentNodeResult.Payload, &resultPayload); err == nil {
				// A JSON "null" payload decodes to a nil map; allocate one
				// before writing into it, otherwise the merge panics.
				if taskPayload == nil && len(resultPayload) > 0 {
					taskPayload = make(map[string]any, len(resultPayload))
				}
				for key, val := range resultPayload {
					taskPayload[key] = val
				}
				// Re-encoding values just decoded from JSON cannot fail.
				payload, _ = json.Marshal(taskPayload)
			}
		}
	}
	firstNode, err := tm.parseInitialNode(ctx)
	if err != nil {
		return mq.Result{Error: err, Ctx: ctx}
	}
	// Build the task only after the merge. A task created from the original
	// payload would silently drop the merged node result.
	t := mq.NewTask(taskID, payload, "")
	t.Topic = firstNode
	ctx = context.WithValue(ctx, ContextIndex, "0")
	headers, hasHeaders := mq.GetHeaders(ctx)
	schedCtx := context.Background()
	if hasHeaders {
		schedCtx = mq.SetHeaders(schedCtx, headers.AsMap())
	}
	tm.pool.Scheduler().AddTask(schedCtx, t, opts...)
	return mq.Result{CreatedAt: t.CreatedAt, TaskID: t.ID, Topic: t.Topic, Status: mq.Pending}
}