diff --git a/dag/dag.go b/dag/dag.go index 64666c3..c0745ef 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -1,102 +1,382 @@ -package v2 +package dag import ( "context" "encoding/json" + "fmt" + "log" + "net/http" "sync" + "time" - "github.com/oarkflow/xid" + "github.com/oarkflow/mq/consts" + + "github.com/oarkflow/mq" ) -type Handler func(ctx context.Context, task *Task) Result - -type Result struct { - TaskID string `json:"task_id"` - NodeKey string `json:"node_key"` - Payload json.RawMessage `json:"payload"` - Status string `json:"status"` - Error error `json:"error"` -} - -type Task struct { - ID string `json:"id"` - NodeKey string `json:"node_key"` - Payload json.RawMessage `json:"payload"` - Results map[string]Result `json:"results"` -} - -type Node struct { - Key string - Edges []Edge - handler Handler -} - -type EdgeType int - -const ( - SimpleEdge EdgeType = iota - LoopEdge - ConditionEdge -) - -type Edge struct { - From *Node - To *Node - Type EdgeType - Condition func(result Result) bool +type taskContext struct { + totalItems int + completed int + results []json.RawMessage + result json.RawMessage + multipleResults bool } type DAG struct { - Nodes map[string]*Node - taskContext map[string]*TaskManager - mu sync.RWMutex + FirstNode string + server *mq.Broker + nodes map[string]*mq.Consumer + edges map[string]string + conditions map[string]map[string]string + loopEdges map[string][]string + taskChMap map[string]chan mq.Result + taskResults map[string]map[string]*taskContext + mu sync.Mutex } -func NewDAG() *DAG { - return &DAG{ - Nodes: make(map[string]*Node), - taskContext: make(map[string]*TaskManager), +func New(opts ...mq.Option) *DAG { + d := &DAG{ + nodes: make(map[string]*mq.Consumer), + edges: make(map[string]string), + conditions: make(map[string]map[string]string), + loopEdges: make(map[string][]string), + taskChMap: make(map[string]chan mq.Result), + taskResults: make(map[string]map[string]*taskContext), + } + opts = append(opts, mq.WithCallback(d.TaskCallback)) + d.server = mq.NewBroker(opts...) + return d +} + +func (d *DAG) AddNode(name string, handler mq.Handler, firstNode ...bool) { + tlsConfig := d.server.TLSConfig() + con := mq.NewConsumer(name, mq.WithTLS(tlsConfig.UseTLS, tlsConfig.CertPath, tlsConfig.KeyPath), mq.WithCAPath(tlsConfig.CAPath)) + if len(firstNode) > 0 { + d.FirstNode = name + } + con.RegisterHandler(name, handler) + d.nodes[name] = con +} + +func (d *DAG) AddCondition(fromNode string, conditions map[string]string) { + d.conditions[fromNode] = conditions +} + +func (d *DAG) AddEdge(fromNode string, toNodes string) { + d.edges[fromNode] = toNodes +} + +func (d *DAG) AddLoop(fromNode string, toNode ...string) { + d.loopEdges[fromNode] = toNode +} + +func (d *DAG) Prepare() { + if d.FirstNode == "" { + firstNode, ok := d.FindFirstNode() + if ok && firstNode != "" { + d.FirstNode = firstNode + } } } -func (tm *DAG) AddNode(key string, handler Handler) { - tm.mu.Lock() - defer tm.mu.Unlock() - tm.Nodes[key] = &Node{ - Key: key, - handler: handler, +func (d *DAG) Start(ctx context.Context, addr string) error { + d.Prepare() + if d.server.SyncMode() { + return nil } + go func() { + err := d.server.Start(ctx) + if err != nil { + panic(err) + } + }() + for _, con := range d.nodes { + go func(con *mq.Consumer) { + con.Consume(ctx) + }(con) + } + log.Printf("HTTP server started on %s", addr) + config := d.server.TLSConfig() + if config.UseTLS { + return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil) + } + return http.ListenAndServe(addr, nil) } -func (tm *DAG) AddEdge(from, to string, edgeType EdgeType) { - tm.mu.Lock() - defer tm.mu.Unlock() - fromNode, ok := tm.Nodes[from] +func (d *DAG) PublishTask(ctx context.Context, payload json.RawMessage, taskID ...string) mq.Result { + queue, ok := mq.GetQueue(ctx) if !ok { - return + queue = d.FirstNode } - toNode, ok := tm.Nodes[to] - if !ok { - return + var id string + if len(taskID) > 0 { + id = taskID[0] + } else { + id = mq.NewID() } - fromNode.Edges = append(fromNode.Edges, Edge{ - From: fromNode, - To: toNode, - Type: edgeType, + task := mq.Task{ + ID: id, + Payload: payload, + CreatedAt: time.Now(), + } + err := d.server.Publish(ctx, task, queue) + if err != nil { + return mq.Result{Error: err} + } + return mq.Result{ + Payload: payload, + Queue: queue, + MessageID: id, + } +} + +func (d *DAG) FindFirstNode() (string, bool) { + inDegree := make(map[string]int) + for n, _ := range d.nodes { + inDegree[n] = 0 + } + for _, outNode := range d.edges { + inDegree[outNode]++ + } + for _, targets := range d.loopEdges { + for _, outNode := range targets { + inDegree[outNode]++ + } + } + for n, count := range inDegree { + if count == 0 { + return n, true + } + } + return "", false +} + +func (d *DAG) Request(ctx context.Context, payload []byte) mq.Result { + return d.sendSync(ctx, mq.Result{Payload: payload}) +} + +func (d *DAG) Send(ctx context.Context, payload []byte) mq.Result { + if d.FirstNode == "" { + return mq.Result{Error: fmt.Errorf("initial node not defined")} + } + if d.server.SyncMode() { + return d.sendSync(ctx, mq.Result{Payload: payload}) + } + resultCh := make(chan mq.Result) + result := d.PublishTask(ctx, payload) + if result.Error != nil { + return result + } + d.mu.Lock() + d.taskChMap[result.MessageID] = resultCh + d.mu.Unlock() + finalResult := <-resultCh + return finalResult +} + +func (d *DAG) processNode(ctx context.Context, task mq.Result) mq.Result { + if con, ok := d.nodes[task.Queue]; ok { + return con.ProcessTask(ctx, mq.Task{ + ID: task.MessageID, + Payload: task.Payload, + }) + } + return mq.Result{Error: fmt.Errorf("no consumer to process %s", task.Queue)} +} + +func (d *DAG) sendSync(ctx context.Context, task mq.Result) mq.Result { + if task.MessageID == "" { + task.MessageID = mq.NewID() + } + if task.Queue == "" { + task.Queue = d.FirstNode + } + ctx = mq.SetHeaders(ctx, map[string]string{ + consts.QueueKey: task.Queue, }) + result := d.processNode(ctx, task) + if result.Error != nil { + return result + } + for _, target := range d.loopEdges[task.Queue] { + var items, results []json.RawMessage + if err := json.Unmarshal(result.Payload, &items); err != nil { + return mq.Result{Error: err} + } + for _, item := range items { + ctx = mq.SetHeaders(ctx, map[string]string{ + consts.QueueKey: target, + }) + result = d.sendSync(ctx, mq.Result{ + Payload: item, + Queue: target, + MessageID: result.MessageID, + }) + if result.Error != nil { + return result + } + results = append(results, result.Payload) + } + bt, err := json.Marshal(results) + if err != nil { + return mq.Result{Error: err} + } + result.Payload = bt + } + if conditions, ok := d.conditions[task.Queue]; ok { + if target, exists := conditions[result.Status]; exists { + ctx = mq.SetHeaders(ctx, map[string]string{ + consts.QueueKey: target, + }) + result = d.sendSync(ctx, mq.Result{ + Payload: result.Payload, + Queue: target, + MessageID: result.MessageID, + }) + if result.Error != nil { + return result + } + } + } + if target, ok := d.edges[task.Queue]; ok { + ctx = mq.SetHeaders(ctx, map[string]string{ + consts.QueueKey: target, + }) + result = d.sendSync(ctx, mq.Result{ + Payload: result.Payload, + Queue: target, + MessageID: result.MessageID, + }) + if result.Error != nil { + return result + } + } + return result } -func (tm *DAG) ProcessTask(ctx context.Context, node string, payload []byte) Result { - tm.mu.Lock() - defer tm.mu.Unlock() - taskID := xid.New().String() - task := &Task{ - ID: taskID, - NodeKey: node, - Payload: payload, - Results: make(map[string]Result), +func (d *DAG) getCompletedResults(task mq.Result, ok bool, triggeredNode string) ([]byte, bool, bool) { + var result any + var payload []byte + completed := false + multipleResults := false + if ok && triggeredNode != "" { + taskResults, ok := d.taskResults[task.MessageID] + if ok { + nodeResult, exists := taskResults[triggeredNode] + if exists { + multipleResults = nodeResult.multipleResults + nodeResult.completed++ + if nodeResult.completed == nodeResult.totalItems { + completed = true + } + if multipleResults { + nodeResult.results = append(nodeResult.results, task.Payload) + if completed { + result = nodeResult.results + } + } else { + nodeResult.result = task.Payload + if completed { + result = nodeResult.result + } + } + } + if completed { + delete(taskResults, triggeredNode) + } + } } - manager := NewTaskManager(tm) - tm.taskContext[taskID] = manager - return manager.processTask(ctx, node, task) + if completed { + payload, _ = json.Marshal(result) + } else { + payload = task.Payload + } + return payload, completed, multipleResults +} + +func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result { + if task.Error != nil { + return mq.Result{Error: task.Error} + } + triggeredNode, ok := mq.GetTriggerNode(ctx) + payload, completed, multipleResults := d.getCompletedResults(task, ok, triggeredNode) + if loopNodes, exists := d.loopEdges[task.Queue]; exists { + var items []json.RawMessage + if err := json.Unmarshal(payload, &items); err != nil { + return mq.Result{Error: task.Error} + } + d.taskResults[task.MessageID] = map[string]*taskContext{ + task.Queue: { + totalItems: len(items), + multipleResults: true, + }, + } + + ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue}) + for _, loopNode := range loopNodes { + for _, item := range items { + ctx = mq.SetHeaders(ctx, map[string]string{ + consts.QueueKey: loopNode, + }) + result := d.PublishTask(ctx, item, task.MessageID) + if result.Error != nil { + return result + } + } + } + + return task + } + if multipleResults && completed { + task.Queue = triggeredNode + } + if conditions, ok := d.conditions[task.Queue]; ok { + if target, exists := conditions[task.Status]; exists { + d.taskResults[task.MessageID] = map[string]*taskContext{ + task.Queue: { + totalItems: len(conditions), + }, + } + ctx = mq.SetHeaders(ctx, map[string]string{ + consts.QueueKey: target, + consts.TriggerNode: task.Queue, + }) + result := d.PublishTask(ctx, payload, task.MessageID) + if result.Error != nil { + return result + } + } + } else { + ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue}) + edge, exists := d.edges[task.Queue] + if exists { + d.taskResults[task.MessageID] = map[string]*taskContext{ + task.Queue: { + totalItems: 1, + }, + } + ctx = mq.SetHeaders(ctx, map[string]string{ + consts.QueueKey: edge, + }) + result := d.PublishTask(ctx, payload, task.MessageID) + if result.Error != nil { + return result + } + } else if completed { + d.mu.Lock() + if resultCh, ok := d.taskChMap[task.MessageID]; ok { + resultCh <- mq.Result{ + Payload: payload, + Queue: task.Queue, + MessageID: task.MessageID, + Status: "done", + } + delete(d.taskChMap, task.MessageID) + delete(d.taskResults, task.MessageID) + } + d.mu.Unlock() + } + } + + return task } diff --git a/examples/dag_v2.go b/examples/dag_v2.go index 8757147..9ca810e 100644 --- a/examples/dag_v2.go +++ b/examples/dag_v2.go @@ -4,272 +4,58 @@ import ( "context" "encoding/json" "fmt" - "sync" - "time" + v2 "github.com/oarkflow/mq/v2" ) -type Task struct { - ID string `json:"id"` - Payload json.RawMessage `json:"payload"` - CreatedAt time.Time `json:"created_at"` - ProcessedAt time.Time `json:"processed_at"` - Status string `json:"status"` - Error error `json:"error"` -} - -type Result struct { - Payload json.RawMessage `json:"payload"` - Queue string `json:"queue"` - MessageID string `json:"message_id"` - Error error `json:"error,omitempty"` - Status string `json:"status"` -} - -const ( - SimpleEdge = iota - LoopEdge - ConditionEdge -) - -type Edge struct { - edgeType int - to string - conditions map[string]string -} - -type Node struct { - key string - handler func(context.Context, Task) Result - edges []Edge -} - -type RadixTrie struct { - children map[rune]*RadixTrie - node *Node - mu sync.RWMutex -} - -func NewRadixTrie() *RadixTrie { - return &RadixTrie{ - children: make(map[rune]*RadixTrie), +func handler1(ctx context.Context, task *v2.Task) v2.Result { + return v2.Result{ + TaskID: task.ID, + NodeKey: "A", + Payload: task.Payload, + Status: "success", } } -func (trie *RadixTrie) Insert(key string, node *Node) { - trie.mu.Lock() - defer trie.mu.Unlock() - - current := trie - for _, char := range key { - if _, exists := current.children[char]; !exists { - current.children[char] = NewRadixTrie() - } - current = current.children[char] - } - current.node = node -} - -func (trie *RadixTrie) Search(key string) (*Node, bool) { - trie.mu.RLock() - defer trie.mu.RUnlock() - current := trie - for _, char := range key { - if _, exists := current.children[char]; !exists { - return nil, false - } - current = current.children[char] - } - if current.node != nil { - return current.node, true - } - return nil, false -} - -type DAG struct { - trie *RadixTrie - mu sync.RWMutex -} - -func NewDAG() *DAG { - return &DAG{ - trie: NewRadixTrie(), +func handler2(ctx context.Context, task *v2.Task) v2.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + return v2.Result{ + TaskID: task.ID, + NodeKey: "B", + Payload: task.Payload, + Status: "success", } } -func (d *DAG) AddNode(key string, handler func(context.Context, Task) Result, isRoot ...bool) { - node := &Node{key: key, handler: handler} - d.trie.Insert(key, node) -} - -func (d *DAG) AddEdge(fromKey string, toKey string) { - d.mu.Lock() - defer d.mu.Unlock() - node, exists := d.trie.Search(fromKey) - if !exists { - fmt.Printf("Node %s not found to add edge.\n", fromKey) - return +func handler3(ctx context.Context, task *v2.Task) v2.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + age := int(user["age"].(float64)) + status := "FAIL" + if age > 20 { + status = "PASS" } - edge := Edge{edgeType: SimpleEdge, to: toKey} - node.edges = append(node.edges, edge) -} - -func (d *DAG) AddLoop(fromKey string, toKey string) { - d.mu.Lock() - defer d.mu.Unlock() - node, exists := d.trie.Search(fromKey) - if !exists { - fmt.Printf("Node %s not found to add loop edge.\n", fromKey) - return + user["status"] = status + resultPayload, _ := json.Marshal(user) + return v2.Result{ + TaskID: task.ID, + NodeKey: "C", + Payload: resultPayload, + Status: status, } - edge := Edge{edgeType: LoopEdge, to: toKey} - node.edges = append(node.edges, edge) -} - -func (d *DAG) AddCondition(fromKey string, conditions map[string]string) { - d.mu.Lock() - defer d.mu.Unlock() - node, exists := d.trie.Search(fromKey) - if !exists { - fmt.Printf("Node %s not found to add condition edge.\n", fromKey) - return - } - edge := Edge{edgeType: ConditionEdge, conditions: conditions} - node.edges = append(node.edges, edge) -} - -type ProcessCallback func(ctx context.Context, key string, result Result) string - -func (d *DAG) ProcessTask(ctx context.Context, key string, task Task) { - node, exists := d.trie.Search(key) - if !exists { - fmt.Printf("Node %s not found.\n", key) - return - } - result := node.handler(ctx, task) - nextKey := d.callback(ctx, key, result) - if nextKey != "" { - d.ProcessTask(ctx, nextKey, task) - } -} - -func (d *DAG) ProcessLoop(ctx context.Context, key string, task Task) { - _, exists := d.trie.Search(key) - if !exists { - fmt.Printf("Node %s not found.\n", key) - return - } - var items []json.RawMessage - err := json.Unmarshal(task.Payload, &items) - if err != nil { - fmt.Printf("Error unmarshaling payload as slice: %v\n", err) - return - } - for _, item := range items { - newTask := Task{ - ID: task.ID, - Payload: item, - } - - d.ProcessTask(ctx, key, newTask) - } -} - -func (d *DAG) callback(ctx context.Context, currentKey string, result Result) string { - fmt.Printf("Callback received result from %s: %s\n", currentKey, string(result.Payload)) - node, exists := d.trie.Search(currentKey) - if !exists { - return "" - } - for _, edge := range node.edges { - switch edge.edgeType { - case SimpleEdge: - return edge.to - case LoopEdge: - - d.ProcessLoop(ctx, edge.to, Task{Payload: result.Payload}) - return "" - case ConditionEdge: - if nextKey, conditionMet := edge.conditions[result.Status]; conditionMet { - return nextKey - } - } - } - return "" -} - -func Node1(ctx context.Context, task Task) Result { - return Result{Payload: task.Payload, MessageID: task.ID} -} - -func Node2(ctx context.Context, task Task) Result { - return Result{Payload: task.Payload, MessageID: task.ID} -} - -func Node3(ctx context.Context, task Task) Result { - var data map[string]any - err := json.Unmarshal(task.Payload, &data) - if err != nil { - return Result{Error: err} - } - data["salary"] = fmt.Sprintf("12000%v", data["user_id"]) - bt, _ := json.Marshal(data) - return Result{Payload: bt, MessageID: task.ID} -} - -func Node4(ctx context.Context, task Task) Result { - var data []map[string]any - err := json.Unmarshal(task.Payload, &data) - if err != nil { - return Result{Error: err} - } - payload := map[string]any{"storage": data} - bt, _ := json.Marshal(payload) - return Result{Payload: bt, MessageID: task.ID} -} - -func CheckCondition(ctx context.Context, task Task) Result { - var data map[string]any - err := json.Unmarshal(task.Payload, &data) - if err != nil { - return Result{Error: err} - } - var status string - if data["user_id"].(float64) == 2 { - status = "pass" - } else { - status = "fail" - } - return Result{Status: status, Payload: task.Payload, MessageID: task.ID} -} - -func Pass(ctx context.Context, task Task) Result { - fmt.Println("Pass") - return Result{Payload: task.Payload} -} - -func Fail(ctx context.Context, task Task) Result { - fmt.Println("Fail") - return Result{Payload: []byte(`{"test2": "asdsa"}`)} } func main() { - dag := NewDAG() - dag.AddNode("queue1", Node1, true) - dag.AddNode("queue2", Node2) - dag.AddNode("queue3", Node3) - dag.AddNode("queue4", Node4) - dag.AddNode("queue5", CheckCondition) - dag.AddNode("queue6", Pass) - dag.AddNode("queue7", Fail) - dag.AddEdge("queue1", "queue2") - dag.AddEdge("queue2", "queue4") - dag.AddEdge("queue3", "queue5") - dag.AddLoop("queue2", "queue3") - dag.AddCondition("queue5", map[string]string{"pass": "queue6", "fail": "queue7"}) - ctx := context.Background() - task := Task{ - ID: "task1", - Payload: []byte(`[{"user_id": 1}, {"user_id": 2}]`), - } - dag.ProcessTask(ctx, "queue1", task) + dag := v2.NewDAG() + dag.AddNode("A", handler1) + dag.AddNode("B", handler2) + dag.AddNode("C", handler3) + dag.AddEdge("A", "B", v2.LoopEdge) + dag.AddEdge("B", "C", v2.SimpleEdge) + initialPayload, _ := json.Marshal([]map[string]any{ + {"user_id": 1, "age": 12}, + {"user_id": 2, "age": 34}, + }) + rs := dag.ProcessTask(context.Background(), "A", initialPayload) + fmt.Println(string(rs.Payload)) } diff --git a/v2/dag.go b/v2/dag.go new file mode 100644 index 0000000..64666c3 --- /dev/null +++ b/v2/dag.go @@ -0,0 +1,102 @@ +package v2 + +import ( + "context" + "encoding/json" + "sync" + + "github.com/oarkflow/xid" +) + +type Handler func(ctx context.Context, task *Task) Result + +type Result struct { + TaskID string `json:"task_id"` + NodeKey string `json:"node_key"` + Payload json.RawMessage `json:"payload"` + Status string `json:"status"` + Error error `json:"error"` +} + +type Task struct { + ID string `json:"id"` + NodeKey string `json:"node_key"` + Payload json.RawMessage `json:"payload"` + Results map[string]Result `json:"results"` +} + +type Node struct { + Key string + Edges []Edge + handler Handler +} + +type EdgeType int + +const ( + SimpleEdge EdgeType = iota + LoopEdge + ConditionEdge +) + +type Edge struct { + From *Node + To *Node + Type EdgeType + Condition func(result Result) bool +} + +type DAG struct { + Nodes map[string]*Node + taskContext map[string]*TaskManager + mu sync.RWMutex +} + +func NewDAG() *DAG { + return &DAG{ + Nodes: make(map[string]*Node), + taskContext: make(map[string]*TaskManager), + } +} + +func (tm *DAG) AddNode(key string, handler Handler) { + tm.mu.Lock() + defer tm.mu.Unlock() + tm.Nodes[key] = &Node{ + Key: key, + handler: handler, + } +} + +func (tm *DAG) AddEdge(from, to string, edgeType EdgeType) { + tm.mu.Lock() + defer tm.mu.Unlock() + fromNode, ok := tm.Nodes[from] + if !ok { + return + } + toNode, ok := tm.Nodes[to] + if !ok { + return + } + fromNode.Edges = append(fromNode.Edges, Edge{ + From: fromNode, + To: toNode, + Type: edgeType, + }) +} + +func (tm *DAG) ProcessTask(ctx context.Context, node string, payload []byte) Result { + tm.mu.Lock() + defer tm.mu.Unlock() + taskID := xid.New().String() + task := &Task{ + ID: taskID, + NodeKey: node, + Payload: payload, + Results: make(map[string]Result), + } + manager := NewTaskManager(tm) + tm.taskContext[taskID] = manager + return manager.processTask(ctx, node, task) +} diff --git a/dag/task_manager.go b/v2/task_manager.go similarity index 96% rename from dag/task_manager.go rename to v2/task_manager.go index ae92306..316130c 100644 --- a/dag/task_manager.go +++ b/v2/task_manager.go @@ -18,9 +18,10 @@ type TaskManager struct { func NewTaskManager(d *DAG) *TaskManager { return &TaskManager{ - dag: d, - results: make([]Result, 0), - done: make(chan struct{}), + dag: d, + nodeResults: make(map[string]Result), + results: make([]Result, 0), + done: make(chan struct{}), } }