diff --git a/examples/dag.go b/examples/dag.go index e59d722..8ce672f 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -13,14 +13,17 @@ import ( func main() { dag := NewDAG() + dag.AddNode("queue1", func(ctx context.Context, task mq.Task) mq.Result { log.Printf("Handling task for queue1: %s", string(task.Payload)) return mq.Result{Payload: task.Payload, MessageID: task.ID} }) + dag.AddNode("queue2", func(ctx context.Context, task mq.Task) mq.Result { log.Printf("Handling task for queue2: %s", string(task.Payload)) return mq.Result{Payload: task.Payload, MessageID: task.ID} }) + dag.AddNode("queue3", func(ctx context.Context, task mq.Task) mq.Result { var data map[string]any err := json.Unmarshal(task.Payload, &data) @@ -29,16 +32,21 @@ func main() { } data["salary"] = fmt.Sprintf("12000%v", data["user_id"]) bt, _ := json.Marshal(data) - log.Printf("Handling task for queue3: %s", string(task.Payload)) + log.Printf("Handling task for queue3: %s", string(bt)) return mq.Result{Payload: bt, MessageID: task.ID} }) + dag.AddNode("queue4", func(ctx context.Context, task mq.Task) mq.Result { log.Printf("Handling task for queue4: %s", string(task.Payload)) return mq.Result{Payload: task.Payload, MessageID: task.ID} }) - dag.AddEdge("queue1", "queue2") - dag.AddLoop("queue2", "queue3") // Add loop to handle array + // Define edges and loops + dag.AddEdge("queue1", "queue2") + dag.AddLoop("queue2", "queue3") // Loop through queue3 for each item from queue2 + dag.AddEdge("queue2", "queue4") // After processing queue2 (including loop), continue to queue4 + + // Send task payload go func() { time.Sleep(2 * time.Second) finalResult := dag.Send([]byte(`[{"user_id": 1}, {"user_id": 2}]`)) @@ -51,14 +59,14 @@ func main() { } } -// DAG struct to handle tasks and loops +// DAG struct to handle tasks, edges, and loops type DAG struct { server *mq.Broker nodes map[string]*mq.Consumer edges map[string][]string - loopEdges map[string]string // Handles loop edges + loopEdges map[string]string // Tracks loop edges taskChMap map[string]chan mq.Result - loopTaskMap map[string]*loopTaskContext // Map to handle loop tasks + loopTaskMap map[string]*loopTaskContext // Manages loop tasks mu sync.Mutex } @@ -70,7 +78,7 @@ type loopTaskContext struct { results []json.RawMessage } -// NewDAG initializes the DAG structure with necessary fields +// NewDAG initializes the DAG structure func NewDAG(opts ...mq.Option) *DAG { d := &DAG{ nodes: make(map[string]*mq.Consumer), @@ -105,7 +113,7 @@ func (d *DAG) Start(ctx context.Context) error { return d.server.Start(ctx) } -// PublishTask sends a task to a queue +// PublishTask sends a task to a specific queue func (d *DAG) PublishTask(ctx context.Context, payload []byte, queueName string, taskID ...string) (*mq.Task, error) { task := mq.Task{ Payload: payload, @@ -116,39 +124,48 @@ func (d *DAG) PublishTask(ctx context.Context, payload []byte, queueName string, return d.server.Publish(ctx, task, queueName) } -// TaskCallback is called when a task completes and decides the next step +// TaskCallback processes the completion of a task and continues the flow func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error { log.Printf("Callback from queue %s with result: %s", task.CurrentQueue, string(task.Result)) + + // Check if the task belongs to a loop d.mu.Lock() loopCtx, isLoopTask := d.loopTaskMap[task.ID] d.mu.Unlock() + if isLoopTask { loopCtx.subResultCh <- mq.Result{Payload: task.Result, MessageID: task.ID} } + + // Handle loopEdges first if loopNode, exists := d.loopEdges[task.CurrentQueue]; exists { var items []json.RawMessage if err := json.Unmarshal(task.Result, &items); err != nil { return err } loopCtx := &loopTaskContext{ - subResultCh: make(chan mq.Result, len(items)), // A channel to collect sub-task results + subResultCh: make(chan mq.Result, len(items)), // Collect sub-task results totalItems: len(items), results: make([]json.RawMessage, 0, len(items)), } d.mu.Lock() d.loopTaskMap[task.ID] = loopCtx d.mu.Unlock() + + // Publish a sub-task for each item for _, item := range items { _, err := d.PublishTask(ctx, item, loopNode, task.ID) if err != nil { return err } } + + // Wait for loop completion go d.waitForLoopCompletion(ctx, task.ID, task.CurrentQueue) return nil } - // Normal edge processing + // Handle normal edges edges, exists := d.edges[task.CurrentQueue] if exists { for _, edge := range edges { @@ -158,10 +175,11 @@ func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error { } } } + return nil } -// waitForLoopCompletion waits until all sub-tasks are processed, aggregates results, and proceeds. +// waitForLoopCompletion waits for all sub-tasks in a loop to finish, aggregates the results, and proceeds func (d *DAG) waitForLoopCompletion(ctx context.Context, taskID string, currentQueue string) { // Get the loop context d.mu.Lock() @@ -169,38 +187,52 @@ func (d *DAG) waitForLoopCompletion(ctx context.Context, taskID string, currentQ d.mu.Unlock() for result := range loopCtx.subResultCh { - // Collect the result + // Collect results loopCtx.results = append(loopCtx.results, result.Payload) loopCtx.completed++ - // If all sub-tasks are completed, aggregate results and proceed + + // If all sub-tasks are completed, aggregate results and continue if loopCtx.completed == loopCtx.totalItems { close(loopCtx.subResultCh) - // Aggregate the results + // Aggregate results aggregatedResult, err := json.Marshal(loopCtx.results) if err != nil { log.Printf("Error aggregating results: %v", err) return } - d.mu.Lock() - if resultCh, ok := d.taskChMap[taskID]; ok { - result := mq.Result{ - Command: "complete", - Payload: aggregatedResult, - Queue: currentQueue, - MessageID: taskID, - Status: "done", - } - resultCh <- result - delete(d.taskChMap, taskID) - } - d.mu.Unlock() - // Remove the loop context + // Continue flow after loop completion d.mu.Lock() delete(d.loopTaskMap, taskID) d.mu.Unlock() + // Handle the next nodes in the DAG + edges, exists := d.edges[currentQueue] + if exists { + for _, edge := range edges { + _, err := d.PublishTask(ctx, aggregatedResult, edge, taskID) + if err != nil { + log.Printf("Error publishing aggregated result: %v", err) + return + } + } + } else { + // If no further edges, finalize + d.mu.Lock() + if resultCh, ok := d.taskChMap[taskID]; ok { + resultCh <- mq.Result{ + Command: "complete", + Payload: aggregatedResult, + Queue: currentQueue, + MessageID: taskID, + Status: "done", + } + delete(d.taskChMap, taskID) + } + d.mu.Unlock() + } + break } }