diff --git a/consumer.go b/consumer.go index 6911868..26b62ab 100644 --- a/consumer.go +++ b/consumer.go @@ -154,6 +154,7 @@ func (c *Consumer) Consume(ctx context.Context) error { if err != nil { return err } + if err := c.subscribe(ctx, c.queue); err != nil { return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err) } diff --git a/ctx.go b/ctx.go index 8907cb0..eaf139a 100644 --- a/ctx.go +++ b/ctx.go @@ -118,8 +118,6 @@ func GetPublisherID(ctx context.Context) (string, bool) { // Helper function to convert HeaderMap to a regular map func getMapAsRegularMap(hd *HeaderMap) map[string]string { result := make(map[string]string) - hd.mu.RLock() - defer hd.mu.RUnlock() for key, value := range hd.headers { result[key] = value } diff --git a/dag/dag.go b/dag/dag.go deleted file mode 100644 index 0915ecf..0000000 --- a/dag/dag.go +++ /dev/null @@ -1,381 +0,0 @@ -package dag - -import ( - "context" - "encoding/json" - "fmt" - "log" - "net/http" - "sync" - "time" - - "github.com/oarkflow/mq/consts" - - "github.com/oarkflow/mq" -) - -type taskContext struct { - totalItems int - completed int - results []json.RawMessage - result json.RawMessage - multipleResults bool -} - -type DAG struct { - FirstNode string - server *mq.Broker - nodes map[string]*mq.Consumer - edges map[string]string - conditions map[string]map[string]string - loopEdges map[string][]string - taskChMap map[string]chan mq.Result - taskResults map[string]map[string]*taskContext - mu sync.Mutex -} - -func New(opts ...mq.Option) *DAG { - d := &DAG{ - nodes: make(map[string]*mq.Consumer), - edges: make(map[string]string), - conditions: make(map[string]map[string]string), - loopEdges: make(map[string][]string), - taskChMap: make(map[string]chan mq.Result), - taskResults: make(map[string]map[string]*taskContext), - } - opts = append(opts, mq.WithCallback(d.TaskCallback)) - d.server = mq.NewBroker(opts...) - return d -} - -func (d *DAG) AddNode(name string, handler mq.Handler, firstNode ...bool) { - tlsConfig := d.server.TLSConfig() - con := mq.NewConsumer(name, name, handler, mq.WithTLS(tlsConfig.UseTLS, tlsConfig.CertPath, tlsConfig.KeyPath), mq.WithCAPath(tlsConfig.CAPath)) - if len(firstNode) > 0 { - d.FirstNode = name - } - d.nodes[name] = con -} - -func (d *DAG) AddCondition(fromNode string, conditions map[string]string) { - d.conditions[fromNode] = conditions -} - -func (d *DAG) AddEdge(fromNode string, toNodes string) { - d.edges[fromNode] = toNodes -} - -func (d *DAG) AddLoop(fromNode string, toNode ...string) { - d.loopEdges[fromNode] = toNode -} - -func (d *DAG) Prepare() { - if d.FirstNode == "" { - firstNode, ok := d.FindFirstNode() - if ok && firstNode != "" { - d.FirstNode = firstNode - } - } -} - -func (d *DAG) Start(ctx context.Context, addr string) error { - d.Prepare() - if d.server.SyncMode() { - return nil - } - go func() { - err := d.server.Start(ctx) - if err != nil { - panic(err) - } - }() - for _, con := range d.nodes { - go func(con *mq.Consumer) { - con.Consume(ctx) - }(con) - } - log.Printf("HTTP server started on %s", addr) - config := d.server.TLSConfig() - if config.UseTLS { - return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil) - } - return http.ListenAndServe(addr, nil) -} - -func (d *DAG) PublishTask(ctx context.Context, payload json.RawMessage, taskID ...string) mq.Result { - queue, ok := mq.GetQueue(ctx) - if !ok { - queue = d.FirstNode - } - var id string - if len(taskID) > 0 { - id = taskID[0] - } else { - id = mq.NewID() - } - task := &mq.Task{ - ID: id, - Payload: payload, - CreatedAt: time.Now(), - } - err := d.server.Publish(ctx, task, queue) - if err != nil { - return mq.Result{Error: err} - } - return mq.Result{ - Payload: payload, - Topic: queue, - TaskID: id, - } -} - -func (d *DAG) FindFirstNode() (string, bool) { - inDegree := make(map[string]int) - for n, _ := range d.nodes { - inDegree[n] = 0 - } - for _, outNode := range d.edges { - inDegree[outNode]++ - } - for _, targets := range d.loopEdges { - for _, outNode := range targets { - inDegree[outNode]++ - } - } - for n, count := range inDegree { - if count == 0 { - return n, true - } - } - return "", false -} - -func (d *DAG) Request(ctx context.Context, payload []byte) mq.Result { - return d.sendSync(ctx, mq.Result{Payload: payload}) -} - -func (d *DAG) Send(ctx context.Context, payload []byte) mq.Result { - if d.FirstNode == "" { - return mq.Result{Error: fmt.Errorf("initial node not defined")} - } - if d.server.SyncMode() { - return d.sendSync(ctx, mq.Result{Payload: payload}) - } - resultCh := make(chan mq.Result) - result := d.PublishTask(ctx, payload) - if result.Error != nil { - return result - } - d.mu.Lock() - d.taskChMap[result.TaskID] = resultCh - d.mu.Unlock() - finalResult := <-resultCh - return finalResult -} - -func (d *DAG) processNode(ctx context.Context, task mq.Result) mq.Result { - if con, ok := d.nodes[task.Topic]; ok { - return con.ProcessTask(ctx, &mq.Task{ - ID: task.TaskID, - Payload: task.Payload, - }) - } - return mq.Result{Error: fmt.Errorf("no consumer to process %s", task.Topic)} -} - -func (d *DAG) sendSync(ctx context.Context, task mq.Result) mq.Result { - if task.TaskID == "" { - task.TaskID = mq.NewID() - } - if task.Topic == "" { - task.Topic = d.FirstNode - } - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: task.Topic, - }) - result := d.processNode(ctx, task) - if result.Error != nil { - return result - } - for _, target := range d.loopEdges[task.Topic] { - var items, results []json.RawMessage - if err := json.Unmarshal(result.Payload, &items); err != nil { - return mq.Result{Error: err} - } - for _, item := range items { - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: target, - }) - result = d.sendSync(ctx, mq.Result{ - Payload: item, - Topic: target, - TaskID: result.TaskID, - }) - if result.Error != nil { - return result - } - results = append(results, result.Payload) - } - bt, err := json.Marshal(results) - if err != nil { - return mq.Result{Error: err} - } - result.Payload = bt - } - if conditions, ok := d.conditions[task.Topic]; ok { - if target, exists := conditions[result.Status]; exists { - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: target, - }) - result = d.sendSync(ctx, mq.Result{ - Payload: result.Payload, - Topic: target, - TaskID: result.TaskID, - }) - if result.Error != nil { - return result - } - } - } - if target, ok := d.edges[task.Topic]; ok { - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: target, - }) - result = d.sendSync(ctx, mq.Result{ - Payload: result.Payload, - Topic: target, - TaskID: result.TaskID, - }) - if result.Error != nil { - return result - } - } - return result -} - -func (d *DAG) getCompletedResults(task mq.Result, ok bool, triggeredNode string) ([]byte, bool, bool) { - var result any - var payload []byte - completed := false - multipleResults := false - if ok && triggeredNode != "" { - taskResults, ok := d.taskResults[task.TaskID] - if ok { - nodeResult, exists := taskResults[triggeredNode] - if exists { - multipleResults = nodeResult.multipleResults - nodeResult.completed++ - if nodeResult.completed == nodeResult.totalItems { - completed = true - } - if multipleResults { - nodeResult.results = append(nodeResult.results, task.Payload) - if completed { - result = nodeResult.results - } - } else { - nodeResult.result = task.Payload - if completed { - result = nodeResult.result - } - } - } - if completed { - delete(taskResults, triggeredNode) - } - } - } - if completed { - payload, _ = json.Marshal(result) - } else { - payload = task.Payload - } - return payload, completed, multipleResults -} - -func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result { - if task.Error != nil { - return mq.Result{Error: task.Error} - } - triggeredNode, ok := mq.GetTriggerNode(ctx) - payload, completed, multipleResults := d.getCompletedResults(task, ok, triggeredNode) - if loopNodes, exists := d.loopEdges[task.Topic]; exists { - var items []json.RawMessage - if err := json.Unmarshal(payload, &items); err != nil { - return mq.Result{Error: task.Error} - } - d.taskResults[task.TaskID] = map[string]*taskContext{ - task.Topic: { - totalItems: len(items), - multipleResults: true, - }, - } - - ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Topic}) - for _, loopNode := range loopNodes { - for _, item := range items { - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: loopNode, - }) - result := d.PublishTask(ctx, item, task.TaskID) - if result.Error != nil { - return result - } - } - } - - return task - } - if multipleResults && completed { - task.Topic = triggeredNode - } - if conditions, ok := d.conditions[task.Topic]; ok { - if target, exists := conditions[task.Status]; exists { - d.taskResults[task.TaskID] = map[string]*taskContext{ - task.Topic: { - totalItems: len(conditions), - }, - } - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: target, - consts.TriggerNode: task.Topic, - }) - result := d.PublishTask(ctx, payload, task.TaskID) - if result.Error != nil { - return result - } - } - } else { - ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Topic}) - edge, exists := d.edges[task.Topic] - if exists { - d.taskResults[task.TaskID] = map[string]*taskContext{ - task.Topic: { - totalItems: 1, - }, - } - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: edge, - }) - result := d.PublishTask(ctx, payload, task.TaskID) - if result.Error != nil { - return result - } - } else if completed { - d.mu.Lock() - if resultCh, ok := d.taskChMap[task.TaskID]; ok { - resultCh <- mq.Result{ - Payload: payload, - Topic: task.Topic, - TaskID: task.TaskID, - Status: "done", - } - delete(d.taskChMap, task.TaskID) - delete(d.taskResults, task.TaskID) - } - d.mu.Unlock() - } - } - - return task -} diff --git a/examples/dag.go b/examples/dag.go index 59f8431..76bf297 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -1,53 +1,79 @@ package main -/* import ( "context" "encoding/json" - "fmt" "io" "net/http" - "time" "github.com/oarkflow/mq" - "github.com/oarkflow/mq/dag" - "github.com/oarkflow/mq/examples/tasks" + "github.com/oarkflow/mq/v2" ) -var d *dag.DAG +func handler1(ctx context.Context, task *mq.Task) mq.Result { + return mq.Result{Payload: task.Payload} +} + +func handler2(ctx context.Context, task *mq.Task) mq.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + return mq.Result{Payload: task.Payload} +} + +func handler3(ctx context.Context, task *mq.Task) mq.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + age := int(user["age"].(float64)) + status := "FAIL" + if age > 20 { + status = "PASS" + } + user["status"] = status + resultPayload, _ := json.Marshal(user) + return mq.Result{Payload: resultPayload, Status: status} +} + +func handler4(ctx context.Context, task *mq.Task) mq.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + user["final"] = "D" + resultPayload, _ := json.Marshal(user) + return mq.Result{Payload: resultPayload} +} + +func handler5(ctx context.Context, task *mq.Task) mq.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + user["salary"] = "E" + resultPayload, _ := json.Marshal(user) + return mq.Result{Payload: resultPayload} +} + +func handler6(ctx context.Context, task *mq.Task) mq.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + resultPayload, _ := json.Marshal(map[string]any{"storage": user}) + return mq.Result{Payload: resultPayload} +} + +var ( + d = v2.NewDAG(mq.WithSyncMode(false)) +) func main() { - d = dag.New(mq.WithSyncMode(true), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.crt")) - d.AddNode("queue1", tasks.Node1, true) - d.AddNode("queue2", tasks.Node2) - d.AddNode("queue3", tasks.Node3) - d.AddNode("queue4", tasks.Node4) + d.AddNode("A", handler1) + d.AddNode("B", handler2) + d.AddNode("C", handler3) + d.AddNode("D", handler4) + d.AddNode("E", handler5) + d.AddNode("F", handler6) + d.AddEdge("A", "B", v2.LoopEdge) + d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"}) + d.AddEdge("B", "C") + d.AddEdge("D", "F") + d.AddEdge("E", "F") - d.AddNode("queue5", tasks.CheckCondition) - d.AddNode("queue6", tasks.Pass) - d.AddNode("queue7", tasks.Fail) - - d.AddCondition("queue5", map[string]string{"pass": "queue6", "fail": "queue7"}) - d.AddEdge("queue1", "queue2") - d.AddEdge("queue2", "queue4") - d.AddEdge("queue3", "queue5") - - d.AddLoop("queue2", "queue3") - d.Prepare() - go func() { - d.Start(context.Background(), ":8081") - }() - go func() { - time.Sleep(3 * time.Second) - result := d.Send(context.Background(), []byte(`[{"user_id": 1}, {"user_id": 2}]`)) - if result.Error != nil { - panic(result.Error) - } - fmt.Println("Response", string(result.Payload)) - }() - - time.Sleep(10 * time.Second) - d.Prepare() + // fmt.Println(rs.TaskID, "Task", string(rs.Payload)) http.HandleFunc("POST /publish", requestHandler("publish")) http.HandleFunc("POST /request", requestHandler("request")) err := d.Start(context.TODO(), ":8083") @@ -75,19 +101,13 @@ func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Requ http.Error(w, "Empty request body", http.StatusBadRequest) return } - var rs mq.Result - if requestType == "request" { - rs = d.Request(context.Background(), payload) - } else { - rs = d.Send(context.Background(), payload) - } + rs := d.ProcessTask(context.Background(), "A", payload) w.Header().Set("Content-Type", "application/json") result := map[string]any{ "message_id": rs.TaskID, - "payload": string(rs.Payload), + "payload": rs.Payload, "error": rs.Error, } json.NewEncoder(w).Encode(result) } } -*/ diff --git a/examples/dag_v2.go b/examples/dag_v2.go deleted file mode 100644 index 82e12eb..0000000 --- a/examples/dag_v2.go +++ /dev/null @@ -1,125 +0,0 @@ -package main - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - - "github.com/oarkflow/mq" - "github.com/oarkflow/mq/v2" -) - -func handler1(ctx context.Context, task *mq.Task) mq.Result { - return mq.Result{Payload: task.Payload} -} - -func handler2(ctx context.Context, task *mq.Task) mq.Result { - var user map[string]any - json.Unmarshal(task.Payload, &user) - return mq.Result{Payload: task.Payload} -} - -func handler3(ctx context.Context, task *mq.Task) mq.Result { - var user map[string]any - json.Unmarshal(task.Payload, &user) - age := int(user["age"].(float64)) - status := "FAIL" - if age > 20 { - status = "PASS" - } - user["status"] = status - resultPayload, _ := json.Marshal(user) - return mq.Result{Payload: resultPayload, Status: status} -} - -func handler4(ctx context.Context, task *mq.Task) mq.Result { - var user map[string]any - json.Unmarshal(task.Payload, &user) - user["final"] = "D" - resultPayload, _ := json.Marshal(user) - return mq.Result{Payload: resultPayload} -} - -func handler5(ctx context.Context, task *mq.Task) mq.Result { - var user map[string]any - json.Unmarshal(task.Payload, &user) - user["salary"] = "E" - resultPayload, _ := json.Marshal(user) - return mq.Result{Payload: resultPayload} -} - -func handler6(ctx context.Context, task *mq.Task) mq.Result { - var user map[string]any - json.Unmarshal(task.Payload, &user) - resultPayload, _ := json.Marshal(map[string]any{"storage": user}) - return mq.Result{Payload: resultPayload} -} - -var ( - d = v2.NewDAG(mq.WithSyncMode(true)) -) - -func main() { - d.AddNode("A", handler1) - d.AddNode("B", handler2) - d.AddNode("C", handler3) - d.AddNode("D", handler4) - d.AddNode("E", handler5) - d.AddNode("F", handler6) - d.AddEdge("A", "B", v2.LoopEdge) - d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"}) - d.AddEdge("B", "C") - d.AddEdge("D", "F") - d.AddEdge("E", "F") - - initialPayload, _ := json.Marshal([]map[string]any{ - {"user_id": 1, "age": 12}, - {"user_id": 2, "age": 34}, - }) - /*for i := 0; i < 100; i++ { - - }*/ - rs := d.ProcessTask(context.Background(), "A", initialPayload) - if rs.Error != nil { - panic(rs.Error) - } - fmt.Println(rs.TaskID, "Task", string(rs.Payload)) - /*http.HandleFunc("POST /publish", requestHandler("publish")) - http.HandleFunc("POST /request", requestHandler("request")) - err := d.Start(context.TODO(), ":8083") - if err != nil { - panic(err) - }*/ -} - -func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - http.Error(w, "Invalid request method", http.StatusMethodNotAllowed) - return - } - var payload []byte - if r.Body != nil { - defer r.Body.Close() - var err error - payload, err = io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Failed to read request body", http.StatusBadRequest) - return - } - } else { - http.Error(w, "Empty request body", http.StatusBadRequest) - return - } - rs := d.ProcessTask(context.Background(), "A", payload) - w.Header().Set("Content-Type", "application/json") - result := map[string]any{ - "message_id": rs.TaskID, - "payload": string(rs.Payload), - "error": rs.Error, - } - json.NewEncoder(w).Encode(result) - } -} diff --git a/v2/dag.go b/v2/dag.go index 8868591..5163fa2 100644 --- a/v2/dag.go +++ b/v2/dag.go @@ -6,6 +6,7 @@ import ( "log" "net/http" "sync" + "time" "github.com/oarkflow/xid" @@ -78,6 +79,7 @@ func (tm *DAG) Start(ctx context.Context, addr string) error { }() for _, con := range tm.Nodes { go func(con *Node) { + time.Sleep(1 * time.Second) con.consumer.Consume(ctx) }(con) } diff --git a/v2/task_manager.go b/v2/task_manager.go index 66daedf..b9a5f27 100644 --- a/v2/task_manager.go +++ b/v2/task_manager.go @@ -5,20 +5,22 @@ import ( "encoding/json" "fmt" "sync" + "sync/atomic" "github.com/oarkflow/mq" "github.com/oarkflow/mq/consts" ) type TaskManager struct { - taskID string - dag *DAG - wg sync.WaitGroup - mutex sync.Mutex - results []mq.Result - nodeResults map[string]mq.Result - done chan struct{} - finalResult chan mq.Result // Channel to collect final results + taskID string + dag *DAG + wg sync.WaitGroup + mutex sync.Mutex + results []mq.Result + waitingCallback int64 + nodeResults map[string]mq.Result + done chan struct{} + finalResult chan mq.Result // Channel to collect final results } func NewTaskManager(d *DAG, taskID string) *TaskManager { @@ -26,9 +28,7 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager { dag: d, nodeResults: make(map[string]mq.Result), results: make([]mq.Result, 0), - done: make(chan struct{}), taskID: taskID, - finalResult: make(chan mq.Result), // Initialize finalResult channel } } @@ -37,26 +37,97 @@ func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload j if !ok { return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)} } - tm.wg.Add(1) - go tm.processNode(ctx, node, payload) - go func() { - tm.wg.Wait() - close(tm.done) - }() - select { - case <-ctx.Done(): - return mq.Result{Error: ctx.Err()} - case <-tm.done: - tm.mutex.Lock() - defer tm.mutex.Unlock() - if len(tm.results) == 1 { - return tm.handleResult(ctx, tm.results[0]) + if tm.dag.server.SyncMode() { + tm.done = make(chan struct{}) + tm.wg.Add(1) + go tm.processNode(ctx, node, payload) + go func() { + tm.wg.Wait() + close(tm.done) + }() + select { + case <-ctx.Done(): + return mq.Result{Error: ctx.Err()} + case <-tm.done: + tm.mutex.Lock() + defer tm.mutex.Unlock() + if len(tm.results) == 1 { + return tm.handleResult(ctx, tm.results[0]) + } + return tm.handleResult(ctx, tm.results) + } + } else { + tm.finalResult = make(chan mq.Result) + tm.wg.Add(1) + go tm.processNode(ctx, node, payload) + go func() { + tm.wg.Wait() + }() + select { + case result := <-tm.finalResult: // Block until a result is available + return result + case <-ctx.Done(): // Handle context cancellation + return mq.Result{Error: ctx.Err()} } - return tm.handleResult(ctx, tm.results) } } func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.Result { + if result.Topic != "" { + atomic.AddInt64(&tm.waitingCallback, -1) + } + node, ok := tm.dag.Nodes[result.Topic] + if !ok { + return result + } + edges := make([]Edge, len(node.Edges)) + copy(edges, node.Edges) + if result.Status != "" { + if conditions, ok := tm.dag.conditions[result.Topic]; ok { + if targetNodeKey, ok := conditions[result.Status]; ok { + if targetNode, ok := tm.dag.Nodes[targetNodeKey]; ok { + edges = append(edges, Edge{From: node, To: targetNode}) + } + } + } + } + if len(edges) == 0 { + tm.appendFinalResult(result) + if !tm.dag.server.SyncMode() { + var rs mq.Result + if len(tm.results) == 1 { + rs = tm.handleResult(ctx, tm.results[0]) + } else { + rs = tm.handleResult(ctx, tm.results) + } + if tm.waitingCallback == 0 { + tm.finalResult <- rs + } + } + return result + } + for _, edge := range edges { + switch edge.Type { + case LoopEdge: + var items []json.RawMessage + err := json.Unmarshal(result.Payload, &items) + if err != nil { + tm.appendFinalResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err}) + return result + } + for _, item := range items { + tm.wg.Add(1) + ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) + go tm.processNode(ctx, edge.To, item) + } + case SimpleEdge: + if edge.To != nil { + tm.wg.Add(1) + ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) + go tm.processNode(ctx, edge.To, result.Payload) + } + } + } return mq.Result{} } @@ -103,6 +174,7 @@ func (tm *TaskManager) appendFinalResult(result mq.Result) { } func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) { + atomic.AddInt64(&tm.waitingCallback, 1) defer tm.wg.Done() var result mq.Result select { @@ -115,6 +187,7 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json if tm.dag.server.SyncMode() { result = node.consumer.ProcessTask(ctx, NewTask(tm.taskID, payload, node.Key)) result.Topic = node.Key + result.TaskID = tm.taskID if result.Error != nil { tm.appendFinalResult(result) return @@ -130,41 +203,5 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json tm.mutex.Lock() tm.nodeResults[node.Key] = result tm.mutex.Unlock() - edges := make([]Edge, len(node.Edges)) - copy(edges, node.Edges) - if result.Status != "" { - if conditions, ok := tm.dag.conditions[result.Topic]; ok { - if targetNodeKey, ok := conditions[result.Status]; ok { - if targetNode, ok := tm.dag.Nodes[targetNodeKey]; ok { - edges = append(edges, Edge{From: node, To: targetNode}) - } - } - } - } - if len(edges) == 0 { - tm.appendFinalResult(result) - return - } - for _, edge := range edges { - switch edge.Type { - case LoopEdge: - var items []json.RawMessage - err := json.Unmarshal(result.Payload, &items) - if err != nil { - tm.appendFinalResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err}) - return - } - for _, item := range items { - tm.wg.Add(1) - ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) - go tm.processNode(ctx, edge.To, item) - } - case SimpleEdge: - if edge.To != nil { - tm.wg.Add(1) - ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) - go tm.processNode(ctx, edge.To, result.Payload) - } - } - } + tm.handleCallback(ctx, result) }