This commit is contained in:
sujit
2025-09-19 22:30:21 +05:45
parent e4344bc96e
commit b82cd20eef
9 changed files with 565 additions and 22 deletions

View File

@@ -311,7 +311,8 @@ func (tm *TaskManager) areDependenciesMet(nodeID string) bool {
logger.Field{Key: "nodeID", Value: nodeID}, logger.Field{Key: "nodeID", Value: nodeID},
logger.Field{Key: "dependency", Value: prevNode.ID}, logger.Field{Key: "dependency", Value: prevNode.ID},
logger.Field{Key: "stateExists", Value: exists}, logger.Field{Key: "stateExists", Value: exists},
logger.Field{Key: "stateStatus", Value: string(state.Status)}) logger.Field{Key: "stateStatus", Value: string(state.Status)},
logger.Field{Key: "taskID", Value: tm.taskID})
return false return false
} }
} }
@@ -706,6 +707,13 @@ func (tm *TaskManager) onNodeCompleted(nr nodeResult) {
if !ok { if !ok {
return return
} }
// Handle ResetTo functionality
if nr.result.ResetTo != "" {
tm.handleResetTo(nr)
return
}
if nr.result.Error != nil || nr.status == mq.Failed { if nr.result.Error != nil || nr.status == mq.Failed {
if state, exists := tm.taskStates.Get(nr.nodeID); exists { if state, exists := tm.taskStates.Get(nr.nodeID); exists {
tm.processFinalResult(state) tm.processFinalResult(state)
@@ -1086,3 +1094,393 @@ func (tm *TaskManager) updateTaskPosition(ctx context.Context, taskID, currentNo
// Save the updated task // Save the updated task
return tm.storage.SaveTask(ctx, task) return tm.storage.SaveTask(ctx, task)
} }
// handleResetTo handles the ResetTo functionality for resetting a task to a specific node
func (tm *TaskManager) handleResetTo(nr nodeResult) {
resetTo := nr.result.ResetTo
nodeID := strings.Split(nr.nodeID, Delimiter)[0]
var targetNodeID string
var err error
if resetTo == "back" {
// Use GetPreviousPageNode to find the previous page node
var prevNode *Node
prevNode, err = tm.dag.GetPreviousPageNode(nodeID)
if err != nil {
tm.dag.Logger().Error("Failed to get previous page node",
logger.Field{Key: "currentNodeID", Value: nodeID},
logger.Field{Key: "error", Value: err.Error()})
// Send error result
tm.resultCh <- mq.Result{
Error: fmt.Errorf("failed to reset to previous page node: %w", err),
Ctx: nr.ctx,
TaskID: nr.result.TaskID,
Topic: nr.result.Topic,
Status: mq.Failed,
Payload: nr.result.Payload,
}
return
}
if prevNode == nil {
tm.dag.Logger().Error("No previous page node found",
logger.Field{Key: "currentNodeID", Value: nodeID})
// Send error result
tm.resultCh <- mq.Result{
Error: fmt.Errorf("no previous page node found"),
Ctx: nr.ctx,
TaskID: nr.result.TaskID,
Topic: nr.result.Topic,
Status: mq.Failed,
Payload: nr.result.Payload,
}
return
}
targetNodeID = prevNode.ID
} else {
// Use the specified node ID
targetNodeID = resetTo
// Validate that the target node exists
if _, exists := tm.dag.nodes.Get(targetNodeID); !exists {
tm.dag.Logger().Error("Reset target node does not exist",
logger.Field{Key: "targetNodeID", Value: targetNodeID})
// Send error result
tm.resultCh <- mq.Result{
Error: fmt.Errorf("reset target node %s does not exist", targetNodeID),
Ctx: nr.ctx,
TaskID: nr.result.TaskID,
Topic: nr.result.Topic,
Status: mq.Failed,
Payload: nr.result.Payload,
}
return
}
}
if tm.dag.debug {
tm.dag.Logger().Info("Resetting task to node",
logger.Field{Key: "taskID", Value: nr.result.TaskID},
logger.Field{Key: "fromNode", Value: nodeID},
logger.Field{Key: "toNode", Value: targetNodeID},
logger.Field{Key: "resetTo", Value: resetTo})
}
// Clear task states of all nodes between current node and target node
// This ensures that when we reset, the workflow can proceed correctly
tm.clearTaskStatesInPath(nodeID, targetNodeID)
// Also clear any deferred tasks for the target node itself
tm.deferredTasks.ForEach(func(taskID string, tsk *task) bool {
if strings.Split(tsk.nodeID, Delimiter)[0] == targetNodeID {
tm.deferredTasks.Del(taskID)
if tm.dag.debug {
tm.dag.Logger().Debug("Cleared deferred task for target node",
logger.Field{Key: "nodeID", Value: targetNodeID},
logger.Field{Key: "taskID", Value: taskID})
}
}
return true
})
// Handle dependencies of the target node - if they exist and are not completed,
// we need to mark them as completed to allow the workflow to proceed
tm.handleTargetNodeDependencies(targetNodeID, nr)
// Get previously received data for the target node
var previousPayload json.RawMessage
if prevResult, hasResult := tm.currentNodeResult.Get(targetNodeID); hasResult {
previousPayload = prevResult.Payload
if tm.dag.debug {
tm.dag.Logger().Info("Using previous payload for reset",
logger.Field{Key: "targetNodeID", Value: targetNodeID},
logger.Field{Key: "payloadSize", Value: len(previousPayload)})
}
} else {
// If no previous data, use the current result's payload
previousPayload = nr.result.Payload
if tm.dag.debug {
tm.dag.Logger().Info("No previous payload found, using current payload",
logger.Field{Key: "targetNodeID", Value: targetNodeID})
}
}
// Reset task state for the target node
if state, exists := tm.taskStates.Get(targetNodeID); exists {
state.Status = mq.Completed // Mark as completed to satisfy dependencies
state.UpdatedAt = time.Now()
state.Result = mq.Result{
Status: mq.Completed,
Ctx: nr.ctx,
}
} else {
// Create new state if it doesn't exist and mark as completed
newState := newTaskState(targetNodeID)
newState.Status = mq.Completed
newState.Result = mq.Result{
Status: mq.Completed,
Ctx: nr.ctx,
}
tm.taskStates.Set(targetNodeID, newState)
}
// Update current node result with the reset result (clear ResetTo to avoid loops)
resetResult := mq.Result{
TaskID: nr.result.TaskID,
Topic: targetNodeID,
Status: mq.Completed, // Mark as completed
Payload: previousPayload,
Ctx: nr.ctx,
// ResetTo is intentionally not set to avoid infinite loops
}
tm.currentNodeResult.Set(targetNodeID, resetResult)
// Re-enqueue the task for the target node
tm.enqueueTask(nr.ctx, targetNodeID, nr.result.TaskID, previousPayload)
// Log the reset activity
tm.logActivity(nr.ctx, nr.result.TaskID, targetNodeID, "task_reset",
fmt.Sprintf("Task reset from %s to %s", nodeID, targetNodeID), nil)
}
// clearTaskStatesInPath clears all task states in the path from current node to target node
// This is necessary when resetting to ensure the workflow can proceed without dependency issues
func (tm *TaskManager) clearTaskStatesInPath(currentNodeID, targetNodeID string) {
// Get all nodes in the path from current to target
pathNodes := tm.getNodesInPath(currentNodeID, targetNodeID)
if tm.dag.debug {
tm.dag.Logger().Info("Clearing task states in path",
logger.Field{Key: "fromNode", Value: currentNodeID},
logger.Field{Key: "toNode", Value: targetNodeID},
logger.Field{Key: "pathNodeCount", Value: len(pathNodes)})
}
// Also clear the current node itself (ValidateInput in the example)
if state, exists := tm.taskStates.Get(currentNodeID); exists {
state.Status = mq.Pending
state.UpdatedAt = time.Now()
state.Result = mq.Result{} // Clear previous result
if tm.dag.debug {
tm.dag.Logger().Debug("Cleared task state for current node",
logger.Field{Key: "nodeID", Value: currentNodeID})
}
}
// Also clear any cached results for the current node
tm.currentNodeResult.Del(currentNodeID)
// Clear any deferred tasks for the current node
tm.deferredTasks.ForEach(func(taskID string, tsk *task) bool {
if strings.Split(tsk.nodeID, Delimiter)[0] == currentNodeID {
tm.deferredTasks.Del(taskID)
if tm.dag.debug {
tm.dag.Logger().Debug("Cleared deferred task for current node",
logger.Field{Key: "nodeID", Value: currentNodeID},
logger.Field{Key: "taskID", Value: taskID})
}
}
return true
})
// Clear task states for all nodes in the path
for _, pathNodeID := range pathNodes {
if state, exists := tm.taskStates.Get(pathNodeID); exists {
state.Status = mq.Pending
state.UpdatedAt = time.Now()
state.Result = mq.Result{} // Clear previous result
if tm.dag.debug {
tm.dag.Logger().Debug("Cleared task state for path node",
logger.Field{Key: "nodeID", Value: pathNodeID})
}
}
// Also clear any cached results for this node
tm.currentNodeResult.Del(pathNodeID)
// Clear any deferred tasks for this node
tm.deferredTasks.ForEach(func(taskID string, tsk *task) bool {
if strings.Split(tsk.nodeID, Delimiter)[0] == pathNodeID {
tm.deferredTasks.Del(taskID)
if tm.dag.debug {
tm.dag.Logger().Debug("Cleared deferred task for path node",
logger.Field{Key: "nodeID", Value: pathNodeID},
logger.Field{Key: "taskID", Value: taskID})
}
}
return true
})
}
}
// getNodesInPath returns all nodes in the path from start node to end node
func (tm *TaskManager) getNodesInPath(startNodeID, endNodeID string) []string {
visited := make(map[string]bool)
var result []string
// Use BFS to find the path from start to end
queue := []string{startNodeID}
visited[startNodeID] = true
parent := make(map[string]string)
found := false
for len(queue) > 0 && !found {
currentNodeID := queue[0]
queue = queue[1:]
// Get all nodes that this node points to
if node, exists := tm.dag.nodes.Get(currentNodeID); exists {
for _, edge := range node.Edges {
if edge.Type == Simple || edge.Type == Iterator {
targetNodeID := edge.To.ID
if !visited[targetNodeID] {
visited[targetNodeID] = true
parent[targetNodeID] = currentNodeID
queue = append(queue, targetNodeID)
if targetNodeID == endNodeID {
found = true
break
}
}
}
}
}
}
// If we found the end node, reconstruct the path
if found {
current := endNodeID
for current != startNodeID {
result = append([]string{current}, result...)
if parentNode, exists := parent[current]; exists {
current = parentNode
} else {
break
}
}
result = append([]string{startNodeID}, result...)
}
return result
}
// getAllDownstreamNodes returns all nodes that come after the given node in the workflow
func (tm *TaskManager) getAllDownstreamNodes(nodeID string) []string {
visited := make(map[string]bool)
var result []string
// Use BFS to find all downstream nodes
queue := []string{nodeID}
visited[nodeID] = true
for len(queue) > 0 {
currentNodeID := queue[0]
queue = queue[1:]
// Get all nodes that this node points to
if node, exists := tm.dag.nodes.Get(currentNodeID); exists {
for _, edge := range node.Edges {
if edge.Type == Simple || edge.Type == Iterator {
targetNodeID := edge.To.ID
if !visited[targetNodeID] {
visited[targetNodeID] = true
result = append(result, targetNodeID)
queue = append(queue, targetNodeID)
}
}
}
}
}
return result
}
// handleTargetNodeDependencies handles the dependencies of the target node during reset
// If the target node has unmet dependencies, we mark them as completed to allow the workflow to proceed
func (tm *TaskManager) handleTargetNodeDependencies(targetNodeID string, nr nodeResult) {
// Get the dependencies of the target node
prevNodes, err := tm.dag.GetPreviousNodes(targetNodeID)
if err != nil {
tm.dag.Logger().Error("Error getting previous nodes for target",
logger.Field{Key: "targetNodeID", Value: targetNodeID},
logger.Field{Key: "error", Value: err.Error()})
return
}
if tm.dag.debug {
tm.dag.Logger().Info("Checking dependencies for target node",
logger.Field{Key: "targetNodeID", Value: targetNodeID},
logger.Field{Key: "dependencyCount", Value: len(prevNodes)})
}
// Check each dependency and ensure it's marked as completed for reset
for _, prevNode := range prevNodes {
// Check both the pure node ID and the indexed node ID for state
state, exists := tm.taskStates.Get(prevNode.ID)
if !exists {
// Also check if there's a state with an index suffix
tm.taskStates.ForEach(func(key string, s *TaskState) bool {
if strings.Split(key, Delimiter)[0] == prevNode.ID {
state = s
exists = true
return false // Stop iteration
}
return true
})
}
if !exists {
// Create new state and mark as completed for reset
newState := newTaskState(prevNode.ID)
newState.Status = mq.Completed
newState.UpdatedAt = time.Now()
newState.Result = mq.Result{
Status: mq.Completed,
Ctx: nr.ctx,
}
tm.taskStates.Set(prevNode.ID, newState)
if tm.dag.debug {
tm.dag.Logger().Debug("Created completed state for dependency node during reset",
logger.Field{Key: "dependencyNodeID", Value: prevNode.ID})
}
} else if state.Status != mq.Completed {
// Mark existing state as completed for reset
state.Status = mq.Completed
state.UpdatedAt = time.Now()
if state.Result.Status == "" {
state.Result = mq.Result{
Status: mq.Completed,
Ctx: nr.ctx,
}
}
if tm.dag.debug {
tm.dag.Logger().Debug("Marked dependency node as completed during reset",
logger.Field{Key: "dependencyNodeID", Value: prevNode.ID},
logger.Field{Key: "previousStatus", Value: string(state.Status)})
}
} else {
if tm.dag.debug {
tm.dag.Logger().Debug("Dependency already satisfied",
logger.Field{Key: "dependencyNodeID", Value: prevNode.ID},
logger.Field{Key: "status", Value: string(state.Status)})
}
}
// Ensure cached result exists for this dependency
if _, hasResult := tm.currentNodeResult.Get(prevNode.ID); !hasResult {
tm.currentNodeResult.Set(prevNode.ID, mq.Result{
Status: mq.Completed,
Ctx: nr.ctx,
})
}
// Clear any deferred tasks for this dependency since it's now satisfied
tm.deferredTasks.ForEach(func(taskID string, tsk *task) bool {
if strings.Split(tsk.nodeID, Delimiter)[0] == prevNode.ID {
tm.deferredTasks.Del(taskID)
if tm.dag.debug {
tm.dag.Logger().Debug("Cleared deferred task for satisfied dependency",
logger.Field{Key: "dependencyNodeID", Value: prevNode.ID},
logger.Field{Key: "taskID", Value: taskID})
}
}
return true
})
}
}

View File

@@ -24,7 +24,7 @@ func subDAG() *dag.DAG {
return f return f
} }
func main() { func mai2n() {
flow := dag.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) { flow := dag.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) {
fmt.Printf("DAG Final result for task %s: %s\n", taskID, string(result.Payload)) fmt.Printf("DAG Final result for task %s: %s\n", taskID, string(result.Payload))
}) })

View File

@@ -11,7 +11,7 @@ import (
) )
// Enhanced DAG Example demonstrates how to use the enhanced DAG system with workflow capabilities // Enhanced DAG Example demonstrates how to use the enhanced DAG system with workflow capabilities
func main() { func mai1n() {
fmt.Println("🚀 Starting Enhanced DAG with Workflow Engine Demo...") fmt.Println("🚀 Starting Enhanced DAG with Workflow Engine Demo...")
// Create enhanced DAG configuration // Create enhanced DAG configuration

View File

@@ -26,17 +26,17 @@ func main() {
// Add SMS workflow nodes // Add SMS workflow nodes
// Note: Page nodes have no timeout by default, allowing users unlimited time for form input // Note: Page nodes have no timeout by default, allowing users unlimited time for form input
flow.AddDAGNode(dag.Page, "Login", "login", loginSubDAG(), true) // flow.AddDAGNode(dag.Page, "Login", "login", loginSubDAG(), true)
flow.AddNode(dag.Page, "SMS Form", "SMSForm", &SMSFormNode{}) flow.AddNode(dag.Page, "SMS Form", "SMSForm", &SMSFormNode{}, true)
flow.AddNode(dag.Function, "Validate Input", "ValidateInput", &ValidateInputNode{}) flow.AddNode(dag.Function, "Validate Input", "ValidateInput", &ValidateInputNode{})
flow.AddNode(dag.Function, "Send SMS", "SendSMS", &SendSMSNode{}) flow.AddNode(dag.Function, "Send SMS", "SendSMS", &SendSMSNode{})
flow.AddNode(dag.Page, "SMS Result", "SMSResult", &SMSResultNode{}) flow.AddNode(dag.Page, "SMS Result", "SMSResult", &SMSResultNode{})
flow.AddNode(dag.Page, "Error Page", "ErrorPage", &ErrorPageNode{}) flow.AddNode(dag.Page, "Error Page", "ErrorPage", &ErrorPageNode{})
// Define edges for SMS workflow // Define edges for SMS workflow
flow.AddEdge(dag.Simple, "Login to Form", "login", "SMSForm") // flow.AddEdge(dag.Simple, "Login to Form", "login", "SMSForm")
flow.AddEdge(dag.Simple, "Form to Validation", "SMSForm", "ValidateInput") flow.AddEdge(dag.Simple, "Form to Validation", "SMSForm", "ValidateInput")
flow.AddCondition("ValidateInput", map[string]string{"valid": "SendSMS", "invalid": "ErrorPage"}) flow.AddCondition("ValidateInput", map[string]string{"valid": "SendSMS"}) // Removed invalid -> ErrorPage since we use ResetTo
flow.AddCondition("SendSMS", map[string]string{"sent": "SMSResult", "failed": "ErrorPage"}) flow.AddCondition("SendSMS", map[string]string{"sent": "SMSResult", "failed": "ErrorPage"})
// Start the flow // Start the flow
@@ -303,12 +303,17 @@ func (s *SMSFormNode) ProcessTask(ctx context.Context, task *mq.Task) mq.Result
var inputData map[string]any var inputData map[string]any
if task.Payload != nil && len(task.Payload) > 0 { if task.Payload != nil && len(task.Payload) > 0 {
if err := json.Unmarshal(task.Payload, &inputData); err == nil { if err := json.Unmarshal(task.Payload, &inputData); err == nil {
// If we have valid input data, pass it through for validation // Check if this is validation error data (contains validation_error)
return mq.Result{Payload: task.Payload, Ctx: ctx} if _, hasValidationError := inputData["validation_error"]; hasValidationError {
// This is validation error data, show the form with errors
} else {
// If we have valid input data, pass it through for validation
return mq.Result{Payload: task.Payload, Ctx: ctx}
}
} }
} }
// Otherwise, show the form // Show the form (either initial load or with validation errors)
htmlTemplate := ` htmlTemplate := `
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>
@@ -399,12 +404,21 @@ func (s *SMSFormNode) ProcessTask(ctx context.Context, task *mq.Task) mq.Result
<div class="info"> <div class="info">
<p>Send SMS messages through our secure DAG workflow</p> <p>Send SMS messages through our secure DAG workflow</p>
</div> </div>
{{if validation_error}}
<div class="error-message" style="background: rgba(255, 100, 100, 0.2); border: 1px solid #ff6b6b; padding: 15px; border-radius: 8px; margin-bottom: 20px; color: #ffcccc;">
<strong>⚠️ Validation Error:</strong> {{validation_error}}
</div>
{{end}}
<form method="post" action="/process?task_id={{task_id}}&next=true"> <form method="post" action="/process?task_id={{task_id}}&next=true">
<div class="form-group"> <div class="form-group">
<label for="phone">📞 Phone Number:</label> <label for="phone">📞 Phone Number:</label>
<input type="tel" id="phone" name="phone" <input type="tel" id="phone" name="phone"
placeholder="+1234567890 or 1234567890" placeholder="+1234567890 or 1234567890"
required> value="{{phone}}"
required
{{if error_field_phone}}style="border: 2px solid #ff6b6b; background: rgba(255, 100, 100, 0.1);"{{end}}>
<div class="info" style="margin-top: 5px; font-size: 12px;"> <div class="info" style="margin-top: 5px; font-size: 12px;">
Supports US format: +1234567890 or 1234567890 Supports US format: +1234567890 or 1234567890
</div> </div>
@@ -416,14 +430,16 @@ func (s *SMSFormNode) ProcessTask(ctx context.Context, task *mq.Task) mq.Result
placeholder="Enter your message here..." placeholder="Enter your message here..."
maxlength="160" maxlength="160"
required required
oninput="updateCharCount()"></textarea> oninput="updateCharCount()"
<div class="char-count" id="charCount">0/160 characters</div> {{if error_field_message}}style="border: 2px solid #ff6b6b; background: rgba(255, 100, 100, 0.1);"{{end}}>{{message}}</textarea>
<div class="char-count" id="charCount">{{message_length}}/160 characters</div>
</div> </div>
<div class="form-group"> <div class="form-group">
<label for="sender_name">👤 Sender Name (Optional):</label> <label for="sender_name">👤 Sender Name (Optional):</label>
<input type="text" id="sender_name" name="sender_name" <input type="text" id="sender_name" name="sender_name"
placeholder="Your name or organization" placeholder="Your name or organization"
value="{{sender_name}}"
maxlength="50"> maxlength="50">
</div> </div>
@@ -460,9 +476,20 @@ func (s *SMSFormNode) ProcessTask(ctx context.Context, task *mq.Task) mq.Result
</body> </body>
</html>` </html>`
messageStr, _ := inputData["message"].(string)
messageLength := len(messageStr)
parser := jet.NewWithMemory(jet.WithDelims("{{", "}}")) parser := jet.NewWithMemory(jet.WithDelims("{{", "}}"))
rs, err := parser.ParseTemplate(htmlTemplate, map[string]any{ rs, err := parser.ParseTemplate(htmlTemplate, map[string]any{
"task_id": ctx.Value("task_id"), "task_id": ctx.Value("task_id"),
"validation_error": inputData["validation_error"],
"error_field": inputData["error_field"],
"error_field_phone": inputData["error_field"] == "phone",
"error_field_message": inputData["error_field"] == "message",
"phone": inputData["phone"],
"message": inputData["message"],
"message_length": messageLength,
"sender_name": inputData["sender_name"],
}) })
if err != nil { if err != nil {
return mq.Result{Error: err, Ctx: ctx} return mq.Result{Error: err, Ctx: ctx}
@@ -501,7 +528,11 @@ func (v *ValidateInputNode) ProcessTask(ctx context.Context, task *mq.Task) mq.R
inputData["validation_error"] = "Phone number is required" inputData["validation_error"] = "Phone number is required"
inputData["error_field"] = "phone" inputData["error_field"] = "phone"
bt, _ := json.Marshal(inputData) bt, _ := json.Marshal(inputData)
return mq.Result{Payload: bt, Ctx: ctx, ConditionStatus: "invalid"} return mq.Result{
Payload: bt,
Ctx: ctx,
ResetTo: "SMSForm", // Reset to form instead of going to error page
}
} }
// Clean and validate phone number format // Clean and validate phone number format
@@ -514,7 +545,11 @@ func (v *ValidateInputNode) ProcessTask(ctx context.Context, task *mq.Task) mq.R
inputData["validation_error"] = "Invalid phone number format. Please use US format: +1234567890 or 1234567890" inputData["validation_error"] = "Invalid phone number format. Please use US format: +1234567890 or 1234567890"
inputData["error_field"] = "phone" inputData["error_field"] = "phone"
bt, _ := json.Marshal(inputData) bt, _ := json.Marshal(inputData)
return mq.Result{Payload: bt, Ctx: ctx, ConditionStatus: "invalid"} return mq.Result{
Payload: bt,
Ctx: ctx,
ResetTo: "SMSForm", // Reset to form instead of going to error page
}
} }
// Validate message // Validate message
@@ -522,14 +557,22 @@ func (v *ValidateInputNode) ProcessTask(ctx context.Context, task *mq.Task) mq.R
inputData["validation_error"] = "Message is required" inputData["validation_error"] = "Message is required"
inputData["error_field"] = "message" inputData["error_field"] = "message"
bt, _ := json.Marshal(inputData) bt, _ := json.Marshal(inputData)
return mq.Result{Payload: bt, Ctx: ctx, ConditionStatus: "invalid"} return mq.Result{
Payload: bt,
Ctx: ctx,
ResetTo: "SMSForm", // Reset to form instead of going to error page
}
} }
if len(message) > 160 { if len(message) > 160 {
inputData["validation_error"] = "Message too long. Maximum 160 characters allowed" inputData["validation_error"] = "Message too long. Maximum 160 characters allowed"
inputData["error_field"] = "message" inputData["error_field"] = "message"
bt, _ := json.Marshal(inputData) bt, _ := json.Marshal(inputData)
return mq.Result{Payload: bt, Ctx: ctx, ConditionStatus: "invalid"} return mq.Result{
Payload: bt,
Ctx: ctx,
ResetTo: "SMSForm", // Reset to form instead of going to error page
}
} }
// Check for potentially harmful content // Check for potentially harmful content
@@ -540,7 +583,11 @@ func (v *ValidateInputNode) ProcessTask(ctx context.Context, task *mq.Task) mq.R
inputData["validation_error"] = "Message contains prohibited content" inputData["validation_error"] = "Message contains prohibited content"
inputData["error_field"] = "message" inputData["error_field"] = "message"
bt, _ := json.Marshal(inputData) bt, _ := json.Marshal(inputData)
return mq.Result{Payload: bt, Ctx: ctx, ConditionStatus: "invalid"} return mq.Result{
Payload: bt,
Ctx: ctx,
ResetTo: "SMSForm", // Reset to form instead of going to error page
}
} }
} }

