mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-13 10:03:42 +08:00
feat: add example
This commit is contained in:
@@ -5,20 +5,22 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/oarkflow/mq"
|
||||
"github.com/oarkflow/mq/consts"
|
||||
)
|
||||
|
||||
type TaskManager struct {
|
||||
taskID string
|
||||
dag *DAG
|
||||
wg sync.WaitGroup
|
||||
mutex sync.Mutex
|
||||
results []mq.Result
|
||||
nodeResults map[string]mq.Result
|
||||
done chan struct{}
|
||||
finalResult chan mq.Result // Channel to collect final results
|
||||
taskID string
|
||||
dag *DAG
|
||||
wg sync.WaitGroup
|
||||
mutex sync.Mutex
|
||||
results []mq.Result
|
||||
waitingCallback int64
|
||||
nodeResults map[string]mq.Result
|
||||
done chan struct{}
|
||||
finalResult chan mq.Result // Channel to collect final results
|
||||
}
|
||||
|
||||
func NewTaskManager(d *DAG, taskID string) *TaskManager {
|
||||
@@ -26,9 +28,7 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager {
|
||||
dag: d,
|
||||
nodeResults: make(map[string]mq.Result),
|
||||
results: make([]mq.Result, 0),
|
||||
done: make(chan struct{}),
|
||||
taskID: taskID,
|
||||
finalResult: make(chan mq.Result), // Initialize finalResult channel
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,26 +37,97 @@ func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload j
|
||||
if !ok {
|
||||
return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)}
|
||||
}
|
||||
tm.wg.Add(1)
|
||||
go tm.processNode(ctx, node, payload)
|
||||
go func() {
|
||||
tm.wg.Wait()
|
||||
close(tm.done)
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return mq.Result{Error: ctx.Err()}
|
||||
case <-tm.done:
|
||||
tm.mutex.Lock()
|
||||
defer tm.mutex.Unlock()
|
||||
if len(tm.results) == 1 {
|
||||
return tm.handleResult(ctx, tm.results[0])
|
||||
if tm.dag.server.SyncMode() {
|
||||
tm.done = make(chan struct{})
|
||||
tm.wg.Add(1)
|
||||
go tm.processNode(ctx, node, payload)
|
||||
go func() {
|
||||
tm.wg.Wait()
|
||||
close(tm.done)
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return mq.Result{Error: ctx.Err()}
|
||||
case <-tm.done:
|
||||
tm.mutex.Lock()
|
||||
defer tm.mutex.Unlock()
|
||||
if len(tm.results) == 1 {
|
||||
return tm.handleResult(ctx, tm.results[0])
|
||||
}
|
||||
return tm.handleResult(ctx, tm.results)
|
||||
}
|
||||
} else {
|
||||
tm.finalResult = make(chan mq.Result)
|
||||
tm.wg.Add(1)
|
||||
go tm.processNode(ctx, node, payload)
|
||||
go func() {
|
||||
tm.wg.Wait()
|
||||
}()
|
||||
select {
|
||||
case result := <-tm.finalResult: // Block until a result is available
|
||||
return result
|
||||
case <-ctx.Done(): // Handle context cancellation
|
||||
return mq.Result{Error: ctx.Err()}
|
||||
}
|
||||
return tm.handleResult(ctx, tm.results)
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.Result {
|
||||
if result.Topic != "" {
|
||||
atomic.AddInt64(&tm.waitingCallback, -1)
|
||||
}
|
||||
node, ok := tm.dag.Nodes[result.Topic]
|
||||
if !ok {
|
||||
return result
|
||||
}
|
||||
edges := make([]Edge, len(node.Edges))
|
||||
copy(edges, node.Edges)
|
||||
if result.Status != "" {
|
||||
if conditions, ok := tm.dag.conditions[result.Topic]; ok {
|
||||
if targetNodeKey, ok := conditions[result.Status]; ok {
|
||||
if targetNode, ok := tm.dag.Nodes[targetNodeKey]; ok {
|
||||
edges = append(edges, Edge{From: node, To: targetNode})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(edges) == 0 {
|
||||
tm.appendFinalResult(result)
|
||||
if !tm.dag.server.SyncMode() {
|
||||
var rs mq.Result
|
||||
if len(tm.results) == 1 {
|
||||
rs = tm.handleResult(ctx, tm.results[0])
|
||||
} else {
|
||||
rs = tm.handleResult(ctx, tm.results)
|
||||
}
|
||||
if tm.waitingCallback == 0 {
|
||||
tm.finalResult <- rs
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
for _, edge := range edges {
|
||||
switch edge.Type {
|
||||
case LoopEdge:
|
||||
var items []json.RawMessage
|
||||
err := json.Unmarshal(result.Payload, &items)
|
||||
if err != nil {
|
||||
tm.appendFinalResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err})
|
||||
return result
|
||||
}
|
||||
for _, item := range items {
|
||||
tm.wg.Add(1)
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key})
|
||||
go tm.processNode(ctx, edge.To, item)
|
||||
}
|
||||
case SimpleEdge:
|
||||
if edge.To != nil {
|
||||
tm.wg.Add(1)
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key})
|
||||
go tm.processNode(ctx, edge.To, result.Payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
return mq.Result{}
|
||||
}
|
||||
|
||||
@@ -103,6 +174,7 @@ func (tm *TaskManager) appendFinalResult(result mq.Result) {
|
||||
}
|
||||
|
||||
func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) {
|
||||
atomic.AddInt64(&tm.waitingCallback, 1)
|
||||
defer tm.wg.Done()
|
||||
var result mq.Result
|
||||
select {
|
||||
@@ -115,6 +187,7 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json
|
||||
if tm.dag.server.SyncMode() {
|
||||
result = node.consumer.ProcessTask(ctx, NewTask(tm.taskID, payload, node.Key))
|
||||
result.Topic = node.Key
|
||||
result.TaskID = tm.taskID
|
||||
if result.Error != nil {
|
||||
tm.appendFinalResult(result)
|
||||
return
|
||||
@@ -130,41 +203,5 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json
|
||||
tm.mutex.Lock()
|
||||
tm.nodeResults[node.Key] = result
|
||||
tm.mutex.Unlock()
|
||||
edges := make([]Edge, len(node.Edges))
|
||||
copy(edges, node.Edges)
|
||||
if result.Status != "" {
|
||||
if conditions, ok := tm.dag.conditions[result.Topic]; ok {
|
||||
if targetNodeKey, ok := conditions[result.Status]; ok {
|
||||
if targetNode, ok := tm.dag.Nodes[targetNodeKey]; ok {
|
||||
edges = append(edges, Edge{From: node, To: targetNode})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(edges) == 0 {
|
||||
tm.appendFinalResult(result)
|
||||
return
|
||||
}
|
||||
for _, edge := range edges {
|
||||
switch edge.Type {
|
||||
case LoopEdge:
|
||||
var items []json.RawMessage
|
||||
err := json.Unmarshal(result.Payload, &items)
|
||||
if err != nil {
|
||||
tm.appendFinalResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err})
|
||||
return
|
||||
}
|
||||
for _, item := range items {
|
||||
tm.wg.Add(1)
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key})
|
||||
go tm.processNode(ctx, edge.To, item)
|
||||
}
|
||||
case SimpleEdge:
|
||||
if edge.To != nil {
|
||||
tm.wg.Add(1)
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key})
|
||||
go tm.processNode(ctx, edge.To, result.Payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
tm.handleCallback(ctx, result)
|
||||
}
|
||||
|
Reference in New Issue
Block a user