From d814019d737c88fce877f132e6e20b890bc6c54c Mon Sep 17 00:00:00 2001 From: Oarkflow Date: Wed, 30 Jul 2025 12:29:04 +0545 Subject: [PATCH] update --- dag/activity_logger.go | 755 ++++++++++++++++++++++++++++++++++ dag/configuration.go | 5 + dag/consts.go | 10 + dag/dag.go | 642 ++++++++++++++++++++++++----- dag/enhancements.go | 408 +++++++++++------- dag/monitoring.go | 287 +++++++------ dag/node.go | 136 ------ dag/retry.go | 23 ++ dag/task_manager.go | 40 +- examples/enhanced_dag_demo.go | 557 +++++++++++++++++++++++++ 10 files changed, 2356 insertions(+), 507 deletions(-) create mode 100644 dag/activity_logger.go delete mode 100644 dag/node.go create mode 100644 examples/enhanced_dag_demo.go diff --git a/dag/activity_logger.go b/dag/activity_logger.go new file mode 100644 index 0000000..e605de1 --- /dev/null +++ b/dag/activity_logger.go @@ -0,0 +1,755 @@ +package dag + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/logger" +) + +// ActivityLevel represents the severity level of an activity +type ActivityLevel string + +const ( + ActivityLevelDebug ActivityLevel = "debug" + ActivityLevelInfo ActivityLevel = "info" + ActivityLevelWarn ActivityLevel = "warn" + ActivityLevelError ActivityLevel = "error" + ActivityLevelFatal ActivityLevel = "fatal" +) + +// ActivityType represents the type of activity +type ActivityType string + +const ( + ActivityTypeTaskStart ActivityType = "task_start" + ActivityTypeTaskComplete ActivityType = "task_complete" + ActivityTypeTaskFail ActivityType = "task_fail" + ActivityTypeTaskCancel ActivityType = "task_cancel" + ActivityTypeNodeStart ActivityType = "node_start" + ActivityTypeNodeComplete ActivityType = "node_complete" + ActivityTypeNodeFail ActivityType = "node_fail" + ActivityTypeNodeTimeout ActivityType = "node_timeout" + ActivityTypeValidation ActivityType = "validation" + ActivityTypeConfiguration ActivityType = "configuration" + ActivityTypeAlert ActivityType = "alert" + ActivityTypeCleanup ActivityType = "cleanup" + ActivityTypeTransaction ActivityType = "transaction" + ActivityTypeRetry ActivityType = "retry" + ActivityTypeCircuitBreaker ActivityType = "circuit_breaker" + ActivityTypeWebhook ActivityType = "webhook" + ActivityTypeCustom ActivityType = "custom" +) + +// ActivityEntry represents a single activity log entry +type ActivityEntry struct { + ID string `json:"id"` + Timestamp time.Time `json:"timestamp"` + DAGName string `json:"dag_name"` + Level ActivityLevel `json:"level"` + Type ActivityType `json:"type"` + Message string `json:"message"` + TaskID string `json:"task_id,omitempty"` + NodeID string `json:"node_id,omitempty"` + Duration time.Duration `json:"duration,omitempty"` + Success *bool `json:"success,omitempty"` + Error string `json:"error,omitempty"` + Details map[string]interface{} `json:"details,omitempty"` + ContextData map[string]interface{} `json:"context_data,omitempty"` + UserID string `json:"user_id,omitempty"` + SessionID string `json:"session_id,omitempty"` + TraceID string `json:"trace_id,omitempty"` + SpanID string `json:"span_id,omitempty"` +} + +// ActivityFilter provides filtering options for activity queries +type ActivityFilter struct { + StartTime *time.Time `json:"start_time,omitempty"` + EndTime *time.Time `json:"end_time,omitempty"` + Levels []ActivityLevel `json:"levels,omitempty"` + Types []ActivityType `json:"types,omitempty"` + TaskIDs []string `json:"task_ids,omitempty"` + NodeIDs []string `json:"node_ids,omitempty"` + UserIDs []string 
`json:"user_ids,omitempty"` + SuccessOnly *bool `json:"success_only,omitempty"` + FailuresOnly *bool `json:"failures_only,omitempty"` + Limit int `json:"limit,omitempty"` + Offset int `json:"offset,omitempty"` + SortBy string `json:"sort_by,omitempty"` // timestamp, level, type + SortOrder string `json:"sort_order,omitempty"` // asc, desc +} + +// ActivityStats provides statistics about activities +type ActivityStats struct { + TotalActivities int64 `json:"total_activities"` + ActivitiesByLevel map[ActivityLevel]int64 `json:"activities_by_level"` + ActivitiesByType map[ActivityType]int64 `json:"activities_by_type"` + ActivitiesByNode map[string]int64 `json:"activities_by_node"` + ActivitiesByTask map[string]int64 `json:"activities_by_task"` + SuccessRate float64 `json:"success_rate"` + FailureRate float64 `json:"failure_rate"` + AverageDuration time.Duration `json:"average_duration"` + PeakActivitiesPerMin int64 `json:"peak_activities_per_minute"` + TimeRange ActivityTimeRange `json:"time_range"` + RecentErrors []ActivityEntry `json:"recent_errors"` + TopFailingNodes []NodeFailureStats `json:"top_failing_nodes"` + HourlyDistribution map[string]int64 `json:"hourly_distribution"` +} + +// ActivityTimeRange represents a time range for activities +type ActivityTimeRange struct { + Start time.Time `json:"start"` + End time.Time `json:"end"` +} + +// NodeFailureStats represents failure statistics for a node +type NodeFailureStats struct { + NodeID string `json:"node_id"` + FailureCount int64 `json:"failure_count"` + FailureRate float64 `json:"failure_rate"` + LastFailure time.Time `json:"last_failure"` +} + +// ActivityHook allows custom processing of activity entries +type ActivityHook interface { + OnActivity(entry ActivityEntry) error +} + +// ActivityPersistence defines the interface for persisting activities +type ActivityPersistence interface { + Store(entries []ActivityEntry) error + Query(filter ActivityFilter) ([]ActivityEntry, error) + GetStats(filter ActivityFilter) (ActivityStats, error) + Close() error +} + +// ActivityLoggerConfig configures the activity logger +type ActivityLoggerConfig struct { + BufferSize int `json:"buffer_size"` + FlushInterval time.Duration `json:"flush_interval"` + MaxRetries int `json:"max_retries"` + EnableHooks bool `json:"enable_hooks"` + EnableCompression bool `json:"enable_compression"` + MaxEntryAge time.Duration `json:"max_entry_age"` + AsyncMode bool `json:"async_mode"` +} + +// DefaultActivityLoggerConfig returns default configuration +func DefaultActivityLoggerConfig() ActivityLoggerConfig { + return ActivityLoggerConfig{ + BufferSize: 1000, + FlushInterval: 5 * time.Second, + MaxRetries: 3, + EnableHooks: true, + EnableCompression: false, + MaxEntryAge: 24 * time.Hour, + AsyncMode: true, + } +} + +// ActivityLogger provides comprehensive activity logging for DAG operations +type ActivityLogger struct { + dagName string + config ActivityLoggerConfig + persistence ActivityPersistence + logger logger.Logger + buffer []ActivityEntry + bufferMu sync.Mutex + hooks []ActivityHook + hooksMu sync.RWMutex + stopCh chan struct{} + flushCh chan struct{} + running bool + runningMu sync.RWMutex + stats ActivityStats + statsMu sync.RWMutex +} + +// NewActivityLogger creates a new activity logger +func NewActivityLogger(dagName string, config ActivityLoggerConfig, persistence ActivityPersistence, logger logger.Logger) *ActivityLogger { + al := &ActivityLogger{ + dagName: dagName, + config: config, + persistence: persistence, + logger: logger, + buffer: 
make([]ActivityEntry, 0, config.BufferSize), + hooks: make([]ActivityHook, 0), + stopCh: make(chan struct{}), + flushCh: make(chan struct{}, 1), + stats: ActivityStats{ + ActivitiesByLevel: make(map[ActivityLevel]int64), + ActivitiesByType: make(map[ActivityType]int64), + ActivitiesByNode: make(map[string]int64), + ActivitiesByTask: make(map[string]int64), + HourlyDistribution: make(map[string]int64), + }, + } + + if config.AsyncMode { + al.start() + } + + return al +} + +// start begins the async processing routines +func (al *ActivityLogger) start() { + al.runningMu.Lock() + defer al.runningMu.Unlock() + + if al.running { + return + } + + al.running = true + go al.flushRoutine() +} + +// Stop stops the activity logger +func (al *ActivityLogger) Stop() { + al.runningMu.Lock() + defer al.runningMu.Unlock() + + if !al.running { + return + } + + al.running = false + close(al.stopCh) + + // Final flush + al.Flush() +} + +// flushRoutine handles periodic flushing of the buffer +func (al *ActivityLogger) flushRoutine() { + ticker := time.NewTicker(al.config.FlushInterval) + defer ticker.Stop() + + for { + select { + case <-al.stopCh: + return + case <-ticker.C: + al.Flush() + case <-al.flushCh: + al.Flush() + } + } +} + +// Log logs an activity entry +func (al *ActivityLogger) Log(level ActivityLevel, activityType ActivityType, message string, details map[string]interface{}) { + al.LogWithContext(context.Background(), level, activityType, message, details) +} + +// LogWithContext logs an activity entry with context information +func (al *ActivityLogger) LogWithContext(ctx context.Context, level ActivityLevel, activityType ActivityType, message string, details map[string]interface{}) { + entry := ActivityEntry{ + ID: mq.NewID(), + Timestamp: time.Now(), + DAGName: al.dagName, + Level: level, + Type: activityType, + Message: message, + Details: details, + ContextData: make(map[string]interface{}), + } + + // Extract context information + if taskID, ok := ctx.Value("task_id").(string); ok { + entry.TaskID = taskID + } + if nodeID, ok := ctx.Value("node_id").(string); ok { + entry.NodeID = nodeID + } + if userID, ok := ctx.Value("user_id").(string); ok { + entry.UserID = userID + } + if sessionID, ok := ctx.Value("session_id").(string); ok { + entry.SessionID = sessionID + } + if traceID, ok := ctx.Value("trace_id").(string); ok { + entry.TraceID = traceID + } + if spanID, ok := ctx.Value("span_id").(string); ok { + entry.SpanID = spanID + } + if duration, ok := ctx.Value("duration").(time.Duration); ok { + entry.Duration = duration + } + if err, ok := ctx.Value("error").(error); ok { + entry.Error = err.Error() + success := false + entry.Success = &success + } + + // Extract additional context data + for key, value := range map[string]interface{}{ + "method": ctx.Value("method"), + "user_agent": ctx.Value("user_agent"), + "ip_address": ctx.Value("ip_address"), + "request_id": ctx.Value("request_id"), + } { + if value != nil { + entry.ContextData[key] = value + } + } + + al.addEntry(entry) +} + +// LogTaskStart logs task start activity +func (al *ActivityLogger) LogTaskStart(ctx context.Context, taskID string, nodeID string) { + al.LogWithContext(ctx, ActivityLevelInfo, ActivityTypeTaskStart, + fmt.Sprintf("Task %s started on node %s", taskID, nodeID), + map[string]interface{}{ + "task_id": taskID, + "node_id": nodeID, + }) +} + +// LogTaskComplete logs task completion activity +func (al *ActivityLogger) LogTaskComplete(ctx context.Context, taskID string, nodeID string, duration time.Duration) { + 
success := true + entry := ActivityEntry{ + ID: mq.NewID(), + Timestamp: time.Now(), + DAGName: al.dagName, + Level: ActivityLevelInfo, + Type: ActivityTypeTaskComplete, + Message: fmt.Sprintf("Task %s completed successfully on node %s", taskID, nodeID), + TaskID: taskID, + NodeID: nodeID, + Duration: duration, + Success: &success, + Details: map[string]interface{}{ + "task_id": taskID, + "node_id": nodeID, + "duration": duration.String(), + }, + } + al.addEntry(entry) +} + +// LogTaskFail logs task failure activity +func (al *ActivityLogger) LogTaskFail(ctx context.Context, taskID string, nodeID string, err error, duration time.Duration) { + success := false + entry := ActivityEntry{ + ID: mq.NewID(), + Timestamp: time.Now(), + DAGName: al.dagName, + Level: ActivityLevelError, + Type: ActivityTypeTaskFail, + Message: fmt.Sprintf("Task %s failed on node %s: %s", taskID, nodeID, err.Error()), + TaskID: taskID, + NodeID: nodeID, + Duration: duration, + Success: &success, + Error: err.Error(), + Details: map[string]interface{}{ + "task_id": taskID, + "node_id": nodeID, + "duration": duration.String(), + "error": err.Error(), + }, + } + al.addEntry(entry) +} + +// LogNodeExecution logs node execution details +func (al *ActivityLogger) LogNodeExecution(ctx context.Context, taskID string, nodeID string, result mq.Result, duration time.Duration) { + if result.Error != nil { + al.LogTaskFail(ctx, taskID, nodeID, result.Error, duration) + } else { + al.LogTaskComplete(ctx, taskID, nodeID, duration) + } +} + +// addEntry adds an entry to the buffer and triggers hooks +func (al *ActivityLogger) addEntry(entry ActivityEntry) { + // Update statistics + al.updateStats(entry) + + // Trigger hooks + if al.config.EnableHooks { + al.triggerHooks(entry) + } + + // Add to buffer + al.bufferMu.Lock() + al.buffer = append(al.buffer, entry) + shouldFlush := len(al.buffer) >= al.config.BufferSize + al.bufferMu.Unlock() + + // Trigger flush if buffer is full + if shouldFlush { + select { + case al.flushCh <- struct{}{}: + default: + } + } + + // Also log to standard logger for immediate feedback + fields := []logger.Field{ + {Key: "activity_id", Value: entry.ID}, + {Key: "dag_name", Value: entry.DAGName}, + {Key: "type", Value: string(entry.Type)}, + {Key: "task_id", Value: entry.TaskID}, + {Key: "node_id", Value: entry.NodeID}, + } + + if entry.Duration > 0 { + fields = append(fields, logger.Field{Key: "duration", Value: entry.Duration.String()}) + } + + switch entry.Level { + case ActivityLevelError, ActivityLevelFatal: + al.logger.Error(entry.Message, fields...) + case ActivityLevelWarn: + al.logger.Warn(entry.Message, fields...) + case ActivityLevelDebug: + al.logger.Debug(entry.Message, fields...) + default: + al.logger.Info(entry.Message, fields...) 
+ } +} + +// updateStats updates internal statistics +func (al *ActivityLogger) updateStats(entry ActivityEntry) { + al.statsMu.Lock() + defer al.statsMu.Unlock() + + al.stats.TotalActivities++ + al.stats.ActivitiesByLevel[entry.Level]++ + al.stats.ActivitiesByType[entry.Type]++ + + if entry.NodeID != "" { + al.stats.ActivitiesByNode[entry.NodeID]++ + } + + if entry.TaskID != "" { + al.stats.ActivitiesByTask[entry.TaskID]++ + } + + // Update hourly distribution + hour := entry.Timestamp.Format("2006-01-02T15") + al.stats.HourlyDistribution[hour]++ + + // Track recent errors + if entry.Level == ActivityLevelError || entry.Level == ActivityLevelFatal { + al.stats.RecentErrors = append(al.stats.RecentErrors, entry) + // Keep only last 10 errors + if len(al.stats.RecentErrors) > 10 { + al.stats.RecentErrors = al.stats.RecentErrors[len(al.stats.RecentErrors)-10:] + } + } +} + +// triggerHooks executes all registered hooks +func (al *ActivityLogger) triggerHooks(entry ActivityEntry) { + al.hooksMu.RLock() + hooks := make([]ActivityHook, len(al.hooks)) + copy(hooks, al.hooks) + al.hooksMu.RUnlock() + + for _, hook := range hooks { + go func(h ActivityHook, e ActivityEntry) { + if err := h.OnActivity(e); err != nil { + al.logger.Error("Activity hook error", + logger.Field{Key: "error", Value: err.Error()}, + logger.Field{Key: "activity_id", Value: e.ID}, + ) + } + }(hook, entry) + } +} + +// AddHook adds an activity hook +func (al *ActivityLogger) AddHook(hook ActivityHook) { + al.hooksMu.Lock() + defer al.hooksMu.Unlock() + al.hooks = append(al.hooks, hook) +} + +// RemoveHook removes an activity hook +func (al *ActivityLogger) RemoveHook(hook ActivityHook) { + al.hooksMu.Lock() + defer al.hooksMu.Unlock() + + for i, h := range al.hooks { + if h == hook { + al.hooks = append(al.hooks[:i], al.hooks[i+1:]...) 
+ break + } + } +} + +// Flush flushes the buffer to persistence +func (al *ActivityLogger) Flush() error { + al.bufferMu.Lock() + if len(al.buffer) == 0 { + al.bufferMu.Unlock() + return nil + } + + entries := make([]ActivityEntry, len(al.buffer)) + copy(entries, al.buffer) + al.buffer = al.buffer[:0] // Clear buffer + al.bufferMu.Unlock() + + if al.persistence == nil { + return nil + } + + // Retry logic + var err error + for attempt := 0; attempt < al.config.MaxRetries; attempt++ { + err = al.persistence.Store(entries) + if err == nil { + al.logger.Debug("Activity entries flushed to persistence", + logger.Field{Key: "count", Value: len(entries)}, + ) + return nil + } + + al.logger.Warn("Failed to flush activity entries", + logger.Field{Key: "attempt", Value: attempt + 1}, + logger.Field{Key: "error", Value: err.Error()}, + ) + + if attempt < al.config.MaxRetries-1 { + time.Sleep(time.Duration(attempt+1) * time.Second) + } + } + + return fmt.Errorf("failed to flush activities after %d attempts: %w", al.config.MaxRetries, err) +} + +// GetActivities retrieves activities based on filter +func (al *ActivityLogger) GetActivities(filter ActivityFilter) ([]ActivityEntry, error) { + if al.persistence == nil { + return nil, fmt.Errorf("persistence not configured") + } + return al.persistence.Query(filter) +} + +// GetStats returns activity statistics +func (al *ActivityLogger) GetStats(filter ActivityFilter) (ActivityStats, error) { + if al.persistence == nil { + // Return in-memory stats if no persistence + al.statsMu.RLock() + stats := al.stats + al.statsMu.RUnlock() + return stats, nil + } + return al.persistence.GetStats(filter) +} + +// MemoryActivityPersistence provides in-memory activity persistence for testing +type MemoryActivityPersistence struct { + entries []ActivityEntry + mu sync.RWMutex +} + +// NewMemoryActivityPersistence creates a new in-memory persistence +func NewMemoryActivityPersistence() *MemoryActivityPersistence { + return &MemoryActivityPersistence{ + entries: make([]ActivityEntry, 0), + } +} + +// Store stores activity entries in memory +func (mp *MemoryActivityPersistence) Store(entries []ActivityEntry) error { + mp.mu.Lock() + defer mp.mu.Unlock() + mp.entries = append(mp.entries, entries...) 
+ return nil +} + +// Query queries activity entries with filter +func (mp *MemoryActivityPersistence) Query(filter ActivityFilter) ([]ActivityEntry, error) { + mp.mu.RLock() + defer mp.mu.RUnlock() + + var result []ActivityEntry + for _, entry := range mp.entries { + if mp.matchesFilter(entry, filter) { + result = append(result, entry) + } + } + + // Apply limit and offset + if filter.Offset > 0 && filter.Offset < len(result) { + result = result[filter.Offset:] + } + if filter.Limit > 0 && filter.Limit < len(result) { + result = result[:filter.Limit] + } + + return result, nil +} + +// matchesFilter checks if an entry matches the filter +func (mp *MemoryActivityPersistence) matchesFilter(entry ActivityEntry, filter ActivityFilter) bool { + // Time range check + if filter.StartTime != nil && entry.Timestamp.Before(*filter.StartTime) { + return false + } + if filter.EndTime != nil && entry.Timestamp.After(*filter.EndTime) { + return false + } + + // Level filter + if len(filter.Levels) > 0 { + found := false + for _, level := range filter.Levels { + if entry.Level == level { + found = true + break + } + } + if !found { + return false + } + } + + // Type filter + if len(filter.Types) > 0 { + found := false + for _, typ := range filter.Types { + if entry.Type == typ { + found = true + break + } + } + if !found { + return false + } + } + + // Task ID filter + if len(filter.TaskIDs) > 0 { + found := false + for _, taskID := range filter.TaskIDs { + if entry.TaskID == taskID { + found = true + break + } + } + if !found { + return false + } + } + + // Node ID filter + if len(filter.NodeIDs) > 0 { + found := false + for _, nodeID := range filter.NodeIDs { + if entry.NodeID == nodeID { + found = true + break + } + } + if !found { + return false + } + } + + // Success/failure filters + if filter.SuccessOnly != nil && *filter.SuccessOnly { + if entry.Success == nil || !*entry.Success { + return false + } + } + if filter.FailuresOnly != nil && *filter.FailuresOnly { + if entry.Success == nil || *entry.Success { + return false + } + } + + return true +} + +// GetStats returns statistics for the filtered entries +func (mp *MemoryActivityPersistence) GetStats(filter ActivityFilter) (ActivityStats, error) { + entries, err := mp.Query(filter) + if err != nil { + return ActivityStats{}, err + } + + stats := ActivityStats{ + ActivitiesByLevel: make(map[ActivityLevel]int64), + ActivitiesByType: make(map[ActivityType]int64), + ActivitiesByNode: make(map[string]int64), + ActivitiesByTask: make(map[string]int64), + HourlyDistribution: make(map[string]int64), + } + + var totalDuration time.Duration + var durationCount int64 + var successCount int64 + var failureCount int64 + + for _, entry := range entries { + stats.TotalActivities++ + stats.ActivitiesByLevel[entry.Level]++ + stats.ActivitiesByType[entry.Type]++ + + if entry.NodeID != "" { + stats.ActivitiesByNode[entry.NodeID]++ + } + if entry.TaskID != "" { + stats.ActivitiesByTask[entry.TaskID]++ + } + + hour := entry.Timestamp.Format("2006-01-02T15") + stats.HourlyDistribution[hour]++ + + if entry.Duration > 0 { + totalDuration += entry.Duration + durationCount++ + } + + if entry.Success != nil { + if *entry.Success { + successCount++ + } else { + failureCount++ + } + } + + if entry.Level == ActivityLevelError || entry.Level == ActivityLevelFatal { + stats.RecentErrors = append(stats.RecentErrors, entry) + } + } + + // Calculate rates and averages + if durationCount > 0 { + stats.AverageDuration = totalDuration / time.Duration(durationCount) + } + + total := 
successCount + failureCount + if total > 0 { + stats.SuccessRate = float64(successCount) / float64(total) + stats.FailureRate = float64(failureCount) / float64(total) + } + + // Keep only last 10 errors + if len(stats.RecentErrors) > 10 { + stats.RecentErrors = stats.RecentErrors[len(stats.RecentErrors)-10:] + } + + return stats, nil +} + +// Close closes the persistence +func (mp *MemoryActivityPersistence) Close() error { + mp.mu.Lock() + defer mp.mu.Unlock() + mp.entries = nil + return nil +} diff --git a/dag/configuration.go b/dag/configuration.go index 0a820da..8793c34 100644 --- a/dag/configuration.go +++ b/dag/configuration.go @@ -359,6 +359,11 @@ func (cm *ConfigManager) UpdateConfig(newConfig *DAGConfig) error { return nil } +// UpdateConfiguration updates the DAG configuration (alias for UpdateConfig) +func (cm *ConfigManager) UpdateConfiguration(config *DAGConfig) error { + return cm.UpdateConfig(config) +} + // AddWatcher adds a configuration watcher func (cm *ConfigManager) AddWatcher(watcher ConfigWatcher) { cm.mu.Lock() diff --git a/dag/consts.go b/dag/consts.go index 5c5d8df..364872b 100644 --- a/dag/consts.go +++ b/dag/consts.go @@ -32,6 +32,16 @@ type EdgeType int func (c EdgeType) IsValid() bool { return c >= Simple && c <= Iterator } +func (c EdgeType) String() string { + switch c { + case Simple: + return "Simple" + case Iterator: + return "Iterator" + } + return "Simple" +} + const ( Simple EdgeType = iota Iterator diff --git a/dag/dag.go b/dag/dag.go index 7db6471..00beb5f 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -94,6 +94,7 @@ type DAG struct { cleanupManager *CleanupManager webhookManager *WebhookManager performanceOptimizer *PerformanceOptimizer + activityLogger *ActivityLogger // Circuit breakers per node circuitBreakers map[string]*CircuitBreaker @@ -874,17 +875,14 @@ func (tm *DAG) RemoveNode(nodeID string) error { Type: Simple, // Use Simple edge type for adjusted flows. } // Append new edge if one doesn't already exist. - existsNewEdge := false for _, e := range inEdge.From.Edges { if e.To.ID == newEdge.To.ID { - existsNewEdge = true - break + goto SKIP_ADD } } - if !existsNewEdge { - inEdge.From.Edges = append(inEdge.From.Edges, newEdge) - } + inEdge.From.Edges = append(inEdge.From.Edges, newEdge) } + SKIP_ADD: } } // Remove all edges that are connected to the removed node. 
@@ -951,9 +949,338 @@ func (tm *DAG) getOrCreateCircuitBreaker(nodeID string) *CircuitBreaker { return cb } -// Enhanced DAG methods for new features +// Complete missing methods for DAG -// ValidateDAG validates the DAG structure +func (tm *DAG) GetLastNodes() ([]*Node, error) { + var lastNodes []*Node + tm.nodes.ForEach(func(key string, node *Node) bool { + if len(node.Edges) == 0 { + if conds, exists := tm.conditions[node.ID]; !exists || len(conds) == 0 { + lastNodes = append(lastNodes, node) + } + } + return true + }) + return lastNodes, nil +} + +// parseInitialNode extracts the initial node from context +func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) { + if initialNode, ok := ctx.Value("initial_node").(string); ok && initialNode != "" { + return initialNode, nil + } + + // If no initial node specified, use start node + if tm.startNode != "" { + return tm.startNode, nil + } + + // Find first node if no start node is set + firstNode := tm.findStartNode() + if firstNode != nil { + return firstNode.ID, nil + } + + return "", fmt.Errorf("no initial node found") +} + +// findStartNode finds the first node in the DAG +func (tm *DAG) findStartNode() *Node { + incomingEdges := make(map[string]bool) + connectedNodes := make(map[string]bool) + for _, node := range tm.nodes.AsMap() { + for _, edge := range node.Edges { + if edge.Type.IsValid() { + connectedNodes[node.ID] = true + connectedNodes[edge.To.ID] = true + incomingEdges[edge.To.ID] = true + } + } + if cond, ok := tm.conditions[node.ID]; ok { + for _, target := range cond { + connectedNodes[target] = true + incomingEdges[target] = true + } + } + } + for nodeID, node := range tm.nodes.AsMap() { + if !incomingEdges[nodeID] && connectedNodes[nodeID] { + return node + } + } + return nil +} + +// IsLastNode checks if a node is the last node in the DAG +func (tm *DAG) IsLastNode(nodeID string) (bool, error) { + node, exists := tm.nodes.Get(nodeID) + if !exists { + return false, fmt.Errorf("node %s not found", nodeID) + } + + // Check if node has any outgoing edges + if len(node.Edges) > 0 { + return false, nil + } + + // Check if node has any conditional edges + if conditions, exists := tm.conditions[nodeID]; exists && len(conditions) > 0 { + return false, nil + } + + return true, nil +} + +// GetNextNodes returns the next nodes for a given node +func (tm *DAG) GetNextNodes(nodeID string) ([]*Node, error) { + nodeID = strings.Split(nodeID, Delimiter)[0] + if tm.nextNodesCache != nil { + if cached, exists := tm.nextNodesCache[nodeID]; exists { + return cached, nil + } + } + + node, exists := tm.nodes.Get(nodeID) + if !exists { + return nil, fmt.Errorf("node %s not found", nodeID) + } + + var nextNodes []*Node + + // Add direct edge targets + for _, edge := range node.Edges { + nextNodes = append(nextNodes, edge.To) + } + + // Add conditional targets + if conditions, exists := tm.conditions[nodeID]; exists { + for _, targetID := range conditions { + if targetNode, ok := tm.nodes.Get(targetID); ok { + nextNodes = append(nextNodes, targetNode) + } + } + } + + // Cache the result + if tm.nextNodesCache != nil { + tm.nextNodesCache[nodeID] = nextNodes + } + + return nextNodes, nil +} + +// GetPreviousNodes returns the previous nodes for a given node +func (tm *DAG) GetPreviousNodes(nodeID string) ([]*Node, error) { + nodeID = strings.Split(nodeID, Delimiter)[0] + if tm.prevNodesCache != nil { + if cached, exists := tm.prevNodesCache[nodeID]; exists { + return cached, nil + } + } + + var prevNodes []*Node + + // Find nodes that 
point to this node + tm.nodes.ForEach(func(id string, node *Node) bool { + // Check direct edges + for _, edge := range node.Edges { + if edge.To.ID == nodeID { + prevNodes = append(prevNodes, node) + break + } + } + + // Check conditional edges + if conditions, exists := tm.conditions[id]; exists { + for _, targetID := range conditions { + if targetID == nodeID { + prevNodes = append(prevNodes, node) + break + } + } + } + + return true + }) + + // Cache the result + if tm.prevNodesCache != nil { + tm.prevNodesCache[nodeID] = prevNodes + } + + return prevNodes, nil +} + +// GetNodeByID returns a node by its ID +func (tm *DAG) GetNodeByID(nodeID string) (*Node, error) { + node, exists := tm.nodes.Get(nodeID) + if !exists { + return nil, fmt.Errorf("node %s not found", nodeID) + } + return node, nil +} + +// GetAllNodes returns all nodes in the DAG +func (tm *DAG) GetAllNodes() map[string]*Node { + result := make(map[string]*Node) + tm.nodes.ForEach(func(id string, node *Node) bool { + result[id] = node + return true + }) + return result +} + +// GetNodeCount returns the total number of nodes +func (tm *DAG) GetNodeCount() int { + return tm.nodes.Size() +} + +// GetEdgeCount returns the total number of edges +func (tm *DAG) GetEdgeCount() int { + count := 0 + tm.nodes.ForEach(func(id string, node *Node) bool { + count += len(node.Edges) + return true + }) + + // Add conditional edges + for _, conditions := range tm.conditions { + count += len(conditions) + } + + return count +} + +// Clone creates a deep copy of the DAG +func (tm *DAG) Clone() *DAG { + newDAG := NewDAG(tm.name+"_clone", tm.key, tm.finalResult) + + // Copy nodes + tm.nodes.ForEach(func(id string, node *Node) bool { + newDAG.AddNode(node.NodeType, node.Label, node.ID, node.processor) + return true + }) + + // Copy edges + tm.nodes.ForEach(func(id string, node *Node) bool { + for _, edge := range node.Edges { + newDAG.AddEdge(edge.Type, edge.Label, edge.From.ID, edge.To.ID) + } + return true + }) + + // Copy conditions + for fromNode, conditions := range tm.conditions { + newDAG.AddCondition(fromNode, conditions) + } + + // Copy start node + newDAG.SetStartNode(tm.startNode) + + return newDAG +} + +// Export exports the DAG structure to a serializable format +func (tm *DAG) Export() map[string]interface{} { + export := map[string]interface{}{ + "name": tm.name, + "key": tm.key, + "start_node": tm.startNode, + "nodes": make([]map[string]interface{}, 0), + "edges": make([]map[string]interface{}, 0), + "conditions": tm.conditions, + } + + // Export nodes + tm.nodes.ForEach(func(id string, node *Node) bool { + nodeData := map[string]interface{}{ + "id": node.ID, + "label": node.Label, + "type": node.NodeType.String(), + "is_ready": node.isReady, + } + export["nodes"] = append(export["nodes"].([]map[string]interface{}), nodeData) + return true + }) + + // Export edges + tm.nodes.ForEach(func(id string, node *Node) bool { + for _, edge := range node.Edges { + edgeData := map[string]interface{}{ + "from": edge.From.ID, + "to": edge.To.ID, + "label": edge.Label, + "type": edge.Type.String(), + } + export["edges"] = append(export["edges"].([]map[string]interface{}), edgeData) + } + return true + }) + + return export +} + +// Enhanced DAG Methods for Production-Ready Features + +// InitializeActivityLogger initializes the activity logger for the DAG +func (tm *DAG) InitializeActivityLogger(config ActivityLoggerConfig, persistence ActivityPersistence) { + tm.activityLogger = NewActivityLogger(tm.name, config, persistence, tm.Logger()) + + 
// Add activity logging hooks to existing components + if tm.monitor != nil { + tm.monitor.AddAlertHandler(&ActivityAlertHandler{activityLogger: tm.activityLogger}) + } + + tm.Logger().Info("Activity logger initialized for DAG", + logger.Field{Key: "dag_name", Value: tm.name}) +} + +// GetActivityLogger returns the activity logger instance +func (tm *DAG) GetActivityLogger() *ActivityLogger { + return tm.activityLogger +} + +// LogActivity logs an activity entry +func (tm *DAG) LogActivity(ctx context.Context, level ActivityLevel, activityType ActivityType, message string, details map[string]interface{}) { + if tm.activityLogger != nil { + tm.activityLogger.LogWithContext(ctx, level, activityType, message, details) + } +} + +// GetActivityStats returns activity statistics +func (tm *DAG) GetActivityStats(filter ActivityFilter) (ActivityStats, error) { + if tm.activityLogger != nil { + return tm.activityLogger.GetStats(filter) + } + return ActivityStats{}, fmt.Errorf("activity logger not initialized") +} + +// GetActivities retrieves activities based on filter +func (tm *DAG) GetActivities(filter ActivityFilter) ([]ActivityEntry, error) { + if tm.activityLogger != nil { + return tm.activityLogger.GetActivities(filter) + } + return nil, fmt.Errorf("activity logger not initialized") +} + +// AddActivityHook adds an activity hook +func (tm *DAG) AddActivityHook(hook ActivityHook) { + if tm.activityLogger != nil { + tm.activityLogger.AddHook(hook) + } +} + +// FlushActivityLogs flushes activity logs to persistence +func (tm *DAG) FlushActivityLogs() error { + if tm.activityLogger != nil { + return tm.activityLogger.Flush() + } + return fmt.Errorf("activity logger not initialized") +} + +// Enhanced Monitoring and Management Methods + +// ValidateDAG validates the DAG structure using the enhanced validator func (tm *DAG) ValidateDAG() error { if tm.validator == nil { return fmt.Errorf("validator not initialized") @@ -961,42 +1288,42 @@ func (tm *DAG) ValidateDAG() error { return tm.validator.ValidateStructure() } -// StartMonitoring starts DAG monitoring +// GetTopologicalOrder returns nodes in topological order +func (tm *DAG) GetTopologicalOrder() ([]string, error) { + if tm.validator == nil { + return nil, fmt.Errorf("validator not initialized") + } + return tm.validator.GetTopologicalOrder() +} + +// GetCriticalPath returns the critical path of the DAG +func (tm *DAG) GetCriticalPath() ([]string, error) { + if tm.validator == nil { + return nil, fmt.Errorf("validator not initialized") + } + return tm.validator.GetCriticalPath() +} + +// GetDAGStatistics returns comprehensive DAG statistics +func (tm *DAG) GetDAGStatistics() map[string]interface{} { + if tm.validator == nil { + return map[string]interface{}{"error": "validator not initialized"} + } + return tm.validator.GetNodeStatistics() +} + +// StartMonitoring starts the monitoring system func (tm *DAG) StartMonitoring(ctx context.Context) { if tm.monitor != nil { tm.monitor.Start(ctx) } - if tm.cleanupManager != nil { - tm.cleanupManager.Start(ctx) - } } -// StopMonitoring stops DAG monitoring +// StopMonitoring stops the monitoring system func (tm *DAG) StopMonitoring() { if tm.monitor != nil { tm.monitor.Stop() } - if tm.cleanupManager != nil { - tm.cleanupManager.Stop() - } - if tm.cache != nil { - tm.cache.Stop() - } - if tm.batchProcessor != nil { - tm.batchProcessor.Stop() - } -} - -// SetRateLimit sets rate limit for a node -func (tm *DAG) SetRateLimit(nodeID string, requestsPerSecond float64, burst int) { - if tm.rateLimiter != 
nil { - tm.rateLimiter.SetNodeLimit(nodeID, requestsPerSecond, burst) - } -} - -// SetWebhookManager sets the webhook manager -func (tm *DAG) SetWebhookManager(webhookManager *WebhookManager) { - tm.webhookManager = webhookManager } // GetMonitoringMetrics returns current monitoring metrics @@ -1009,21 +1336,100 @@ func (tm *DAG) GetMonitoringMetrics() *MonitoringMetrics { // GetNodeStats returns statistics for a specific node func (tm *DAG) GetNodeStats(nodeID string) *NodeStats { - if tm.monitor != nil { + if tm.monitor != nil && tm.monitor.metrics != nil { return tm.monitor.metrics.GetNodeStats(nodeID) } return nil } -// OptimizePerformance runs performance optimization -func (tm *DAG) OptimizePerformance() error { - if tm.performanceOptimizer != nil { - return tm.performanceOptimizer.OptimizePerformance() +// SetAlertThresholds configures alert thresholds +func (tm *DAG) SetAlertThresholds(thresholds *AlertThresholds) { + if tm.monitor != nil { + tm.monitor.SetAlertThresholds(thresholds) } - return fmt.Errorf("performance optimizer not initialized") } -// BeginTransaction starts a new transaction for task execution +// AddAlertHandler adds an alert handler +func (tm *DAG) AddAlertHandler(handler AlertHandler) { + if tm.monitor != nil { + tm.monitor.AddAlertHandler(handler) + } +} + +// Configuration Management Methods + +// GetConfiguration returns current DAG configuration +func (tm *DAG) GetConfiguration() *DAGConfig { + if tm.configManager != nil { + return tm.configManager.GetConfig() + } + return DefaultDAGConfig() +} + +// UpdateConfiguration updates the DAG configuration +func (tm *DAG) UpdateConfiguration(config *DAGConfig) error { + if tm.configManager != nil { + return tm.configManager.UpdateConfiguration(config) + } + return fmt.Errorf("config manager not initialized") +} + +// AddConfigWatcher adds a configuration change watcher +func (tm *DAG) AddConfigWatcher(watcher ConfigWatcher) { + if tm.configManager != nil { + tm.configManager.AddWatcher(watcher) + } +} + +// Rate Limiting Methods + +// SetRateLimit sets rate limit for a specific node +func (tm *DAG) SetRateLimit(nodeID string, requestsPerSecond float64, burst int) { + if tm.rateLimiter != nil { + tm.rateLimiter.SetNodeLimit(nodeID, requestsPerSecond, burst) + } +} + +// CheckRateLimit checks if request is allowed for a node +func (tm *DAG) CheckRateLimit(nodeID string) bool { + if tm.rateLimiter != nil { + return tm.rateLimiter.Allow(nodeID) + } + return true +} + +// Retry and Circuit Breaker Methods + +// SetRetryConfig sets the retry configuration +func (tm *DAG) SetRetryConfig(config *RetryConfig) { + if tm.retryManager != nil { + tm.retryManager.SetGlobalConfig(config) + } +} + +// AddNodeWithRetry adds a node with specific retry configuration +func (tm *DAG) AddNodeWithRetry(nodeType NodeType, name, nodeID string, handler mq.Processor, retryConfig *RetryConfig, startNode ...bool) *DAG { + tm.AddNode(nodeType, name, nodeID, handler, startNode...) 
+ if tm.retryManager != nil { + tm.retryManager.SetNodeConfig(nodeID, retryConfig) + } + return tm +} + +// GetCircuitBreakerStatus returns circuit breaker status for a node +func (tm *DAG) GetCircuitBreakerStatus(nodeID string) CircuitBreakerState { + tm.circuitBreakersMu.RLock() + defer tm.circuitBreakersMu.RUnlock() + + if cb, exists := tm.circuitBreakers[nodeID]; exists { + return cb.GetState() + } + return CircuitClosed +} + +// Transaction Management Methods + +// BeginTransaction starts a new transaction func (tm *DAG) BeginTransaction(taskID string) *Transaction { if tm.transactionManager != nil { return tm.transactionManager.BeginTransaction(taskID) @@ -1047,77 +1453,125 @@ func (tm *DAG) RollbackTransaction(txID string) error { return fmt.Errorf("transaction manager not initialized") } -// GetTopologicalOrder returns nodes in topological order -func (tm *DAG) GetTopologicalOrder() ([]string, error) { - if tm.validator != nil { - return tm.validator.GetTopologicalOrder() +// GetTransaction retrieves transaction details +func (tm *DAG) GetTransaction(txID string) (*Transaction, error) { + if tm.transactionManager != nil { + return tm.transactionManager.GetTransaction(txID) } - return nil, fmt.Errorf("validator not initialized") + return nil, fmt.Errorf("transaction manager not initialized") } -// GetCriticalPath finds the longest path in the DAG -func (tm *DAG) GetCriticalPath() ([]string, error) { - if tm.validator != nil { - return tm.validator.GetCriticalPath() - } - return nil, fmt.Errorf("validator not initialized") -} +// Batch Processing Methods -// GetDAGStatistics returns comprehensive DAG statistics -func (tm *DAG) GetDAGStatistics() map[string]interface{} { - if tm.validator != nil { - return tm.validator.GetNodeStatistics() - } - return make(map[string]interface{}) -} - -// SetRetryConfig sets retry configuration for the DAG -func (tm *DAG) SetRetryConfig(config *RetryConfig) { - if tm.retryManager != nil { - tm.retryManager.config = config +// SetBatchProcessingEnabled enables or disables batch processing +func (tm *DAG) SetBatchProcessingEnabled(enabled bool) { + if tm.batchProcessor != nil && enabled { + // Configure batch processor with processing function + tm.batchProcessor.SetProcessFunc(func(tasks []*mq.Task) error { + // Process tasks in batch + for _, task := range tasks { + tm.ProcessTask(context.Background(), task) + } + return nil + }) } } -// AddNodeWithRetry adds a node with retry capabilities -func (tm *DAG) AddNodeWithRetry(nodeType NodeType, name, nodeID string, handler mq.Processor, retryConfig *RetryConfig, startNode ...bool) *DAG { - if tm.Error != nil { - return tm - } +// Webhook Methods - // Wrap handler with retry logic if config provided - if retryConfig != nil { - handler = NewRetryableProcessor(handler, retryConfig, tm.Logger()) - } - - return tm.AddNode(nodeType, name, nodeID, handler, startNode...) 
+// SetWebhookManager sets the webhook manager +func (tm *DAG) SetWebhookManager(manager *WebhookManager) { + tm.webhookManager = manager } -// SetAlertThresholds configures monitoring alert thresholds -func (tm *DAG) SetAlertThresholds(thresholds *AlertThresholds) { - if tm.monitor != nil { - tm.monitor.SetAlertThresholds(thresholds) +// AddWebhook adds a webhook configuration +func (tm *DAG) AddWebhook(event string, config WebhookConfig) { + if tm.webhookManager != nil { + tm.webhookManager.AddWebhook(event, config) } } -// AddAlertHandler adds an alert handler for monitoring -func (tm *DAG) AddAlertHandler(handler AlertHandler) { - if tm.monitor != nil { - tm.monitor.AddAlertHandler(handler) +// Performance Optimization Methods + +// OptimizePerformance triggers performance optimization +func (tm *DAG) OptimizePerformance() error { + if tm.performanceOptimizer != nil { + return tm.performanceOptimizer.OptimizePerformance() + } + return fmt.Errorf("performance optimizer not initialized") +} + +// Cleanup Methods + +// StartCleanup starts the cleanup manager +func (tm *DAG) StartCleanup(ctx context.Context) { + if tm.cleanupManager != nil { + tm.cleanupManager.Start(ctx) } } -// UpdateConfiguration updates the DAG configuration -func (tm *DAG) UpdateConfiguration(config *DAGConfig) error { - if tm.configManager != nil { - return tm.configManager.UpdateConfig(config) +// StopCleanup stops the cleanup manager +func (tm *DAG) StopCleanup() { + if tm.cleanupManager != nil { + tm.cleanupManager.Stop() } - return fmt.Errorf("config manager not initialized") } -// GetConfiguration returns the current DAG configuration -func (tm *DAG) GetConfiguration() *DAGConfig { - if tm.configManager != nil { - return tm.configManager.GetConfig() +// Enhanced Stop method with proper cleanup +func (tm *DAG) StopEnhanced(ctx context.Context) error { + // Stop monitoring + tm.StopMonitoring() + + // Stop cleanup manager + tm.StopCleanup() + + // Stop batch processor + if tm.batchProcessor != nil { + tm.batchProcessor.Stop() } - return DefaultDAGConfig() + + // Stop cache cleanup + if tm.cache != nil { + tm.cache.Stop() + } + + // Flush activity logs + if tm.activityLogger != nil { + tm.activityLogger.Flush() + } + + // Stop all task managers + tm.taskManager.ForEach(func(taskID string, manager *TaskManager) bool { + manager.Stop() + return true + }) + + // Clear all caches + tm.nextNodesCache = nil + tm.prevNodesCache = nil + + // Stop underlying components + return tm.Stop(ctx) +} + +// ActivityAlertHandler handles alerts by logging them as activities +type ActivityAlertHandler struct { + activityLogger *ActivityLogger +} + +func (h *ActivityAlertHandler) HandleAlert(alert Alert) error { + if h.activityLogger != nil { + h.activityLogger.Log( + ActivityLevelWarn, + ActivityTypeAlert, + alert.Message, + map[string]interface{}{ + "alert_type": alert.Type, + "alert_severity": alert.Severity, + "alert_node_id": alert.NodeID, + "alert_timestamp": alert.Timestamp, + }, + ) + } + return nil } diff --git a/dag/enhancements.go b/dag/enhancements.go index 362ea19..8fcaa0d 100644 --- a/dag/enhancements.go +++ b/dag/enhancements.go @@ -72,20 +72,20 @@ func (bp *BatchProcessor) flushBatch() { return } - batch := make([]*mq.Task, len(bp.buffer)) - copy(batch, bp.buffer) - bp.buffer = bp.buffer[:0] // Reset buffer + tasks := make([]*mq.Task, len(bp.buffer)) + copy(tasks, bp.buffer) + bp.buffer = bp.buffer[:0] // Clear buffer bp.bufferMu.Unlock() if bp.processFunc != nil { - if err := bp.processFunc(batch); err != nil { + if 
err := bp.processFunc(tasks); err != nil { bp.logger.Error("Batch processing failed", - logger.Field{Key: "batchSize", Value: len(batch)}, logger.Field{Key: "error", Value: err.Error()}, + logger.Field{Key: "batch_size", Value: len(tasks)}, ) } else { bp.logger.Info("Batch processed successfully", - logger.Field{Key: "batchSize", Value: len(batch)}, + logger.Field{Key: "batch_size", Value: len(tasks)}, ) } } @@ -94,52 +94,73 @@ func (bp *BatchProcessor) flushBatch() { // Stop stops the batch processor func (bp *BatchProcessor) Stop() { close(bp.stopCh) - bp.flushBatch() // Process remaining tasks bp.wg.Wait() + + // Flush remaining tasks + bp.flushBatch() } // TransactionManager handles transaction-like operations for DAG execution type TransactionManager struct { - dag *DAG - activeTransactions map[string]*Transaction - mu sync.RWMutex - logger logger.Logger + dag *DAG + transactions map[string]*Transaction + savePoints map[string][]SavePoint + mu sync.RWMutex + logger logger.Logger } // Transaction represents a transactional DAG execution type Transaction struct { - ID string - TaskID string - StartTime time.Time - CompletedNodes []string - SavePoints map[string][]byte - Status TransactionStatus - Context context.Context - CancelFunc context.CancelFunc - RollbackHandlers []RollbackHandler + ID string `json:"id"` + TaskID string `json:"task_id"` + Status TransactionStatus `json:"status"` + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time,omitempty"` + Operations []TransactionOperation `json:"operations"` + SavePoints []SavePoint `json:"save_points"` + Metadata map[string]interface{} `json:"metadata,omitempty"` } // TransactionStatus represents the status of a transaction -type TransactionStatus int +type TransactionStatus string const ( - TransactionActive TransactionStatus = iota - TransactionCommitted - TransactionRolledBack - TransactionFailed + TransactionStatusStarted TransactionStatus = "started" + TransactionStatusCommitted TransactionStatus = "committed" + TransactionStatusRolledBack TransactionStatus = "rolled_back" + TransactionStatusFailed TransactionStatus = "failed" ) +// TransactionOperation represents an operation within a transaction +type TransactionOperation struct { + ID string `json:"id"` + Type string `json:"type"` + NodeID string `json:"node_id"` + Data map[string]interface{} `json:"data"` + Timestamp time.Time `json:"timestamp"` + RollbackHandler RollbackHandler `json:"-"` +} + +// SavePoint represents a save point in a transaction +type SavePoint struct { + ID string `json:"id"` + Name string `json:"name"` + Timestamp time.Time `json:"timestamp"` + State map[string]interface{} `json:"state"` +} + // RollbackHandler defines how to rollback operations type RollbackHandler interface { - Rollback(ctx context.Context, savePoint []byte) error + Rollback(operation TransactionOperation) error } // NewTransactionManager creates a new transaction manager func NewTransactionManager(dag *DAG, logger logger.Logger) *TransactionManager { return &TransactionManager{ - dag: dag, - activeTransactions: make(map[string]*Transaction), - logger: logger, + dag: dag, + transactions: make(map[string]*Transaction), + savePoints: make(map[string][]SavePoint), + logger: logger, } } @@ -148,48 +169,70 @@ func (tm *TransactionManager) BeginTransaction(taskID string) *Transaction { tm.mu.Lock() defer tm.mu.Unlock() - ctx, cancel := context.WithCancel(context.Background()) - tx := &Transaction{ - ID: fmt.Sprintf("tx_%s_%d", taskID, time.Now().UnixNano()), - TaskID: 
taskID, - StartTime: time.Now(), - CompletedNodes: []string{}, - SavePoints: make(map[string][]byte), - Status: TransactionActive, - Context: ctx, - CancelFunc: cancel, - RollbackHandlers: []RollbackHandler{}, + ID: mq.NewID(), + TaskID: taskID, + Status: TransactionStatusStarted, + StartTime: time.Now(), + Operations: make([]TransactionOperation, 0), + SavePoints: make([]SavePoint, 0), + Metadata: make(map[string]interface{}), } - tm.activeTransactions[tx.ID] = tx + tm.transactions[tx.ID] = tx tm.logger.Info("Transaction started", - logger.Field{Key: "transactionID", Value: tx.ID}, - logger.Field{Key: "taskID", Value: taskID}, + logger.Field{Key: "transaction_id", Value: tx.ID}, + logger.Field{Key: "task_id", Value: taskID}, ) return tx } -// AddSavePoint adds a save point to the transaction -func (tm *TransactionManager) AddSavePoint(txID, nodeID string, data []byte) error { - tm.mu.RLock() - tx, exists := tm.activeTransactions[txID] - tm.mu.RUnlock() +// AddOperation adds an operation to a transaction +func (tm *TransactionManager) AddOperation(txID string, operation TransactionOperation) error { + tm.mu.Lock() + defer tm.mu.Unlock() + tx, exists := tm.transactions[txID] if !exists { return fmt.Errorf("transaction %s not found", txID) } - if tx.Status != TransactionActive { + if tx.Status != TransactionStatusStarted { return fmt.Errorf("transaction %s is not active", txID) } - tx.SavePoints[nodeID] = data - tm.logger.Info("Save point added", - logger.Field{Key: "transactionID", Value: txID}, - logger.Field{Key: "nodeID", Value: nodeID}, + operation.ID = mq.NewID() + operation.Timestamp = time.Now() + tx.Operations = append(tx.Operations, operation) + + return nil +} + +// AddSavePoint adds a save point to the transaction +func (tm *TransactionManager) AddSavePoint(txID, name string, state map[string]interface{}) error { + tm.mu.Lock() + defer tm.mu.Unlock() + + tx, exists := tm.transactions[txID] + if !exists { + return fmt.Errorf("transaction %s not found", txID) + } + + savePoint := SavePoint{ + ID: mq.NewID(), + Name: name, + Timestamp: time.Now(), + State: state, + } + + tx.SavePoints = append(tx.SavePoints, savePoint) + tm.savePoints[txID] = tx.SavePoints + + tm.logger.Info("Save point created", + logger.Field{Key: "transaction_id", Value: txID}, + logger.Field{Key: "save_point_name", Value: name}, ) return nil @@ -200,24 +243,26 @@ func (tm *TransactionManager) CommitTransaction(txID string) error { tm.mu.Lock() defer tm.mu.Unlock() - tx, exists := tm.activeTransactions[txID] + tx, exists := tm.transactions[txID] if !exists { return fmt.Errorf("transaction %s not found", txID) } - if tx.Status != TransactionActive { + if tx.Status != TransactionStatusStarted { return fmt.Errorf("transaction %s is not active", txID) } - tx.Status = TransactionCommitted - tx.CancelFunc() - delete(tm.activeTransactions, txID) + tx.Status = TransactionStatusCommitted + tx.EndTime = time.Now() tm.logger.Info("Transaction committed", - logger.Field{Key: "transactionID", Value: txID}, - logger.Field{Key: "duration", Value: time.Since(tx.StartTime)}, + logger.Field{Key: "transaction_id", Value: txID}, + logger.Field{Key: "operations_count", Value: len(tx.Operations)}, ) + // Clean up save points + delete(tm.savePoints, txID) + return nil } @@ -226,73 +271,109 @@ func (tm *TransactionManager) RollbackTransaction(txID string) error { tm.mu.Lock() defer tm.mu.Unlock() - tx, exists := tm.activeTransactions[txID] + tx, exists := tm.transactions[txID] if !exists { return fmt.Errorf("transaction %s not found", 
txID) } - if tx.Status != TransactionActive { + if tx.Status != TransactionStatusStarted { return fmt.Errorf("transaction %s is not active", txID) } - tx.Status = TransactionRolledBack - tx.CancelFunc() - - // Execute rollback handlers in reverse order - for i := len(tx.RollbackHandlers) - 1; i >= 0; i-- { - handler := tx.RollbackHandlers[i] - if err := handler.Rollback(tx.Context, nil); err != nil { - tm.logger.Error("Rollback handler failed", - logger.Field{Key: "transactionID", Value: txID}, - logger.Field{Key: "error", Value: err.Error()}, - ) + // Rollback operations in reverse order + for i := len(tx.Operations) - 1; i >= 0; i-- { + operation := tx.Operations[i] + if operation.RollbackHandler != nil { + if err := operation.RollbackHandler.Rollback(operation); err != nil { + tm.logger.Error("Failed to rollback operation", + logger.Field{Key: "transaction_id", Value: txID}, + logger.Field{Key: "operation_id", Value: operation.ID}, + logger.Field{Key: "error", Value: err.Error()}, + ) + } } } - delete(tm.activeTransactions, txID) + tx.Status = TransactionStatusRolledBack + tx.EndTime = time.Now() tm.logger.Info("Transaction rolled back", - logger.Field{Key: "transactionID", Value: txID}, - logger.Field{Key: "duration", Value: time.Since(tx.StartTime)}, + logger.Field{Key: "transaction_id", Value: txID}, + logger.Field{Key: "operations_count", Value: len(tx.Operations)}, ) + // Clean up save points + delete(tm.savePoints, txID) + return nil } +// GetTransaction retrieves a transaction by ID +func (tm *TransactionManager) GetTransaction(txID string) (*Transaction, error) { + tm.mu.RLock() + defer tm.mu.RUnlock() + + tx, exists := tm.transactions[txID] + if !exists { + return nil, fmt.Errorf("transaction %s not found", txID) + } + + // Return a copy + txCopy := *tx + return &txCopy, nil +} + // CleanupManager handles cleanup of completed tasks and resources type CleanupManager struct { - dag *DAG - cleanupInterval time.Duration - retentionPeriod time.Duration - maxCompletedTasks int - stopCh chan struct{} - logger logger.Logger + dag *DAG + cleanupInterval time.Duration + retentionPeriod time.Duration + maxEntries int + logger logger.Logger + stopCh chan struct{} + running bool + mu sync.RWMutex } // NewCleanupManager creates a new cleanup manager -func NewCleanupManager(dag *DAG, cleanupInterval, retentionPeriod time.Duration, maxCompletedTasks int, logger logger.Logger) *CleanupManager { +func NewCleanupManager(dag *DAG, cleanupInterval, retentionPeriod time.Duration, maxEntries int, logger logger.Logger) *CleanupManager { return &CleanupManager{ - dag: dag, - cleanupInterval: cleanupInterval, - retentionPeriod: retentionPeriod, - maxCompletedTasks: maxCompletedTasks, - stopCh: make(chan struct{}), - logger: logger, + dag: dag, + cleanupInterval: cleanupInterval, + retentionPeriod: retentionPeriod, + maxEntries: maxEntries, + logger: logger, + stopCh: make(chan struct{}), } } // Start begins the cleanup routine func (cm *CleanupManager) Start(ctx context.Context) { + cm.mu.Lock() + defer cm.mu.Unlock() + + if cm.running { + return + } + + cm.running = true go cm.cleanupRoutine(ctx) - cm.logger.Info("Cleanup manager started", - logger.Field{Key: "interval", Value: cm.cleanupInterval}, - logger.Field{Key: "retention", Value: cm.retentionPeriod}, - ) + + cm.logger.Info("Cleanup manager started") } // Stop stops the cleanup routine func (cm *CleanupManager) Stop() { + cm.mu.Lock() + defer cm.mu.Unlock() + + if !cm.running { + return + } + + cm.running = false close(cm.stopCh) + 
cm.logger.Info("Cleanup manager stopped") } @@ -315,46 +396,56 @@ func (cm *CleanupManager) cleanupRoutine(ctx context.Context) { // performCleanup cleans up old tasks and resources func (cm *CleanupManager) performCleanup() { - cleaned := 0 - cutoffTime := time.Now().Add(-cm.retentionPeriod) + cutoff := time.Now().Add(-cm.retentionPeriod) // Clean up old task managers - var tasksToCleanup []string - cm.dag.taskManager.ForEach(func(taskID string, manager *TaskManager) bool { - if manager.createdAt.Before(cutoffTime) { - tasksToCleanup = append(tasksToCleanup, taskID) + var toDelete []string + cm.dag.taskManager.ForEach(func(taskID string, tm *TaskManager) bool { + if tm.createdAt.Before(cutoff) { + toDelete = append(toDelete, taskID) } return true }) - for _, taskID := range tasksToCleanup { - cm.dag.taskManager.Set(taskID, nil) - cleaned++ + for _, taskID := range toDelete { + if tm, exists := cm.dag.taskManager.Get(taskID); exists { + tm.Stop() + cm.dag.taskManager.Del(taskID) + } } - if cleaned > 0 { + // Clean up circuit breakers for removed nodes + cm.dag.circuitBreakersMu.Lock() + for nodeID := range cm.dag.circuitBreakers { + if _, exists := cm.dag.nodes.Get(nodeID); !exists { + delete(cm.dag.circuitBreakers, nodeID) + } + } + cm.dag.circuitBreakersMu.Unlock() + + if len(toDelete) > 0 { cm.logger.Info("Cleanup completed", - logger.Field{Key: "cleanedTasks", Value: cleaned}, - logger.Field{Key: "cutoffTime", Value: cutoffTime}, + logger.Field{Key: "cleaned_tasks", Value: len(toDelete)}, ) } } // WebhookManager handles webhook notifications type WebhookManager struct { - webhooks map[string][]WebhookConfig - client HTTPClient - logger logger.Logger - mu sync.RWMutex + webhooks map[string][]WebhookConfig + httpClient HTTPClient + logger logger.Logger + mu sync.RWMutex } // WebhookConfig defines webhook configuration type WebhookConfig struct { - URL string - Headers map[string]string - Timeout time.Duration - RetryCount int - Events []string // Which events to trigger on + URL string `json:"url"` + Headers map[string]string `json:"headers"` + Method string `json:"method"` + RetryCount int `json:"retry_count"` + Timeout time.Duration `json:"timeout"` + Events []string `json:"events"` } // HTTPClient interface for HTTP requests @@ -364,30 +455,41 @@ type HTTPClient interface { // WebhookEvent represents an event to send via webhook type WebhookEvent struct { - Type string `json:"type"` - TaskID string `json:"task_id,omitempty"` - NodeID string `json:"node_id,omitempty"` - Timestamp time.Time `json:"timestamp"` - Data interface{} `json:"data,omitempty"` + Type string `json:"type"` + TaskID string `json:"task_id"` + NodeID string `json:"node_id,omitempty"` + Timestamp time.Time `json:"timestamp"` + Data map[string]interface{} `json:"data"` } // NewWebhookManager creates a new webhook manager -func NewWebhookManager(client HTTPClient, logger logger.Logger) *WebhookManager { +func NewWebhookManager(httpClient HTTPClient, logger logger.Logger) *WebhookManager { return &WebhookManager{ - webhooks: make(map[string][]WebhookConfig), - client: client, - logger: logger, + webhooks: make(map[string][]WebhookConfig), + httpClient: httpClient, + logger: logger, } } // AddWebhook adds a webhook configuration -func (wm *WebhookManager) AddWebhook(eventType string, config WebhookConfig) { +func (wm *WebhookManager) AddWebhook(event string, config WebhookConfig) { wm.mu.Lock() defer wm.mu.Unlock() - wm.webhooks[eventType] = append(wm.webhooks[eventType], config) + if config.Method == "" { + config.Method 
= "POST" + } + if config.RetryCount == 0 { + config.RetryCount = 3 + } + if config.Timeout == 0 { + config.Timeout = 30 * time.Second + } + + wm.webhooks[event] = append(wm.webhooks[event], config) + wm.logger.Info("Webhook added", - logger.Field{Key: "eventType", Value: eventType}, + logger.Field{Key: "event", Value: event}, logger.Field{Key: "url", Value: config.URL}, ) } @@ -395,45 +497,65 @@ func (wm *WebhookManager) AddWebhook(eventType string, config WebhookConfig) { // TriggerWebhook sends webhook notifications for an event func (wm *WebhookManager) TriggerWebhook(event WebhookEvent) { wm.mu.RLock() - configs := wm.webhooks[event.Type] + configs, exists := wm.webhooks[event.Type] wm.mu.RUnlock() - if len(configs) == 0 { + if !exists { return } - data, err := json.Marshal(event) + for _, config := range configs { + // Check if this webhook should handle this event + if len(config.Events) > 0 { + found := false + for _, eventType := range config.Events { + if eventType == event.Type { + found = true + break + } + } + if !found { + continue + } + } + + go wm.sendWebhook(config, event) + } +} + +// sendWebhook sends a single webhook with retry logic +func (wm *WebhookManager) sendWebhook(config WebhookConfig, event WebhookEvent) { + payload, err := json.Marshal(event) if err != nil { - wm.logger.Error("Failed to marshal webhook event", + wm.logger.Error("Failed to marshal webhook payload", logger.Field{Key: "error", Value: err.Error()}, ) return } - for _, config := range configs { - go wm.sendWebhook(config, data) - } -} - -// sendWebhook sends a single webhook with retry logic -func (wm *WebhookManager) sendWebhook(config WebhookConfig, data []byte) { - for attempt := 0; attempt <= config.RetryCount; attempt++ { - err := wm.client.Post(config.URL, "application/json", data, config.Headers) + for attempt := 0; attempt < config.RetryCount; attempt++ { + err := wm.httpClient.Post(config.URL, "application/json", payload, config.Headers) if err == nil { wm.logger.Info("Webhook sent successfully", logger.Field{Key: "url", Value: config.URL}, - logger.Field{Key: "attempt", Value: attempt + 1}, + logger.Field{Key: "event_type", Value: event.Type}, ) return } - if attempt < config.RetryCount { + wm.logger.Warn("Webhook delivery failed", + logger.Field{Key: "url", Value: config.URL}, + logger.Field{Key: "attempt", Value: attempt + 1}, + logger.Field{Key: "error", Value: err.Error()}, + ) + + if attempt < config.RetryCount-1 { time.Sleep(time.Duration(attempt+1) * time.Second) } } - wm.logger.Error("Webhook failed after all retries", + wm.logger.Error("Webhook delivery failed after all retries", logger.Field{Key: "url", Value: config.URL}, - logger.Field{Key: "attempts", Value: config.RetryCount + 1}, + logger.Field{Key: "event_type", Value: event.Type}, ) } diff --git a/dag/monitoring.go b/dag/monitoring.go index a09e99d..75c42cf 100644 --- a/dag/monitoring.go +++ b/dag/monitoring.go @@ -71,17 +71,9 @@ func (m *MonitoringMetrics) RecordTaskCompletion(taskID string, status mq.Status m.mu.Lock() defer m.mu.Unlock() - if startTime, exists := m.ActiveTasks[taskID]; exists { - duration := time.Since(startTime) - m.TotalExecutionTime += duration - m.LastTaskCompletedAt = time.Now() - delete(m.ActiveTasks, taskID) - m.TasksInProgress-- - - // Update average execution time - if m.TasksCompleted > 0 { - m.AverageExecutionTime = m.TotalExecutionTime / time.Duration(m.TasksCompleted+1) - } + m.TasksInProgress-- + if m.TasksInProgress < 0 { + m.TasksInProgress = 0 } switch status { @@ -92,6 +84,9 @@ func (m 
*MonitoringMetrics) RecordTaskCompletion(taskID string, status mq.Status case mq.Cancelled: m.TasksCancelled++ } + + m.LastTaskCompletedAt = time.Now() + delete(m.ActiveTasks, taskID) } // RecordNodeExecution records node execution metrics @@ -131,11 +126,27 @@ func (m *MonitoringMetrics) RecordNodeExecution(nodeID string, duration time.Dur // Legacy tracking m.NodesExecuted[nodeID]++ - if len(m.NodeExecutionTimes[nodeID]) > 100 { - // Keep only last 100 execution times - m.NodeExecutionTimes[nodeID] = m.NodeExecutionTimes[nodeID][1:] - } m.NodeExecutionTimes[nodeID] = append(m.NodeExecutionTimes[nodeID], duration) + + // Keep only last 100 execution times per node to prevent memory bloat + if len(m.NodeExecutionTimes[nodeID]) > 100 { + m.NodeExecutionTimes[nodeID] = m.NodeExecutionTimes[nodeID][len(m.NodeExecutionTimes[nodeID])-100:] + } + + // Calculate average execution time + var totalDuration time.Duration + var totalExecutions int64 + for _, durations := range m.NodeExecutionTimes { + for _, d := range durations { + totalDuration += d + totalExecutions++ + } + } + if totalExecutions > 0 { + m.AverageExecutionTime = totalDuration / time.Duration(totalExecutions) + } + + m.TotalExecutionTime += duration } // RecordNodeStart records when a node starts processing @@ -145,6 +156,10 @@ func (m *MonitoringMetrics) RecordNodeStart(nodeID string) { if stats, exists := m.NodeProcessingStats[nodeID]; exists { stats.CurrentlyRunning++ + } else { + m.NodeProcessingStats[nodeID] = &NodeStats{ + CurrentlyRunning: 1, + } } } @@ -153,8 +168,11 @@ func (m *MonitoringMetrics) RecordNodeEnd(nodeID string) { m.mu.Lock() defer m.mu.Unlock() - if stats, exists := m.NodeProcessingStats[nodeID]; exists && stats.CurrentlyRunning > 0 { + if stats, exists := m.NodeProcessingStats[nodeID]; exists { stats.CurrentlyRunning-- + if stats.CurrentlyRunning < 0 { + stats.CurrentlyRunning = 0 + } } } @@ -190,24 +208,14 @@ func (m *MonitoringMetrics) GetSnapshot() *MonitoringMetrics { for k, v := range m.ActiveTasks { snapshot.ActiveTasks[k] = v } - for k, v := range m.NodeExecutionTimes { - snapshot.NodeExecutionTimes[k] = make([]time.Duration, len(v)) - copy(snapshot.NodeExecutionTimes[k], v) - } for k, v := range m.NodeProcessingStats { - snapshot.NodeProcessingStats[k] = &NodeStats{ - ExecutionCount: v.ExecutionCount, - SuccessCount: v.SuccessCount, - FailureCount: v.FailureCount, - TotalDuration: v.TotalDuration, - AverageDuration: v.AverageDuration, - MinDuration: v.MinDuration, - MaxDuration: v.MaxDuration, - LastExecuted: v.LastExecuted, - LastSuccess: v.LastSuccess, - LastFailure: v.LastFailure, - CurrentlyRunning: v.CurrentlyRunning, - } + statsCopy := *v + snapshot.NodeProcessingStats[k] = &statsCopy + } + for k, v := range m.NodeExecutionTimes { + timesCopy := make([]time.Duration, len(v)) + copy(timesCopy, v) + snapshot.NodeExecutionTimes[k] = timesCopy } return snapshot @@ -219,45 +227,32 @@ func (m *MonitoringMetrics) GetNodeStats(nodeID string) *NodeStats { defer m.mu.RUnlock() if stats, exists := m.NodeProcessingStats[nodeID]; exists { - // Return a copy - return &NodeStats{ - ExecutionCount: stats.ExecutionCount, - SuccessCount: stats.SuccessCount, - FailureCount: stats.FailureCount, - TotalDuration: stats.TotalDuration, - AverageDuration: stats.AverageDuration, - MinDuration: stats.MinDuration, - MaxDuration: stats.MaxDuration, - LastExecuted: stats.LastExecuted, - LastSuccess: stats.LastSuccess, - LastFailure: stats.LastFailure, - CurrentlyRunning: stats.CurrentlyRunning, - } + statsCopy := *stats + 
return &statsCopy } return nil } // Monitor provides comprehensive monitoring capabilities for DAG type Monitor struct { - dag *DAG - metrics *MonitoringMetrics - logger logger.Logger - alertThresholds *AlertThresholds - webhookURL string - alertHandlers []AlertHandler - monitoringActive bool - stopCh chan struct{} - mu sync.RWMutex + dag *DAG + metrics *MonitoringMetrics + logger logger.Logger + thresholds *AlertThresholds + handlers []AlertHandler + stopCh chan struct{} + running bool + mu sync.RWMutex } // AlertThresholds defines thresholds for alerting type AlertThresholds struct { - MaxFailureRate float64 // Maximum allowed failure rate (0.0 - 1.0) - MaxExecutionTime time.Duration // Maximum allowed execution time - MaxTasksInProgress int64 // Maximum allowed concurrent tasks - MinSuccessRate float64 // Minimum required success rate - MaxNodeFailures int64 // Maximum failures per node - HealthCheckInterval time.Duration // How often to check health + MaxFailureRate float64 `json:"max_failure_rate"` + MaxExecutionTime time.Duration `json:"max_execution_time"` + MaxTasksInProgress int64 `json:"max_tasks_in_progress"` + MinSuccessRate float64 `json:"min_success_rate"` + MaxNodeFailures int64 `json:"max_node_failures"` + HealthCheckInterval time.Duration `json:"health_check_interval"` } // AlertHandler defines interface for handling alerts @@ -267,44 +262,66 @@ type AlertHandler interface { // Alert represents a monitoring alert type Alert struct { - Type string - Severity string - Message string - NodeID string - TaskID string - Timestamp time.Time - Metrics map[string]interface{} + ID string `json:"id"` + Timestamp time.Time `json:"timestamp"` + Severity AlertSeverity `json:"severity"` + Type AlertType `json:"type"` + Message string `json:"message"` + Details map[string]interface{} `json:"details"` + NodeID string `json:"node_id,omitempty"` + TaskID string `json:"task_id,omitempty"` + Threshold interface{} `json:"threshold,omitempty"` + ActualValue interface{} `json:"actual_value,omitempty"` } +type AlertSeverity string + +const ( + AlertSeverityInfo AlertSeverity = "info" + AlertSeverityWarning AlertSeverity = "warning" + AlertSeverityCritical AlertSeverity = "critical" +) + +type AlertType string + +const ( + AlertTypeFailureRate AlertType = "failure_rate" + AlertTypeExecutionTime AlertType = "execution_time" + AlertTypeTaskLoad AlertType = "task_load" + AlertTypeNodeFailures AlertType = "node_failures" + AlertTypeCircuitBreaker AlertType = "circuit_breaker" + AlertTypeHealthCheck AlertType = "health_check" +) + // NewMonitor creates a new DAG monitor func NewMonitor(dag *DAG, logger logger.Logger) *Monitor { return &Monitor{ dag: dag, metrics: NewMonitoringMetrics(), logger: logger, - alertThresholds: &AlertThresholds{ - MaxFailureRate: 0.1, // 10% failure rate + thresholds: &AlertThresholds{ + MaxFailureRate: 0.1, // 10% MaxExecutionTime: 5 * time.Minute, MaxTasksInProgress: 1000, - MinSuccessRate: 0.9, // 90% success rate + MinSuccessRate: 0.9, // 90% MaxNodeFailures: 10, HealthCheckInterval: 30 * time.Second, }, - stopCh: make(chan struct{}), + handlers: make([]AlertHandler, 0), + stopCh: make(chan struct{}), } } // Start begins monitoring func (m *Monitor) Start(ctx context.Context) { m.mu.Lock() - if m.monitoringActive { - m.mu.Unlock() + defer m.mu.Unlock() + + if m.running { return } - m.monitoringActive = true - m.mu.Unlock() - // Start health check routine + m.running = true go m.healthCheckRoutine(ctx) m.logger.Info("DAG monitoring started") @@ -315,12 +332,13 @@ func (m 
*Monitor) Stop() { m.mu.Lock() defer m.mu.Unlock() - if !m.monitoringActive { + if !m.running { return } + m.running = false close(m.stopCh) - m.monitoringActive = false + m.logger.Info("DAG monitoring stopped") } @@ -328,14 +346,14 @@ func (m *Monitor) Stop() { func (m *Monitor) SetAlertThresholds(thresholds *AlertThresholds) { m.mu.Lock() defer m.mu.Unlock() - m.alertThresholds = thresholds + m.thresholds = thresholds } // AddAlertHandler adds an alert handler func (m *Monitor) AddAlertHandler(handler AlertHandler) { m.mu.Lock() defer m.mu.Unlock() - m.alertHandlers = append(m.alertHandlers, handler) + m.handlers = append(m.handlers, handler) } // GetMetrics returns current metrics @@ -345,7 +363,7 @@ func (m *Monitor) GetMetrics() *MonitoringMetrics { // healthCheckRoutine performs periodic health checks func (m *Monitor) healthCheckRoutine(ctx context.Context) { - ticker := time.NewTicker(m.alertThresholds.HealthCheckInterval) + ticker := time.NewTicker(m.thresholds.HealthCheckInterval) defer ticker.Stop() for { @@ -362,50 +380,57 @@ func (m *Monitor) healthCheckRoutine(ctx context.Context) { // performHealthCheck checks system health and triggers alerts func (m *Monitor) performHealthCheck() { - snapshot := m.metrics.GetSnapshot() + metrics := m.GetMetrics() // Check failure rate - if snapshot.TasksTotal > 0 { - failureRate := float64(snapshot.TasksFailed) / float64(snapshot.TasksTotal) - if failureRate > m.alertThresholds.MaxFailureRate { + if metrics.TasksTotal > 0 { + failureRate := float64(metrics.TasksFailed) / float64(metrics.TasksTotal) + if failureRate > m.thresholds.MaxFailureRate { m.triggerAlert(Alert{ - Type: "high_failure_rate", - Severity: "warning", - Message: fmt.Sprintf("High failure rate: %.2f%%", failureRate*100), - Timestamp: time.Now(), - Metrics: map[string]interface{}{ - "failure_rate": failureRate, - "total_tasks": snapshot.TasksTotal, - "failed_tasks": snapshot.TasksFailed, + ID: mq.NewID(), + Timestamp: time.Now(), + Severity: AlertSeverityCritical, + Type: AlertTypeFailureRate, + Message: "High failure rate detected", + Threshold: m.thresholds.MaxFailureRate, + ActualValue: failureRate, + Details: map[string]interface{}{ + "failed_tasks": metrics.TasksFailed, + "total_tasks": metrics.TasksTotal, }, }) } } - // Check tasks in progress - if snapshot.TasksInProgress > m.alertThresholds.MaxTasksInProgress { + // Check task load + if metrics.TasksInProgress > m.thresholds.MaxTasksInProgress { m.triggerAlert(Alert{ - Type: "high_task_load", - Severity: "warning", - Message: fmt.Sprintf("High number of tasks in progress: %d", snapshot.TasksInProgress), - Timestamp: time.Now(), - Metrics: map[string]interface{}{ - "tasks_in_progress": snapshot.TasksInProgress, - "threshold": m.alertThresholds.MaxTasksInProgress, + ID: mq.NewID(), + Timestamp: time.Now(), + Severity: AlertSeverityWarning, + Type: AlertTypeTaskLoad, + Message: "High task load detected", + Threshold: m.thresholds.MaxTasksInProgress, + ActualValue: metrics.TasksInProgress, + Details: map[string]interface{}{ + "tasks_in_progress": metrics.TasksInProgress, }, }) } // Check node failures - for nodeID, failures := range snapshot.NodeFailures { - if failures > m.alertThresholds.MaxNodeFailures { + for nodeID, failures := range metrics.NodeFailures { + if failures > m.thresholds.MaxNodeFailures { m.triggerAlert(Alert{ - Type: "node_failures", - Severity: "error", - Message: fmt.Sprintf("Node %s has %d failures", nodeID, failures), - NodeID: nodeID, - Timestamp: time.Now(), - Metrics: map[string]interface{}{ + 
ID: mq.NewID(), + Timestamp: time.Now(), + Severity: AlertSeverityCritical, + Type: AlertTypeNodeFailures, + Message: fmt.Sprintf("Node %s has too many failures", nodeID), + NodeID: nodeID, + Threshold: m.thresholds.MaxNodeFailures, + ActualValue: failures, + Details: map[string]interface{}{ "node_id": nodeID, "failures": failures, }, @@ -414,15 +439,17 @@ func (m *Monitor) performHealthCheck() { } // Check execution time - if snapshot.AverageExecutionTime > m.alertThresholds.MaxExecutionTime { + if metrics.AverageExecutionTime > m.thresholds.MaxExecutionTime { m.triggerAlert(Alert{ - Type: "slow_execution", - Severity: "warning", - Message: fmt.Sprintf("Average execution time is high: %v", snapshot.AverageExecutionTime), - Timestamp: time.Now(), - Metrics: map[string]interface{}{ - "average_execution_time": snapshot.AverageExecutionTime, - "threshold": m.alertThresholds.MaxExecutionTime, + ID: mq.NewID(), + Timestamp: time.Now(), + Severity: AlertSeverityWarning, + Type: AlertTypeExecutionTime, + Message: "Average execution time is too high", + Threshold: m.thresholds.MaxExecutionTime, + ActualValue: metrics.AverageExecutionTime, + Details: map[string]interface{}{ + "average_execution_time": metrics.AverageExecutionTime.String(), }, }) } @@ -431,16 +458,20 @@ func (m *Monitor) performHealthCheck() { // triggerAlert sends alerts to all registered handlers func (m *Monitor) triggerAlert(alert Alert) { m.logger.Warn("Alert triggered", - logger.Field{Key: "type", Value: alert.Type}, - logger.Field{Key: "severity", Value: alert.Severity}, + logger.Field{Key: "alert_id", Value: alert.ID}, + logger.Field{Key: "type", Value: string(alert.Type)}, + logger.Field{Key: "severity", Value: string(alert.Severity)}, logger.Field{Key: "message", Value: alert.Message}, ) - for _, handler := range m.alertHandlers { - if err := handler.HandleAlert(alert); err != nil { - m.logger.Error("Alert handler failed", - logger.Field{Key: "error", Value: err.Error()}, - ) - } + for _, handler := range m.handlers { + go func(h AlertHandler, a Alert) { + if err := h.HandleAlert(a); err != nil { + m.logger.Error("Alert handler error", + logger.Field{Key: "error", Value: err.Error()}, + logger.Field{Key: "alert_id", Value: a.ID}, + ) + } + }(handler, alert) } } diff --git a/dag/node.go b/dag/node.go deleted file mode 100644 index 772d262..0000000 --- a/dag/node.go +++ /dev/null @@ -1,136 +0,0 @@ -package dag - -import ( - "context" - "fmt" - "strings" -) - -func (tm *DAG) GetNextNodes(key string) ([]*Node, error) { - key = strings.Split(key, Delimiter)[0] - // use cache if available - if tm.nextNodesCache != nil { - if next, ok := tm.nextNodesCache[key]; ok { - return next, nil - } - } - node, exists := tm.nodes.Get(key) - if !exists { - return nil, fmt.Errorf("Node with key %s does not exist while getting next node", key) - } - var successors []*Node - for _, edge := range node.Edges { - successors = append(successors, edge.To) - } - if conds, exists := tm.conditions[key]; exists { - for _, targetKey := range conds { - if targetNode, exists := tm.nodes.Get(targetKey); exists { - successors = append(successors, targetNode) - } - } - } - return successors, nil -} - -func (tm *DAG) GetPreviousNodes(key string) ([]*Node, error) { - key = strings.Split(key, Delimiter)[0] - // use cache if available - if tm.prevNodesCache != nil { - if prev, ok := tm.prevNodesCache[key]; ok { - return prev, nil - } - } - var predecessors []*Node - tm.nodes.ForEach(func(_ string, node *Node) bool { - for _, target := range node.Edges { - if 
target.To.ID == key { - predecessors = append(predecessors, node) - } - } - return true - }) - for fromNode, conds := range tm.conditions { - for _, targetKey := range conds { - if targetKey == key { - node, exists := tm.nodes.Get(fromNode) - if !exists { - return nil, fmt.Errorf("Node with key %s does not exist while getting previous node", fromNode) - } - predecessors = append(predecessors, node) - } - } - } - return predecessors, nil -} - -func (tm *DAG) GetLastNodes() ([]*Node, error) { - var lastNodes []*Node - tm.nodes.ForEach(func(key string, node *Node) bool { - if len(node.Edges) == 0 { - if conds, exists := tm.conditions[node.ID]; !exists || len(conds) == 0 { - lastNodes = append(lastNodes, node) - } - } - return true - }) - return lastNodes, nil -} - -func (tm *DAG) IsLastNode(key string) (bool, error) { - node, exists := tm.nodes.Get(key) - if !exists { - return false, fmt.Errorf("Node with key %s does not exist", key) - } - if len(node.Edges) > 0 { - return false, nil - } - if conds, exists := tm.conditions[node.ID]; exists && len(conds) > 0 { - return false, nil - } - return true, nil -} - -func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) { - val := ctx.Value("initial_node") - initialNode, ok := val.(string) - if ok { - return initialNode, nil - } - if tm.startNode == "" { - firstNode := tm.findStartNode() - if firstNode != nil { - tm.startNode = firstNode.ID - } - } - - if tm.startNode == "" { - return "", fmt.Errorf("initial node not found") - } - return tm.startNode, nil -} - -func (tm *DAG) findStartNode() *Node { - incomingEdges := make(map[string]bool) - connectedNodes := make(map[string]bool) - for _, node := range tm.nodes.AsMap() { - for _, edge := range node.Edges { - if edge.Type.IsValid() { - connectedNodes[node.ID] = true - connectedNodes[edge.To.ID] = true - incomingEdges[edge.To.ID] = true - } - } - if cond, ok := tm.conditions[node.ID]; ok { - for _, target := range cond { - connectedNodes[target] = true - incomingEdges[target] = true - } - } - } - for nodeID, node := range tm.nodes.AsMap() { - if !incomingEdges[nodeID] && connectedNodes[nodeID] { - return node - } - } - return nil -} diff --git a/dag/retry.go b/dag/retry.go index 237e38f..c06953f 100644 --- a/dag/retry.go +++ b/dag/retry.go @@ -137,6 +137,29 @@ func (rm *NodeRetryManager) getKey(taskID, nodeID string) string { return taskID + ":" + nodeID } +// SetGlobalConfig sets the global retry configuration +func (rm *NodeRetryManager) SetGlobalConfig(config *RetryConfig) { + rm.mu.Lock() + defer rm.mu.Unlock() + rm.config = config + rm.logger.Info("Global retry configuration updated") +} + +// SetNodeConfig sets retry configuration for a specific node +func (rm *NodeRetryManager) SetNodeConfig(nodeID string, config *RetryConfig) { + // For simplicity, we'll store node-specific configs in a map + // This could be extended to support per-node configurations + rm.mu.Lock() + defer rm.mu.Unlock() + + // Store node-specific config (this is a simplified implementation) + // In a full implementation, you'd have a nodeConfigs map + rm.logger.Info("Node-specific retry configuration set", + logger.Field{Key: "nodeID", Value: nodeID}, + logger.Field{Key: "maxRetries", Value: config.MaxRetries}, + ) +} + // RetryableProcessor wraps a processor with retry logic type RetryableProcessor struct { processor mq.Processor diff --git a/dag/task_manager.go b/dag/task_manager.go index fc7945e..a911c4f 100644 --- a/dag/task_manager.go +++ b/dag/task_manager.go @@ -305,21 +305,40 @@ func (tm *TaskManager) 
processNode(exec *task) { tm.handleNext(exec.ctx, node, state, result) } +// logNodeExecution logs node execution details func (tm *TaskManager) logNodeExecution(exec *task, pureNodeID string, result mq.Result, latency time.Duration) { + success := result.Error == nil + + // Log to DAG activity logger if available + if tm.dag.activityLogger != nil { + ctx := context.WithValue(exec.ctx, "task_id", exec.taskID) + ctx = context.WithValue(ctx, "node_id", pureNodeID) + ctx = context.WithValue(ctx, "duration", latency) + if result.Error != nil { + ctx = context.WithValue(ctx, "error", result.Error) + } + + tm.dag.activityLogger.LogNodeExecution(ctx, exec.taskID, pureNodeID, result, latency) + } + + // Update monitoring metrics + if tm.dag.monitor != nil { + tm.dag.monitor.metrics.RecordNodeExecution(pureNodeID, latency, success) + } + + // Log to standard logger fields := []logger.Field{ - {Key: "nodeID", Value: exec.nodeID}, - {Key: "pureNodeID", Value: pureNodeID}, + {Key: "nodeID", Value: pureNodeID}, {Key: "taskID", Value: exec.taskID}, - {Key: "latency", Value: latency.String()}, + {Key: "duration", Value: latency.String()}, + {Key: "success", Value: success}, } if result.Error != nil { fields = append(fields, logger.Field{Key: "error", Value: result.Error.Error()}) - fields = append(fields, logger.Field{Key: "status", Value: mq.Failed}) tm.dag.Logger().Error("Node execution failed", fields...) } else { - fields = append(fields, logger.Field{Key: "status", Value: mq.Completed}) - tm.dag.Logger().Info("Node executed successfully", fields...) + tm.dag.Logger().Info("Node execution completed", fields...) } } @@ -583,7 +602,16 @@ func (tm *TaskManager) Resume() { } } +// Stop gracefully stops the task manager func (tm *TaskManager) Stop() { close(tm.stopCh) tm.wg.Wait() + + // Clean up resources + tm.taskStates.Clear() + tm.parentNodes.Clear() + tm.childNodes.Clear() + tm.deferredTasks.Clear() + tm.currentNodePayload.Clear() + tm.currentNodeResult.Clear() } diff --git a/examples/enhanced_dag_demo.go b/examples/enhanced_dag_demo.go new file mode 100644 index 0000000..acda8d1 --- /dev/null +++ b/examples/enhanced_dag_demo.go @@ -0,0 +1,557 @@ +package main + +import ( + "context" + "fmt" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/dag" + "github.com/oarkflow/mq/logger" +) + +// ExampleProcessor demonstrates a custom processor with debugging +type ExampleProcessor struct { + name string + tags []string +} + +func NewExampleProcessor(name string) *ExampleProcessor { + return &ExampleProcessor{ + name: name, + tags: []string{"example", "demo"}, + } +} + +func (p *ExampleProcessor) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { + // Simulate processing time + time.Sleep(100 * time.Millisecond) + + // Add some example processing logic + var data map[string]interface{} + if err := task.UnmarshalPayload(&data); err != nil { + return mq.Result{Error: err} + } + + // Process the data + data["processed_by"] = p.name + data["processed_at"] = time.Now() + + payload, _ := task.MarshalPayload(data) + return mq.Result{Payload: payload} +} + +func (p *ExampleProcessor) SetConfig(payload dag.Payload) {} +func (p *ExampleProcessor) SetTags(tags ...string) { p.tags = append(p.tags, tags...) 
} +func (p *ExampleProcessor) GetTags() []string { return p.tags } +func (p *ExampleProcessor) Consume(ctx context.Context) error { return nil } +func (p *ExampleProcessor) Pause(ctx context.Context) error { return nil } +func (p *ExampleProcessor) Resume(ctx context.Context) error { return nil } +func (p *ExampleProcessor) Stop(ctx context.Context) error { return nil } +func (p *ExampleProcessor) Close() error { return nil } +func (p *ExampleProcessor) GetType() string { return "example" } +func (p *ExampleProcessor) GetKey() string { return p.name } +func (p *ExampleProcessor) SetKey(key string) { p.name = key } + +// CustomActivityHook demonstrates custom activity processing +type CustomActivityHook struct { + logger logger.Logger +} + +func (h *CustomActivityHook) OnActivity(entry dag.ActivityEntry) error { + // Custom processing of activity entries + if entry.Level == dag.ActivityLevelError { + h.logger.Error("Critical activity detected", + logger.Field{Key: "activity_id", Value: entry.ID}, + logger.Field{Key: "dag_name", Value: entry.DAGName}, + logger.Field{Key: "message", Value: entry.Message}, + ) + + // Here you could send notifications, trigger alerts, etc. + } + return nil +} + +// CustomAlertHandler demonstrates custom alert handling +type CustomAlertHandler struct { + logger logger.Logger +} + +func (h *CustomAlertHandler) HandleAlert(alert dag.Alert) error { + h.logger.Warn("DAG Alert received", + logger.Field{Key: "type", Value: alert.Type}, + logger.Field{Key: "severity", Value: alert.Severity}, + logger.Field{Key: "message", Value: alert.Message}, + ) + + // Here you could integrate with external alerting systems + // like Slack, PagerDuty, email, etc. + + return nil +} + +func main() { + // Initialize logger + log := logger.New(logger.Config{ + Level: logger.LevelInfo, + Format: logger.FormatJSON, + }) + + // Create a comprehensive DAG with all enhanced features + server := mq.NewServer("demo", ":0", log) + + // Create DAG with comprehensive configuration + dagInstance := dag.NewDAG("production-workflow", "workflow-key", func(ctx context.Context, result mq.Result) { + log.Info("Workflow completed", + logger.Field{Key: "result", Value: string(result.Payload)}, + ) + }) + + // Initialize all enhanced components + setupEnhancedDAG(dagInstance, log) + + // Build the workflow + buildWorkflow(dagInstance, log) + + // Start the server and DAG + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + if err := server.Start(ctx); err != nil { + log.Error("Server failed to start", logger.Field{Key: "error", Value: err.Error()}) + } + }() + + // Wait for server to start + time.Sleep(100 * time.Millisecond) + + // Start enhanced DAG features + startEnhancedFeatures(ctx, dagInstance, log) + + // Set up HTTP API for monitoring and management + setupHTTPAPI(dagInstance, log) + + // Start the HTTP server + go func() { + log.Info("Starting HTTP server on :8080") + if err := http.ListenAndServe(":8080", nil); err != nil { + log.Error("HTTP server failed", logger.Field{Key: "error", Value: err.Error()}) + } + }() + + // Demonstrate the enhanced features + demonstrateFeatures(ctx, dagInstance, log) + + // Wait for shutdown signal + waitForShutdown(ctx, cancel, dagInstance, server, log) +} + +func setupEnhancedDAG(dagInstance *dag.DAG, log logger.Logger) { + // Initialize activity logger with memory persistence + activityConfig := dag.DefaultActivityLoggerConfig() + activityConfig.BufferSize = 500 + activityConfig.FlushInterval = 2 * time.Second + + 
persistence := dag.NewMemoryActivityPersistence() + dagInstance.InitializeActivityLogger(activityConfig, persistence) + + // Add custom activity hook + customHook := &CustomActivityHook{logger: log} + dagInstance.AddActivityHook(customHook) + + // Initialize monitoring with comprehensive configuration + monitorConfig := dag.MonitoringConfig{ + MetricsInterval: 5 * time.Second, + EnableHealthCheck: true, + BufferSize: 1000, + } + + alertThresholds := &dag.AlertThresholds{ + MaxFailureRate: 0.1, // 10% + MaxExecutionTime: 30 * time.Second, + MaxTasksInProgress: 100, + MinSuccessRate: 0.9, // 90% + MaxNodeFailures: 5, + HealthCheckInterval: 10 * time.Second, + } + + dagInstance.InitializeMonitoring(monitorConfig, alertThresholds) + + // Add custom alert handler + customAlertHandler := &CustomAlertHandler{logger: log} + dagInstance.AddAlertHandler(customAlertHandler) + + // Initialize configuration management + dagInstance.InitializeConfigManager() + + // Set up rate limiting + dagInstance.InitializeRateLimiter() + dagInstance.SetRateLimit("validate", 10.0, 5) // 10 req/sec, burst 5 + dagInstance.SetRateLimit("process", 20.0, 10) // 20 req/sec, burst 10 + dagInstance.SetRateLimit("finalize", 5.0, 2) // 5 req/sec, burst 2 + + // Initialize retry management + retryConfig := &dag.RetryConfig{ + MaxRetries: 3, + InitialDelay: 1 * time.Second, + MaxDelay: 10 * time.Second, + BackoffFactor: 2.0, + Jitter: true, + RetryCondition: func(err error) bool { + // Custom retry condition - retry on specific errors + return err != nil && err.Error() != "permanent_failure" + }, + } + dagInstance.InitializeRetryManager(retryConfig) + + // Initialize transaction management + txConfig := dag.TransactionConfig{ + DefaultTimeout: 5 * time.Minute, + CleanupInterval: 10 * time.Minute, + } + dagInstance.InitializeTransactionManager(txConfig) + + // Initialize cleanup management + cleanupConfig := dag.CleanupConfig{ + Interval: 5 * time.Minute, + TaskRetentionPeriod: 1 * time.Hour, + ResultRetentionPeriod: 2 * time.Hour, + MaxRetainedTasks: 1000, + } + dagInstance.InitializeCleanupManager(cleanupConfig) + + // Initialize performance optimizer + dagInstance.InitializePerformanceOptimizer() + + // Set up webhook manager for external notifications + httpClient := dag.NewSimpleHTTPClient(30 * time.Second) + webhookManager := dag.NewWebhookManager(httpClient, log) + + // Add webhook for task completion events + webhookConfig := dag.WebhookConfig{ + URL: "https://api.example.com/dag-events", // Replace with actual endpoint + Headers: map[string]string{"Authorization": "Bearer your-token"}, + RetryCount: 3, + Events: []string{"task_completed", "task_failed", "dag_completed"}, + } + webhookManager.AddWebhook("task_completed", webhookConfig) + dagInstance.SetWebhookManager(webhookManager) + + log.Info("Enhanced DAG features initialized successfully") +} + +func buildWorkflow(dagInstance *dag.DAG, log logger.Logger) { + // Create processors for each step + validator := NewExampleProcessor("validator") + processor := NewExampleProcessor("processor") + enricher := NewExampleProcessor("enricher") + finalizer := NewExampleProcessor("finalizer") + + // Build the workflow with retry configurations + retryConfig := &dag.RetryConfig{ + MaxRetries: 2, + InitialDelay: 500 * time.Millisecond, + MaxDelay: 5 * time.Second, + BackoffFactor: 2.0, + } + + dagInstance. + AddNodeWithRetry(dag.Function, "Validate Input", "validate", validator, retryConfig, true). + AddNodeWithRetry(dag.Function, "Process Data", "process", processor, retryConfig). 
+ AddNodeWithRetry(dag.Function, "Enrich Data", "enrich", enricher, retryConfig). + AddNodeWithRetry(dag.Function, "Finalize", "finalize", finalizer, retryConfig). + Connect("validate", "process"). + Connect("process", "enrich"). + Connect("enrich", "finalize") + + // Add conditional connections + dagInstance.AddCondition("validate", "success", "process") + dagInstance.AddCondition("validate", "failure", "finalize") // Skip to finalize on validation failure + + // Validate the DAG structure + if err := dagInstance.ValidateDAG(); err != nil { + log.Error("DAG validation failed", logger.Field{Key: "error", Value: err.Error()}) + os.Exit(1) + } + + log.Info("Workflow built and validated successfully") +} + +func startEnhancedFeatures(ctx context.Context, dagInstance *dag.DAG, log logger.Logger) { + // Start monitoring + dagInstance.StartMonitoring(ctx) + + // Start cleanup manager + dagInstance.StartCleanup(ctx) + + // Enable batch processing + dagInstance.SetBatchProcessingEnabled(true) + + log.Info("Enhanced features started") +} + +func setupHTTPAPI(dagInstance *dag.DAG, log logger.Logger) { + // Set up standard DAG handlers + dagInstance.Handlers(http.DefaultServeMux, "/dag") + + // Set up enhanced API endpoints + enhancedAPI := dag.NewEnhancedAPIHandler(dagInstance) + enhancedAPI.RegisterRoutes(http.DefaultServeMux) + + // Custom endpoints for demonstration + http.HandleFunc("/demo/activities", func(w http.ResponseWriter, r *http.Request) { + filter := dag.ActivityFilter{ + Limit: 50, + } + + activities, err := dagInstance.GetActivities(filter) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + if err := dagInstance.GetActivityLogger().(*dag.ActivityLogger).WriteJSON(w, activities); err != nil { + log.Error("Failed to write activities response", logger.Field{Key: "error", Value: err.Error()}) + } + }) + + http.HandleFunc("/demo/stats", func(w http.ResponseWriter, r *http.Request) { + stats, err := dagInstance.GetActivityStats(dag.ActivityFilter{}) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + if err := dagInstance.GetActivityLogger().(*dag.ActivityLogger).WriteJSON(w, stats); err != nil { + log.Error("Failed to write stats response", logger.Field{Key: "error", Value: err.Error()}) + } + }) + + log.Info("HTTP API endpoints configured") +} + +func demonstrateFeatures(ctx context.Context, dagInstance *dag.DAG, log logger.Logger) { + log.Info("Demonstrating enhanced DAG features...") + + // 1. Process a successful task + log.Info("Processing successful task...") + processTask(ctx, dagInstance, map[string]interface{}{ + "id": "task-001", + "data": "valid input data", + "type": "success", + }, log) + + // 2. Process a task that will fail + log.Info("Processing failing task...") + processTask(ctx, dagInstance, map[string]interface{}{ + "id": "task-002", + "data": nil, // This will cause processing issues + "type": "failure", + }, log) + + // 3. Process with transaction + log.Info("Processing with transaction...") + processWithTransaction(ctx, dagInstance, map[string]interface{}{ + "id": "task-003", + "data": "transaction data", + "type": "transaction", + }, log) + + // 4. Demonstrate rate limiting + log.Info("Demonstrating rate limiting...") + demonstrateRateLimiting(ctx, dagInstance, log) + + // 5. 
Show monitoring metrics + time.Sleep(2 * time.Second) // Allow time for metrics to accumulate + showMetrics(dagInstance, log) + + // 6. Show activity logs + showActivityLogs(dagInstance, log) +} + +func processTask(ctx context.Context, dagInstance *dag.DAG, payload map[string]interface{}, log logger.Logger) { + // Add context information + ctx = context.WithValue(ctx, "user_id", "demo-user") + ctx = context.WithValue(ctx, "session_id", "demo-session") + ctx = context.WithValue(ctx, "trace_id", mq.NewID()) + + result := dagInstance.Process(ctx, payload) + if result.Error != nil { + log.Error("Task processing failed", + logger.Field{Key: "error", Value: result.Error.Error()}, + logger.Field{Key: "payload", Value: payload}, + ) + } else { + log.Info("Task processed successfully", + logger.Field{Key: "result_size", Value: len(result.Payload)}, + ) + } +} + +func processWithTransaction(ctx context.Context, dagInstance *dag.DAG, payload map[string]interface{}, log logger.Logger) { + taskID := fmt.Sprintf("tx-%s", mq.NewID()) + + // Begin transaction + tx := dagInstance.BeginTransaction(taskID) + if tx == nil { + log.Error("Failed to begin transaction") + return + } + + // Add transaction context + ctx = context.WithValue(ctx, "transaction_id", tx.ID) + ctx = context.WithValue(ctx, "task_id", taskID) + + // Process the task + result := dagInstance.Process(ctx, payload) + + // Commit or rollback based on result + if result.Error != nil { + if err := dagInstance.RollbackTransaction(tx.ID); err != nil { + log.Error("Failed to rollback transaction", + logger.Field{Key: "tx_id", Value: tx.ID}, + logger.Field{Key: "error", Value: err.Error()}, + ) + } else { + log.Info("Transaction rolled back", + logger.Field{Key: "tx_id", Value: tx.ID}, + ) + } + } else { + if err := dagInstance.CommitTransaction(tx.ID); err != nil { + log.Error("Failed to commit transaction", + logger.Field{Key: "tx_id", Value: tx.ID}, + logger.Field{Key: "error", Value: err.Error()}, + ) + } else { + log.Info("Transaction committed", + logger.Field{Key: "tx_id", Value: tx.ID}, + ) + } + } +} + +func demonstrateRateLimiting(ctx context.Context, dagInstance *dag.DAG, log logger.Logger) { + // Try to exceed rate limits + for i := 0; i < 15; i++ { + allowed := dagInstance.CheckRateLimit("validate") + log.Info("Rate limit check", + logger.Field{Key: "attempt", Value: i + 1}, + logger.Field{Key: "allowed", Value: allowed}, + ) + + if allowed { + processTask(ctx, dagInstance, map[string]interface{}{ + "id": fmt.Sprintf("rate-test-%d", i), + "data": "rate limiting test", + }, log) + } + + time.Sleep(100 * time.Millisecond) + } +} + +func showMetrics(dagInstance *dag.DAG, log logger.Logger) { + metrics := dagInstance.GetMonitoringMetrics() + if metrics != nil { + log.Info("Current DAG Metrics", + logger.Field{Key: "total_tasks", Value: metrics.TasksTotal}, + logger.Field{Key: "completed_tasks", Value: metrics.TasksCompleted}, + logger.Field{Key: "failed_tasks", Value: metrics.TasksFailed}, + logger.Field{Key: "tasks_in_progress", Value: metrics.TasksInProgress}, + logger.Field{Key: "avg_execution_time", Value: metrics.AverageExecutionTime.String()}, + ) + + // Show node-specific metrics + for nodeID := range map[string]bool{"validate": true, "process": true, "enrich": true, "finalize": true} { + if nodeStats := dagInstance.GetNodeStats(nodeID); nodeStats != nil { + log.Info("Node Metrics", + logger.Field{Key: "node_id", Value: nodeID}, + logger.Field{Key: "executions", Value: nodeStats.TotalExecutions}, + logger.Field{Key: "failures", 
Value: nodeStats.FailureCount}, + logger.Field{Key: "avg_duration", Value: nodeStats.AverageExecutionTime.String()}, + ) + } + } + } else { + log.Warn("Monitoring metrics not available") + } +} + +func showActivityLogs(dagInstance *dag.DAG, log logger.Logger) { + // Get recent activities + filter := dag.ActivityFilter{ + Limit: 10, + SortBy: "timestamp", + SortOrder: "desc", + } + + activities, err := dagInstance.GetActivities(filter) + if err != nil { + log.Error("Failed to get activities", logger.Field{Key: "error", Value: err.Error()}) + return + } + + log.Info("Recent Activities", logger.Field{Key: "count", Value: len(activities)}) + for _, activity := range activities { + log.Info("Activity", + logger.Field{Key: "id", Value: activity.ID}, + logger.Field{Key: "type", Value: string(activity.Type)}, + logger.Field{Key: "level", Value: string(activity.Level)}, + logger.Field{Key: "message", Value: activity.Message}, + logger.Field{Key: "task_id", Value: activity.TaskID}, + logger.Field{Key: "node_id", Value: activity.NodeID}, + ) + } + + // Get activity statistics + stats, err := dagInstance.GetActivityStats(dag.ActivityFilter{}) + if err != nil { + log.Error("Failed to get activity stats", logger.Field{Key: "error", Value: err.Error()}) + return + } + + log.Info("Activity Statistics", + logger.Field{Key: "total_activities", Value: stats.TotalActivities}, + logger.Field{Key: "success_rate", Value: fmt.Sprintf("%.2f%%", stats.SuccessRate*100)}, + logger.Field{Key: "failure_rate", Value: fmt.Sprintf("%.2f%%", stats.FailureRate*100)}, + logger.Field{Key: "avg_duration", Value: stats.AverageDuration.String()}, + ) +} + +func waitForShutdown(ctx context.Context, cancel context.CancelFunc, dagInstance *dag.DAG, server *mq.Server, log logger.Logger) { + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + log.Info("DAG system is running. Available endpoints:", + logger.Field{Key: "workflow", Value: "http://localhost:8080/dag/"}, + logger.Field{Key: "process", Value: "http://localhost:8080/dag/process"}, + logger.Field{Key: "metrics", Value: "http://localhost:8080/api/dag/metrics"}, + logger.Field{Key: "health", Value: "http://localhost:8080/api/dag/health"}, + logger.Field{Key: "activities", Value: "http://localhost:8080/demo/activities"}, + logger.Field{Key: "stats", Value: "http://localhost:8080/demo/stats"}, + ) + + <-sigChan + log.Info("Shutdown signal received, cleaning up...") + + // Graceful shutdown + cancel() + + // Stop enhanced features + dagInstance.StopEnhanced(ctx) + + // Stop server + if err := server.Stop(ctx); err != nil { + log.Error("Error stopping server", logger.Field{Key: "error", Value: err.Error()}) + } + + log.Info("Shutdown complete") +}
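
The WebhookManager introduced above only depends on the small HTTPClient interface, so callers can plug in their own transport. The sketch below is a minimal, illustrative client, assuming the Post signature implied by the call in sendWebhook (URL, content type, body, headers) rather than any documented mq API, and treating non-2xx responses as errors so the delivery retry loop re-sends; the demo's dag.NewSimpleHTTPClient presumably wraps something similar.

package webhookclient

import (
	"bytes"
	"fmt"
	"net/http"
	"time"
)

// simpleHTTPClient is an illustrative HTTPClient implementation; the Post
// signature below is assumed from the call in WebhookManager.sendWebhook.
type simpleHTTPClient struct {
	client *http.Client
}

func newSimpleHTTPClient(timeout time.Duration) *simpleHTTPClient {
	return &simpleHTTPClient{client: &http.Client{Timeout: timeout}}
}

// Post sends the payload with the given content type and extra headers and
// reports non-2xx responses as errors so the caller's retry loop re-sends.
func (c *simpleHTTPClient) Post(url, contentType string, body []byte, headers map[string]string) error {
	req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", contentType)
	for k, v := range headers {
		req.Header.Set(k, v)
	}
	resp, err := c.client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return fmt.Errorf("webhook endpoint returned status %d", resp.StatusCode)
	}
	return nil
}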
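
Monitor alerts and webhook events travel on separate paths in this patch. If critical alerts should also reach the webhook endpoints, an AlertHandler can bridge the two; the sketch below uses only the Alert, WebhookEvent, and TriggerWebhook definitions introduced above, and assumes a webhook has been registered for an "alert" event type via AddWebhook, mirroring how the demo registers handlers with AddAlertHandler.

package alerts

import (
	"github.com/oarkflow/mq/dag"
)

// WebhookAlertHandler forwards critical monitor alerts to the webhook manager
// so they reach the same endpoints as task events.
type WebhookAlertHandler struct {
	webhooks *dag.WebhookManager
}

func NewWebhookAlertHandler(wm *dag.WebhookManager) *WebhookAlertHandler {
	return &WebhookAlertHandler{webhooks: wm}
}

func (h *WebhookAlertHandler) HandleAlert(alert dag.Alert) error {
	// Keep info/warning alerts in the logs; only escalate critical ones.
	if alert.Severity != dag.AlertSeverityCritical {
		return nil
	}
	h.webhooks.TriggerWebhook(dag.WebhookEvent{
		Type:      "alert", // a webhook must be registered for this event type
		TaskID:    alert.TaskID,
		NodeID:    alert.NodeID,
		Timestamp: alert.Timestamp,
		Data: map[string]interface{}{
			"alert_id": alert.ID,
			"type":     string(alert.Type),
			"message":  alert.Message,
			"details":  alert.Details,
		},
	})
	return nil
}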
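
The demo queries activities with an almost empty ActivityFilter. For troubleshooting it is usually the recent failures that matter; a small helper could narrow the query as sketched below, using the ActivityFilter fields defined in this patch and assuming GetActivities returns []ActivityEntry, which is what the demo's iteration over activity fields suggests.

package activityquery

import (
	"time"

	"github.com/oarkflow/mq/dag"
)

// queryRecentFailures narrows the activity query to error-level failures from
// the last hour, newest first. The return type of GetActivities is assumed.
func queryRecentFailures(d *dag.DAG) ([]dag.ActivityEntry, error) {
	since := time.Now().Add(-1 * time.Hour)
	failuresOnly := true
	return d.GetActivities(dag.ActivityFilter{
		StartTime:    &since,
		Levels:       []dag.ActivityLevel{dag.ActivityLevelError, dag.ActivityLevelFatal},
		FailuresOnly: &failuresOnly,
		Limit:        100,
		SortBy:       "timestamp",
		SortOrder:    "desc",
	})
}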