Files
mq/dag/task_manager.go
2024-11-23 10:51:22 +05:45

389 lines
10 KiB
Go

package dag
import (
"context"
"encoding/json"
"fmt"
"log"
"strings"
"time"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/storage"
"github.com/oarkflow/mq/storage/memory"
)
type TaskState struct {
NodeID string
Status mq.Status
UpdatedAt time.Time
Result mq.Result
targetResults storage.IMap[string, mq.Result]
}
func newTaskState(nodeID string) *TaskState {
return &TaskState{
NodeID: nodeID,
Status: mq.Pending,
UpdatedAt: time.Now(),
targetResults: memory.New[string, mq.Result](),
}
}
type nodeResult struct {
ctx context.Context
nodeID string
status mq.Status
result mq.Result
}
type TaskManager struct {
taskStates storage.IMap[string, *TaskState]
parentNodes storage.IMap[string, string]
childNodes storage.IMap[string, int]
deferredTasks storage.IMap[string, *task]
iteratorNodes storage.IMap[string, []Edge]
currentNodePayload storage.IMap[string, json.RawMessage]
currentNodeResult storage.IMap[string, mq.Result]
result *mq.Result
dag *DAG
taskID string
taskQueue chan *task
resultQueue chan nodeResult
resultCh chan mq.Result
stopCh chan struct{}
}
type task struct {
ctx context.Context
taskID string
nodeID string
payload json.RawMessage
}
func newTask(ctx context.Context, taskID, nodeID string, payload json.RawMessage) *task {
return &task{
ctx: ctx,
taskID: taskID,
nodeID: nodeID,
payload: payload,
}
}
func NewTaskManager(dag *DAG, taskID string, resultCh chan mq.Result, iteratorNodes storage.IMap[string, []Edge]) *TaskManager {
tm := &TaskManager{
taskStates: memory.New[string, *TaskState](),
parentNodes: memory.New[string, string](),
childNodes: memory.New[string, int](),
deferredTasks: memory.New[string, *task](),
currentNodePayload: memory.New[string, json.RawMessage](),
currentNodeResult: memory.New[string, mq.Result](),
taskQueue: make(chan *task, DefaultChannelSize),
resultQueue: make(chan nodeResult, DefaultChannelSize),
iteratorNodes: iteratorNodes,
stopCh: make(chan struct{}),
resultCh: resultCh,
taskID: taskID,
dag: dag,
}
go tm.run()
go tm.waitForResult()
return tm
}
func (tm *TaskManager) ProcessTask(ctx context.Context, startNode string, payload json.RawMessage) {
tm.send(ctx, startNode, tm.taskID, payload)
}
func (tm *TaskManager) send(ctx context.Context, startNode, taskID string, payload json.RawMessage) {
if index, ok := ctx.Value(ContextIndex).(string); ok {
startNode = strings.Split(startNode, Delimiter)[0]
startNode = fmt.Sprintf("%s%s%s", startNode, Delimiter, index)
}
if _, exists := tm.taskStates.Get(startNode); !exists {
tm.taskStates.Set(startNode, newTaskState(startNode))
}
t := newTask(ctx, taskID, startNode, payload)
select {
case tm.taskQueue <- t:
default:
log.Println("task queue is full, dropping task.")
tm.deferredTasks.Set(taskID, t)
}
}
func (tm *TaskManager) run() {
for {
select {
case <-tm.stopCh:
log.Println("Stopping TaskManager")
return
case task := <-tm.taskQueue:
tm.processNode(task)
}
}
}
func (tm *TaskManager) waitForResult() {
for {
select {
case <-tm.stopCh:
log.Println("Stopping Result Listener")
return
case nr := <-tm.resultQueue:
tm.onNodeCompleted(nr)
}
}
}
func (tm *TaskManager) processNode(exec *task) {
pureNodeID := strings.Split(exec.nodeID, Delimiter)[0]
node, exists := tm.dag.nodes.Get(pureNodeID)
if !exists {
log.Printf("Node %s does not exist while processing node\n", pureNodeID)
return
}
state, _ := tm.taskStates.Get(exec.nodeID)
if state == nil {
log.Printf("State for node %s not found; creating new state.\n", exec.nodeID)
state = newTaskState(exec.nodeID)
tm.taskStates.Set(exec.nodeID, state)
}
state.Status = mq.Processing
state.UpdatedAt = time.Now()
tm.currentNodePayload.Clear()
tm.currentNodeResult.Clear()
tm.currentNodePayload.Set(exec.nodeID, exec.payload)
result := node.processor.ProcessTask(exec.ctx, mq.NewTask(exec.taskID, exec.payload, exec.nodeID))
tm.currentNodeResult.Set(exec.nodeID, result)
state.Result = result
result.Topic = node.ID
if result.Error != nil {
tm.result = &result
tm.resultCh <- result
tm.processFinalResult(state)
return
}
if node.NodeType == Page {
tm.result = &result
tm.resultCh <- result
return
}
tm.handleNext(exec.ctx, node, state, result)
}
func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, result mq.Result, childNode string, dispatchFinal bool) {
state.targetResults.Set(childNode, result)
state.targetResults.Del(state.NodeID)
targetsCount, _ := tm.childNodes.Get(state.NodeID)
size := state.targetResults.Size()
nodeID := strings.Split(state.NodeID, Delimiter)
if size == targetsCount {
if size > 1 {
aggregatedData := make([]json.RawMessage, size)
i := 0
state.targetResults.ForEach(func(_ string, rs mq.Result) bool {
aggregatedData[i] = rs.Payload
i++
return true
})
aggregatedPayload, err := json.Marshal(aggregatedData)
if err != nil {
panic(err)
}
state.Result = mq.Result{Payload: aggregatedPayload, Status: mq.Completed, Ctx: ctx, Topic: state.NodeID}
} else if size == 1 {
state.Result = state.targetResults.Values()[0]
}
state.Status = result.Status
state.Result.Status = result.Status
}
if state.Result.Payload == nil {
state.Result.Payload = result.Payload
}
state.UpdatedAt = time.Now()
if result.Ctx == nil {
result.Ctx = ctx
}
if result.Error != nil {
state.Status = mq.Failed
}
pn, ok := tm.parentNodes.Get(state.NodeID)
if edges, exists := tm.iteratorNodes.Get(nodeID[0]); exists && state.Status == mq.Completed {
state.Status = mq.Processing
tm.iteratorNodes.Del(nodeID[0])
state.targetResults.Clear()
if len(nodeID) == 2 {
ctx = context.WithValue(ctx, ContextIndex, nodeID[1])
}
toProcess := nodeResult{
ctx: ctx,
nodeID: state.NodeID,
status: state.Status,
result: state.Result,
}
tm.handleEdges(toProcess, edges)
} else if ok {
if targetsCount == size {
parentState, _ := tm.taskStates.Get(pn)
if parentState != nil {
state.Result.Topic = state.NodeID
tm.handlePrevious(ctx, parentState, state.Result, state.NodeID, dispatchFinal)
}
}
} else {
tm.result = &state.Result
state.Result.Topic = strings.Split(state.NodeID, Delimiter)[0]
tm.resultCh <- state.Result
tm.processFinalResult(state)
}
}
func (tm *TaskManager) handleNext(ctx context.Context, node *Node, state *TaskState, result mq.Result) {
state.UpdatedAt = time.Now()
if result.Ctx == nil {
result.Ctx = ctx
}
if result.Error != nil {
state.Status = mq.Failed
} else {
edges := tm.getConditionalEdges(node, result)
if len(edges) == 0 {
state.Status = mq.Completed
}
}
if result.Status == "" {
result.Status = state.Status
}
select {
case tm.resultQueue <- nodeResult{
ctx: ctx,
nodeID: state.NodeID,
result: result,
status: state.Status,
}:
default:
log.Println("Result queue is full, dropping result.")
}
}
func (tm *TaskManager) onNodeCompleted(rs nodeResult) {
nodeID := strings.Split(rs.nodeID, Delimiter)[0]
node, ok := tm.dag.nodes.Get(nodeID)
if !ok {
return
}
edges := tm.getConditionalEdges(node, rs.result)
hasErrorOrCompleted := rs.result.Error != nil || len(edges) == 0
if hasErrorOrCompleted {
if index, ok := rs.ctx.Value(ContextIndex).(string); ok {
childNode := fmt.Sprintf("%s%s%s", node.ID, Delimiter, index)
pn, ok := tm.parentNodes.Get(childNode)
if ok {
parentState, _ := tm.taskStates.Get(pn)
if parentState != nil {
pn = strings.Split(pn, Delimiter)[0]
tm.handlePrevious(rs.ctx, parentState, rs.result, rs.nodeID, true)
}
}
}
return
}
tm.handleEdges(rs, edges)
}
func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge {
edges := make([]Edge, len(node.Edges))
copy(edges, node.Edges)
if result.ConditionStatus != "" {
if conditions, ok := tm.dag.conditions[result.Topic]; ok {
if targetNodeKey, ok := conditions[result.ConditionStatus]; ok {
if targetNode, ok := tm.dag.nodes.Get(targetNodeKey); ok {
edges = append(edges, Edge{From: node, To: targetNode})
}
} else if targetNodeKey, ok = conditions["default"]; ok {
if targetNode, ok := tm.dag.nodes.Get(targetNodeKey); ok {
edges = append(edges, Edge{From: node, To: targetNode})
}
}
}
}
return edges
}
func (tm *TaskManager) handleEdges(currentResult nodeResult, edges []Edge) {
for _, edge := range edges {
index, ok := currentResult.ctx.Value(ContextIndex).(string)
if !ok {
index = "0"
}
parentNode := fmt.Sprintf("%s%s%s", edge.From.ID, Delimiter, index)
if edge.Type == Simple {
if _, ok := tm.iteratorNodes.Get(edge.From.ID); ok {
continue
}
}
if edge.Type == Iterator {
var items []json.RawMessage
err := json.Unmarshal(currentResult.result.Payload, &items)
if err != nil {
log.Printf("Error unmarshalling data for node %s: %v\n", edge.To.ID, err)
tm.resultQueue <- nodeResult{
ctx: currentResult.ctx,
nodeID: edge.To.ID,
status: mq.Failed,
result: mq.Result{Error: err},
}
return
}
tm.childNodes.Set(parentNode, len(items))
for i, item := range items {
childNode := fmt.Sprintf("%s%s%d", edge.To.ID, Delimiter, i)
ctx := context.WithValue(currentResult.ctx, ContextIndex, fmt.Sprintf("%d", i))
tm.parentNodes.Set(childNode, parentNode)
tm.send(ctx, edge.To.ID, tm.taskID, item)
}
} else {
tm.childNodes.Set(parentNode, 1)
idx, ok := currentResult.ctx.Value(ContextIndex).(string)
if !ok {
idx = "0"
}
childNode := fmt.Sprintf("%s%s%s", edge.To.ID, Delimiter, idx)
ctx := context.WithValue(currentResult.ctx, ContextIndex, idx)
tm.parentNodes.Set(childNode, parentNode)
tm.send(ctx, edge.To.ID, tm.taskID, currentResult.result.Payload)
}
}
}
func (tm *TaskManager) retryDeferredTasks() {
const maxRetries = 5
retries := 0
for retries < maxRetries {
select {
case <-tm.stopCh:
log.Println("Stopping Deferred task Retrier")
return
case <-time.After(RetryInterval):
tm.deferredTasks.ForEach(func(taskID string, task *task) bool {
tm.send(task.ctx, task.nodeID, taskID, task.payload)
retries++
return true
})
}
}
}
func (tm *TaskManager) processFinalResult(state *TaskState) {
state.targetResults.Clear()
if tm.dag.finalResult != nil {
tm.dag.finalResult(tm.taskID, state.Result)
}
}
func (tm *TaskManager) Stop() {
close(tm.stopCh)
}