Files
mq/dag/task_manager.go
2024-10-15 20:34:56 +05:45

230 lines
5.7 KiB
Go

package dag
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/consts"
)
// TaskManager tracks the execution of a single task as it flows through the
// nodes of a DAG, collecting per-node and final results along the way.
type TaskManager struct {
// createdAt is set by processTask the first time the task is started and is
// copied into the final result by updateTS.
createdAt time.Time
// processedAt is not read or written anywhere in this file.
processedAt time.Time
// dag is the DAG whose nodes, conditions and server this manager uses.
dag *DAG
// nodeResults holds the most recently recorded result per node key / result
// topic. Guarded by mutex.
nodeResults map[string]mq.Result
// wg counts in-flight node executions: Add(1) is called before each
// processNode is started and Done is called by handleCallback.
wg *WaitGroup
// taskID identifies the task this manager is responsible for.
taskID string
// results accumulates the final results recorded via appendFinalResult.
// Guarded by mutex.
results []mq.Result
// mutex protects results and nodeResults.
mutex sync.Mutex
}
// NewTaskManager returns a TaskManager bound to DAG d for the task identified
// by taskID, with empty result collections and a fresh wait group.
func NewTaskManager(d *DAG, taskID string) *TaskManager {
	tm := &TaskManager{taskID: taskID, dag: d, wg: NewWaitGroup()}
	tm.results = make([]mq.Result, 0)
	tm.nodeResults = make(map[string]mq.Result)
	return tm
}
// updateTS stamps result with the task's creation time and uses the current
// time as its processing time.
func (tm *TaskManager) updateTS(result *mq.Result) {
	result.CreatedAt, result.ProcessedAt = tm.createdAt, time.Now()
}
// processTask starts execution of the task at the node identified by nodeID,
// blocks until the wait group reports that every scheduled node execution has
// completed, and then returns the dispatched final result.
//
// If nodeID is not registered in the DAG an error result is returned
// immediately and nothing is scheduled.
func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result {
	node, ok := tm.dag.nodes[nodeID]
	if !ok {
		return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)}
	}
	// Only the first call records the creation time.
	if tm.createdAt.IsZero() {
		tm.createdAt = time.Now()
	}
	// Register the execution before starting it so Wait cannot return early.
	tm.wg.Add(1)
	// A single goroutine is sufficient; wrapping the call in a second
	// goroutine only added scheduling overhead.
	go tm.processNode(ctx, node, payload)
	tm.wg.Wait()
	return tm.dispatchFinalResult(ctx)
}
// dispatchFinalResult turns the collected results into one final result,
// timestamps it, delivers it to the consumer callback and to the server's
// notify handler (when one is set), queues the task ID for cleanup and
// returns the final result.
func (tm *TaskManager) dispatchFinalResult(ctx context.Context) mq.Result {
	// A lone result is forwarded as-is; anything else is aggregated.
	var collected any = tm.results
	if len(tm.results) == 1 {
		collected = tm.results[0]
	}
	final := tm.handleResult(ctx, collected)
	tm.updateTS(&final)
	tm.dag.callbackToConsumer(ctx, final)
	if tm.dag.server.NotifyHandler() != nil {
		tm.dag.server.NotifyHandler()(ctx, final)
	}
	tm.dag.taskCleanupCh <- tm.taskID
	return final
}
// getConditionalEdges returns a copy of node's static edges, extended by at
// most one conditional edge chosen from the DAG's conditions for
// result.Topic.
//
// The condition registered for result.Status is preferred; if none exists the
// "default" condition is used. A conditional edge is only added when the
// selected target key refers to a node registered in the DAG. No conditional
// lookup is done when result.Status is empty.
func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge {
	edges := make([]Edge, len(node.Edges))
	copy(edges, node.Edges)
	if result.Status == "" {
		return edges
	}
	conditions, found := tm.dag.conditions[FromNode(result.Topic)]
	if !found {
		return edges
	}
	targetKey, found := conditions[When(result.Status)]
	if !found {
		targetKey, found = conditions["default"]
	}
	if !found {
		return edges
	}
	if target, exists := tm.dag.nodes[string(targetKey)]; exists {
		edges = append(edges, Edge{From: node, To: []*Node{target}})
	}
	return edges
}
// handleCallback handles the result of a finished node execution and
// schedules the nodes that follow it. It always returns result unchanged.
//
// The wait group's Done call is deferred, so every Add(1) made here for a
// follow-up node happens before this execution is marked finished.
//
//   - Unknown result.Topic: nothing else happens.
//   - No outgoing edges: result is recorded as a final result.
//   - Iterator edge: result.Payload is decoded as a JSON array and each
//     target node is started once per element. A decode failure records an
//     error result and stops processing the remaining edges.
//   - Simple edge: each target node is started once with result.Payload.
func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.Result {
	defer tm.wg.Done()
	node, known := tm.dag.nodes[result.Topic]
	if !known {
		return result
	}
	outgoing := tm.getConditionalEdges(node, result)
	if len(outgoing) == 0 {
		tm.appendFinalResult(result)
		return result
	}
	for _, edge := range outgoing {
		switch edge.Type {
		case Iterator:
			var items []json.RawMessage
			if err := json.Unmarshal(result.Payload, &items); err != nil {
				tm.appendFinalResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err})
				return result
			}
			for _, target := range edge.To {
				for _, item := range items {
					ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: target.Key})
					tm.wg.Add(1)
					// Arguments of a go statement are evaluated here, so each
					// goroutine gets its own ctx, target and item.
					go tm.processNode(ctx, target, item)
				}
			}
		case Simple:
			for _, target := range edge.To {
				ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: target.Key})
				tm.wg.Add(1)
				go tm.processNode(ctx, target, result.Payload)
			}
		}
	}
	return result
}
// handleResult normalises results into a single mq.Result.
//
//   - mq.Result: returned as-is, with Ctx set to ctx when it was nil.
//   - []mq.Result: the payloads are combined into one JSON array; Status and
//     Topic come from the first element. The first element carrying an error,
//     or a payload that fails to decode, produces mq.HandleError's result.
//   - anything else: an empty result carrying only ctx.
func (tm *TaskManager) handleResult(ctx context.Context, results any) mq.Result {
	if single, ok := results.(mq.Result); ok {
		if single.Ctx == nil {
			single.Ctx = ctx
		}
		return single
	}
	batch, ok := results.([]mq.Result)
	if !ok {
		return mq.Result{Ctx: ctx}
	}
	var status, topic string
	if len(batch) > 0 {
		status, topic = batch[0].Status, batch[0].Topic
	}
	merged := make([]json.RawMessage, 0, len(batch))
	for _, r := range batch {
		if r.Error != nil {
			return mq.HandleError(ctx, r.Error)
		}
		var item json.RawMessage
		if err := json.Unmarshal(r.Payload, &item); err != nil {
			return mq.HandleError(ctx, err)
		}
		merged = append(merged, item)
	}
	encoded, err := json.Marshal(merged)
	if err != nil {
		return mq.HandleError(ctx, err)
	}
	return mq.Result{TaskID: tm.taskID, Payload: encoded, Status: status, Topic: topic, Ctx: ctx}
}
// appendFinalResult records result as one of the task's final results and as
// the latest result for result.Topic, holding the mutex for both writes.
func (tm *TaskManager) appendFinalResult(result mq.Result) {
	tm.mutex.Lock()
	defer tm.mutex.Unlock()
	tm.nodeResults[result.Topic] = result
	tm.results = append(tm.results, result)
}
// processNode executes node for this task with the given payload.
//
// The caller must have called tm.wg.Add(1) for this execution. Every path
// through processNode makes sure the matching Done eventually happens:
//
//   - Sync mode: the node is processed inline and a deferred call stores the
//     result in nodeResults and passes it to handleCallback, which calls
//     Done and schedules follow-up nodes.
//   - Async mode: the task is published to the server. Done is then expected
//     from handleCallback once the result comes back (presumably invoked by
//     the consumer callback — confirm against the caller). If the context is
//     already cancelled or publishing fails, no such callback can arrive, so
//     the error is recorded and Done is called here; otherwise processTask
//     would block forever in wg.Wait.
//
// When node wraps a nested DAG and this DAG runs in sync mode, the nested
// DAG's server is switched to sync mode as well.
func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) {
	// Read the mode once so that registering the deferred callback and
	// choosing the execution branch can never disagree.
	syncMode := tm.dag.server.SyncMode()
	dag, isDAG := isDAGNode(node)
	if isDAG && syncMode && !dag.server.SyncMode() {
		dag.server.Options().SetSyncMode(true)
	}
	var result mq.Result
	if syncMode {
		// The closure reads ctx and result when it runs, so it sees the
		// header-enriched context and the final result assigned below.
		// NOTE(review): on the error paths below the result is appended here
		// and handleCallback appends it again when the node has no outgoing
		// edges — confirm whether the duplicate entry is intended.
		defer func() {
			tm.mutex.Lock()
			tm.nodeResults[node.Key] = result
			tm.mutex.Unlock()
			tm.handleCallback(ctx, result)
		}()
	}
	select {
	case <-ctx.Done():
		result = mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: ctx.Err(), Ctx: ctx}
		tm.appendFinalResult(result)
		if !syncMode {
			// Nothing was published, so no callback will release the wait group.
			tm.wg.Done()
		}
		return
	default:
	}
	ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: node.Key})
	if syncMode {
		result = node.ProcessTask(ctx, NewTask(tm.taskID, payload, node.Key))
		if isDAG {
			// Report the nested DAG's outcome under its consumer topic and
			// this task's ID.
			result.Topic = dag.consumerTopic
			result.TaskID = tm.taskID
		}
		if result.Error != nil {
			tm.appendFinalResult(result)
		}
		return
	}
	if err := tm.dag.server.Publish(ctx, NewTask(tm.taskID, payload, node.Key), node.Key); err != nil {
		// Identify the failing node like the other error results do, and
		// release the wait group because no callback will follow.
		tm.appendFinalResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err, Ctx: ctx})
		tm.wg.Done()
	}
}
// isDAGNode reports whether node's processor is itself a *DAG and, if so,
// returns that nested DAG. For any other processor it returns (nil, false).
func isDAGNode(node *Node) (*DAG, bool) {
	nested, ok := node.processor.(*DAG)
	return nested, ok
}