feat: update

This commit is contained in:
sujit
2024-10-13 22:19:11 +05:45
parent cf669b0a38
commit 1f0727bc2a
3 changed files with 80 additions and 23 deletions

View File

@@ -268,7 +268,14 @@ func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
} }
task.Topic = initialNode task.Topic = initialNode
} }
return manager.processTask(ctx, task.Topic, task.Payload) manager.wg.Add(1)
go func() {
manager.processTask(ctx, task.Topic, task.Payload)
}()
manager.wg.Wait()
result := manager.dispatchFinalResult(ctx)
manager.finalResult <- result
return result
} }
func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) { func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) {

View File

@@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/oarkflow/mq" "github.com/oarkflow/mq"
@@ -13,15 +12,15 @@ import (
) )
type TaskManager struct { type TaskManager struct {
taskID string taskID string
dag *DAG dag *DAG
mutex sync.Mutex mutex sync.Mutex
createdAt time.Time createdAt time.Time
processedAt time.Time processedAt time.Time
results []mq.Result results []mq.Result
waitingCallback int64 nodeResults map[string]mq.Result
nodeResults map[string]mq.Result finalResult chan mq.Result
finalResult chan mq.Result wg *WaitGroup
} }
func NewTaskManager(d *DAG, taskID string) *TaskManager { func NewTaskManager(d *DAG, taskID string) *TaskManager {
@@ -31,6 +30,7 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager {
results: make([]mq.Result, 0), results: make([]mq.Result, 0),
taskID: taskID, taskID: taskID,
finalResult: make(chan mq.Result, 1), finalResult: make(chan mq.Result, 1),
wg: NewWaitGroup(),
} }
} }
@@ -68,16 +68,14 @@ func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload j
} }
} }
func (tm *TaskManager) dispatchFinalResult(ctx context.Context) { func (tm *TaskManager) dispatchFinalResult(ctx context.Context) mq.Result {
var rs mq.Result var rs mq.Result
if len(tm.results) == 1 { if len(tm.results) == 1 {
rs = tm.handleResult(ctx, tm.results[0]) rs = tm.handleResult(ctx, tm.results[0])
} else { } else {
rs = tm.handleResult(ctx, tm.results) rs = tm.handleResult(ctx, tm.results)
} }
if tm.waitingCallback == 0 { return rs
tm.finalResult <- rs
}
} }
func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge { func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge {
@@ -100,9 +98,7 @@ func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge
} }
func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.Result { func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.Result {
if result.Topic != "" { defer tm.wg.Done()
atomic.AddInt64(&tm.waitingCallback, -1)
}
node, ok := tm.dag.nodes[result.Topic] node, ok := tm.dag.nodes[result.Topic]
if !ok { if !ok {
return result return result
@@ -110,7 +106,6 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.
edges := tm.getConditionalEdges(node, result) edges := tm.getConditionalEdges(node, result)
if len(edges) == 0 { if len(edges) == 0 {
tm.appendFinalResult(result) tm.appendFinalResult(result)
tm.dispatchFinalResult(ctx)
return result return result
} }
for _, edge := range edges { for _, edge := range edges {
@@ -125,13 +120,19 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.
for _, target := range edge.To { for _, target := range edge.To {
for _, item := range items { for _, item := range items {
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: target.Key}) ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: target.Key})
go tm.processNode(ctx, target, item) tm.wg.Add(1)
go func(ctx context.Context, target *Node, item json.RawMessage) {
tm.processNode(ctx, target, item)
}(ctx, target, item)
} }
} }
case Simple: case Simple:
for _, target := range edge.To { for _, target := range edge.To {
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: target.Key}) ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: target.Key})
go tm.processNode(ctx, target, result.Payload) tm.wg.Add(1)
go func(ctx context.Context, target *Node, result mq.Result) {
go tm.processNode(ctx, target, result.Payload)
}(ctx, target, result)
} }
} }
} }
@@ -178,7 +179,6 @@ func (tm *TaskManager) appendFinalResult(result mq.Result) {
} }
func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) { func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) {
atomic.AddInt64(&tm.waitingCallback, 1)
var result mq.Result var result mq.Result
defer func() { defer func() {
tm.mutex.Lock() tm.mutex.Lock()
@@ -210,7 +210,6 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json
} }
func (tm *TaskManager) Clear() error { func (tm *TaskManager) Clear() error {
tm.waitingCallback = 0
clear(tm.results) clear(tm.results)
tm.nodeResults = make(map[string]mq.Result) tm.nodeResults = make(map[string]mq.Result)
return nil return nil

51
dag/waitgroup.go Normal file
View File

@@ -0,0 +1,51 @@
package dag
import (
"sync"
)
type WaitGroup struct {
sync.Mutex
counter int
cond *sync.Cond
}
func NewWaitGroup() *WaitGroup {
awg := &WaitGroup{}
awg.cond = sync.NewCond(&awg.Mutex)
return awg
}
// Add increments the counter for an async task
func (awg *WaitGroup) Add(delta int) {
awg.Lock()
awg.counter += delta
awg.Unlock()
}
// Reset sets the counter to zero and notifies waiting goroutines
func (awg *WaitGroup) Reset() {
awg.Lock()
awg.counter = 0
awg.cond.Broadcast() // Notify any waiting goroutines that we're done
awg.Unlock()
}
// Done decrements the counter when a task is completed
func (awg *WaitGroup) Done() {
awg.Lock()
awg.counter--
if awg.counter == 0 {
awg.cond.Broadcast() // Notify all waiting goroutines
}
awg.Unlock()
}
// Wait blocks until the counter is zero
func (awg *WaitGroup) Wait() {
awg.Lock()
for awg.counter > 0 {
awg.cond.Wait() // Wait for notification
}
awg.Unlock()
}