Files
mq/dag/v1/task_manager.go
2024-11-23 10:51:22 +05:45

375 lines
10 KiB
Go

package v1
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/storage"
"github.com/oarkflow/mq/storage/memory"
"github.com/oarkflow/mq"
)
// TaskManager holds the state of one task as it moves through a DAG: the
// per-node status records, the wait group used to wait for node executions,
// and the result that is finally returned to the caller.
type TaskManager struct {
// createdAt is stamped by processTask the first time it runs for this manager.
createdAt time.Time
// processedAt is not written in the visible part of this file.
// NOTE(review): presumably maintained by updateTS — confirm.
processedAt time.Time
// status is not read or written in the visible part of this file.
status string
// dag is the DAG whose nodes, conditions, server and callbacks this manager uses.
dag *DAG
// taskID identifies the task; it is attached to tasks and results created here.
taskID string
// wg is incremented before each processNode call and decremented by handleNextTask.
wg *WaitGroup
// topic records the topic of the most recently handled or dispatched result.
topic string
// result is the result returned by dispatchFinalResult; ChangeNodeStatus
// assigns it when the start node completes or fails.
result mq.Result
// iteratorNodes maps a node key to edges that ChangeNodeStatus processes
// once that node is marked Completed; handleNextTask skips Simple edges
// from nodes present in this map.
iteratorNodes storage.IMap[string, []Edge]
// taskNodeStatus maps a node topic (as produced by getTopic) to its status record.
taskNodeStatus storage.IMap[string, *taskNodeStatus]
}
// NewTaskManager builds a TaskManager for taskID bound to DAG d.
//
// The manager starts with an empty in-memory node status map and a fresh wait
// group; iteratorNodes is stored as given.
func NewTaskManager(d *DAG, taskID string, iteratorNodes storage.IMap[string, []Edge]) *TaskManager {
	manager := new(TaskManager)
	manager.dag = d
	manager.taskNodeStatus = memory.New[string, *taskNodeStatus]()
	manager.taskID = taskID
	manager.iteratorNodes = iteratorNodes
	manager.wg = NewWaitGroup()
	return manager
}
// dispatchFinalResult publishes the task's final result and returns it.
//
// After refreshing the result through updateTS it hands the result to the
// DAG's consumer callback and, when one is configured, to the server's notify
// handler (whose error is deliberately ignored). The task ID is then sent on
// the DAG's cleanup channel and tm.topic is updated from the result.
func (tm *TaskManager) dispatchFinalResult(ctx context.Context) mq.Result {
	tm.updateTS(&tm.result)
	tm.dag.callbackToConsumer(ctx, tm.result)

	if notify := tm.dag.server.NotifyHandler(); notify != nil {
		// Notification is best-effort; a failure must not affect the result.
		_ = notify(ctx, tm.result)
	}

	tm.dag.taskCleanupCh <- tm.taskID
	tm.topic = tm.result.Topic
	return tm.result
}
// reportNodeResult forwards result to the DAG's node-result callback when one
// is registered. The final flag is accepted but not used by this method.
func (tm *TaskManager) reportNodeResult(result mq.Result, final bool) {
	report := tm.dag.reportNodeResultCallback
	if report == nil {
		return
	}
	report(result)
}
// SetTotalItems records i as the total item count on the status record stored
// under topic. It does nothing when no record exists for topic.
func (tm *TaskManager) SetTotalItems(topic string, i int) {
	record, found := tm.taskNodeStatus.Get(topic)
	if !found {
		return
	}
	record.totalItems = i
}
// processNode executes one DAG node for the current task.
//
// It stores a fresh status record under the node's topic, marks the node as
// Processing, and then either runs the node inline (server in sync mode) or
// publishes the task to the node's queue.
//
// In sync mode a deferred step reports the result and passes it to
// handleNextTask once the node has run. If ctx is already cancelled the node
// is reported and marked Failed with ctx.Err(). A publish failure is reported
// and marked Failed with a result carrying the task ID, node key and error.
func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) {
	topic := getTopic(ctx, node.Key)
	tm.taskNodeStatus.Set(topic, newNodeStatus(topic))
	defer mq.RecoverPanic(mq.RecoverTitle)

	dag, isDAG := isDAGNode(node)
	// A nested DAG is switched to sync mode when its parent runs in sync mode.
	if isDAG && tm.dag.server.SyncMode() && !dag.server.SyncMode() {
		dag.server.Options().SetSyncMode(true)
	}

	tm.ChangeNodeStatus(ctx, node.Key, Processing, mq.Result{Payload: payload, Topic: node.Key})

	var result mq.Result
	if tm.dag.server.SyncMode() {
		// Runs after the select below; it sees whatever was assigned to result
		// and the ctx value as updated by mq.SetHeaders.
		// NOTE(review): on a sync-mode error the result is reported both in
		// the select below and again here via handleNextTask — confirm that
		// the double report is intended.
		defer func() {
			if isDAG {
				result.Topic = dag.consumerTopic
				result.TaskID = tm.taskID
				tm.reportNodeResult(result, false)
				tm.handleNextTask(result.Ctx, result)
				return
			}
			result.Topic = node.Key
			tm.reportNodeResult(result, false)
			tm.handleNextTask(ctx, result)
		}()
	}

	select {
	case <-ctx.Done():
		result = mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: ctx.Err(), Ctx: ctx}
		tm.reportNodeResult(result, true)
		tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
	default:
		ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: node.Key})
		if tm.dag.server.SyncMode() {
			result = node.ProcessTask(ctx, mq.NewTask(tm.taskID, payload, node.Key))
			if result.Error != nil {
				tm.reportNodeResult(result, true)
				tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
			}
			return
		}
		if err := tm.dag.server.Publish(ctx, mq.NewTask(tm.taskID, payload, node.Key), node.Key); err != nil {
			// Build one populated result so the reported result and the
			// recorded Failed status both carry the publish error. This
			// branch is only reached in async mode, where the deferred step
			// above is not installed.
			// NOTE(review): nothing on this path calls tm.wg.Done(), so a
			// waiter on tm.wg may block — confirm against the async callback.
			result = mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err, Ctx: ctx}
			tm.reportNodeResult(result, true)
			tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
		}
	}
}
// processTask starts the task at nodeID and blocks on tm.wg.Wait() before
// producing the task's result.
//
// It returns an error result when nodeID is not registered in the DAG.
// Otherwise it returns renderResult when the "request_type" header equals
// "render", and dispatchFinalResult in every other case.
func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result {
defer mq.RecoverPanic(mq.RecoverTitle)
node, ok := tm.dag.nodes[nodeID]
if !ok {
return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)}
}
// Stamp the creation time only the first time this manager processes a task.
if tm.createdAt.IsZero() {
tm.createdAt = time.Now()
}
// Register the start node's execution before launching it so that Wait
// below observes it.
tm.wg.Add(1)
go func() {
// NOTE(review): ctxx receives the queue key and index headers but is not
// passed on; processNode below is given the caller's ctx instead. Whether
// the headers.Set calls are visible through ctx depends on mq.GetHeaders
// returning a shared object — confirm before changing this.
ctxx := context.Background()
if headers, ok := mq.GetHeaders(ctx); ok {
headers.Set(consts.QueueKey, node.Key)
headers.Set("index", fmt.Sprintf("%s__%d", node.Key, 0))
ctxx = mq.SetHeaders(ctxx, headers.AsMap())
}
// The node runs in its own goroutine, separate from this launcher goroutine.
go tm.processNode(ctx, node, payload)
}()
tm.wg.Wait()
requestType, ok := mq.GetHeader(ctx, "request_type")
if ok && requestType == "render" {
return tm.renderResult(ctx)
}
return tm.dispatchFinalResult(ctx)
}
// handleNextTask consumes the result of a finished node and schedules the
// nodes that follow it.
//
// It always decrements tm.wg once on return. The result is returned unchanged
// except when an Iterator edge's payload cannot be decoded as a JSON array, in
// which case the decode error is set on it.
//
// Behaviour by case:
//   - unknown result.Topic: returns immediately;
//   - result.Error set: reports the result and marks the node Failed;
//   - no outgoing edges: reports the result and marks the node Completed;
//   - node of type Page: reports the result and stops without scheduling;
//   - otherwise: fans out every Iterator edge (one processNode per item per
//     target), then processes Simple edges whose source node is not listed in
//     tm.iteratorNodes.
func (tm *TaskManager) handleNextTask(ctx context.Context, result mq.Result) mq.Result {
	tm.topic = result.Topic
	// recover only takes effect when called directly by the deferred function,
	// so mq.RecoverPanic is deferred on its own rather than called from inside
	// a deferred closure. Deferred calls run last-in-first-out: wg.Done runs
	// first, then the panic handler.
	defer mq.RecoverPanic(mq.RecoverTitle)
	defer tm.wg.Done()

	// Continue on the result's own context, carrying over the caller's headers.
	if result.Ctx != nil {
		if headers, ok := mq.GetHeaders(ctx); ok {
			ctx = mq.SetHeaders(result.Ctx, headers.AsMap())
		}
	}

	node, ok := tm.dag.nodes[result.Topic]
	if !ok {
		return result
	}
	if result.Error != nil {
		tm.reportNodeResult(result, true)
		tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
		return result
	}

	edges := tm.getConditionalEdges(node, result)
	if len(edges) == 0 {
		tm.reportNodeResult(result, true)
		tm.ChangeNodeStatus(ctx, node.Key, Completed, result)
		return result
	}
	tm.reportNodeResult(result, false)
	if node.Type == Page {
		return result
	}

	// First pass: Iterator edges fan the payload array out to every target.
	for _, edge := range edges {
		if edge.Type != Iterator {
			continue
		}
		var items []json.RawMessage
		if err := json.Unmarshal(result.Payload, &items); err != nil {
			tm.reportNodeResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err}, false)
			result.Error = err
			tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
			return result
		}
		tm.SetTotalItems(getTopic(ctx, edge.From.Key), len(items)*len(edge.To))
		for _, target := range edge.To {
			for i, item := range items {
				tm.wg.Add(1)
				go func(ctx context.Context, target *Node, item json.RawMessage, i int) {
					// Each item runs on a fresh context that carries only the headers.
					ctxx := context.Background()
					if headers, ok := mq.GetHeaders(ctx); ok {
						headers.Set(consts.QueueKey, target.Key)
						headers.Set("index", fmt.Sprintf("%s__%d", target.Key, i))
						ctxx = mq.SetHeaders(ctxx, headers.AsMap())
					}
					tm.processNode(ctxx, target, item)
				}(ctx, target, item, i)
			}
		}
	}

	// Second pass: Simple edges. Those leaving a node listed in iteratorNodes
	// are skipped here; ChangeNodeStatus processes them when that node is
	// marked Completed.
	for _, edge := range edges {
		if edge.Type != Simple {
			continue
		}
		if _, deferred := tm.iteratorNodes.Get(edge.From.Key); deferred {
			continue
		}
		tm.processEdge(ctx, edge, result)
	}
	return result
}
// processEdge schedules every target of edge with result's payload.
//
// The source node's total item count is set to the number of targets. The
// item index is taken from the "index" header (the part after the first "__")
// and defaults to "0"; it is carried over into each target's own "index"
// header. Each target runs in its own goroutine after tm.wg is incremented.
func (tm *TaskManager) processEdge(ctx context.Context, edge Edge, result mq.Result) {
	tm.SetTotalItems(getTopic(ctx, edge.From.Key), len(edge.To))

	suffix := "0"
	if header, _ := mq.GetHeader(ctx, "index"); strings.Contains(header, "__") {
		suffix = strings.Split(header, "__")[1]
	}

	for _, next := range edge.To {
		tm.wg.Add(1)
		go func(next *Node) {
			// Run the target on a fresh context that carries only the headers.
			childCtx := context.Background()
			if hdrs, present := mq.GetHeaders(ctx); present {
				hdrs.Set(consts.QueueKey, next.Key)
				hdrs.Set("index", fmt.Sprintf("%s__%s", next.Key, suffix))
				childCtx = mq.SetHeaders(childCtx, hdrs.AsMap())
			}
			tm.processNode(childCtx, next, result.Payload)
		}(next)
	}
}
// getConditionalEdges returns a copy of node's edges, extended by at most one
// conditional edge.
//
// When result carries a condition status and the DAG has conditions registered
// for result.Topic, the target mapped to that status is used; if the status
// has no mapping, the "default" entry is used instead. The edge is appended
// only when the chosen target key names a node known to the DAG.
func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge {
	edges := make([]Edge, len(node.Edges))
	copy(edges, node.Edges)

	if result.ConditionStatus == "" {
		return edges
	}
	conditions, registered := tm.dag.conditions[FromNode(result.Topic)]
	if !registered {
		return edges
	}

	targetKey, matched := conditions[When(result.ConditionStatus)]
	if !matched {
		targetKey, matched = conditions["default"]
	}
	if !matched {
		return edges
	}

	if target, known := tm.dag.nodes[string(targetKey)]; known {
		edges = append(edges, Edge{From: node, To: []*Node{target}})
	}
	return edges
}
// renderResult builds an empty result, passes it through updateTS, hands it to
// the DAG's consumer callback, stores its topic on the manager and returns it.
func (tm *TaskManager) renderResult(ctx context.Context) mq.Result {
	rendered := mq.Result{}
	tm.updateTS(&rendered)
	tm.dag.callbackToConsumer(ctx, rendered)
	tm.topic = rendered.Topic
	return rendered
}
// ChangeNodeStatus records status and rs on a node's status record and
// propagates Completed/Failed outcomes.
//
// nodeID may be a bare node key (resolved to a topic with getTopic) or an
// already indexed "key__index" identifier. Nothing happens when no status
// record exists for the node.
//
// On Failed, and on Completed when the node has no pending iterator edges, the
// result becomes the task result if the node is the DAG's start node;
// otherwise it is passed to markParentTask. On Completed with pending iterator
// edges, the record is reset to Processing, those edges are processed and the
// node's entry is removed from tm.iteratorNodes.
func (tm *TaskManager) ChangeNodeStatus(ctx context.Context, nodeID string, status NodeStatus, rs mq.Result) {
	topic := nodeID
	if strings.Contains(nodeID, "__") {
		topic = strings.Split(nodeID, "__")[0]
	} else {
		nodeID = getTopic(ctx, nodeID)
	}

	record, found := tm.taskNodeStatus.Get(nodeID)
	if !found || record == nil {
		return
	}
	record.markAs(rs, status)

	// propagate hands the outcome to the task result or to the parent nodes.
	propagate := func() {
		if topic == tm.dag.startNode {
			tm.result = rs
			return
		}
		tm.markParentTask(ctx, topic, nodeID, status, rs)
	}

	switch status {
	case Failed:
		propagate()
	case Completed:
		pending, isIterator := tm.iteratorNodes.Get(topic)
		if !isIterator || len(pending) == 0 {
			propagate()
			return
		}
		// Iterator edges still have to run: reopen the record and process them.
		record.status = Processing
		record.totalItems = 1
		record.itemResults.Clear()
		for _, edge := range pending {
			tm.processEdge(ctx, edge, rs)
		}
		tm.iteratorNodes.Del(topic)
	}
}
// markParentTask records rs as a child result on every parent of topic and,
// for each parent whose status record reports IsDone, propagates status to
// that parent with the result built by prepareResult.
//
// The parent's record is looked up under "parentKey__index", where index is
// the suffix of nodeID when nodeID has exactly two "__"-separated parts, and
// falls back to "parentKey__0". Parents without a record are skipped. The
// method returns silently when the DAG cannot list the previous nodes.
func (tm *TaskManager) markParentTask(ctx context.Context, topic, nodeID string, status NodeStatus, rs mq.Result) {
	parents, err := tm.dag.GetPreviousNodes(topic)
	if err != nil {
		return
	}

	index := ""
	if parts := strings.Split(nodeID, "__"); len(parts) == 2 {
		index = parts[1]
	}

	for _, parent := range parents {
		key := fmt.Sprintf("%s__%s", parent.Key, index)
		parentStatus, found := tm.taskNodeStatus.Get(key)
		if !found {
			key = fmt.Sprintf("%s__%s", parent.Key, "0")
			parentStatus, found = tm.taskNodeStatus.Get(key)
		}
		if !found {
			continue
		}

		parentStatus.itemResults.Set(nodeID, rs)
		if !parentStatus.IsDone() {
			continue
		}
		aggregated := tm.prepareResult(ctx, parentStatus)
		tm.ChangeNodeStatus(ctx, key, status, aggregated)
	}
}
// prepareResult builds the result for a node from its collected item results.
//
// When the node expects exactly one item, the first stored result is returned
// as is, with ctx filled in if the result has no context. Otherwise the item
// payloads are gathered into a JSON array; the topic and status come from the
// visited results (re-read while the captured topic is still empty). The first
// item error or payload decode error stops the walk and is returned through
// mq.HandleError, as is a failure to encode the array.
func (tm *TaskManager) prepareResult(ctx context.Context, nodeStatus *taskNodeStatus) mq.Result {
	if nodeStatus.totalItems == 1 {
		single := nodeStatus.itemResults.Values()[0]
		if single.Ctx == nil {
			single.Ctx = ctx
		}
		return single
	}

	var (
		firstTopic  string
		firstStatus mq.Status
		failure     error
	)
	// Kept non-nil so an empty collection encodes as [] rather than null.
	collected := make([]json.RawMessage, 0)

	nodeStatus.itemResults.ForEach(func(_ string, item mq.Result) bool {
		if firstTopic == "" {
			firstTopic, firstStatus = item.Topic, item.Status
		}
		if item.Error != nil {
			failure = item.Error
			return false
		}
		var raw json.RawMessage
		if decodeErr := json.Unmarshal(item.Payload, &raw); decodeErr != nil {
			failure = decodeErr
			return false
		}
		collected = append(collected, raw)
		return true
	})
	if failure != nil {
		return mq.HandleError(ctx, failure)
	}

	encoded, err := json.Marshal(collected)
	if err != nil {
		return mq.HandleError(ctx, err)
	}
	return mq.Result{TaskID: tm.taskID, Payload: encoded, Status: firstStatus, Topic: firstTopic, Ctx: ctx}
}