View File

@@ -0,0 +1,97 @@
package main
import (
"context"
"fmt"
"log"
"github.com/oarkflow/json"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/dag"
)
// ResetToExample demonstrates the ResetTo functionality
type ResetToExample struct {
dag.Operation
}
func (r *ResetToExample) Process(ctx context.Context, task *mq.Task) mq.Result {
payload := string(task.Payload)
log.Printf("Processing node %s with payload: %s", task.Topic, payload)
// Simulate some processing logic
if task.Topic == "step1" {
// For step1, we'll return a result that resets to step2
return mq.Result{
Status: mq.Completed,
Payload: json.RawMessage(`{"message": "Step 1 completed, resetting to step2"}`),
Ctx: ctx,
TaskID: task.ID,
Topic: task.Topic,
ResetTo: "step2", // Reset to step2
}
} else if task.Topic == "step2" {
// For step2, we'll return a result that resets to the previous page node
return mq.Result{
Status: mq.Completed,
Payload: json.RawMessage(`{"message": "Step 2 completed, resetting to back"}`),
Ctx: ctx,
TaskID: task.ID,
Topic: task.Topic,
ResetTo: "back", // Reset to previous page node
}
} else if task.Topic == "step3" {
// Final step
return mq.Result{
Status: mq.Completed,
Payload: json.RawMessage(`{"message": "Step 3 completed - final result"}`),
Ctx: ctx,
TaskID: task.ID,
Topic: task.Topic,
}
}
return mq.Result{
Status: mq.Failed,
Error: fmt.Errorf("unknown step: %s", task.Topic),
Ctx: ctx,
TaskID: task.ID,
Topic: task.Topic,
}
}
func runResetToExample() {
// Create a DAG with ResetTo functionality
flow := dag.NewDAG("ResetTo Example", "reset-to-example", func(taskID string, result mq.Result) {
log.Printf("Final result for task %s: %s", taskID, string(result.Payload))
})
// Add nodes
flow.AddNode(dag.Function, "Step 1", "step1", &ResetToExample{}, true)
flow.AddNode(dag.Page, "Step 2", "step2", &ResetToExample{})
flow.AddNode(dag.Page, "Step 3", "step3", &ResetToExample{})
// Add edges
flow.AddEdge(dag.Simple, "Step 1 to Step 2", "step1", "step2")
flow.AddEdge(dag.Simple, "Step 2 to Step 3", "step2", "step3")
// Validate the DAG
if err := flow.Validate(); err != nil {
log.Fatalf("DAG validation failed: %v", err)
}
// Process a task
data := json.RawMessage(`{"initial": "data"}`)
log.Println("Starting DAG processing...")
result := flow.Process(context.Background(), data)
if result.Error != nil {
log.Printf("Processing failed: %v", result.Error)
} else {
log.Printf("Processing completed successfully: %s", string(result.Payload))
}
}
func main() {
runResetToExample()
}

