// mq/dag/retry.go

package dag

import (
	"context"
	"fmt"
	"math/rand"
	"strings"
	"sync"
	"time"

	"github.com/oarkflow/mq"
	"github.com/oarkflow/mq/logger"
)

// RetryConfig defines retry behavior for failed nodes
type RetryConfig struct {
	MaxRetries     int
	InitialDelay   time.Duration
	MaxDelay       time.Duration
	BackoffFactor  float64
	Jitter         bool
	RetryCondition func(err error) bool
}

// DefaultRetryConfig returns a sensible default retry configuration
func DefaultRetryConfig() *RetryConfig {
	return &RetryConfig{
		MaxRetries:     3,
		InitialDelay:   1 * time.Second,
		MaxDelay:       30 * time.Second,
		BackoffFactor:  2.0,
		Jitter:         true,
		RetryCondition: func(err error) bool { return true }, // Retry all errors by default
	}
}
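
// Example: a config that retries only errors the caller considers
// transient. Illustrative sketch; isTransient is a hypothetical
// predicate you would supply for your own error types.
func exampleTransientOnlyConfig(isTransient func(error) bool) *RetryConfig {
	cfg := DefaultRetryConfig()
	cfg.MaxRetries = 5
	cfg.RetryCondition = isTransient // skip retries for permanent failures
	return cfg
}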

// NodeRetryManager handles retry logic for individual nodes
type NodeRetryManager struct {
	config   *RetryConfig
	attempts map[string]int
	mu       sync.RWMutex
	logger   logger.Logger
}

// NewNodeRetryManager creates a new retry manager
func NewNodeRetryManager(config *RetryConfig, logger logger.Logger) *NodeRetryManager {
	if config == nil {
		config = DefaultRetryConfig()
	}
	return &NodeRetryManager{
		config:   config,
		attempts: make(map[string]int),
		logger:   logger,
	}
}

// ShouldRetry determines if a failed node should be retried
func (rm *NodeRetryManager) ShouldRetry(taskID, nodeID string, err error) bool {
	rm.mu.RLock()
	attempts := rm.attempts[rm.getKey(taskID, nodeID)]
	rm.mu.RUnlock()
	if attempts >= rm.config.MaxRetries {
		return false
	}
	if rm.config.RetryCondition != nil && !rm.config.RetryCondition(err) {
		return false
	}
	return true
}

// GetRetryDelay calculates the delay before the next retry using
// exponential backoff, capped at MaxDelay
func (rm *NodeRetryManager) GetRetryDelay(taskID, nodeID string) time.Duration {
	rm.mu.RLock()
	attempts := rm.attempts[rm.getKey(taskID, nodeID)]
	rm.mu.RUnlock()
	delay := rm.config.InitialDelay
	for i := 0; i < attempts; i++ {
		delay = time.Duration(float64(delay) * rm.config.BackoffFactor)
		if delay > rm.config.MaxDelay {
			delay = rm.config.MaxDelay
			break
		}
	}
	if rm.config.Jitter {
		// Add up to 25% random jitter to avoid thundering-herd retries
		delay += time.Duration(rand.Float64() * 0.25 * float64(delay))
	}
	return delay
}
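
// Example: previewing the backoff schedule GetRetryDelay produces.
// Minimal sketch; it records attempts for a throwaway task/node pair
// purely to print the growing delays (jitter disabled for determinism).
func exampleBackoffSchedule(log logger.Logger) {
	cfg := DefaultRetryConfig()
	cfg.Jitter = false
	rm := NewNodeRetryManager(cfg, log)
	for i := 0; i < 4; i++ {
		rm.RecordAttempt("task-1", "node-a")
		// Prints 2s, 4s, 8s, 16s with the default 1s initial delay and 2.0 factor
		fmt.Printf("retry %d after %s\n", i+1, rm.GetRetryDelay("task-1", "node-a"))
	}
}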

// RecordAttempt records a retry attempt
func (rm *NodeRetryManager) RecordAttempt(taskID, nodeID string) {
	rm.mu.Lock()
	key := rm.getKey(taskID, nodeID)
	rm.attempts[key]++
	attempt := rm.attempts[key] // read under the lock to avoid a data race
	rm.mu.Unlock()
	rm.logger.Info("Retry attempt recorded",
		logger.Field{Key: "taskID", Value: taskID},
		logger.Field{Key: "nodeID", Value: nodeID},
		logger.Field{Key: "attempt", Value: attempt},
	)
}

// Reset clears retry attempts for a task/node combination
func (rm *NodeRetryManager) Reset(taskID, nodeID string) {
	rm.mu.Lock()
	delete(rm.attempts, rm.getKey(taskID, nodeID))
	rm.mu.Unlock()
}

// ResetTask clears all retry attempts for a task
func (rm *NodeRetryManager) ResetTask(taskID string) {
	rm.mu.Lock()
	for key := range rm.attempts {
		if strings.HasPrefix(key, taskID+":") {
			delete(rm.attempts, key)
		}
	}
	rm.mu.Unlock()
}

// GetAttempts returns the number of attempts for a task/node combination
func (rm *NodeRetryManager) GetAttempts(taskID, nodeID string) int {
	rm.mu.RLock()
	attempts := rm.attempts[rm.getKey(taskID, nodeID)]
	rm.mu.RUnlock()
	return attempts
}

// getKey builds the attempts-map key for a task/node pair
func (rm *NodeRetryManager) getKey(taskID, nodeID string) string {
	return taskID + ":" + nodeID
}

// SetGlobalConfig sets the global retry configuration
func (rm *NodeRetryManager) SetGlobalConfig(config *RetryConfig) {
	rm.mu.Lock()
	defer rm.mu.Unlock()
	rm.config = config
	rm.logger.Info("Global retry configuration updated")
}

// SetNodeConfig sets retry configuration for a specific node.
// Note: node-specific configs are not stored yet; a full implementation
// would keep a nodeConfigs map keyed by nodeID. For now this only logs
// the request.
func (rm *NodeRetryManager) SetNodeConfig(nodeID string, config *RetryConfig) {
	rm.mu.Lock()
	defer rm.mu.Unlock()
	rm.logger.Info("Node-specific retry configuration set",
		logger.Field{Key: "nodeID", Value: nodeID},
		logger.Field{Key: "maxRetries", Value: config.MaxRetries},
	)
}

// RetryableProcessor wraps a processor with retry logic
type RetryableProcessor struct {
	processor    mq.Processor
	retryManager *NodeRetryManager
	logger       logger.Logger
}

// NewRetryableProcessor creates a processor with retry capabilities
func NewRetryableProcessor(processor mq.Processor, config *RetryConfig, logger logger.Logger) *RetryableProcessor {
	return &RetryableProcessor{
		processor:    processor,
		retryManager: NewNodeRetryManager(config, logger),
		logger:       logger,
	}
}
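
// Example: wrapping an existing processor with retry behavior.
// Minimal sketch; base is whatever mq.Processor the node already uses
// and log is any logger.Logger implementation, both supplied by the
// caller.
func exampleWrapWithRetries(base mq.Processor, log logger.Logger) *RetryableProcessor {
	cfg := DefaultRetryConfig()
	cfg.MaxRetries = 2 // fail faster on latency-sensitive nodes
	return NewRetryableProcessor(base, cfg, log)
}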

// ProcessTask processes a task with retry logic. On failure it schedules
// a single asynchronous retry after the backoff delay and returns the
// original failure result immediately.
func (rp *RetryableProcessor) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
	taskID := task.ID
	nodeID := task.Topic
	result := rp.processor.ProcessTask(ctx, task)
	// If the task failed and should be retried
	if result.Error != nil && rp.retryManager.ShouldRetry(taskID, nodeID, result.Error) {
		rp.retryManager.RecordAttempt(taskID, nodeID)
		delay := rp.retryManager.GetRetryDelay(taskID, nodeID)
		rp.logger.Warn("Task failed, scheduling retry",
			logger.Field{Key: "taskID", Value: taskID},
			logger.Field{Key: "nodeID", Value: nodeID},
			logger.Field{Key: "error", Value: result.Error.Error()},
			logger.Field{Key: "retryDelay", Value: delay.String()},
			logger.Field{Key: "attempt", Value: rp.retryManager.GetAttempts(taskID, nodeID)},
		)
		// Schedule the retry after the delay; its result is not returned
		// to the caller, but a success clears the attempt counter
		time.AfterFunc(delay, func() {
			retryResult := rp.processor.ProcessTask(ctx, task)
			if retryResult.Error == nil {
				rp.retryManager.Reset(taskID, nodeID)
				rp.logger.Info("Task retry succeeded",
					logger.Field{Key: "taskID", Value: taskID},
					logger.Field{Key: "nodeID", Value: nodeID},
				)
			}
		})
		// Return the original failure result
		return result
	}
	// If successful, reset retry attempts
	if result.Error == nil {
		rp.retryManager.Reset(taskID, nodeID)
	}
	return result
}

// Stop stops the processor
func (rp *RetryableProcessor) Stop(ctx context.Context) error {
	return rp.processor.Stop(ctx)
}

// Close closes the processor if the wrapped implementation supports it
func (rp *RetryableProcessor) Close() error {
	if closer, ok := rp.processor.(interface{ Close() error }); ok {
		return closer.Close()
	}
	return nil
}

// Consume starts consuming messages
func (rp *RetryableProcessor) Consume(ctx context.Context) error {
	return rp.processor.Consume(ctx)
}

// Pause pauses the processor
func (rp *RetryableProcessor) Pause(ctx context.Context) error {
	return rp.processor.Pause(ctx)
}

// Resume resumes the processor
func (rp *RetryableProcessor) Resume(ctx context.Context) error {
	return rp.processor.Resume(ctx)
}

// GetKey returns the processor key
func (rp *RetryableProcessor) GetKey() string {
	return rp.processor.GetKey()
}

// SetKey sets the processor key
func (rp *RetryableProcessor) SetKey(key string) {
	rp.processor.SetKey(key)
}

// GetType returns the processor type
func (rp *RetryableProcessor) GetType() string {
	return rp.processor.GetType()
}

// Circuit Breaker Implementation

// CircuitBreakerState represents the current state of a circuit breaker
type CircuitBreakerState int

const (
	CircuitClosed CircuitBreakerState = iota
	CircuitOpen
	CircuitHalfOpen
)

// CircuitBreakerConfig defines circuit breaker behavior
type CircuitBreakerConfig struct {
	FailureThreshold int
	ResetTimeout     time.Duration
	HalfOpenMaxCalls int
}

// CircuitBreaker implements the circuit breaker pattern for nodes
type CircuitBreaker struct {
	config        *CircuitBreakerConfig
	state         CircuitBreakerState
	failures      int
	lastFailTime  time.Time
	halfOpenCalls int
	mu            sync.RWMutex
	logger        logger.Logger
}

// NewCircuitBreaker creates a new circuit breaker
func NewCircuitBreaker(config *CircuitBreakerConfig, logger logger.Logger) *CircuitBreaker {
	return &CircuitBreaker{
		config: config,
		state:  CircuitClosed,
		logger: logger,
	}
}

// Execute runs a function with circuit breaker protection. Note that the
// lock is held for the duration of fn, so protected calls are serialized.
func (cb *CircuitBreaker) Execute(fn func() error) error {
	cb.mu.Lock()
	defer cb.mu.Unlock()
	switch cb.state {
	case CircuitOpen:
		if time.Since(cb.lastFailTime) > cb.config.ResetTimeout {
			cb.state = CircuitHalfOpen
			cb.halfOpenCalls = 1 // count this call as the first half-open probe
			cb.logger.Info("Circuit breaker transitioning to half-open")
		} else {
			return fmt.Errorf("circuit breaker is open")
		}
	case CircuitHalfOpen:
		if cb.halfOpenCalls >= cb.config.HalfOpenMaxCalls {
			return fmt.Errorf("circuit breaker half-open call limit exceeded")
		}
		cb.halfOpenCalls++
	}
	err := fn()
	if err != nil {
		cb.failures++
		cb.lastFailTime = time.Now()
		if cb.state == CircuitHalfOpen {
			cb.state = CircuitOpen
			cb.logger.Warn("Circuit breaker opened from half-open state")
		} else if cb.failures >= cb.config.FailureThreshold {
			cb.state = CircuitOpen
			cb.logger.Warn("Circuit breaker opened due to failure threshold")
		}
	} else {
		if cb.state == CircuitHalfOpen {
			cb.state = CircuitClosed
			cb.failures = 0
			cb.logger.Info("Circuit breaker closed from half-open state")
		} else if cb.state == CircuitClosed {
			cb.failures = 0
		}
	}
	return err
}

// GetState returns the current circuit breaker state
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
	cb.mu.RLock()
	defer cb.mu.RUnlock()
	return cb.state
}

// Reset manually resets the circuit breaker
func (cb *CircuitBreaker) Reset() {
	cb.mu.Lock()
	defer cb.mu.Unlock()
	cb.state = CircuitClosed
	cb.failures = 0
	cb.halfOpenCalls = 0
}
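
// Example: guarding a node's downstream call with a circuit breaker.
// Minimal sketch; callDownstream is a hypothetical stand-in for whatever
// remote call the node performs, and log is any logger.Logger.
func exampleGuardedCall(log logger.Logger, callDownstream func() error) error {
	cb := NewCircuitBreaker(&CircuitBreakerConfig{
		FailureThreshold: 5,                // open after 5 failures without a success
		ResetTimeout:     30 * time.Second, // probe again after 30s
		HalfOpenMaxCalls: 2,                // allow 2 probes while half-open
	}, log)
	// Once open, Execute fails fast without invoking callDownstream
	// until ResetTimeout elapses and the breaker goes half-open.
	return cb.Execute(callDownstream)
}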