From 1f0727bc2a9efe31fc584684c0bfaa699cac4ea2 Mon Sep 17 00:00:00 2001 From: sujit Date: Sun, 13 Oct 2024 22:19:11 +0545 Subject: [PATCH] feat: update --- dag/dag.go | 9 +++++++- dag/task_manager.go | 43 +++++++++++++++++++------------------- dag/waitgroup.go | 51 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 23 deletions(-) create mode 100644 dag/waitgroup.go diff --git a/dag/dag.go b/dag/dag.go index 4134d58..4289d32 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -268,7 +268,14 @@ func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { } task.Topic = initialNode } - return manager.processTask(ctx, task.Topic, task.Payload) + manager.wg.Add(1) + go func() { + manager.processTask(ctx, task.Topic, task.Payload) + }() + manager.wg.Wait() + result := manager.dispatchFinalResult(ctx) + manager.finalResult <- result + return result } func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) { diff --git a/dag/task_manager.go b/dag/task_manager.go index e35564d..b163bbd 100644 --- a/dag/task_manager.go +++ b/dag/task_manager.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "sync" - "sync/atomic" "time" "github.com/oarkflow/mq" @@ -13,15 +12,15 @@ import ( ) type TaskManager struct { - taskID string - dag *DAG - mutex sync.Mutex - createdAt time.Time - processedAt time.Time - results []mq.Result - waitingCallback int64 - nodeResults map[string]mq.Result - finalResult chan mq.Result + taskID string + dag *DAG + mutex sync.Mutex + createdAt time.Time + processedAt time.Time + results []mq.Result + nodeResults map[string]mq.Result + finalResult chan mq.Result + wg *WaitGroup } func NewTaskManager(d *DAG, taskID string) *TaskManager { @@ -31,6 +30,7 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager { results: make([]mq.Result, 0), taskID: taskID, finalResult: make(chan mq.Result, 1), + wg: NewWaitGroup(), } } @@ -68,16 +68,14 @@ func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload j } } -func (tm *TaskManager) dispatchFinalResult(ctx context.Context) { +func (tm *TaskManager) dispatchFinalResult(ctx context.Context) mq.Result { var rs mq.Result if len(tm.results) == 1 { rs = tm.handleResult(ctx, tm.results[0]) } else { rs = tm.handleResult(ctx, tm.results) } - if tm.waitingCallback == 0 { - tm.finalResult <- rs - } + return rs } func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge { @@ -100,9 +98,7 @@ func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge } func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.Result { - if result.Topic != "" { - atomic.AddInt64(&tm.waitingCallback, -1) - } + defer tm.wg.Done() node, ok := tm.dag.nodes[result.Topic] if !ok { return result @@ -110,7 +106,6 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq. edges := tm.getConditionalEdges(node, result) if len(edges) == 0 { tm.appendFinalResult(result) - tm.dispatchFinalResult(ctx) return result } for _, edge := range edges { @@ -125,13 +120,19 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq. for _, target := range edge.To { for _, item := range items { ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: target.Key}) - go tm.processNode(ctx, target, item) + tm.wg.Add(1) + go func(ctx context.Context, target *Node, item json.RawMessage) { + tm.processNode(ctx, target, item) + }(ctx, target, item) } } case Simple: for _, target := range edge.To { ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: target.Key}) - go tm.processNode(ctx, target, result.Payload) + tm.wg.Add(1) + go func(ctx context.Context, target *Node, result mq.Result) { + go tm.processNode(ctx, target, result.Payload) + }(ctx, target, result) } } } @@ -178,7 +179,6 @@ func (tm *TaskManager) appendFinalResult(result mq.Result) { } func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) { - atomic.AddInt64(&tm.waitingCallback, 1) var result mq.Result defer func() { tm.mutex.Lock() @@ -210,7 +210,6 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json } func (tm *TaskManager) Clear() error { - tm.waitingCallback = 0 clear(tm.results) tm.nodeResults = make(map[string]mq.Result) return nil diff --git a/dag/waitgroup.go b/dag/waitgroup.go new file mode 100644 index 0000000..defa02d --- /dev/null +++ b/dag/waitgroup.go @@ -0,0 +1,51 @@ +package dag + +import ( + "sync" +) + +type WaitGroup struct { + sync.Mutex + counter int + cond *sync.Cond +} + +func NewWaitGroup() *WaitGroup { + awg := &WaitGroup{} + awg.cond = sync.NewCond(&awg.Mutex) + return awg +} + +// Add increments the counter for an async task +func (awg *WaitGroup) Add(delta int) { + awg.Lock() + awg.counter += delta + awg.Unlock() +} + +// Reset sets the counter to zero and notifies waiting goroutines +func (awg *WaitGroup) Reset() { + awg.Lock() + awg.counter = 0 + awg.cond.Broadcast() // Notify any waiting goroutines that we're done + awg.Unlock() +} + +// Done decrements the counter when a task is completed +func (awg *WaitGroup) Done() { + awg.Lock() + awg.counter-- + if awg.counter == 0 { + awg.cond.Broadcast() // Notify all waiting goroutines + } + awg.Unlock() +} + +// Wait blocks until the counter is zero +func (awg *WaitGroup) Wait() { + awg.Lock() + for awg.counter > 0 { + awg.cond.Wait() // Wait for notification + } + awg.Unlock() +}