View File

@@ -551,7 +551,7 @@ func Logger() HandlerFunc {
// Example // Example
// ---------------------------- // ----------------------------
func main() { func mai3n() {
app := New() app := New()
app.Use(Recover()) app.Use(Recover())

View File

@@ -24,7 +24,7 @@ func enhancedSubDAG() *dag.DAG {
return f return f
} }
func main() { func mai4n() {
fmt.Println("🚀 Starting Simple Enhanced DAG Demo...") fmt.Println("🚀 Starting Simple Enhanced DAG Demo...")
// Create enhanced DAG - simple configuration, just like regular DAG but with enhanced features // Create enhanced DAG - simple configuration, just like regular DAG but with enhanced features

View File

@@ -116,7 +116,7 @@ func demonstrateTaskRecovery() {
log.Println("💡 In a real scenario, the recovered task would continue processing from the 'process' node") log.Println("💡 In a real scenario, the recovered task would continue processing from the 'process' node")
} }
func main() { func mai5n() {
fmt.Println("=== DAG Task Recovery Example ===") fmt.Println("=== DAG Task Recovery Example ===")
demonstrateTaskRecovery() demonstrateTaskRecovery()
} }

1
mq.go
View File

@@ -45,6 +45,7 @@ type Result struct {
ConditionStatus string `json:"condition_status"` ConditionStatus string `json:"condition_status"`
Ctx context.Context `json:"-"` Ctx context.Context `json:"-"`
Payload json.RawMessage `json:"payload"` Payload json.RawMessage `json:"payload"`
ResetTo string `json:"reset_to,omitempty"` // Node ID to reset to, or "back" for previous page node
Last bool Last bool
} }