package dag

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"strings"
	"time"

	"github.com/oarkflow/mq"
	"github.com/oarkflow/mq/logger"

	"github.com/oarkflow/mq/storage"
	"github.com/oarkflow/mq/storage/memory"
)

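// TaskState tracks the lifecycle of a single node execution within a task:
// its current status, the latest result, and the buffered results of its
// target (child) nodes used for aggregation.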
type TaskState struct {
	UpdatedAt     time.Time
	targetResults storage.IMap[string, mq.Result]
	NodeID        string
	Status        mq.Status
	Result        mq.Result
}

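// newTaskState returns a TaskState in the Pending status with an empty
// target-result buffer for the given node ID.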
func newTaskState(nodeID string) *TaskState {
	return &TaskState{
		NodeID:        nodeID,
		Status:        mq.Pending,
		UpdatedAt:     time.Now(),
		targetResults: memory.New[string, mq.Result](),
	}
}

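// nodeResult carries the outcome of one node execution through the result
// queue, along with the context it was produced under.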
type nodeResult struct {
	ctx    context.Context
	nodeID string
	status mq.Status
	result mq.Result
}

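// TaskManager drives a single task through the DAG. It queues node
// executions, collects node results, resolves conditional and iterator
// edges, and publishes the final result on resultCh.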
type TaskManager struct {
	createdAt          time.Time
	taskStates         storage.IMap[string, *TaskState]
	parentNodes        storage.IMap[string, string]
	childNodes         storage.IMap[string, int]
	deferredTasks      storage.IMap[string, *task]
	iteratorNodes      storage.IMap[string, []Edge]
	currentNodePayload storage.IMap[string, json.RawMessage]
	currentNodeResult  storage.IMap[string, mq.Result]
	taskQueue          chan *task
	result             *mq.Result
	dag                *DAG
	resultQueue        chan nodeResult
	resultCh           chan mq.Result
	stopCh             chan struct{}
	taskID             string
	latency            string
}

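// task is one unit of work: a payload to be processed by a specific node
// on behalf of a task ID.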
type task struct {
	ctx     context.Context
	taskID  string
	nodeID  string
	payload json.RawMessage
}

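// newTask bundles the context, task ID, node ID, and payload into a task.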
func newTask(ctx context.Context, taskID, nodeID string, payload json.RawMessage) *task {
	return &task{
		ctx:     ctx,
		taskID:  taskID,
		nodeID:  nodeID,
		payload: payload,
	}
}

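// NewTaskManager creates a TaskManager bound to the given DAG and task ID
// and starts its processing and result-listening goroutines.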
func NewTaskManager(dag *DAG, taskID string, resultCh chan mq.Result, iteratorNodes storage.IMap[string, []Edge]) *TaskManager {
	tm := &TaskManager{
		taskStates:         memory.New[string, *TaskState](),
		parentNodes:        memory.New[string, string](),
		childNodes:         memory.New[string, int](),
		deferredTasks:      memory.New[string, *task](),
		currentNodePayload: memory.New[string, json.RawMessage](),
		currentNodeResult:  memory.New[string, mq.Result](),
		taskQueue:          make(chan *task, DefaultChannelSize),
		resultQueue:        make(chan nodeResult, DefaultChannelSize),
		iteratorNodes:      iteratorNodes,
		createdAt:          time.Now(),
		stopCh:             make(chan struct{}),
		resultCh:           resultCh,
		taskID:             taskID,
		dag:                dag,
	}
	go tm.run()
	go tm.waitForResult()
	return tm
}

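// ProcessTask enqueues the payload for execution starting at startNode.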
func (tm *TaskManager) ProcessTask(ctx context.Context, startNode string, payload json.RawMessage) {
	tm.send(ctx, startNode, tm.taskID, payload)
}

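// send enqueues a task for the given node, suffixing the node ID with the
// iteration index from the context when present and creating task state on
// first use. If the task queue is full, the task is stored in deferredTasks.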
func (tm *TaskManager) send(ctx context.Context, startNode, taskID string, payload json.RawMessage) {
	if index, ok := ctx.Value(ContextIndex).(string); ok {
		startNode = strings.Split(startNode, Delimiter)[0]
		startNode = fmt.Sprintf("%s%s%s", startNode, Delimiter, index)
	}
	if _, exists := tm.taskStates.Get(startNode); !exists {
		tm.taskStates.Set(startNode, newTaskState(startNode))
	}
	t := newTask(ctx, taskID, startNode, payload)
	select {
	case tm.taskQueue <- t:
	default:
		log.Println("task queue is full, deferring task")
		tm.deferredTasks.Set(taskID, t)
	}
}

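// run consumes the task queue and processes each task until Stop is called.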
func (tm *TaskManager) run() {
	for {
		select {
		case <-tm.stopCh:
			log.Println("Stopping TaskManager")
			return
		case task := <-tm.taskQueue:
			tm.processNode(task)
		}
	}
}

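// waitForResult consumes the result queue and forwards each node result to
// onNodeCompleted until Stop is called.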
func (tm *TaskManager) waitForResult() {
	for {
		select {
		case <-tm.stopCh:
			log.Println("Stopping Result Listener")
			return
		case nr := <-tm.resultQueue:
			tm.onNodeCompleted(nr)
		}
	}
}

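// processNode executes a single node's processor, records its state, result,
// and latency, and either finalizes the task (on error, on the last node, or
// on a Page node) or forwards the result to the next edges.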
func (tm *TaskManager) processNode(exec *task) {
	startTime := time.Now()
	pureNodeID := strings.Split(exec.nodeID, Delimiter)[0]
	node, exists := tm.dag.nodes.Get(pureNodeID)
	if !exists {
		tm.dag.Logger().Error("Node not found while processing node",
			logger.Field{Key: "nodeID", Value: pureNodeID})
		return
	}
	state, _ := tm.taskStates.Get(exec.nodeID)
	if state == nil {
		tm.dag.Logger().Warn("State not found; creating new state",
			logger.Field{Key: "nodeID", Value: exec.nodeID})
		state = newTaskState(exec.nodeID)
		tm.taskStates.Set(exec.nodeID, state)
	}
	state.Status = mq.Processing
	state.UpdatedAt = time.Now()
	tm.currentNodePayload.Clear()
	tm.currentNodeResult.Clear()
	tm.currentNodePayload.Set(exec.nodeID, exec.payload)

	// Execute the node's task.
	result := node.processor.ProcessTask(exec.ctx, mq.NewTask(exec.taskID, exec.payload, exec.nodeID))
	// Calculate the per-node latency.
	nodeLatency := time.Since(startTime)

	// Log the result of node execution with comprehensive details.
	logFields := []logger.Field{
		{Key: "nodeID", Value: exec.nodeID},
		{Key: "pureNodeID", Value: pureNodeID},
		{Key: "taskID", Value: exec.taskID},
		{Key: "latency", Value: nodeLatency.String()},
	}
	if result.Error != nil {
		logFields = append(logFields, logger.Field{Key: "error", Value: result.Error.Error()})
		logFields = append(logFields, logger.Field{Key: "status", Value: mq.Failed})
		tm.dag.Logger().Error("Node execution failed", logFields...)
	} else {
		logFields = append(logFields, logger.Field{Key: "status", Value: mq.Completed})
		tm.dag.Logger().Info("Node executed successfully", logFields...)
	}

	// If this is the last node, mark it accordingly.
	isLast, err := tm.dag.IsLastNode(pureNodeID)
	if err != nil {
		tm.dag.Logger().Error("Error checking if node is last",
			logger.Field{Key: "nodeID", Value: pureNodeID},
			logger.Field{Key: "error", Value: err.Error()})
	} else if isLast {
		result.Last = true
	}

	tm.currentNodeResult.Set(exec.nodeID, result)
	state.Result = result
	result.Topic = node.ID
	tm.updateTimestamps(&result)
	if result.Error != nil {
		result.Status = mq.Failed
		state.Status = mq.Failed
		state.Result.Status = mq.Failed
		state.Result.Latency = result.Latency
		tm.result = &result
		tm.resultCh <- result
		tm.processFinalResult(state)
		return
	}
	result.Status = mq.Completed
	state.Result.Status = mq.Completed
	state.Result.Latency = result.Latency
	if isLast {
		tm.processFinalResult(state)
	}
	if node.NodeType == Page {
		tm.result = &result
		tm.resultCh <- result
		return
	}
	if !isLast {
		tm.handleNext(exec.ctx, node, state, result)
	}
}

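// updateTimestamps stamps the result with the task's creation time, the
// processing time, and the total latency since creation.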
func (tm *TaskManager) updateTimestamps(rs *mq.Result) {
	rs.CreatedAt = tm.createdAt
	rs.ProcessedAt = time.Now()
	rs.Latency = time.Since(rs.CreatedAt).String()
}

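// handlePrevious propagates a child result back into its parent state. Once
// all child results have arrived it aggregates their payloads, re-dispatches
// any pending iterator edges, bubbles further up the parent chain, or emits
// the final result when no parent remains.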
func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, result mq.Result, childNode string, dispatchFinal bool) {
	state.targetResults.Set(childNode, result)
	state.targetResults.Del(state.NodeID)
	targetsCount, _ := tm.childNodes.Get(state.NodeID)
	size := state.targetResults.Size()
	nodeID := strings.Split(state.NodeID, Delimiter)
	if size == targetsCount {
		if size > 1 {
			aggregatedData := make([]json.RawMessage, size)
			i := 0
			state.targetResults.ForEach(func(_ string, rs mq.Result) bool {
				aggregatedData[i] = rs.Payload
				i++
				return true
			})
			aggregatedPayload, err := json.Marshal(aggregatedData)
			if err != nil {
				panic(err)
			}
			state.Result = mq.Result{Payload: aggregatedPayload, Status: mq.Completed, Ctx: ctx, Topic: state.NodeID}
		} else if size == 1 {
			state.Result = state.targetResults.Values()[0]
		}
		state.Status = result.Status
		state.Result.Status = result.Status
	}
	if state.Result.Payload == nil {
		state.Result.Payload = result.Payload
	}
	state.UpdatedAt = time.Now()
	if result.Ctx == nil {
		result.Ctx = ctx
	}
	if result.Error != nil {
		state.Status = mq.Failed
	}
	pn, ok := tm.parentNodes.Get(state.NodeID)
	if edges, exists := tm.iteratorNodes.Get(nodeID[0]); exists && state.Status == mq.Completed {
		state.Status = mq.Processing
		tm.iteratorNodes.Del(nodeID[0])
		state.targetResults.Clear()
		if len(nodeID) == 2 {
			ctx = context.WithValue(ctx, ContextIndex, nodeID[1])
		}
		toProcess := nodeResult{
			ctx:    ctx,
			nodeID: state.NodeID,
			status: state.Status,
			result: state.Result,
		}
		tm.handleEdges(toProcess, edges)
	} else if ok {
		if targetsCount == size {
			parentState, _ := tm.taskStates.Get(pn)
			if parentState != nil {
				state.Result.Topic = state.NodeID
				tm.handlePrevious(ctx, parentState, state.Result, state.NodeID, dispatchFinal)
			}
		}
	} else {
		tm.updateTimestamps(&state.Result)
		tm.result = &state.Result
		state.Result.Topic = strings.Split(state.NodeID, Delimiter)[0]
		tm.resultCh <- state.Result
		tm.processFinalResult(state)
	}
}

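// handleNext records the node's post-execution status and pushes the result
// onto the result queue for edge resolution; if the queue is full the result
// is dropped.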
func (tm *TaskManager) handleNext(ctx context.Context, node *Node, state *TaskState, result mq.Result) {
	state.UpdatedAt = time.Now()
	if result.Ctx == nil {
		result.Ctx = ctx
	}
	if result.Error != nil {
		state.Status = mq.Failed
	} else {
		edges := tm.getConditionalEdges(node, result)
		if len(edges) == 0 {
			state.Status = mq.Completed
		}
	}
	if result.Status == "" {
		result.Status = state.Status
	}
	select {
	case tm.resultQueue <- nodeResult{
		ctx:    ctx,
		nodeID: state.NodeID,
		result: result,
		status: state.Status,
	}:
	default:
		log.Println("Result queue is full, dropping result.")
	}
}

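// onNodeCompleted resolves the edges of a finished node. Terminal results
// (errors or nodes with no outgoing edges) are routed back to the parent via
// handlePrevious or emitted as the final result; otherwise the outgoing edges
// are dispatched.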
func (tm *TaskManager) onNodeCompleted(rs nodeResult) {
	nodeID := strings.Split(rs.nodeID, Delimiter)[0]
	node, ok := tm.dag.nodes.Get(nodeID)
	if !ok {
		return
	}
	edges := tm.getConditionalEdges(node, rs.result)
	hasErrorOrCompleted := rs.result.Error != nil || len(edges) == 0
	if hasErrorOrCompleted {
		if index, ok := rs.ctx.Value(ContextIndex).(string); ok {
			childNode := fmt.Sprintf("%s%s%s", node.ID, Delimiter, index)
			pn, ok := tm.parentNodes.Get(childNode)
			if ok {
				parentState, _ := tm.taskStates.Get(pn)
				if parentState != nil {
					pn = strings.Split(pn, Delimiter)[0]
					tm.handlePrevious(rs.ctx, parentState, rs.result, rs.nodeID, true)
				}
			} else {
				tm.updateTimestamps(&rs.result)
				tm.resultCh <- rs.result
				if state, ok := tm.taskStates.Get(rs.nodeID); ok {
					tm.processFinalResult(state)
				}
			}
		}
		return
	}
	tm.handleEdges(rs, edges)
}

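// getConditionalEdges returns the node's static edges plus, when the result
// carries a ConditionStatus, the edge selected by the DAG's condition map
// (falling back to the "default" condition if defined).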
func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge {
	edges := make([]Edge, len(node.Edges))
	copy(edges, node.Edges)
	if result.ConditionStatus != "" {
		if conditions, ok := tm.dag.conditions[result.Topic]; ok {
			if targetNodeKey, ok := conditions[result.ConditionStatus]; ok {
				if targetNode, ok := tm.dag.nodes.Get(targetNodeKey); ok {
					edges = append(edges, Edge{From: node, To: targetNode})
				}
			} else if targetNodeKey, ok = conditions["default"]; ok {
				if targetNode, ok := tm.dag.nodes.Get(targetNodeKey); ok {
					edges = append(edges, Edge{From: node, To: targetNode})
				}
			}
		}
	}
	return edges
}

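// handleEdges dispatches a node result across its outgoing edges. Iterator
// edges fan out one child task per element of the JSON array payload; other
// edges forward the payload to a single child node.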
func (tm *TaskManager) handleEdges(currentResult nodeResult, edges []Edge) {
	for _, edge := range edges {
		index, ok := currentResult.ctx.Value(ContextIndex).(string)
		if !ok {
			index = "0"
		}
		parentNode := fmt.Sprintf("%s%s%s", edge.From.ID, Delimiter, index)
		if edge.Type == Simple {
			if _, ok := tm.iteratorNodes.Get(edge.From.ID); ok {
				continue
			}
		}
		if edge.Type == Iterator {
			var items []json.RawMessage
			err := json.Unmarshal(currentResult.result.Payload, &items)
			if err != nil {
				log.Printf("Error unmarshalling data for node %s: %v\n", edge.To.ID, err)
				tm.resultQueue <- nodeResult{
					ctx:    currentResult.ctx,
					nodeID: edge.To.ID,
					status: mq.Failed,
					result: mq.Result{Error: err},
				}
				return
			}
			tm.childNodes.Set(parentNode, len(items))
			for i, item := range items {
				childNode := fmt.Sprintf("%s%s%d", edge.To.ID, Delimiter, i)
				ctx := context.WithValue(currentResult.ctx, ContextIndex, fmt.Sprintf("%d", i))
				tm.parentNodes.Set(childNode, parentNode)
				tm.send(ctx, edge.To.ID, tm.taskID, item)
			}
		} else {
			tm.childNodes.Set(parentNode, 1)
			idx, ok := currentResult.ctx.Value(ContextIndex).(string)
			if !ok {
				idx = "0"
			}
			childNode := fmt.Sprintf("%s%s%s", edge.To.ID, Delimiter, idx)
			ctx := context.WithValue(currentResult.ctx, ContextIndex, idx)
			tm.parentNodes.Set(childNode, parentNode)
			tm.send(ctx, edge.To.ID, tm.taskID, currentResult.result.Payload)
		}
	}
}

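// retryDeferredTasks periodically re-sends tasks that were deferred because
// the task queue was full, backing off exponentially between retry rounds.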
func (tm *TaskManager) retryDeferredTasks() {
	const maxRetries = 5
	backoff := time.Second

	for retries := 0; retries < maxRetries; retries++ {
		select {
		case <-tm.stopCh:
			log.Println("Stopping deferred task retrier")
			return
		case <-time.After(backoff):
			tm.deferredTasks.ForEach(func(taskID string, task *task) bool {
				tm.send(task.ctx, task.nodeID, taskID, task.payload)
				return true
			})
			backoff *= 2 // exponential backoff between retry rounds
		}
	}
}

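// processFinalResult marks the task state as completed, clears its buffered
// child results, and invokes the DAG's finalResult callback when configured.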
func (tm *TaskManager) processFinalResult(state *TaskState) {
	state.Status = mq.Completed
	state.targetResults.Clear()
	if tm.dag.finalResult != nil {
		tm.dag.finalResult(tm.taskID, state.Result)
	}
}

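// Stop shuts down the TaskManager's background goroutines.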
func (tm *TaskManager) Stop() {
	close(tm.stopCh)
}