mirror of
https://github.com/oarkflow/mq.git
synced 2025-09-29 13:22:10 +08:00
1537 lines
48 KiB
Go
1537 lines
48 KiB
Go
package dag
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"log"
	"math/rand" // ...new import for jitter...
	"strings"
	"sync"
	"time"

	"github.com/oarkflow/json"

	"github.com/oarkflow/mq"
	"github.com/oarkflow/mq/consts"
	dagstorage "github.com/oarkflow/mq/dag/storage" // Import dag storage package with alias
	"github.com/oarkflow/mq/logger"
	mqstorage "github.com/oarkflow/mq/storage"
	"github.com/oarkflow/mq/storage/memory"
)
|
|
|
|
// TaskError is returned by node processors to indicate whether a failure is
// recoverable.
type TaskError struct {
	// Err is the underlying failure reported by the processor.
	Err error
	// Recoverable reports whether the failure is recoverable.
	Recoverable bool
}

// Error implements the error interface by delegating to the wrapped error.
// A TaskError without a cause yields a fixed placeholder message instead of
// calling a method on a nil error.
func (te TaskError) Error() string {
	if te.Err == nil {
		return "task error: no underlying error"
	}
	return te.Err.Error()
}

// Unwrap returns the wrapped error so that errors.Is and errors.As can
// inspect the cause of a TaskError.
func (te TaskError) Unwrap() error {
	return te.Err
}
|
|
|
|
// TaskState holds state and intermediate results for a given task (identified by a node ID).
|
|
type TaskState struct {
|
|
UpdatedAt time.Time
|
|
targetResults mqstorage.IMap[string, mq.Result]
|
|
NodeID string
|
|
Status mq.Status
|
|
Result mq.Result
|
|
}
|
|
|
|
func newTaskState(nodeID string) *TaskState {
|
|
return &TaskState{
|
|
NodeID: nodeID,
|
|
Status: mq.Pending,
|
|
UpdatedAt: time.Now(),
|
|
targetResults: memory.New[string, mq.Result](),
|
|
}
|
|
}
|
|
|
|
type nodeResult struct {
|
|
ctx context.Context
|
|
nodeID string
|
|
status mq.Status
|
|
result mq.Result
|
|
}
|
|
|
|
type task struct {
|
|
ctx context.Context
|
|
taskID string
|
|
nodeID string
|
|
payload json.RawMessage
|
|
}
|
|
|
|
func newTask(ctx context.Context, taskID, nodeID string, payload json.RawMessage) *task {
|
|
return &task{
|
|
ctx: ctx,
|
|
taskID: taskID,
|
|
nodeID: nodeID,
|
|
payload: payload,
|
|
}
|
|
}
|
|
|
|
type TaskManagerConfig struct {
|
|
MaxRetries int
|
|
BaseBackoff time.Duration
|
|
RecoveryHandler func(ctx context.Context, result mq.Result) error
|
|
}
|
|
|
|
type TaskManager struct {
|
|
createdAt time.Time
|
|
taskStates mqstorage.IMap[string, *TaskState]
|
|
parentNodes mqstorage.IMap[string, string]
|
|
childNodes mqstorage.IMap[string, int]
|
|
deferredTasks mqstorage.IMap[string, *task]
|
|
iteratorNodes mqstorage.IMap[string, []Edge]
|
|
currentNodePayload mqstorage.IMap[string, json.RawMessage]
|
|
currentNodeResult mqstorage.IMap[string, mq.Result]
|
|
taskQueue chan *task
|
|
result *mq.Result
|
|
resultQueue chan nodeResult
|
|
resultCh chan mq.Result
|
|
stopCh chan struct{}
|
|
taskID string
|
|
dag *DAG
|
|
maxRetries int
|
|
baseBackoff time.Duration
|
|
recoveryHandler func(ctx context.Context, result mq.Result) error
|
|
pauseMu sync.Mutex
|
|
pauseCh chan struct{}
|
|
wg sync.WaitGroup
|
|
storage dagstorage.TaskStorage // Added TaskStorage for persistence
|
|
}
|
|
|
|
func NewTaskManager(dag *DAG, taskID string, resultCh chan mq.Result, iteratorNodes mqstorage.IMap[string, []Edge], taskStorage dagstorage.TaskStorage) *TaskManager {
|
|
config := TaskManagerConfig{
|
|
MaxRetries: 3,
|
|
BaseBackoff: time.Second,
|
|
}
|
|
tm := &TaskManager{
|
|
createdAt: time.Now(),
|
|
taskStates: memory.New[string, *TaskState](),
|
|
parentNodes: memory.New[string, string](),
|
|
childNodes: memory.New[string, int](),
|
|
deferredTasks: memory.New[string, *task](),
|
|
currentNodePayload: memory.New[string, json.RawMessage](),
|
|
currentNodeResult: memory.New[string, mq.Result](),
|
|
taskQueue: make(chan *task, DefaultChannelSize),
|
|
resultQueue: make(chan nodeResult, DefaultChannelSize),
|
|
resultCh: resultCh,
|
|
stopCh: make(chan struct{}),
|
|
taskID: taskID,
|
|
dag: dag,
|
|
maxRetries: config.MaxRetries,
|
|
baseBackoff: config.BaseBackoff,
|
|
recoveryHandler: config.RecoveryHandler,
|
|
iteratorNodes: iteratorNodes,
|
|
storage: taskStorage,
|
|
}
|
|
|
|
tm.wg.Add(3)
|
|
go tm.run()
|
|
go tm.waitForResult()
|
|
go tm.retryDeferredTasks()
|
|
|
|
return tm
|
|
}
|
|
|
|
func (tm *TaskManager) ProcessTask(ctx context.Context, startNode string, payload json.RawMessage) {
|
|
tm.enqueueTask(ctx, startNode, tm.taskID, payload)
|
|
}
|
|
|
|
func (tm *TaskManager) enqueueTask(ctx context.Context, startNode, taskID string, payload json.RawMessage) {
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Info("enqueueTask called",
|
|
logger.Field{Key: "startNode", Value: startNode},
|
|
logger.Field{Key: "taskID", Value: taskID},
|
|
logger.Field{Key: "payloadSize", Value: len(payload)})
|
|
}
|
|
|
|
if index, ok := ctx.Value(ContextIndex).(string); ok {
|
|
base := strings.Split(startNode, Delimiter)[0]
|
|
startNode = fmt.Sprintf("%s%s%s", base, Delimiter, index)
|
|
}
|
|
if _, exists := tm.taskStates.Get(startNode); !exists {
|
|
tm.taskStates.Set(startNode, newTaskState(startNode))
|
|
}
|
|
t := newTask(ctx, taskID, startNode, payload)
|
|
// Persist task to storage
|
|
if tm.storage != nil {
|
|
persistentTask := &dagstorage.PersistentTask{
|
|
ID: taskID,
|
|
DAGID: tm.dag.key,
|
|
NodeID: startNode,
|
|
CurrentNodeID: startNode,
|
|
ProcessingState: "enqueued",
|
|
Status: dagstorage.TaskStatusPending,
|
|
Payload: payload,
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
MaxRetries: tm.maxRetries,
|
|
}
|
|
if err := tm.storage.SaveTask(ctx, persistentTask); err != nil {
|
|
tm.dag.Logger().Error("Failed to persist task", logger.Field{Key: "taskID", Value: taskID}, logger.Field{Key: "error", Value: err.Error()})
|
|
} else {
|
|
// Log task creation activity
|
|
tm.logActivity(ctx, taskID, startNode, "task_created", "Task enqueued for processing", nil)
|
|
}
|
|
}
|
|
select {
|
|
case tm.taskQueue <- t:
|
|
// Successfully enqueued
|
|
default:
|
|
// Queue is full, add to deferred tasks with limit
|
|
if tm.deferredTasks.Size() < 1000 { // Limit deferred tasks to prevent memory issues
|
|
tm.deferredTasks.Set(taskID, t)
|
|
tm.dag.Logger().Warn("Task queue full, deferring task",
|
|
logger.Field{Key: "taskID", Value: taskID},
|
|
logger.Field{Key: "nodeID", Value: startNode})
|
|
} else {
|
|
tm.dag.Logger().Error("Deferred tasks queue also full, dropping task",
|
|
logger.Field{Key: "taskID", Value: taskID},
|
|
logger.Field{Key: "nodeID", Value: startNode})
|
|
}
|
|
}
|
|
}
|
|
|
|
func (tm *TaskManager) run() {
|
|
defer tm.wg.Done()
|
|
for {
|
|
select {
|
|
case <-tm.stopCh:
|
|
log.Println("Stopping TaskManager run loop")
|
|
return
|
|
default:
|
|
tm.pauseMu.Lock()
|
|
pch := tm.pauseCh
|
|
tm.pauseMu.Unlock()
|
|
if pch != nil {
|
|
select {
|
|
case <-tm.stopCh:
|
|
log.Println("Stopping TaskManager run loop during pause")
|
|
return
|
|
case <-pch:
|
|
// Resume from pause
|
|
}
|
|
}
|
|
|
|
select {
|
|
case <-tm.stopCh:
|
|
log.Println("Stopping TaskManager run loop")
|
|
return
|
|
case tsk := <-tm.taskQueue:
|
|
tm.processNode(tsk)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// waitForResult listens for node results on resultQueue and processes them.
|
|
func (tm *TaskManager) waitForResult() {
|
|
defer tm.wg.Done()
|
|
for {
|
|
select {
|
|
case <-tm.stopCh:
|
|
log.Println("Stopping TaskManager result listener")
|
|
return
|
|
case nr := <-tm.resultQueue:
|
|
select {
|
|
case <-tm.stopCh:
|
|
log.Println("Stopping TaskManager result listener during processing")
|
|
return
|
|
default:
|
|
tm.onNodeCompleted(nr)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// areDependenciesMet checks if all previous nodes have completed successfully
|
|
func (tm *TaskManager) areDependenciesMet(nodeID string) bool {
|
|
pureNodeID := strings.Split(nodeID, Delimiter)[0]
|
|
|
|
// Get previous nodes
|
|
prevNodes, err := tm.dag.GetPreviousNodes(pureNodeID)
|
|
if err != nil {
|
|
tm.dag.Logger().Error("Error getting previous nodes", logger.Field{Key: "nodeID", Value: nodeID}, logger.Field{Key: "error", Value: err.Error()})
|
|
return false
|
|
}
|
|
|
|
// For iterator nodes, we need to be more selective about dependencies
|
|
// Iterator nodes should only depend on nodes that provide data to them,
|
|
// not on nodes that they create (which would be circular dependencies)
|
|
node, exists := tm.dag.nodes.Get(pureNodeID)
|
|
if exists {
|
|
// Check if this node has any iterator edges (meaning it's an iterator node)
|
|
hasIteratorEdges := false
|
|
for _, edge := range node.Edges {
|
|
if edge.Type == Iterator {
|
|
hasIteratorEdges = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if hasIteratorEdges {
|
|
// For iterator nodes, only check dependencies from Simple edges
|
|
// Iterator edges represent outputs, not inputs
|
|
filteredPrevNodes := make([]*Node, 0)
|
|
for _, prevNode := range prevNodes {
|
|
// Check if there's a Simple edge from prevNode to this node
|
|
hasSimpleEdge := false
|
|
for _, edge := range prevNode.Edges {
|
|
if edge.To.ID == pureNodeID && edge.Type == Simple {
|
|
hasSimpleEdge = true
|
|
break
|
|
}
|
|
}
|
|
if hasSimpleEdge {
|
|
filteredPrevNodes = append(filteredPrevNodes, prevNode)
|
|
}
|
|
}
|
|
prevNodes = filteredPrevNodes
|
|
}
|
|
}
|
|
|
|
// Check if all relevant previous nodes have completed successfully
|
|
for _, prevNode := range prevNodes {
|
|
// Check both the pure node ID and the indexed node ID for state
|
|
state, exists := tm.taskStates.Get(prevNode.ID)
|
|
if !exists {
|
|
// Also check if there's a state with an index suffix
|
|
tm.taskStates.ForEach(func(key string, s *TaskState) bool {
|
|
if strings.Split(key, Delimiter)[0] == prevNode.ID {
|
|
state = s
|
|
exists = true
|
|
return false // Stop iteration
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
if !exists || state.Status != mq.Completed {
|
|
tm.dag.Logger().Debug("Dependency not met",
|
|
logger.Field{Key: "nodeID", Value: nodeID},
|
|
logger.Field{Key: "dependency", Value: prevNode.ID},
|
|
logger.Field{Key: "stateExists", Value: exists},
|
|
logger.Field{Key: "stateStatus", Value: string(state.Status)},
|
|
logger.Field{Key: "taskID", Value: tm.taskID})
|
|
return false
|
|
}
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
func (tm *TaskManager) processNode(exec *task) {
|
|
startTime := time.Now()
|
|
pureNodeID := strings.Split(exec.nodeID, Delimiter)[0]
|
|
node, exists := tm.dag.nodes.Get(pureNodeID)
|
|
if !exists {
|
|
tm.dag.Logger().Error("Node not found", logger.Field{Key: "nodeID", Value: pureNodeID})
|
|
return
|
|
}
|
|
|
|
// Check if all dependencies are met before processing
|
|
if !tm.areDependenciesMet(pureNodeID) {
|
|
tm.dag.Logger().Warn("Dependencies not met for node, deferring", logger.Field{Key: "nodeID", Value: pureNodeID})
|
|
// Defer the task
|
|
tm.deferredTasks.Set(exec.taskID, exec)
|
|
return
|
|
}
|
|
|
|
// Wrap context with timeout if node.Timeout is configured.
|
|
if node.Timeout > 0 {
|
|
var cancel context.CancelFunc
|
|
exec.ctx, cancel = context.WithTimeout(exec.ctx, node.Timeout)
|
|
defer cancel()
|
|
}
|
|
|
|
// Check for context cancellation before processing
|
|
select {
|
|
case <-exec.ctx.Done():
|
|
tm.dag.Logger().Warn("Context cancelled before node processing",
|
|
logger.Field{Key: "nodeID", Value: exec.nodeID},
|
|
logger.Field{Key: "taskID", Value: exec.taskID})
|
|
return
|
|
default:
|
|
}
|
|
|
|
// Invoke PreProcessHook if available.
|
|
if tm.dag.PreProcessHook != nil {
|
|
exec.ctx = tm.dag.PreProcessHook(exec.ctx, node, exec.taskID, exec.payload)
|
|
}
|
|
|
|
// Debug logging before processing
|
|
if node.IsDebugEnabled(tm.dag.IsDebugEnabled()) {
|
|
tm.debugNodeStart(exec, node)
|
|
}
|
|
|
|
state, _ := tm.taskStates.Get(exec.nodeID)
|
|
if state == nil {
|
|
tm.dag.Logger().Warn("State not found; creating new state", logger.Field{Key: "nodeID", Value: exec.nodeID})
|
|
state = newTaskState(exec.nodeID)
|
|
tm.taskStates.Set(exec.nodeID, state)
|
|
}
|
|
state.Status = mq.Processing
|
|
state.UpdatedAt = time.Now()
|
|
tm.currentNodePayload.Clear()
|
|
// Update task status in storage
|
|
if tm.storage != nil {
|
|
// Update task position and status
|
|
if err := tm.updateTaskPosition(exec.ctx, exec.taskID, pureNodeID, "processing"); err != nil {
|
|
tm.dag.Logger().Error("Failed to update task position", logger.Field{Key: "taskID", Value: exec.taskID}, logger.Field{Key: "error", Value: err.Error()})
|
|
}
|
|
if err := tm.storage.UpdateTaskStatus(exec.ctx, exec.taskID, dagstorage.TaskStatusRunning, ""); err != nil {
|
|
tm.dag.Logger().Error("Failed to update task status", logger.Field{Key: "taskID", Value: exec.taskID}, logger.Field{Key: "error", Value: err.Error()})
|
|
}
|
|
// Log node processing start
|
|
tm.logActivity(exec.ctx, exec.taskID, pureNodeID, "node_processing_started", "Node processing started", nil)
|
|
}
|
|
tm.currentNodeResult.Clear()
|
|
tm.currentNodePayload.Set(exec.nodeID, exec.payload)
|
|
|
|
var result mq.Result
|
|
attempts := 0
|
|
for {
|
|
// log.Printf("Tracing: Start processing node %s (attempt %d) on flow %s", exec.nodeID, attempts+1, tm.dag.key)
|
|
// Get middlewares for this node
|
|
middlewares := tm.dag.getNodeMiddlewares(pureNodeID)
|
|
|
|
// Execute middlewares and processor
|
|
result = tm.dag.executeMiddlewares(exec.ctx, mq.NewTask(exec.taskID, exec.payload, exec.nodeID, mq.WithDAG(tm.dag)), middlewares, func(ctx context.Context, task *mq.Task) mq.Result {
|
|
return node.processor.ProcessTask(ctx, task)
|
|
})
|
|
if result.Error != nil {
|
|
if te, ok := result.Error.(TaskError); ok && te.Recoverable {
|
|
if attempts < tm.maxRetries {
|
|
attempts++
|
|
backoff := tm.baseBackoff * time.Duration(1<<attempts)
|
|
// add jitter to avoid thundering herd
|
|
jitter := time.Duration(rand.Int63n(int64(tm.baseBackoff)))
|
|
backoff += jitter
|
|
tm.dag.Logger().Warn("Recoverable error on node, retrying",
|
|
logger.Field{Key: "nodeID", Value: exec.nodeID},
|
|
logger.Field{Key: "attempt", Value: attempts},
|
|
logger.Field{Key: "backoff", Value: backoff.String()},
|
|
logger.Field{Key: "error", Value: result.Error.Error()})
|
|
select {
|
|
case <-time.After(backoff):
|
|
case <-exec.ctx.Done():
|
|
tm.dag.Logger().Warn("Context cancelled for node", logger.Field{Key: "nodeID", Value: exec.nodeID})
|
|
return
|
|
}
|
|
continue
|
|
} else if tm.recoveryHandler != nil {
|
|
if err := tm.recoveryHandler(exec.ctx, result); err == nil {
|
|
result.Error = nil
|
|
result.Status = mq.Completed
|
|
} else {
|
|
result.Error = fmt.Errorf("recovery failed for node %s: %w", exec.nodeID, err)
|
|
}
|
|
}
|
|
} else {
|
|
// Wrap non-recoverable errors with context
|
|
result.Error = fmt.Errorf("node %s failed: %w", exec.nodeID, result.Error)
|
|
}
|
|
}
|
|
break
|
|
}
|
|
|
|
// Reset Last flag for sub-DAG results to prevent premature final result processing
|
|
if _, isSubDAG := node.processor.(*DAG); isSubDAG {
|
|
result.Last = false
|
|
}
|
|
// log.Printf("Tracing: End processing node %s on flow %s", exec.nodeID, tm.dag.key)
|
|
nodeLatency := time.Since(startTime)
|
|
|
|
// Invoke PostProcessHook if available.
|
|
if tm.dag.PostProcessHook != nil {
|
|
tm.dag.PostProcessHook(exec.ctx, node, exec.taskID, result)
|
|
}
|
|
|
|
// Debug logging after processing
|
|
if node.IsDebugEnabled(tm.dag.IsDebugEnabled()) {
|
|
tm.debugNodeComplete(exec, node, result, nodeLatency, attempts)
|
|
}
|
|
|
|
if result.Error != nil {
|
|
result.Status = mq.Failed
|
|
state.Status = mq.Failed
|
|
state.Result.Status = mq.Failed
|
|
state.Result.Latency = nodeLatency.String()
|
|
tm.result = &result
|
|
tm.resultCh <- result
|
|
tm.processFinalResult(state)
|
|
return
|
|
}
|
|
result.Status = mq.Completed
|
|
state.Result = result
|
|
state.Result.Status = mq.Completed
|
|
state.Result.Latency = nodeLatency.String()
|
|
state.Status = mq.Completed // <-- Add this line to set state status
|
|
result.Topic = node.ID
|
|
tm.updateTimestamps(&result)
|
|
|
|
isLast, err := tm.dag.IsLastNode(pureNodeID)
|
|
if err != nil {
|
|
tm.dag.Logger().Error("Error checking if node is last", logger.Field{Key: "nodeID", Value: pureNodeID}, logger.Field{Key: "error", Value: err.Error()})
|
|
} else if isLast {
|
|
// Check if this node has a parent (part of iterator pattern)
|
|
// If it has a parent, it should not be treated as a final node
|
|
if _, hasParent := tm.parentNodes.Get(exec.nodeID); !hasParent {
|
|
result.Last = true
|
|
}
|
|
}
|
|
tm.currentNodeResult.Set(exec.nodeID, result)
|
|
tm.logNodeExecution(exec, pureNodeID, result, nodeLatency)
|
|
|
|
if result.Error != nil {
|
|
tm.result = &result
|
|
tm.resultCh <- result
|
|
tm.processFinalResult(state)
|
|
return
|
|
}
|
|
if result.Last || node.NodeType == Page {
|
|
if node.NodeType == Page {
|
|
exec.ctx = context.WithValue(exec.ctx, consts.ContentType, consts.TypeHtml)
|
|
result.Ctx = context.WithValue(result.Ctx, consts.ContentType, consts.TypeHtml)
|
|
}
|
|
tm.result = &result
|
|
tm.resultCh <- result
|
|
if result.Last {
|
|
tm.processFinalResult(state)
|
|
}
|
|
return
|
|
}
|
|
|
|
tm.handleNext(exec.ctx, node, state, result)
|
|
}
|
|
|
|
// logNodeExecution logs node execution details
|
|
func (tm *TaskManager) logNodeExecution(exec *task, pureNodeID string, result mq.Result, latency time.Duration) {
|
|
success := result.Error == nil
|
|
|
|
// Log to DAG activity logger if available
|
|
if tm.dag.activityLogger != nil {
|
|
ctx := context.WithValue(exec.ctx, "task_id", exec.taskID)
|
|
ctx = context.WithValue(ctx, "node_id", pureNodeID)
|
|
ctx = context.WithValue(ctx, "duration", latency)
|
|
if result.Error != nil {
|
|
ctx = context.WithValue(ctx, "error", result.Error)
|
|
}
|
|
|
|
tm.dag.activityLogger.LogNodeExecution(ctx, exec.taskID, pureNodeID, result, latency)
|
|
}
|
|
|
|
// Update monitoring metrics
|
|
if tm.dag.monitor != nil {
|
|
tm.dag.monitor.metrics.RecordNodeExecution(pureNodeID, latency, success)
|
|
}
|
|
|
|
// Log to standard logger
|
|
fields := []logger.Field{
|
|
{Key: "nodeID", Value: pureNodeID},
|
|
{Key: "taskID", Value: exec.taskID},
|
|
{Key: "duration", Value: latency.String()},
|
|
{Key: "success", Value: success},
|
|
}
|
|
|
|
if result.Error != nil {
|
|
fields = append(fields, logger.Field{Key: "error", Value: result.Error.Error()})
|
|
tm.dag.Logger().Error("Node execution failed", fields...)
|
|
} else {
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Info("Node execution completed", fields...)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (tm *TaskManager) updateTimestamps(rs *mq.Result) {
|
|
rs.CreatedAt = tm.createdAt
|
|
rs.ProcessedAt = time.Now()
|
|
rs.Latency = time.Since(rs.CreatedAt).String()
|
|
}
|
|
|
|
func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, result mq.Result, childNode string, dispatchFinal bool) {
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Info("handlePrevious called",
|
|
logger.Field{Key: "parentNodeID", Value: state.NodeID},
|
|
logger.Field{Key: "childNode", Value: childNode})
|
|
}
|
|
|
|
state.targetResults.Set(childNode, result)
|
|
state.targetResults.Del(state.NodeID)
|
|
targetsCount, _ := tm.childNodes.Get(state.NodeID)
|
|
size := state.targetResults.Size()
|
|
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Info("Aggregation check",
|
|
logger.Field{Key: "parentNodeID", Value: state.NodeID},
|
|
logger.Field{Key: "targetsCount", Value: targetsCount},
|
|
logger.Field{Key: "currentSize", Value: size})
|
|
}
|
|
|
|
if size == targetsCount {
|
|
if size > 1 {
|
|
aggregated := make([]json.RawMessage, size)
|
|
i := 0
|
|
state.targetResults.ForEach(func(_ string, res mq.Result) bool {
|
|
aggregated[i] = res.Payload
|
|
i++
|
|
return true
|
|
})
|
|
aggregatedPayload, err := json.Marshal(aggregated)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
state.Result = mq.Result{Payload: aggregatedPayload, Status: mq.Completed, Ctx: ctx, Topic: state.NodeID}
|
|
} else if size == 1 {
|
|
state.Result = state.targetResults.Values()[0]
|
|
}
|
|
state.Status = result.Status
|
|
state.Result.Status = result.Status
|
|
}
|
|
|
|
if state.Result.Payload == nil {
|
|
state.Result.Payload = result.Payload
|
|
}
|
|
state.UpdatedAt = time.Now()
|
|
if result.Ctx == nil {
|
|
result.Ctx = ctx
|
|
}
|
|
if result.Error != nil {
|
|
state.Status = mq.Failed
|
|
}
|
|
if parentKey, ok := tm.parentNodes.Get(state.NodeID); ok {
|
|
parts := strings.Split(state.NodeID, Delimiter)
|
|
// For iterator nodes, only continue to next edge after ALL children have completed and been aggregated
|
|
if edges, exists := tm.iteratorNodes.Get(parts[0]); exists && state.Status == mq.Completed && size == targetsCount {
|
|
state.Status = mq.Processing
|
|
tm.iteratorNodes.Del(parts[0])
|
|
state.targetResults.Clear()
|
|
if len(parts) == 2 {
|
|
ctx = context.WithValue(ctx, ContextIndex, parts[1])
|
|
}
|
|
toProcess := nodeResult{
|
|
ctx: ctx,
|
|
nodeID: state.NodeID,
|
|
status: state.Status,
|
|
result: state.Result,
|
|
}
|
|
tm.handleEdges(toProcess, edges)
|
|
state.Status = mq.Completed
|
|
} else if size == targetsCount {
|
|
if parentState, _ := tm.taskStates.Get(parentKey); parentState != nil {
|
|
state.Result.Topic = state.NodeID
|
|
tm.handlePrevious(ctx, parentState, state.Result, state.NodeID, dispatchFinal)
|
|
}
|
|
}
|
|
} else {
|
|
tm.updateTimestamps(&state.Result)
|
|
tm.result = &state.Result
|
|
state.Result.Topic = strings.Split(state.NodeID, Delimiter)[0]
|
|
tm.resultCh <- state.Result
|
|
tm.processFinalResult(state)
|
|
}
|
|
}
|
|
|
|
func (tm *TaskManager) handleNext(ctx context.Context, node *Node, state *TaskState, result mq.Result) {
|
|
state.UpdatedAt = time.Now()
|
|
if result.Ctx == nil {
|
|
result.Ctx = ctx
|
|
}
|
|
if result.Error != nil {
|
|
state.Status = mq.Failed
|
|
} else {
|
|
edges := tm.getConditionalEdges(node, result)
|
|
if len(edges) == 0 {
|
|
state.Status = mq.Completed
|
|
}
|
|
}
|
|
if result.Status == "" {
|
|
result.Status = state.Status
|
|
}
|
|
// Update task status in storage based on final result
|
|
if tm.storage != nil {
|
|
var status dagstorage.TaskStatus
|
|
var errorMsg string
|
|
var action string
|
|
var message string
|
|
|
|
if result.Error != nil {
|
|
status = dagstorage.TaskStatusFailed
|
|
errorMsg = result.Error.Error()
|
|
action = "node_failed"
|
|
message = fmt.Sprintf("Node %s failed: %s", state.NodeID, errorMsg)
|
|
} else if state.Status == mq.Completed {
|
|
status = dagstorage.TaskStatusCompleted
|
|
action = "node_completed"
|
|
message = fmt.Sprintf("Node %s completed successfully", state.NodeID)
|
|
} else {
|
|
status = dagstorage.TaskStatusRunning
|
|
action = "node_processing"
|
|
message = fmt.Sprintf("Node %s processing", state.NodeID)
|
|
}
|
|
|
|
if err := tm.storage.UpdateTaskStatus(ctx, tm.taskID, status, errorMsg); err != nil {
|
|
tm.dag.Logger().Error("Failed to update task status", logger.Field{Key: "taskID", Value: tm.taskID}, logger.Field{Key: "error", Value: err.Error()})
|
|
}
|
|
|
|
// Log node completion/failure
|
|
tm.logActivity(ctx, tm.taskID, state.NodeID, action, message, result.Payload)
|
|
}
|
|
|
|
tm.enqueueResult(nodeResult{
|
|
ctx: ctx,
|
|
nodeID: state.NodeID,
|
|
status: state.Status,
|
|
result: result,
|
|
})
|
|
}
|
|
|
|
func (tm *TaskManager) enqueueResult(nr nodeResult) {
|
|
select {
|
|
case tm.resultQueue <- nr:
|
|
// Successfully enqueued
|
|
default:
|
|
tm.dag.Logger().Error("Result queue is full, dropping result",
|
|
logger.Field{Key: "nodeID", Value: nr.nodeID},
|
|
logger.Field{Key: "taskID", Value: nr.result.TaskID})
|
|
}
|
|
}
|
|
|
|
func (tm *TaskManager) onNodeCompleted(nr nodeResult) {
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Info("onNodeCompleted called",
|
|
logger.Field{Key: "nodeID", Value: nr.nodeID},
|
|
logger.Field{Key: "status", Value: string(nr.status)},
|
|
logger.Field{Key: "hasError", Value: nr.result.Error != nil})
|
|
}
|
|
|
|
nodeID := strings.Split(nr.nodeID, Delimiter)[0]
|
|
node, ok := tm.dag.nodes.Get(nodeID)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
// Handle ResetTo functionality
|
|
if nr.result.ResetTo != "" {
|
|
tm.handleResetTo(nr)
|
|
return
|
|
}
|
|
|
|
if nr.result.Error != nil || nr.status == mq.Failed {
|
|
if state, exists := tm.taskStates.Get(nr.nodeID); exists {
|
|
tm.processFinalResult(state)
|
|
return
|
|
}
|
|
}
|
|
edges := tm.getConditionalEdges(node, nr.result)
|
|
if len(edges) > 0 {
|
|
tm.handleEdges(nr, edges)
|
|
return
|
|
}
|
|
// Check if this is a child node from an iterator (has a parent)
|
|
if parentKey, exists := tm.parentNodes.Get(nr.nodeID); exists {
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Info("Found parent for node",
|
|
logger.Field{Key: "nodeID", Value: nr.nodeID},
|
|
logger.Field{Key: "parentKey", Value: parentKey})
|
|
}
|
|
if parentState, _ := tm.taskStates.Get(parentKey); parentState != nil {
|
|
tm.handlePrevious(nr.ctx, parentState, nr.result, nr.nodeID, true)
|
|
return // Don't send to resultCh if has parent
|
|
}
|
|
}
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Info("No parent found for node, sending to resultCh",
|
|
logger.Field{Key: "nodeID", Value: nr.nodeID},
|
|
logger.Field{Key: "result_topic", Value: nr.result.Topic})
|
|
}
|
|
tm.updateTimestamps(&nr.result)
|
|
tm.resultCh <- nr.result
|
|
if state, ok := tm.taskStates.Get(nr.nodeID); ok {
|
|
tm.processFinalResult(state)
|
|
}
|
|
}
|
|
|
|
func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge {
|
|
edges := make([]Edge, len(node.Edges))
|
|
copy(edges, node.Edges)
|
|
|
|
// Handle conditional edges based on ConditionStatus
|
|
if result.ConditionStatus != "" {
|
|
if conditions, ok := tm.dag.conditions[node.ID]; ok {
|
|
if targetKey, exists := conditions[result.ConditionStatus]; exists {
|
|
if targetNode, found := tm.dag.nodes.Get(targetKey); found {
|
|
conditionalEdge := Edge{
|
|
From: node,
|
|
FromSource: node.ID,
|
|
To: targetNode,
|
|
Label: fmt.Sprintf("condition:%s", result.ConditionStatus),
|
|
Type: Simple,
|
|
}
|
|
edges = append(edges, conditionalEdge)
|
|
}
|
|
} else if targetKey, exists := conditions["default"]; exists {
|
|
if targetNode, found := tm.dag.nodes.Get(targetKey); found {
|
|
conditionalEdge := Edge{
|
|
From: node,
|
|
FromSource: node.ID,
|
|
To: targetNode,
|
|
Label: "condition:default",
|
|
Type: Simple,
|
|
}
|
|
edges = append(edges, conditionalEdge)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return edges
|
|
}
|
|
|
|
func (tm *TaskManager) handleEdges(currentResult nodeResult, edges []Edge) {
|
|
if len(edges) == 0 {
|
|
return
|
|
}
|
|
|
|
if len(edges) == 1 {
|
|
tm.processSingleEdge(currentResult, edges[0])
|
|
return
|
|
}
|
|
|
|
// For multiple edges, process sequentially to avoid race conditions
|
|
for _, edge := range edges {
|
|
tm.processSingleEdge(currentResult, edge)
|
|
}
|
|
}
|
|
|
|
func (tm *TaskManager) processSingleEdge(currentResult nodeResult, edge Edge) {
|
|
index, ok := currentResult.ctx.Value(ContextIndex).(string)
|
|
if !ok {
|
|
index = "0"
|
|
}
|
|
parentNode := fmt.Sprintf("%s%s%s", edge.From.ID, Delimiter, index)
|
|
switch edge.Type {
|
|
case Simple:
|
|
if _, exists := tm.iteratorNodes.Get(edge.From.ID); exists {
|
|
return
|
|
}
|
|
fallthrough
|
|
case Iterator:
|
|
if edge.Type == Iterator {
|
|
if _, exists := tm.iteratorNodes.Get(edge.From.ID); !exists {
|
|
return
|
|
}
|
|
// Use the actual completing node as parent, not the edge From ID
|
|
parentNode = currentResult.nodeID
|
|
var items []json.RawMessage
|
|
if err := json.Unmarshal(currentResult.result.Payload, &items); err != nil {
|
|
log.Printf("Error unmarshalling payload for node %s: %v", edge.To.ID, err)
|
|
tm.enqueueResult(nodeResult{
|
|
ctx: currentResult.ctx,
|
|
nodeID: edge.To.ID,
|
|
status: mq.Failed,
|
|
result: mq.Result{Error: err},
|
|
})
|
|
return
|
|
}
|
|
tm.childNodes.Set(parentNode, len(items))
|
|
for i, item := range items {
|
|
childNode := fmt.Sprintf("%s%s%d", edge.To.ID, Delimiter, i)
|
|
ctx := context.WithValue(currentResult.ctx, ContextIndex, fmt.Sprintf("%d", i))
|
|
tm.parentNodes.Set(childNode, parentNode)
|
|
tm.enqueueTask(ctx, edge.To.ID, tm.taskID, item)
|
|
}
|
|
} else {
|
|
tm.childNodes.Set(parentNode, 1)
|
|
idx, _ := currentResult.ctx.Value(ContextIndex).(string)
|
|
childNode := fmt.Sprintf("%s%s%s", edge.To.ID, Delimiter, idx)
|
|
ctx := context.WithValue(currentResult.ctx, ContextIndex, idx)
|
|
|
|
// If the current result came from an iterator child that has a parent,
|
|
// we need to preserve that parent relationship for the new target node
|
|
if originalParent, hasParent := tm.parentNodes.Get(currentResult.nodeID); hasParent {
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Info("Transferring parent relationship for conditional edge",
|
|
logger.Field{Key: "originalChild", Value: currentResult.nodeID},
|
|
logger.Field{Key: "newChild", Value: childNode},
|
|
logger.Field{Key: "parent", Value: originalParent})
|
|
}
|
|
// Remove the original child from parent tracking since it's being replaced by conditional target
|
|
tm.parentNodes.Del(currentResult.nodeID)
|
|
// This edge target should now report back to the original parent instead
|
|
tm.parentNodes.Set(childNode, originalParent)
|
|
} else {
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Info("No parent found for conditional edge source",
|
|
logger.Field{Key: "nodeID", Value: currentResult.nodeID})
|
|
}
|
|
tm.parentNodes.Set(childNode, parentNode)
|
|
}
|
|
|
|
tm.enqueueTask(ctx, edge.To.ID, tm.taskID, currentResult.result.Payload)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (tm *TaskManager) retryDeferredTasks() {
|
|
defer tm.wg.Done()
|
|
ticker := time.NewTicker(tm.baseBackoff)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-tm.stopCh:
|
|
log.Println("Stopping deferred task retrier")
|
|
return
|
|
case <-ticker.C:
|
|
// Process deferred tasks with a limit to prevent overwhelming the queue
|
|
processed := 0
|
|
tm.deferredTasks.ForEach(func(taskID string, tsk *task) bool {
|
|
if processed >= 10 { // Process max 10 deferred tasks per tick
|
|
return false
|
|
}
|
|
|
|
select {
|
|
case tm.taskQueue <- tsk:
|
|
tm.deferredTasks.Del(taskID)
|
|
processed++
|
|
tm.dag.Logger().Debug("Retried deferred task",
|
|
logger.Field{Key: "taskID", Value: taskID})
|
|
default:
|
|
// Queue still full, keep the task deferred
|
|
return false
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
func (tm *TaskManager) processFinalResult(state *TaskState) {
|
|
state.Status = mq.Completed
|
|
// state.targetResults.Clear()
|
|
// update metrics using the task start time for duration calculation
|
|
tm.dag.updateTaskMetrics(tm.taskID, state.Result, time.Since(tm.createdAt))
|
|
if tm.dag.finalResult != nil {
|
|
tm.dag.finalResult(tm.taskID, state.Result)
|
|
}
|
|
}
|
|
|
|
func (tm *TaskManager) Pause() {
|
|
tm.pauseMu.Lock()
|
|
defer tm.pauseMu.Unlock()
|
|
if tm.pauseCh == nil {
|
|
tm.pauseCh = make(chan struct{})
|
|
log.Println("TaskManager paused")
|
|
}
|
|
}
|
|
|
|
func (tm *TaskManager) Resume() {
|
|
tm.pauseMu.Lock()
|
|
defer tm.pauseMu.Unlock()
|
|
if tm.pauseCh != nil {
|
|
close(tm.pauseCh)
|
|
tm.pauseCh = nil
|
|
log.Println("TaskManager resumed")
|
|
}
|
|
}
|
|
|
|
// Stop gracefully stops the task manager
|
|
func (tm *TaskManager) Stop() {
|
|
close(tm.stopCh)
|
|
tm.wg.Wait()
|
|
|
|
// Cancel any pending operations
|
|
tm.pauseMu.Lock()
|
|
if tm.pauseCh != nil {
|
|
close(tm.pauseCh)
|
|
tm.pauseCh = nil
|
|
}
|
|
tm.pauseMu.Unlock()
|
|
|
|
// Clean up resources
|
|
tm.taskStates.Clear()
|
|
tm.parentNodes.Clear()
|
|
tm.childNodes.Clear()
|
|
tm.deferredTasks.Clear()
|
|
tm.currentNodePayload.Clear()
|
|
tm.currentNodeResult.Clear()
|
|
|
|
tm.dag.Logger().Info("TaskManager stopped gracefully",
|
|
logger.Field{Key: "taskID", Value: tm.taskID})
|
|
}
|
|
|
|
// debugNodeStart logs debug information when a node starts processing
|
|
func (tm *TaskManager) debugNodeStart(exec *task, node *Node) {
|
|
var payload map[string]any
|
|
if err := json.Unmarshal(exec.payload, &payload); err != nil {
|
|
payload = map[string]any{"raw_payload": string(exec.payload)}
|
|
}
|
|
|
|
tm.dag.Logger().Info("🐛 [DEBUG] Node processing started",
|
|
logger.Field{Key: "dag_name", Value: tm.dag.name},
|
|
logger.Field{Key: "task_id", Value: exec.taskID},
|
|
logger.Field{Key: "node_id", Value: node.ID},
|
|
logger.Field{Key: "node_type", Value: node.NodeType.String()},
|
|
logger.Field{Key: "node_label", Value: node.Label},
|
|
logger.Field{Key: "timestamp", Value: time.Now().Format(time.RFC3339)},
|
|
logger.Field{Key: "has_timeout", Value: node.Timeout > 0},
|
|
logger.Field{Key: "timeout_duration", Value: node.Timeout.String()},
|
|
logger.Field{Key: "payload_size", Value: len(exec.payload)},
|
|
logger.Field{Key: "payload_preview", Value: tm.getPayloadPreview(payload)},
|
|
logger.Field{Key: "debug_mode", Value: "individual_node:" + fmt.Sprintf("%t", node.Debug) + ", dag_global:" + fmt.Sprintf("%t", tm.dag.IsDebugEnabled())},
|
|
)
|
|
|
|
// Log processor type if it implements the Debugger interface
|
|
if debugger, ok := node.processor.(Debugger); ok {
|
|
debugger.Debug(exec.ctx, mq.NewTask(exec.taskID, exec.payload, exec.nodeID, mq.WithDAG(tm.dag)))
|
|
}
|
|
}
|
|
|
|
// debugNodeComplete logs debug information when a node completes processing
|
|
func (tm *TaskManager) debugNodeComplete(exec *task, node *Node, result mq.Result, latency time.Duration, attempts int) {
|
|
var resultPayload map[string]any
|
|
if len(result.Payload) > 0 {
|
|
if err := json.Unmarshal(result.Payload, &resultPayload); err != nil {
|
|
resultPayload = map[string]any{"raw_payload": string(result.Payload)}
|
|
}
|
|
}
|
|
|
|
tm.dag.Logger().Info("🐛 [DEBUG] Node processing completed",
|
|
logger.Field{Key: "dag_name", Value: tm.dag.name},
|
|
logger.Field{Key: "task_id", Value: exec.taskID},
|
|
logger.Field{Key: "node_id", Value: node.ID},
|
|
logger.Field{Key: "node_type", Value: node.NodeType.String()},
|
|
logger.Field{Key: "node_label", Value: node.Label},
|
|
logger.Field{Key: "timestamp", Value: time.Now().Format(time.RFC3339)},
|
|
logger.Field{Key: "status", Value: string(result.Status)},
|
|
logger.Field{Key: "latency", Value: latency.String()},
|
|
logger.Field{Key: "attempts", Value: attempts + 1},
|
|
logger.Field{Key: "has_error", Value: result.Error != nil},
|
|
logger.Field{Key: "error_message", Value: tm.getErrorMessage(result.Error)},
|
|
logger.Field{Key: "result_size", Value: len(result.Payload)},
|
|
logger.Field{Key: "result_preview", Value: tm.getPayloadPreview(resultPayload)},
|
|
logger.Field{Key: "is_last_node", Value: result.Last},
|
|
)
|
|
}
|
|
|
|
// getPayloadPreview returns a truncated version of the payload for debug logging
|
|
func (tm *TaskManager) getPayloadPreview(payload map[string]any) string {
|
|
if payload == nil {
|
|
return "null"
|
|
}
|
|
|
|
preview := make(map[string]any)
|
|
count := 0
|
|
maxFields := 5 // Limit to first 5 fields to avoid log spam
|
|
|
|
for key, value := range payload {
|
|
if count >= maxFields {
|
|
preview["..."] = fmt.Sprintf("and %d more fields", len(payload)-maxFields)
|
|
break
|
|
}
|
|
|
|
// Truncate string values if they're too long
|
|
if strVal, ok := value.(string); ok && len(strVal) > 100 {
|
|
preview[key] = strVal[:97] + "..."
|
|
} else {
|
|
preview[key] = value
|
|
}
|
|
count++
|
|
}
|
|
|
|
previewBytes, _ := json.Marshal(preview)
|
|
return string(previewBytes)
|
|
}
|
|
|
|
// getErrorMessage safely extracts error message
|
|
func (tm *TaskManager) getErrorMessage(err error) string {
|
|
if err == nil {
|
|
return ""
|
|
}
|
|
return err.Error()
|
|
}
|
|
|
|
// logActivity logs an activity for a task
|
|
func (tm *TaskManager) logActivity(ctx context.Context, taskID, nodeID, action, message string, data json.RawMessage) {
|
|
if tm.storage == nil {
|
|
return
|
|
}
|
|
|
|
logEntry := &dagstorage.TaskActivityLog{
|
|
TaskID: taskID,
|
|
DAGID: tm.dag.key,
|
|
NodeID: nodeID,
|
|
Action: action,
|
|
Message: message,
|
|
Data: data,
|
|
Level: "info",
|
|
CreatedAt: time.Now(),
|
|
}
|
|
|
|
if err := tm.storage.LogActivity(ctx, logEntry); err != nil {
|
|
tm.dag.Logger().Error("Failed to log activity",
|
|
logger.Field{Key: "taskID", Value: taskID},
|
|
logger.Field{Key: "action", Value: action},
|
|
logger.Field{Key: "error", Value: err.Error()})
|
|
}
|
|
}
|
|
|
|
// updateTaskPosition updates the current position of a task in the DAG
|
|
func (tm *TaskManager) updateTaskPosition(ctx context.Context, taskID, currentNodeID, processingState string) error {
|
|
if tm.storage == nil {
|
|
return nil
|
|
}
|
|
|
|
// Get the current task
|
|
task, err := tm.storage.GetTask(ctx, taskID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get task for position update: %w", err)
|
|
}
|
|
|
|
// Update position fields
|
|
task.CurrentNodeID = currentNodeID
|
|
task.ProcessingState = processingState
|
|
task.UpdatedAt = time.Now()
|
|
|
|
// Save the updated task
|
|
return tm.storage.SaveTask(ctx, task)
|
|
}
|
|
|
|
// handleResetTo handles the ResetTo functionality for resetting a task to a specific node
|
|
func (tm *TaskManager) handleResetTo(nr nodeResult) {
|
|
resetTo := nr.result.ResetTo
|
|
nodeID := strings.Split(nr.nodeID, Delimiter)[0]
|
|
|
|
var targetNodeID string
|
|
var err error
|
|
|
|
if resetTo == "back" {
|
|
// Use GetPreviousPageNode to find the previous page node
|
|
var prevNode *Node
|
|
prevNode, err = tm.dag.GetPreviousPageNode(nodeID)
|
|
if err != nil {
|
|
tm.dag.Logger().Error("Failed to get previous page node",
|
|
logger.Field{Key: "currentNodeID", Value: nodeID},
|
|
logger.Field{Key: "error", Value: err.Error()})
|
|
// Send error result
|
|
tm.resultCh <- mq.Result{
|
|
Error: fmt.Errorf("failed to reset to previous page node: %w", err),
|
|
Ctx: nr.ctx,
|
|
TaskID: nr.result.TaskID,
|
|
Topic: nr.result.Topic,
|
|
Status: mq.Failed,
|
|
Payload: nr.result.Payload,
|
|
}
|
|
return
|
|
}
|
|
if prevNode == nil {
|
|
tm.dag.Logger().Error("No previous page node found",
|
|
logger.Field{Key: "currentNodeID", Value: nodeID})
|
|
// Send error result
|
|
tm.resultCh <- mq.Result{
|
|
Error: fmt.Errorf("no previous page node found"),
|
|
Ctx: nr.ctx,
|
|
TaskID: nr.result.TaskID,
|
|
Topic: nr.result.Topic,
|
|
Status: mq.Failed,
|
|
Payload: nr.result.Payload,
|
|
}
|
|
return
|
|
}
|
|
targetNodeID = prevNode.ID
|
|
} else {
|
|
// Use the specified node ID
|
|
targetNodeID = resetTo
|
|
// Validate that the target node exists
|
|
if _, exists := tm.dag.nodes.Get(targetNodeID); !exists {
|
|
tm.dag.Logger().Error("Reset target node does not exist",
|
|
logger.Field{Key: "targetNodeID", Value: targetNodeID})
|
|
// Send error result
|
|
tm.resultCh <- mq.Result{
|
|
Error: fmt.Errorf("reset target node %s does not exist", targetNodeID),
|
|
Ctx: nr.ctx,
|
|
TaskID: nr.result.TaskID,
|
|
Topic: nr.result.Topic,
|
|
Status: mq.Failed,
|
|
Payload: nr.result.Payload,
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Info("Resetting task to node",
|
|
logger.Field{Key: "taskID", Value: nr.result.TaskID},
|
|
logger.Field{Key: "fromNode", Value: nodeID},
|
|
logger.Field{Key: "toNode", Value: targetNodeID},
|
|
logger.Field{Key: "resetTo", Value: resetTo})
|
|
}
|
|
|
|
// Clear task states of all nodes between current node and target node
|
|
// This ensures that when we reset, the workflow can proceed correctly
|
|
tm.clearTaskStatesInPath(nodeID, targetNodeID)
|
|
|
|
// Special handling for subDAG reset: if we're resetting to the same node (subDAG resetting to itself),
|
|
// we need to clear all downstream nodes that depend on this subDAG
|
|
if nodeID == targetNodeID {
|
|
tm.clearDownstreamNodes(targetNodeID)
|
|
}
|
|
|
|
// Also clear any deferred tasks for the target node itself
|
|
tm.deferredTasks.ForEach(func(taskID string, tsk *task) bool {
|
|
if strings.Split(tsk.nodeID, Delimiter)[0] == targetNodeID {
|
|
tm.deferredTasks.Del(taskID)
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Debug("Cleared deferred task for target node",
|
|
logger.Field{Key: "nodeID", Value: targetNodeID},
|
|
logger.Field{Key: "taskID", Value: taskID})
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
|
|
// Handle dependencies of the target node - if they exist and are not completed,
|
|
// we need to mark them as completed to allow the workflow to proceed
|
|
tm.handleTargetNodeDependencies(targetNodeID, nr)
|
|
|
|
// Get previously received data for the target node
|
|
var previousPayload json.RawMessage
|
|
if prevResult, hasResult := tm.currentNodeResult.Get(targetNodeID); hasResult {
|
|
previousPayload = prevResult.Payload
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Info("Using previous payload for reset",
|
|
logger.Field{Key: "targetNodeID", Value: targetNodeID},
|
|
logger.Field{Key: "payloadSize", Value: len(previousPayload)})
|
|
}
|
|
} else {
|
|
// If no previous data, use the current result's payload
|
|
previousPayload = nr.result.Payload
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Info("No previous payload found, using current payload",
|
|
logger.Field{Key: "targetNodeID", Value: targetNodeID})
|
|
}
|
|
}
|
|
|
|
// Reset task state for the target node
|
|
if state, exists := tm.taskStates.Get(targetNodeID); exists {
|
|
state.Status = mq.Completed // Mark as completed to satisfy dependencies
|
|
state.UpdatedAt = time.Now()
|
|
state.Result = mq.Result{
|
|
Status: mq.Completed,
|
|
Ctx: nr.ctx,
|
|
}
|
|
} else {
|
|
// Create new state if it doesn't exist and mark as completed
|
|
newState := newTaskState(targetNodeID)
|
|
newState.Status = mq.Completed
|
|
newState.Result = mq.Result{
|
|
Status: mq.Completed,
|
|
Ctx: nr.ctx,
|
|
}
|
|
tm.taskStates.Set(targetNodeID, newState)
|
|
}
|
|
|
|
// Update current node result with the reset result (clear ResetTo to avoid loops)
|
|
resetResult := mq.Result{
|
|
TaskID: nr.result.TaskID,
|
|
Topic: targetNodeID,
|
|
Status: mq.Completed, // Mark as completed
|
|
Payload: previousPayload,
|
|
Ctx: nr.ctx,
|
|
// ResetTo is intentionally not set to avoid infinite loops
|
|
}
|
|
tm.currentNodeResult.Set(targetNodeID, resetResult)
|
|
|
|
// Re-enqueue the task for the target node
|
|
tm.enqueueTask(nr.ctx, targetNodeID, nr.result.TaskID, previousPayload)
|
|
|
|
// Log the reset activity
|
|
tm.logActivity(nr.ctx, nr.result.TaskID, targetNodeID, "task_reset",
|
|
fmt.Sprintf("Task reset from %s to %s", nodeID, targetNodeID), nil)
|
|
}
|
|
|
|
// clearTaskStatesInPath clears all task states in the path from current node to target node
|
|
// This is necessary when resetting to ensure the workflow can proceed without dependency issues
|
|
func (tm *TaskManager) clearTaskStatesInPath(currentNodeID, targetNodeID string) {
|
|
// Get all nodes in the path from current to target
|
|
pathNodes := tm.getNodesInPath(currentNodeID, targetNodeID)
|
|
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Info("Clearing task states in path",
|
|
logger.Field{Key: "fromNode", Value: currentNodeID},
|
|
logger.Field{Key: "toNode", Value: targetNodeID},
|
|
logger.Field{Key: "pathNodeCount", Value: len(pathNodes)})
|
|
}
|
|
|
|
// Also clear the current node itself (ValidateInput in the example)
|
|
if state, exists := tm.taskStates.Get(currentNodeID); exists {
|
|
state.Status = mq.Pending
|
|
state.UpdatedAt = time.Now()
|
|
state.Result = mq.Result{} // Clear previous result
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Debug("Cleared task state for current node",
|
|
logger.Field{Key: "nodeID", Value: currentNodeID})
|
|
}
|
|
}
|
|
// Also clear any cached results for the current node
|
|
tm.currentNodeResult.Del(currentNodeID)
|
|
// Clear any deferred tasks for the current node
|
|
tm.deferredTasks.ForEach(func(taskID string, tsk *task) bool {
|
|
if strings.Split(tsk.nodeID, Delimiter)[0] == currentNodeID {
|
|
tm.deferredTasks.Del(taskID)
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Debug("Cleared deferred task for current node",
|
|
logger.Field{Key: "nodeID", Value: currentNodeID},
|
|
logger.Field{Key: "taskID", Value: taskID})
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
|
|
// Clear task states for all nodes in the path
|
|
for _, pathNodeID := range pathNodes {
|
|
if state, exists := tm.taskStates.Get(pathNodeID); exists {
|
|
state.Status = mq.Pending
|
|
state.UpdatedAt = time.Now()
|
|
state.Result = mq.Result{} // Clear previous result
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Debug("Cleared task state for path node",
|
|
logger.Field{Key: "nodeID", Value: pathNodeID})
|
|
}
|
|
}
|
|
// Also clear any cached results for this node
|
|
tm.currentNodeResult.Del(pathNodeID)
|
|
// Clear any deferred tasks for this node
|
|
tm.deferredTasks.ForEach(func(taskID string, tsk *task) bool {
|
|
if strings.Split(tsk.nodeID, Delimiter)[0] == pathNodeID {
|
|
tm.deferredTasks.Del(taskID)
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Debug("Cleared deferred task for path node",
|
|
logger.Field{Key: "nodeID", Value: pathNodeID},
|
|
logger.Field{Key: "taskID", Value: taskID})
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
}
|
|
|
|
// getNodesInPath returns all nodes in the path from start node to end node
|
|
func (tm *TaskManager) getNodesInPath(startNodeID, endNodeID string) []string {
|
|
visited := make(map[string]bool)
|
|
var result []string
|
|
|
|
// Use BFS to find the path from start to end
|
|
queue := []string{startNodeID}
|
|
visited[startNodeID] = true
|
|
parent := make(map[string]string)
|
|
|
|
found := false
|
|
for len(queue) > 0 && !found {
|
|
currentNodeID := queue[0]
|
|
queue = queue[1:]
|
|
|
|
// Get all nodes that this node points to
|
|
if node, exists := tm.dag.nodes.Get(currentNodeID); exists {
|
|
for _, edge := range node.Edges {
|
|
if edge.Type == Simple || edge.Type == Iterator {
|
|
targetNodeID := edge.To.ID
|
|
if !visited[targetNodeID] {
|
|
visited[targetNodeID] = true
|
|
parent[targetNodeID] = currentNodeID
|
|
queue = append(queue, targetNodeID)
|
|
|
|
if targetNodeID == endNodeID {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// If we found the end node, reconstruct the path
|
|
if found {
|
|
current := endNodeID
|
|
for current != startNodeID {
|
|
result = append([]string{current}, result...)
|
|
if parentNode, exists := parent[current]; exists {
|
|
current = parentNode
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
result = append([]string{startNodeID}, result...)
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
// getAllDownstreamNodes returns all nodes that come after the given node in the workflow
|
|
func (tm *TaskManager) getAllDownstreamNodes(nodeID string) []string {
|
|
visited := make(map[string]bool)
|
|
var result []string
|
|
|
|
// Use BFS to find all downstream nodes
|
|
queue := []string{nodeID}
|
|
visited[nodeID] = true
|
|
|
|
for len(queue) > 0 {
|
|
currentNodeID := queue[0]
|
|
queue = queue[1:]
|
|
|
|
// Get all nodes that this node points to
|
|
if node, exists := tm.dag.nodes.Get(currentNodeID); exists {
|
|
for _, edge := range node.Edges {
|
|
if edge.Type == Simple || edge.Type == Iterator {
|
|
targetNodeID := edge.To.ID
|
|
if !visited[targetNodeID] {
|
|
visited[targetNodeID] = true
|
|
result = append(result, targetNodeID)
|
|
queue = append(queue, targetNodeID)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
// clearDownstreamNodes clears all task states for nodes that come after the given node
|
|
// This is used when resetting a subDAG to ensure downstream dependencies are cleared
|
|
func (tm *TaskManager) clearDownstreamNodes(nodeID string) {
|
|
downstreamNodes := tm.getAllDownstreamNodes(nodeID)
|
|
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Info("Clearing downstream nodes for subDAG reset",
|
|
logger.Field{Key: "subDAGNodeID", Value: nodeID},
|
|
logger.Field{Key: "downstreamCount", Value: len(downstreamNodes)})
|
|
}
|
|
|
|
// Clear task states for all downstream nodes
|
|
for _, downstreamNodeID := range downstreamNodes {
|
|
if state, exists := tm.taskStates.Get(downstreamNodeID); exists {
|
|
state.Status = mq.Pending
|
|
state.UpdatedAt = time.Now()
|
|
state.Result = mq.Result{} // Clear previous result
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Debug("Cleared task state for downstream node",
|
|
logger.Field{Key: "nodeID", Value: downstreamNodeID})
|
|
}
|
|
}
|
|
// Also clear any cached results for this node
|
|
tm.currentNodeResult.Del(downstreamNodeID)
|
|
// Clear any deferred tasks for this node
|
|
tm.deferredTasks.ForEach(func(taskID string, tsk *task) bool {
|
|
if strings.Split(tsk.nodeID, Delimiter)[0] == downstreamNodeID {
|
|
tm.deferredTasks.Del(taskID)
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Debug("Cleared deferred task for downstream node",
|
|
logger.Field{Key: "nodeID", Value: downstreamNodeID},
|
|
logger.Field{Key: "taskID", Value: taskID})
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
}
|
|
|
|
// handleTargetNodeDependencies handles the dependencies of the target node during reset
|
|
// If the target node has unmet dependencies, we mark them as completed to allow the workflow to proceed
|
|
func (tm *TaskManager) handleTargetNodeDependencies(targetNodeID string, nr nodeResult) {
|
|
// Get the dependencies of the target node
|
|
prevNodes, err := tm.dag.GetPreviousNodes(targetNodeID)
|
|
if err != nil {
|
|
tm.dag.Logger().Error("Error getting previous nodes for target",
|
|
logger.Field{Key: "targetNodeID", Value: targetNodeID},
|
|
logger.Field{Key: "error", Value: err.Error()})
|
|
return
|
|
}
|
|
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Info("Checking dependencies for target node",
|
|
logger.Field{Key: "targetNodeID", Value: targetNodeID},
|
|
logger.Field{Key: "dependencyCount", Value: len(prevNodes)})
|
|
}
|
|
|
|
// Check each dependency and ensure it's marked as completed for reset
|
|
for _, prevNode := range prevNodes {
|
|
// Check both the pure node ID and the indexed node ID for state
|
|
state, exists := tm.taskStates.Get(prevNode.ID)
|
|
if !exists {
|
|
// Also check if there's a state with an index suffix
|
|
tm.taskStates.ForEach(func(key string, s *TaskState) bool {
|
|
if strings.Split(key, Delimiter)[0] == prevNode.ID {
|
|
state = s
|
|
exists = true
|
|
return false // Stop iteration
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
|
|
if !exists {
|
|
// Create new state and mark as completed for reset
|
|
newState := newTaskState(prevNode.ID)
|
|
newState.Status = mq.Completed
|
|
newState.UpdatedAt = time.Now()
|
|
newState.Result = mq.Result{
|
|
Status: mq.Completed,
|
|
Ctx: nr.ctx,
|
|
}
|
|
tm.taskStates.Set(prevNode.ID, newState)
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Debug("Created completed state for dependency node during reset",
|
|
logger.Field{Key: "dependencyNodeID", Value: prevNode.ID})
|
|
}
|
|
} else if state.Status != mq.Completed {
|
|
// Mark existing state as completed for reset
|
|
state.Status = mq.Completed
|
|
state.UpdatedAt = time.Now()
|
|
if state.Result.Status == "" {
|
|
state.Result = mq.Result{
|
|
Status: mq.Completed,
|
|
Ctx: nr.ctx,
|
|
}
|
|
}
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Debug("Marked dependency node as completed during reset",
|
|
logger.Field{Key: "dependencyNodeID", Value: prevNode.ID},
|
|
logger.Field{Key: "previousStatus", Value: string(state.Status)})
|
|
}
|
|
} else {
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Debug("Dependency already satisfied",
|
|
logger.Field{Key: "dependencyNodeID", Value: prevNode.ID},
|
|
logger.Field{Key: "status", Value: string(state.Status)})
|
|
}
|
|
}
|
|
|
|
// Ensure cached result exists for this dependency
|
|
if _, hasResult := tm.currentNodeResult.Get(prevNode.ID); !hasResult {
|
|
tm.currentNodeResult.Set(prevNode.ID, mq.Result{
|
|
Status: mq.Completed,
|
|
Ctx: nr.ctx,
|
|
})
|
|
}
|
|
|
|
// Clear any deferred tasks for this dependency since it's now satisfied
|
|
tm.deferredTasks.ForEach(func(taskID string, tsk *task) bool {
|
|
if strings.Split(tsk.nodeID, Delimiter)[0] == prevNode.ID {
|
|
tm.deferredTasks.Del(taskID)
|
|
if tm.dag.debug {
|
|
tm.dag.Logger().Debug("Cleared deferred task for satisfied dependency",
|
|
logger.Field{Key: "dependencyNodeID", Value: prevNode.ID},
|
|
logger.Field{Key: "taskID", Value: taskID})
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
}
|