diff --git a/examples/dag.go b/examples/dag.go index 8ce672f..cbd4de8 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -13,17 +13,14 @@ import ( func main() { dag := NewDAG() - dag.AddNode("queue1", func(ctx context.Context, task mq.Task) mq.Result { log.Printf("Handling task for queue1: %s", string(task.Payload)) return mq.Result{Payload: task.Payload, MessageID: task.ID} }) - dag.AddNode("queue2", func(ctx context.Context, task mq.Task) mq.Result { log.Printf("Handling task for queue2: %s", string(task.Payload)) return mq.Result{Payload: task.Payload, MessageID: task.ID} }) - dag.AddNode("queue3", func(ctx context.Context, task mq.Task) mq.Result { var data map[string]any err := json.Unmarshal(task.Payload, &data) @@ -35,18 +32,22 @@ func main() { log.Printf("Handling task for queue3: %s", string(bt)) return mq.Result{Payload: bt, MessageID: task.ID} }) - dag.AddNode("queue4", func(ctx context.Context, task mq.Task) mq.Result { + var data []map[string]any + err := json.Unmarshal(task.Payload, &data) + if err != nil { + return mq.Result{Error: err} + } log.Printf("Handling task for queue4: %s", string(task.Payload)) - return mq.Result{Payload: task.Payload, MessageID: task.ID} + payload := map[string]any{"storage": data} + bt, _ := json.Marshal(payload) + return mq.Result{Payload: bt, MessageID: task.ID} }) - // Define edges and loops dag.AddEdge("queue1", "queue2") - dag.AddLoop("queue2", "queue3") // Loop through queue3 for each item from queue2 - dag.AddEdge("queue2", "queue4") // After processing queue2 (including loop), continue to queue4 + dag.AddLoop("queue2", "queue3") + dag.AddEdge("queue2", "queue4") - // Send task payload go func() { time.Sleep(2 * time.Second) finalResult := dag.Send([]byte(`[{"user_id": 1}, {"user_id": 2}]`)) @@ -59,18 +60,16 @@ func main() { } } -// DAG struct to handle tasks, edges, and loops type DAG struct { server *mq.Broker nodes map[string]*mq.Consumer edges map[string][]string - loopEdges map[string]string // Tracks loop edges + loopEdges map[string]string taskChMap map[string]chan mq.Result - loopTaskMap map[string]*loopTaskContext // Manages loop tasks + loopTaskMap map[string]*loopTaskContext mu sync.Mutex } -// Structure to store the loop task context type loopTaskContext struct { subResultCh chan mq.Result totalItems int @@ -78,7 +77,6 @@ type loopTaskContext struct { results []json.RawMessage } -// NewDAG initializes the DAG structure func NewDAG(opts ...mq.Option) *DAG { d := &DAG{ nodes: make(map[string]*mq.Consumer), @@ -113,7 +111,6 @@ func (d *DAG) Start(ctx context.Context) error { return d.server.Start(ctx) } -// PublishTask sends a task to a specific queue func (d *DAG) PublishTask(ctx context.Context, payload []byte, queueName string, taskID ...string) (*mq.Task, error) { task := mq.Task{ Payload: payload, @@ -124,48 +121,37 @@ func (d *DAG) PublishTask(ctx context.Context, payload []byte, queueName string, return d.server.Publish(ctx, task, queueName) } -// TaskCallback processes the completion of a task and continues the flow func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error { log.Printf("Callback from queue %s with result: %s", task.CurrentQueue, string(task.Result)) - - // Check if the task belongs to a loop d.mu.Lock() loopCtx, isLoopTask := d.loopTaskMap[task.ID] d.mu.Unlock() - if isLoopTask { loopCtx.subResultCh <- mq.Result{Payload: task.Result, MessageID: task.ID} } - // Handle loopEdges first if loopNode, exists := d.loopEdges[task.CurrentQueue]; exists { var items []json.RawMessage if err := json.Unmarshal(task.Result, &items); err != nil { return err } loopCtx := &loopTaskContext{ - subResultCh: make(chan mq.Result, len(items)), // Collect sub-task results + subResultCh: make(chan mq.Result, len(items)), totalItems: len(items), results: make([]json.RawMessage, 0, len(items)), } d.mu.Lock() d.loopTaskMap[task.ID] = loopCtx d.mu.Unlock() - - // Publish a sub-task for each item for _, item := range items { _, err := d.PublishTask(ctx, item, loopNode, task.ID) if err != nil { return err } } - - // Wait for loop completion go d.waitForLoopCompletion(ctx, task.ID, task.CurrentQueue) return nil } - - // Handle normal edges edges, exists := d.edges[task.CurrentQueue] if exists { for _, edge := range edges { @@ -174,40 +160,40 @@ func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error { return err } } + } else { + d.mu.Lock() + if resultCh, ok := d.taskChMap[task.ID]; ok { + resultCh <- mq.Result{ + Command: "complete", + Payload: task.Result, + Queue: task.CurrentQueue, + MessageID: task.ID, + Status: "done", + } + delete(d.taskChMap, task.ID) + } + d.mu.Unlock() } - return nil } -// waitForLoopCompletion waits for all sub-tasks in a loop to finish, aggregates the results, and proceeds func (d *DAG) waitForLoopCompletion(ctx context.Context, taskID string, currentQueue string) { - // Get the loop context d.mu.Lock() loopCtx := d.loopTaskMap[taskID] d.mu.Unlock() - for result := range loopCtx.subResultCh { - // Collect results loopCtx.results = append(loopCtx.results, result.Payload) loopCtx.completed++ - - // If all sub-tasks are completed, aggregate results and continue if loopCtx.completed == loopCtx.totalItems { close(loopCtx.subResultCh) - - // Aggregate results aggregatedResult, err := json.Marshal(loopCtx.results) if err != nil { log.Printf("Error aggregating results: %v", err) return } - - // Continue flow after loop completion d.mu.Lock() delete(d.loopTaskMap, taskID) d.mu.Unlock() - - // Handle the next nodes in the DAG edges, exists := d.edges[currentQueue] if exists { for _, edge := range edges { @@ -218,7 +204,6 @@ func (d *DAG) waitForLoopCompletion(ctx context.Context, taskID string, currentQ } } } else { - // If no further edges, finalize d.mu.Lock() if resultCh, ok := d.taskChMap[taskID]; ok { resultCh <- mq.Result{ @@ -232,13 +217,10 @@ func (d *DAG) waitForLoopCompletion(ctx context.Context, taskID string, currentQ } d.mu.Unlock() } - - break } } } -// Send sends the initial task and waits for the final result func (d *DAG) Send(payload []byte) mq.Result { resultCh := make(chan mq.Result) task, err := d.PublishTask(context.TODO(), payload, "queue1")