mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-06 00:16:49 +08:00
feat: update
This commit is contained in:
@@ -268,7 +268,14 @@ func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
|
|||||||
}
|
}
|
||||||
task.Topic = initialNode
|
task.Topic = initialNode
|
||||||
}
|
}
|
||||||
return manager.processTask(ctx, task.Topic, task.Payload)
|
manager.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
manager.processTask(ctx, task.Topic, task.Payload)
|
||||||
|
}()
|
||||||
|
manager.wg.Wait()
|
||||||
|
result := manager.dispatchFinalResult(ctx)
|
||||||
|
manager.finalResult <- result
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) {
|
func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) {
|
||||||
|
@@ -5,7 +5,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/oarkflow/mq"
|
"github.com/oarkflow/mq"
|
||||||
@@ -13,15 +12,15 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type TaskManager struct {
|
type TaskManager struct {
|
||||||
taskID string
|
taskID string
|
||||||
dag *DAG
|
dag *DAG
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
createdAt time.Time
|
createdAt time.Time
|
||||||
processedAt time.Time
|
processedAt time.Time
|
||||||
results []mq.Result
|
results []mq.Result
|
||||||
waitingCallback int64
|
nodeResults map[string]mq.Result
|
||||||
nodeResults map[string]mq.Result
|
finalResult chan mq.Result
|
||||||
finalResult chan mq.Result
|
wg *WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTaskManager(d *DAG, taskID string) *TaskManager {
|
func NewTaskManager(d *DAG, taskID string) *TaskManager {
|
||||||
@@ -31,6 +30,7 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager {
|
|||||||
results: make([]mq.Result, 0),
|
results: make([]mq.Result, 0),
|
||||||
taskID: taskID,
|
taskID: taskID,
|
||||||
finalResult: make(chan mq.Result, 1),
|
finalResult: make(chan mq.Result, 1),
|
||||||
|
wg: NewWaitGroup(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,16 +68,14 @@ func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload j
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *TaskManager) dispatchFinalResult(ctx context.Context) {
|
func (tm *TaskManager) dispatchFinalResult(ctx context.Context) mq.Result {
|
||||||
var rs mq.Result
|
var rs mq.Result
|
||||||
if len(tm.results) == 1 {
|
if len(tm.results) == 1 {
|
||||||
rs = tm.handleResult(ctx, tm.results[0])
|
rs = tm.handleResult(ctx, tm.results[0])
|
||||||
} else {
|
} else {
|
||||||
rs = tm.handleResult(ctx, tm.results)
|
rs = tm.handleResult(ctx, tm.results)
|
||||||
}
|
}
|
||||||
if tm.waitingCallback == 0 {
|
return rs
|
||||||
tm.finalResult <- rs
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge {
|
func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge {
|
||||||
@@ -100,9 +98,7 @@ func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.Result {
|
func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.Result {
|
||||||
if result.Topic != "" {
|
defer tm.wg.Done()
|
||||||
atomic.AddInt64(&tm.waitingCallback, -1)
|
|
||||||
}
|
|
||||||
node, ok := tm.dag.nodes[result.Topic]
|
node, ok := tm.dag.nodes[result.Topic]
|
||||||
if !ok {
|
if !ok {
|
||||||
return result
|
return result
|
||||||
@@ -110,7 +106,6 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.
|
|||||||
edges := tm.getConditionalEdges(node, result)
|
edges := tm.getConditionalEdges(node, result)
|
||||||
if len(edges) == 0 {
|
if len(edges) == 0 {
|
||||||
tm.appendFinalResult(result)
|
tm.appendFinalResult(result)
|
||||||
tm.dispatchFinalResult(ctx)
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
for _, edge := range edges {
|
for _, edge := range edges {
|
||||||
@@ -125,13 +120,19 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.
|
|||||||
for _, target := range edge.To {
|
for _, target := range edge.To {
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: target.Key})
|
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: target.Key})
|
||||||
go tm.processNode(ctx, target, item)
|
tm.wg.Add(1)
|
||||||
|
go func(ctx context.Context, target *Node, item json.RawMessage) {
|
||||||
|
tm.processNode(ctx, target, item)
|
||||||
|
}(ctx, target, item)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case Simple:
|
case Simple:
|
||||||
for _, target := range edge.To {
|
for _, target := range edge.To {
|
||||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: target.Key})
|
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: target.Key})
|
||||||
go tm.processNode(ctx, target, result.Payload)
|
tm.wg.Add(1)
|
||||||
|
go func(ctx context.Context, target *Node, result mq.Result) {
|
||||||
|
go tm.processNode(ctx, target, result.Payload)
|
||||||
|
}(ctx, target, result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -178,7 +179,6 @@ func (tm *TaskManager) appendFinalResult(result mq.Result) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) {
|
func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) {
|
||||||
atomic.AddInt64(&tm.waitingCallback, 1)
|
|
||||||
var result mq.Result
|
var result mq.Result
|
||||||
defer func() {
|
defer func() {
|
||||||
tm.mutex.Lock()
|
tm.mutex.Lock()
|
||||||
@@ -210,7 +210,6 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tm *TaskManager) Clear() error {
|
func (tm *TaskManager) Clear() error {
|
||||||
tm.waitingCallback = 0
|
|
||||||
clear(tm.results)
|
clear(tm.results)
|
||||||
tm.nodeResults = make(map[string]mq.Result)
|
tm.nodeResults = make(map[string]mq.Result)
|
||||||
return nil
|
return nil
|
||||||
|
51
dag/waitgroup.go
Normal file
51
dag/waitgroup.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package dag
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WaitGroup struct {
|
||||||
|
sync.Mutex
|
||||||
|
counter int
|
||||||
|
cond *sync.Cond
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWaitGroup() *WaitGroup {
|
||||||
|
awg := &WaitGroup{}
|
||||||
|
awg.cond = sync.NewCond(&awg.Mutex)
|
||||||
|
return awg
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add increments the counter for an async task
|
||||||
|
func (awg *WaitGroup) Add(delta int) {
|
||||||
|
awg.Lock()
|
||||||
|
awg.counter += delta
|
||||||
|
awg.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset sets the counter to zero and notifies waiting goroutines
|
||||||
|
func (awg *WaitGroup) Reset() {
|
||||||
|
awg.Lock()
|
||||||
|
awg.counter = 0
|
||||||
|
awg.cond.Broadcast() // Notify any waiting goroutines that we're done
|
||||||
|
awg.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Done decrements the counter when a task is completed
|
||||||
|
func (awg *WaitGroup) Done() {
|
||||||
|
awg.Lock()
|
||||||
|
awg.counter--
|
||||||
|
if awg.counter == 0 {
|
||||||
|
awg.cond.Broadcast() // Notify all waiting goroutines
|
||||||
|
}
|
||||||
|
awg.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait blocks until the counter is zero
|
||||||
|
func (awg *WaitGroup) Wait() {
|
||||||
|
awg.Lock()
|
||||||
|
for awg.counter > 0 {
|
||||||
|
awg.cond.Wait() // Wait for notification
|
||||||
|
}
|
||||||
|
awg.Unlock()
|
||||||
|
}
|
Reference in New Issue
Block a user