Files
mq/dag.go
2024-09-26 19:10:36 +05:45

270 lines
5.7 KiB
Go

package mq
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/oarkflow/xsync"
)
// Node describes one step of a DAG: the queue it is registered under, the
// consumer created for that queue, and the handler that processes its tasks.
type Node interface {
// Queue returns the queue name the node is registered under.
Queue() string
// Consumer returns the consumer associated with the node.
Consumer() *Consumer
// Handler returns the handler associated with the node.
Handler() Handler
}
// node is the concrete Node used by this package. It records the queue name,
// consumer and handler it was built with and exposes them through accessors.
type node struct {
	queue    string    // value returned by Queue
	consumer *Consumer // value returned by Consumer
	handler  Handler   // value returned by Handler
}

// Queue returns the queue name stored on the node.
func (n *node) Queue() string { return n.queue }

// Consumer returns the consumer stored on the node.
func (n *node) Consumer() *Consumer { return n.consumer }

// Handler returns the handler stored on the node.
func (n *node) Handler() Handler { return n.handler }
// DAG routes tasks between nodes. Nodes are registered with AddNode and
// connected with AddEdge, AddLoop and AddCondition; TaskCallback uses those
// connections to decide where a finished task's result goes next.
type DAG struct {
// nodes holds every registered node, keyed by its queue name.
nodes *xsync.MapOf[string, Node]
// edges holds plain connections; each entry is {from, to} as appended by
// AddEdge. The whole task result is forwarded to the targets.
edges [][]string
// loopEdges holds loop connections; each entry is {from, to} as appended
// by AddLoop. The task result is decoded as a JSON array and each element
// is forwarded separately.
loopEdges [][]string
// broker is created in NewDAG with TaskCallback as its argument.
broker *Broker
// startNode is the entry node, set by AddNode (firstNode=true) or Prepare.
startNode Node
// brokerAddr is passed to NewConsumer and to broker.Start.
brokerAddr string
// conditions maps a node ID to a status -> target node ID table.
conditions map[string]map[string]string
// syncMode makes Start and Prepare skip the broker and consumers, and makes
// processNode call handlers directly instead of publishing.
syncMode bool
}
// NewDAG returns an empty DAG bound to brokerAddr. syncMode selects whether
// tasks are run by calling handlers directly (true) or by publishing to the
// broker (false). The DAG's broker is created with TaskCallback as its
// argument, so it is wired up last, once the DAG value exists.
func NewDAG(brokerAddr string, syncMode bool) *DAG {
	d := new(DAG)
	d.brokerAddr = brokerAddr
	d.syncMode = syncMode
	d.nodes = xsync.NewMap[string, Node]()
	d.conditions = map[string]map[string]string{}
	d.broker = NewBroker(d.TaskCallback)
	return d
}
// processEdge forwards payload, tagged with task ID id, to every non-empty
// queue name in targets by calling processNode once per target. The Result
// returned by processNode is discarded.
func (dag *DAG) processEdge(ctx context.Context, id string, payload []byte, targets []string) {
	for _, target := range targets {
		if target == "" {
			continue
		}
		// Build a fresh Task for each target. In sync mode processNode stamps
		// CurrentQueue on the task only when it is still empty, so a Task
		// shared across targets would keep the first target's queue name for
		// every later target.
		next := &Task{
			ID:      id,
			Payload: payload,
		}
		dag.processNode(ctx, next, target)
	}
}
// TaskCallback routes the result of a finished task to the nodes that follow
// task.CurrentQueue. It is the callback handed to NewBroker by NewDAG.
//
// Routing happens in three passes, in this order:
//  1. loop edges starting at the current queue: task.Result must be a JSON
//     array, and each element is forwarded as its own payload;
//  2. plain edges starting at the current queue: task.Result is forwarded
//     unchanged;
//  3. the conditional target registered for (current queue, task.Status),
//     if any: task.Result is forwarded unchanged.
//
// It returns task.Error if the task already failed, an error if
// task.CurrentQueue is empty, or the JSON error if a loop edge applies and
// task.Result is not a JSON array. Otherwise it returns nil.
func (dag *DAG) TaskCallback(ctx context.Context, task *Task) error {
	if task.Error != nil {
		return task.Error
	}
	if task.CurrentQueue == "" {
		return errors.New("queue is empty")
	}
	for _, loopEdge := range dag.loopEdges {
		if len(loopEdge) == 0 || loopEdge[0] != task.CurrentQueue {
			continue
		}
		// Decode into raw messages rather than []any. Round-tripping through
		// any turns every number into a float64, which silently corrupts
		// integers above 2^53, and needs a re-marshal step. Raw messages
		// forward each element's original bytes untouched.
		var items []json.RawMessage
		if err := json.Unmarshal(task.Result, &items); err != nil {
			return err
		}
		for _, item := range items {
			dag.processEdge(ctx, task.ID, item, loopEdge[1:])
		}
	}
	for _, edge := range dag.edges {
		if len(edge) == 0 || edge[0] != task.CurrentQueue {
			continue
		}
		dag.processEdge(ctx, task.ID, task.Result, edge[1:])
	}
	// An empty name means no condition matched, so there is nothing to route.
	if next := dag.getConditionalNode(task.Status, task.CurrentQueue); next != "" {
		dag.processEdge(ctx, task.ID, task.Result, []string{next})
	}
	return nil
}
// AddNode registers a node for queue. It creates a consumer for the queue on
// the DAG's broker address, registers handler on that consumer, creates the
// queue on the broker, and stores the node under the queue name. Passing
// firstNode as true also marks the node as the DAG's start node.
func (dag *DAG) AddNode(queue string, handler Handler, firstNode ...bool) {
	c := NewConsumer(dag.brokerAddr, queue)
	c.RegisterHandler(queue, handler)
	dag.broker.NewQueue(queue)

	added := &node{queue: queue, consumer: c, handler: handler}
	// Only the first variadic value is consulted; extra values are ignored.
	isFirst := len(firstNode) > 0 && firstNode[0]
	if isFirst {
		dag.startNode = added
	}
	dag.nodes.Set(queue, added)
}
// AddEdge connects fromNodeID to toNodeID with a plain edge. Both nodes must
// already be registered; otherwise the validation error is returned and no
// edge is added.
func (dag *DAG) AddEdge(fromNodeID, toNodeID string) error {
	if err := dag.validateNodes(fromNodeID, toNodeID); err != nil {
		return err
	}
	edge := []string{fromNodeID, toNodeID}
	dag.edges = append(dag.edges, edge)
	return nil
}
// AddCondition registers status-based routing for conditionNodeID. The
// conditions map goes from a task status to the node ID that should receive
// the result when a task finishes on conditionNodeID with that status.
//
// conditionNodeID and every target must already be registered; otherwise the
// validation error is returned and the DAG's routing is left unchanged.
// A successful call replaces any routing previously registered for
// conditionNodeID.
func (dag *DAG) AddCondition(conditionNodeID string, conditions map[string]string) error {
	// Validate the source node as AddEdge and AddLoop do for theirs.
	if err := dag.validateNodes(conditionNodeID); err != nil {
		return err
	}
	// Copy the table so later changes to the caller's map cannot alter the
	// DAG's routing behind its back.
	routes := make(map[string]string, len(conditions))
	for status, nodeID := range conditions {
		if err := dag.validateNodes(nodeID); err != nil {
			return err
		}
		routes[status] = nodeID
	}
	// Writing to a nil map panics; a DAG not built by NewDAG has none yet.
	if dag.conditions == nil {
		dag.conditions = make(map[string]map[string]string)
	}
	dag.conditions[conditionNodeID] = routes
	return nil
}
// AddLoop connects fromNodeID to toNodeID with a loop edge. Both nodes must
// already be registered; otherwise the validation error is returned and no
// loop edge is added.
func (dag *DAG) AddLoop(fromNodeID, toNodeID string) error {
	if err := dag.validateNodes(fromNodeID, toNodeID); err != nil {
		return err
	}
	loop := []string{fromNodeID, toNodeID}
	dag.loopEdges = append(dag.loopEdges, loop)
	return nil
}
// Start starts the DAG's broker on its broker address and returns the
// broker's error. In sync mode the broker is not used, so Start returns nil
// without touching it.
func (dag *DAG) Start(ctx context.Context) error {
	if !dag.syncMode {
		return dag.broker.Start(ctx, dag.brokerAddr)
	}
	return nil
}
// Prepare resolves the DAG's start node and, outside sync mode, launches one
// goroutine per registered node running that node's consumer with ctx.
//
// It returns the lookup error, or "no initial node found" when the lookup
// yields no node. On error the DAG's start node is left unchanged and no
// consumer is started.
func (dag *DAG) Prepare(ctx context.Context) error {
	first, err := dag.findInitialNode()
	switch {
	case err != nil:
		return err
	case first == nil:
		return fmt.Errorf("no initial node found")
	}
	dag.startNode = first

	if !dag.syncMode {
		dag.nodes.ForEach(func(_ string, n Node) bool {
			go n.Consumer().Consume(ctx)
			return true // keep iterating over every node
		})
	}
	return nil
}
// calculateForFirstNode picks the queue name of the node a run should start
// from when none was marked explicitly. The boolean is false when no
// registered node qualifies.
//
// A node qualifies when no plain edge or loop edge points at it. Among
// qualifying nodes, those that are also not the target of any condition are
// preferred, because a node reachable only through a condition is a
// continuation, not an entry point. If every qualifying node is a condition
// target (for example a conditional route back to the first node), the
// edge-only rule is used as a fallback. Ties are broken by the smallest name
// so the result does not depend on map iteration order.
func (dag *DAG) calculateForFirstNode() (string, bool) {
	var names []string
	for _, name := range dag.nodes.Keys() {
		names = append(names, name)
	}

	// incoming counts plain and loop edges; every element after the first in
	// an edge is a target, matching how TaskCallback reads edges.
	incoming := make(map[string]int)
	countTargets := func(edges [][]string) {
		for _, edge := range edges {
			if len(edge) < 2 {
				continue
			}
			for _, target := range edge[1:] {
				incoming[target]++
			}
		}
	}
	countTargets(dag.edges)
	countTargets(dag.loopEdges)

	// conditional counts routes registered through AddCondition.
	conditional := make(map[string]int)
	for _, routes := range dag.conditions {
		for _, target := range routes {
			conditional[target]++
		}
	}

	var strict, relaxed string
	var haveStrict, haveRelaxed bool
	for _, name := range names {
		if incoming[name] != 0 {
			continue
		}
		if !haveRelaxed || name < relaxed {
			relaxed, haveRelaxed = name, true
		}
		if conditional[name] == 0 && (!haveStrict || name < strict) {
			strict, haveStrict = name, true
		}
	}
	if haveStrict {
		return strict, true
	}
	return relaxed, haveRelaxed
}
// findInitialNode returns the node a run should start from. An explicitly
// set start node wins; otherwise the name computed by calculateForFirstNode
// is looked up among the registered nodes. It returns the error
// "no initial node found" when no name is computed or the name is not
// registered.
func (dag *DAG) findInitialNode() (Node, error) {
	if dag.startNode != nil {
		return dag.startNode, nil
	}
	name, found := dag.calculateForFirstNode()
	if found {
		if first, registered := dag.nodes.Get(name); registered {
			return first, nil
		}
	}
	return nil, errors.New("no initial node found")
}
// processNode hands task to the node registered under queue and reports the
// outcome as a Result.
//
// Outside sync mode the task is only published to the broker for queue. The
// Result is empty on success and carries the publish error in Result.Error
// on failure.
//
// In sync mode the node's handler is called directly. The task's
// CurrentQueue is set to queue if it was empty, the task is added to the
// broker's queue, the handler runs, and a successful handler result is
// passed to the broker's HandleProcessedMessage. The handler's Result is
// returned with Queue and MessageID defaulted from the task when empty. An
// unknown queue, a failed enqueue or a failed HandleProcessedMessage yields
// a Result carrying that error.
func (dag *DAG) processNode(ctx context.Context, task *Task, queue string) Result {
	if !dag.syncMode {
		if err := dag.broker.Publish(ctx, *task, queue); err != nil {
			// processEdge discards the Result, so keep the log line as the
			// only trace on that path, and also return the error so callers
			// of ProcessTask can see the failure.
			fmt.Println("Failed to publish task:", err)
			return Result{Error: err}
		}
		return Result{}
	}

	n, ok := dag.nodes.Get(queue)
	if task.CurrentQueue == "" {
		task.CurrentQueue = queue
	}
	if !ok {
		fmt.Println("Node not found:", queue)
		return Result{Error: fmt.Errorf("node not found %s", queue)}
	}
	if _, err := dag.broker.AddMessageToQueue(task, queue); err != nil {
		return Result{Error: err}
	}

	result := n.Handler()(ctx, *task)
	if result.Queue == "" {
		result.Queue = task.CurrentQueue
	}
	if result.MessageID == "" {
		result.MessageID = task.ID
	}
	if result.Error != nil {
		// A failed handler result is returned as-is and not passed on.
		return result
	}
	if err := dag.broker.HandleProcessedMessage(ctx, result); err != nil {
		return Result{Error: err, Status: result.Status}
	}
	return result
}
// ProcessTask runs task through the DAG, starting at the DAG's start node,
// and returns the Result of processing that node.
//
// The start node is resolved with findInitialNode, so a node marked via
// AddNode or recorded by Prepare is used as-is. If none was recorded, the
// node is derived from the graph. When no start node can be determined the
// error is returned in Result.Error instead of panicking on a nil node.
func (dag *DAG) ProcessTask(ctx context.Context, task Task) Result {
	start, err := dag.findInitialNode()
	if err != nil {
		return Result{Error: err}
	}
	if start == nil {
		return Result{Error: errors.New("no initial node found")}
	}
	return dag.processNode(ctx, &task, start.Queue())
}
// getConditionalNode returns the node ID registered for status on
// currentNode, or "" when currentNode has no conditions or none for status.
func (dag *DAG) getConditionalNode(status, currentNode string) string {
	// Indexing a missing key, or a nil inner map, yields the zero value, so
	// both "not found" cases collapse to the empty string.
	return dag.conditions[currentNode][status]
}
// validateNodes checks, in order, that every ID in nodeIDs is registered on
// the DAG. It returns an error naming the first ID that is not, or nil when
// all are registered (including when nodeIDs is empty).
func (dag *DAG) validateNodes(nodeIDs ...string) error {
	for i := range nodeIDs {
		_, registered := dag.nodes.Get(nodeIDs[i])
		if registered {
			continue
		}
		return fmt.Errorf("node %s not found", nodeIDs[i])
	}
	return nil
}