diff --git a/dedup_and_flow.go b/dedup_and_flow.go index 0ce9fd9..de5543a 100644 --- a/dedup_and_flow.go +++ b/dedup_and_flow.go @@ -6,11 +6,14 @@ import ( "encoding/hex" "encoding/json" "fmt" + "os" "runtime" + "strconv" "sync" "time" "github.com/oarkflow/mq/logger" + "golang.org/x/time/rate" ) // DedupEntry represents a deduplication cache entry @@ -235,31 +238,432 @@ func (dm *DeduplicationManager) Shutdown(ctx context.Context) error { return nil } -// FlowController manages backpressure and flow control -type FlowController struct { - credits int64 - maxCredits int64 - minCredits int64 - creditRefillRate int64 - mu sync.Mutex - logger logger.Logger - shutdown chan struct{} - refillInterval time.Duration - onCreditLow func(current, max int64) - onCreditHigh func(current, max int64) +// TokenBucketStrategy implements token bucket algorithm +type TokenBucketStrategy struct { + tokens int64 + capacity int64 + refillRate int64 + refillInterval time.Duration + lastRefill time.Time + mu sync.Mutex + shutdown chan struct{} + logger logger.Logger +} + +// NewTokenBucketStrategy creates a new token bucket strategy +func NewTokenBucketStrategy(config FlowControlConfig) *TokenBucketStrategy { + if config.MaxCredits == 0 { + config.MaxCredits = 1000 + } + if config.RefillRate == 0 { + config.RefillRate = 10 + } + if config.RefillInterval == 0 { + config.RefillInterval = 100 * time.Millisecond + } + if config.BurstSize == 0 { + config.BurstSize = config.MaxCredits + } + + tbs := &TokenBucketStrategy{ + tokens: config.BurstSize, + capacity: config.BurstSize, + refillRate: config.RefillRate, + refillInterval: config.RefillInterval, + lastRefill: time.Now(), + shutdown: make(chan struct{}), + logger: config.Logger, + } + + go tbs.refillLoop() + + return tbs +} + +// Acquire attempts to acquire tokens +func (tbs *TokenBucketStrategy) Acquire(ctx context.Context, amount int64) error { + for { + tbs.mu.Lock() + if tbs.tokens >= amount { + tbs.tokens -= amount + tbs.mu.Unlock() + return nil + } + tbs.mu.Unlock() + + select { + case <-time.After(10 * time.Millisecond): + continue + case <-ctx.Done(): + return ctx.Err() + case <-tbs.shutdown: + return fmt.Errorf("token bucket shutting down") + } + } +} + +// Release returns tokens (not typically used in token bucket) +func (tbs *TokenBucketStrategy) Release(amount int64) { + // Token bucket doesn't typically release tokens back + // This is a no-op for token bucket strategy +} + +// GetAvailableCredits returns available tokens +func (tbs *TokenBucketStrategy) GetAvailableCredits() int64 { + tbs.mu.Lock() + defer tbs.mu.Unlock() + return tbs.tokens +} + +// GetStats returns token bucket statistics +func (tbs *TokenBucketStrategy) GetStats() map[string]interface{} { + tbs.mu.Lock() + defer tbs.mu.Unlock() + + utilization := float64(tbs.capacity-tbs.tokens) / float64(tbs.capacity) * 100 + + return map[string]interface{}{ + "strategy": "token_bucket", + "tokens": tbs.tokens, + "capacity": tbs.capacity, + "refill_rate": tbs.refillRate, + "utilization": utilization, + "last_refill": tbs.lastRefill, + } +} + +// Shutdown stops the token bucket +func (tbs *TokenBucketStrategy) Shutdown() { + close(tbs.shutdown) +} + +// refillLoop periodically refills tokens +func (tbs *TokenBucketStrategy) refillLoop() { + ticker := time.NewTicker(tbs.refillInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + tbs.mu.Lock() + tbs.tokens += tbs.refillRate + if tbs.tokens > tbs.capacity { + tbs.tokens = tbs.capacity + } + tbs.lastRefill = time.Now() + 
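+			// Refill is additive and clamped to capacity above, so an idle
+			// bucket banks at most one full burst no matter how long it waits.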
tbs.mu.Unlock() + case <-tbs.shutdown: + return + } + } +} + +// FlowControlStrategy defines the interface for different flow control algorithms +type FlowControlStrategy interface { + // Acquire attempts to acquire credits for processing + Acquire(ctx context.Context, amount int64) error + // Release returns credits after processing + Release(amount int64) + // GetAvailableCredits returns current available credits + GetAvailableCredits() int64 + // GetStats returns strategy-specific statistics + GetStats() map[string]interface{} + // Shutdown cleans up resources + Shutdown() } // FlowControlConfig holds flow control configuration type FlowControlConfig struct { - MaxCredits int64 - MinCredits int64 - RefillRate int64 // Credits to add per interval - RefillInterval time.Duration - Logger logger.Logger + Strategy FlowControlStrategyType `json:"strategy" yaml:"strategy"` + MaxCredits int64 `json:"max_credits" yaml:"max_credits"` + MinCredits int64 `json:"min_credits" yaml:"min_credits"` + RefillRate int64 `json:"refill_rate" yaml:"refill_rate"` + RefillInterval time.Duration `json:"refill_interval" yaml:"refill_interval"` + BurstSize int64 `json:"burst_size" yaml:"burst_size"` // For token bucket + Logger logger.Logger `json:"-" yaml:"-"` } -// NewFlowController creates a new flow controller +// FlowControlStrategyType represents different flow control strategies +type FlowControlStrategyType string + +const ( + StrategyTokenBucket FlowControlStrategyType = "token_bucket" + StrategyLeakyBucket FlowControlStrategyType = "leaky_bucket" + StrategyCreditBased FlowControlStrategyType = "credit_based" + StrategyRateLimiter FlowControlStrategyType = "rate_limiter" +) + +// FlowController manages backpressure and flow control using pluggable strategies +type FlowController struct { + strategy FlowControlStrategy + config FlowControlConfig + onCreditLow func(current, max int64) + onCreditHigh func(current, max int64) + logger logger.Logger + shutdown chan struct{} +} + +// FlowControllerFactory creates flow controllers with different strategies +type FlowControllerFactory struct{} + +// NewFlowControllerFactory creates a new factory +func NewFlowControllerFactory() *FlowControllerFactory { + return &FlowControllerFactory{} +} + +// CreateFlowController creates a flow controller with the specified strategy +func (f *FlowControllerFactory) CreateFlowController(config FlowControlConfig) (*FlowController, error) { + if config.Strategy == "" { + config.Strategy = StrategyTokenBucket + } + + // Validate configuration based on strategy + if err := f.validateConfig(config); err != nil { + return nil, fmt.Errorf("invalid configuration: %w", err) + } + + return NewFlowController(config), nil +} + +// CreateTokenBucketFlowController creates a token bucket flow controller +func (f *FlowControllerFactory) CreateTokenBucketFlowController(maxCredits, refillRate int64, refillInterval time.Duration, logger logger.Logger) *FlowController { + config := FlowControlConfig{ + Strategy: StrategyTokenBucket, + MaxCredits: maxCredits, + RefillRate: refillRate, + RefillInterval: refillInterval, + BurstSize: maxCredits, + Logger: logger, + } + return NewFlowController(config) +} + +// CreateLeakyBucketFlowController creates a leaky bucket flow controller +func (f *FlowControllerFactory) CreateLeakyBucketFlowController(capacity int64, leakInterval time.Duration, logger logger.Logger) *FlowController { + config := FlowControlConfig{ + Strategy: StrategyLeakyBucket, + MaxCredits: capacity, + RefillInterval: leakInterval, + 
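+		// MinCredits and RefillRate are intentionally left zero here: the
+		// leaky bucket sizes itself from MaxCredits and drains one unit per
+		// RefillInterval, so no separate refill rate applies.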
Logger: logger, + } + return NewFlowController(config) +} + +// CreateCreditBasedFlowController creates a credit-based flow controller +func (f *FlowControllerFactory) CreateCreditBasedFlowController(maxCredits, minCredits, refillRate int64, refillInterval time.Duration, logger logger.Logger) *FlowController { + config := FlowControlConfig{ + Strategy: StrategyCreditBased, + MaxCredits: maxCredits, + MinCredits: minCredits, + RefillRate: refillRate, + RefillInterval: refillInterval, + Logger: logger, + } + return NewFlowController(config) +} + +// CreateRateLimiterFlowController creates a rate limiter flow controller +func (f *FlowControllerFactory) CreateRateLimiterFlowController(requestsPerSecond, burstSize int64, logger logger.Logger) *FlowController { + config := FlowControlConfig{ + Strategy: StrategyRateLimiter, + RefillRate: requestsPerSecond, + BurstSize: burstSize, + Logger: logger, + } + return NewFlowController(config) +} + +// validateConfig validates the configuration for the specified strategy +func (f *FlowControllerFactory) validateConfig(config FlowControlConfig) error { + switch config.Strategy { + case StrategyTokenBucket: + if config.MaxCredits <= 0 { + return fmt.Errorf("max_credits must be positive for token bucket strategy") + } + if config.RefillRate <= 0 { + return fmt.Errorf("refill_rate must be positive for token bucket strategy") + } + case StrategyLeakyBucket: + if config.MaxCredits <= 0 { + return fmt.Errorf("max_credits must be positive for leaky bucket strategy") + } + case StrategyCreditBased: + if config.MaxCredits <= 0 { + return fmt.Errorf("max_credits must be positive for credit-based strategy") + } + if config.MinCredits < 0 || config.MinCredits > config.MaxCredits { + return fmt.Errorf("min_credits must be between 0 and max_credits for credit-based strategy") + } + case StrategyRateLimiter: + if config.RefillRate <= 0 { + return fmt.Errorf("refill_rate must be positive for rate limiter strategy") + } + if config.BurstSize <= 0 { + return fmt.Errorf("burst_size must be positive for rate limiter strategy") + } + default: + return fmt.Errorf("unknown strategy: %s", config.Strategy) + } + return nil +} + +// NewFlowController creates a new flow controller with the specified strategy func NewFlowController(config FlowControlConfig) *FlowController { + if config.Strategy == "" { + config.Strategy = StrategyTokenBucket + } + + var strategy FlowControlStrategy + switch config.Strategy { + case StrategyTokenBucket: + strategy = NewTokenBucketStrategy(config) + case StrategyLeakyBucket: + strategy = NewLeakyBucketStrategy(config) + case StrategyCreditBased: + strategy = NewCreditBasedStrategy(config) + case StrategyRateLimiter: + strategy = NewRateLimiterStrategy(config) + default: + // Default to token bucket + strategy = NewTokenBucketStrategy(config) + } + + fc := &FlowController{ + strategy: strategy, + config: config, + logger: config.Logger, + shutdown: make(chan struct{}), + } + + return fc +} + +// LeakyBucketStrategy implements leaky bucket algorithm +type LeakyBucketStrategy struct { + queue chan struct{} + capacity int64 + leakRate time.Duration + lastLeak time.Time + mu sync.Mutex + shutdown chan struct{} + logger logger.Logger +} + +// NewLeakyBucketStrategy creates a new leaky bucket strategy +func NewLeakyBucketStrategy(config FlowControlConfig) *LeakyBucketStrategy { + if config.MaxCredits == 0 { + config.MaxCredits = 1000 + } + if config.RefillInterval == 0 { + config.RefillInterval = 100 * time.Millisecond + } + + lbs := &LeakyBucketStrategy{ 
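+		// The buffered channel is the bucket itself: len(queue) is the current
+		// fill level and cap(queue) is the bucket capacity.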
+		queue:    make(chan struct{}, config.MaxCredits),
+		capacity: config.MaxCredits,
+		leakRate: config.RefillInterval,
+		lastLeak: time.Now(),
+		shutdown: make(chan struct{}),
+		logger:   config.Logger,
+	}
+
+	go lbs.leakLoop()
+
+	return lbs
+}
+
+// Acquire attempts to add to the bucket, blocking while it is full
+func (lbs *LeakyBucketStrategy) Acquire(ctx context.Context, amount int64) error {
+	for i := int64(0); i < amount; i++ {
+		// A blocking send applies backpressure: it waits for leakLoop (or
+		// Release) to free a slot, and aborts on cancellation or shutdown.
+		select {
+		case lbs.queue <- struct{}{}:
+			// Successfully added
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-lbs.shutdown:
+			return fmt.Errorf("leaky bucket shutting down")
+		}
+	}
+	return nil
+}
+
+// Release removes from the bucket (leaking)
+func (lbs *LeakyBucketStrategy) Release(amount int64) {
+	for i := int64(0); i < amount; i++ {
+		select {
+		case <-lbs.queue:
+		default:
+		}
+	}
+}
+
+// GetAvailableCredits returns available capacity
+func (lbs *LeakyBucketStrategy) GetAvailableCredits() int64 {
+	return lbs.capacity - int64(len(lbs.queue))
+}
+
+// GetStats returns leaky bucket statistics
+func (lbs *LeakyBucketStrategy) GetStats() map[string]interface{} {
+	return map[string]interface{}{
+		"strategy":    "leaky_bucket",
+		"queue_size":  len(lbs.queue),
+		"capacity":    lbs.capacity,
+		"leak_rate":   lbs.leakRate,
+		"utilization": float64(len(lbs.queue)) / float64(lbs.capacity) * 100,
+	}
+}
+
+// Shutdown stops the leaky bucket
+func (lbs *LeakyBucketStrategy) Shutdown() {
+	close(lbs.shutdown)
+}
+
+// leakLoop periodically leaks from the bucket
+func (lbs *LeakyBucketStrategy) leakLoop() {
+	ticker := time.NewTicker(lbs.leakRate)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ticker.C:
+			select {
+			case <-lbs.queue:
+				// Leaked one
+			default:
+				// Empty
+			}
+		case <-lbs.shutdown:
+			return
+		}
+	}
+}
+
+// CreditBasedStrategy implements credit-based flow control
+type CreditBasedStrategy struct {
+	credits        int64
+	maxCredits     int64
+	minCredits     int64
+	refillRate     int64
+	refillInterval time.Duration
+	mu             sync.Mutex
+	shutdown       chan struct{}
+	logger         logger.Logger
+	onCreditLow    func(current, max int64)
+	onCreditHigh   func(current, max int64)
+}
+
+// NewCreditBasedStrategy creates a new credit-based strategy
+func NewCreditBasedStrategy(config FlowControlConfig) *CreditBasedStrategy {
 	if config.MaxCredits == 0 {
 		config.MaxCredits = 1000
 	}
@@ -273,131 +677,229 @@ func NewFlowController(config FlowControlConfig) *FlowController {
 		config.RefillInterval = 100 * time.Millisecond
 	}
 
-	fc := &FlowController{
-		credits:          config.MaxCredits,
-		maxCredits:       config.MaxCredits,
-		minCredits:       config.MinCredits,
-		creditRefillRate: config.RefillRate,
-		refillInterval:   config.RefillInterval,
-		logger:           config.Logger,
-		shutdown:         make(chan struct{}),
+	cbs := &CreditBasedStrategy{
+		credits:        config.MaxCredits,
+		maxCredits:     config.MaxCredits,
+		minCredits:     config.MinCredits,
+		refillRate:     config.RefillRate,
+		refillInterval: config.RefillInterval,
+		shutdown:       make(chan struct{}),
+		logger:         config.Logger,
 	}
 
-	go fc.refillLoop()
+	go cbs.refillLoop()
 
-	return fc
+	return cbs
 }
 
-// AcquireCredit attempts to acquire credits for processing
-func (fc *FlowController) AcquireCredit(ctx context.Context, amount int64) error {
+// Acquire attempts to acquire credits
+func (cbs *CreditBasedStrategy) Acquire(ctx context.Context, amount int64) error {
 	for {
-		fc.mu.Lock()
-		if
fc.credits >= amount {
-			fc.credits -= amount
+		cbs.mu.Lock()
+		if cbs.credits >= amount {
+			cbs.credits -= amount
 
-			// Check if credits are low
-			if fc.credits < fc.minCredits && fc.onCreditLow != nil {
-				go fc.onCreditLow(fc.credits, fc.maxCredits)
+			if cbs.credits < cbs.minCredits && cbs.onCreditLow != nil {
+				go cbs.onCreditLow(cbs.credits, cbs.maxCredits)
 			}
 
-			fc.mu.Unlock()
+			cbs.mu.Unlock()
 			return nil
 		}
-		fc.mu.Unlock()
+		cbs.mu.Unlock()
 
-		// Wait before retrying
 		select {
 		case <-time.After(10 * time.Millisecond):
 			continue
 		case <-ctx.Done():
 			return ctx.Err()
-		case <-fc.shutdown:
-			return fmt.Errorf("flow controller shutting down")
+		case <-cbs.shutdown:
+			return fmt.Errorf("credit-based strategy shutting down")
 		}
 	}
 }
 
-// ReleaseCredit returns credits after processing
-func (fc *FlowController) ReleaseCredit(amount int64) {
-	fc.mu.Lock()
-	defer fc.mu.Unlock()
+// Release returns credits
+func (cbs *CreditBasedStrategy) Release(amount int64) {
+	cbs.mu.Lock()
+	defer cbs.mu.Unlock()
 
-	fc.credits += amount
-	if fc.credits > fc.maxCredits {
-		fc.credits = fc.maxCredits
+	cbs.credits += amount
+	if cbs.credits > cbs.maxCredits {
+		cbs.credits = cbs.maxCredits
 	}
 
-	// Check if credits recovered
-	if fc.credits > fc.maxCredits/2 && fc.onCreditHigh != nil {
-		go fc.onCreditHigh(fc.credits, fc.maxCredits)
+	if cbs.credits > cbs.maxCredits/2 && cbs.onCreditHigh != nil {
+		go cbs.onCreditHigh(cbs.credits, cbs.maxCredits)
 	}
 }
 
+// GetAvailableCredits returns available credits
+func (cbs *CreditBasedStrategy) GetAvailableCredits() int64 {
+	cbs.mu.Lock()
+	defer cbs.mu.Unlock()
+	return cbs.credits
+}
+
+// GetStats returns credit-based statistics
+func (cbs *CreditBasedStrategy) GetStats() map[string]interface{} {
+	cbs.mu.Lock()
+	defer cbs.mu.Unlock()
+
+	utilization := float64(cbs.maxCredits-cbs.credits) / float64(cbs.maxCredits) * 100
+
+	return map[string]interface{}{
+		"strategy":    "credit_based",
+		"credits":     cbs.credits,
+		"max_credits": cbs.maxCredits,
+		"min_credits": cbs.minCredits,
+		"refill_rate": cbs.refillRate,
+		"utilization": utilization,
+	}
+}
+
+// Shutdown stops the credit-based strategy
+func (cbs *CreditBasedStrategy) Shutdown() {
+	close(cbs.shutdown)
+}
+
 // refillLoop periodically refills credits
-func (fc *FlowController) refillLoop() {
-	ticker := time.NewTicker(fc.refillInterval)
+func (cbs *CreditBasedStrategy) refillLoop() {
+	ticker := time.NewTicker(cbs.refillInterval)
 	defer ticker.Stop()
 
 	for {
 		select {
 		case <-ticker.C:
-			fc.mu.Lock()
-			fc.credits += fc.creditRefillRate
-			if fc.credits > fc.maxCredits {
-				fc.credits = fc.maxCredits
+			cbs.mu.Lock()
+			cbs.credits += cbs.refillRate
+			if cbs.credits > cbs.maxCredits {
+				cbs.credits = cbs.maxCredits
 			}
-			fc.mu.Unlock()
-		case <-fc.shutdown:
+			cbs.mu.Unlock()
+		case <-cbs.shutdown:
 			return
 		}
 	}
 }
 
+// RateLimiterStrategy implements rate limiting using golang.org/x/time/rate
+type RateLimiterStrategy struct {
+	limiter  *rate.Limiter
+	shutdown chan struct{}
+	logger   logger.Logger
+}
+
+// NewRateLimiterStrategy creates a new rate limiter strategy
+func NewRateLimiterStrategy(config FlowControlConfig) *RateLimiterStrategy {
+	if config.RefillRate == 0 {
+		config.RefillRate = 10
+	}
+	if config.BurstSize == 0 {
+		config.BurstSize = 100
+	}
+
+	// RefillRate is interpreted as requests per second; the limiter sustains
+	// that rate while permitting short bursts of up to BurstSize events.
+	rps := rate.Limit(config.RefillRate)
+
+	rls := &RateLimiterStrategy{
+		limiter:  rate.NewLimiter(rps, int(config.BurstSize)),
+		shutdown: make(chan struct{}),
+		logger:   config.Logger,
+	}
+
+
return rls
+}
+
+// Acquire attempts to acquire permission
+func (rls *RateLimiterStrategy) Acquire(ctx context.Context, amount int64) error {
+	// For rate limiter, amount represents the number of requests
+	for i := int64(0); i < amount; i++ {
+		if err := rls.limiter.Wait(ctx); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// Release is a no-op for rate limiter
+func (rls *RateLimiterStrategy) Release(amount int64) {
+	// Rate limiter doesn't release tokens back
+}
+
+// GetAvailableCredits returns the tokens currently available in the limiter
+func (rls *RateLimiterStrategy) GetAvailableCredits() int64 {
+	// Approximate: rate.Limiter refills continuously between calls
+	return int64(rls.limiter.Tokens())
+}
+
+// GetStats returns rate limiter statistics
+func (rls *RateLimiterStrategy) GetStats() map[string]interface{} {
+	return map[string]interface{}{
+		"strategy": "rate_limiter",
+		"limit":    rls.limiter.Limit(),
+		"burst":    rls.limiter.Burst(),
+		"tokens":   rls.limiter.Tokens(),
+	}
+}
+
+// Shutdown stops the rate limiter
+func (rls *RateLimiterStrategy) Shutdown() {
+	close(rls.shutdown)
+}
+
+// AcquireCredit attempts to acquire credits for processing
+func (fc *FlowController) AcquireCredit(ctx context.Context, amount int64) error {
+	return fc.strategy.Acquire(ctx, amount)
+}
+
+// ReleaseCredit returns credits after processing
+func (fc *FlowController) ReleaseCredit(amount int64) {
+	fc.strategy.Release(amount)
+}
+
 // GetAvailableCredits returns the current available credits
 func (fc *FlowController) GetAvailableCredits() int64 {
-	fc.mu.Lock()
-	defer fc.mu.Unlock()
-	return fc.credits
+	return fc.strategy.GetAvailableCredits()
 }
 
 // SetOnCreditLow sets callback for low credit warning
 func (fc *FlowController) SetOnCreditLow(fn func(current, max int64)) {
 	fc.onCreditLow = fn
+	// If strategy supports callbacks, set them
+	if cbs, ok := fc.strategy.(*CreditBasedStrategy); ok {
+		cbs.onCreditLow = fn
+	}
 }
 
 // SetOnCreditHigh sets callback for credit recovery
 func (fc *FlowController) SetOnCreditHigh(fn func(current, max int64)) {
 	fc.onCreditHigh = fn
+	// If strategy supports callbacks, set them
+	if cbs, ok := fc.strategy.(*CreditBasedStrategy); ok {
+		cbs.onCreditHigh = fn
+	}
 }
 
 // AdjustMaxCredits dynamically adjusts maximum credits
 func (fc *FlowController) AdjustMaxCredits(newMax int64) {
-	fc.mu.Lock()
-	defer fc.mu.Unlock()
-
-	fc.maxCredits = newMax
-	if fc.credits > newMax {
-		fc.credits = newMax
-	}
-
+	// Note: only the stored configuration changes; a running strategy keeps
+	// its original limits until the controller is recreated.
+	fc.config.MaxCredits = newMax
 	fc.logger.Info("Adjusted max credits", logger.Field{Key: "newMax", Value: newMax})
 }
 
 // GetStats returns flow control statistics
 func (fc *FlowController) GetStats() map[string]interface{} {
-	fc.mu.Lock()
-	defer fc.mu.Unlock()
-
-	utilization := float64(fc.maxCredits-fc.credits) / float64(fc.maxCredits) * 100
-
-	return map[string]interface{}{
-		"credits":     fc.credits,
-		"max_credits": fc.maxCredits,
-		"min_credits": fc.minCredits,
-		"utilization": utilization,
-		"refill_rate": fc.creditRefillRate,
+	stats := fc.strategy.GetStats()
+	stats["config"] = map[string]interface{}{
+		"strategy":        fc.config.Strategy,
+		"max_credits":     fc.config.MaxCredits,
+		"min_credits":     fc.config.MinCredits,
+		"refill_rate":     fc.config.RefillRate,
+		"refill_interval": fc.config.RefillInterval,
+		"burst_size":      fc.config.BurstSize,
 	}
+	return stats
 }
 
 // Shutdown stops the flow controller
@@ -509,3 +1011,186 @@ func (bm *BackpressureMonitor) SetOnBackpressureRelieved(fn func()) {
 func (bm *BackpressureMonitor) Shutdown() {
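+	// Closing the channel fans out to every goroutine selecting on bm.shutdown.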
close(bm.shutdown) } + +// FlowControlConfigProvider provides configuration from various sources +type FlowControlConfigProvider interface { + GetConfig() (FlowControlConfig, error) +} + +// EnvConfigProvider loads configuration from environment variables +type EnvConfigProvider struct { + prefix string // Environment variable prefix, e.g., "FLOW_" +} + +// NewEnvConfigProvider creates a new environment config provider +func NewEnvConfigProvider(prefix string) *EnvConfigProvider { + if prefix == "" { + prefix = "FLOW_" + } + return &EnvConfigProvider{prefix: prefix} +} + +// GetConfig loads configuration from environment variables +func (e *EnvConfigProvider) GetConfig() (FlowControlConfig, error) { + config := FlowControlConfig{} + + // Load strategy + if strategy := os.Getenv(e.prefix + "STRATEGY"); strategy != "" { + config.Strategy = FlowControlStrategyType(strategy) + } else { + config.Strategy = StrategyTokenBucket + } + + // Load numeric values + if maxCredits := os.Getenv(e.prefix + "MAX_CREDITS"); maxCredits != "" { + if val, err := strconv.ParseInt(maxCredits, 10, 64); err == nil { + config.MaxCredits = val + } + } + + if minCredits := os.Getenv(e.prefix + "MIN_CREDITS"); minCredits != "" { + if val, err := strconv.ParseInt(minCredits, 10, 64); err == nil { + config.MinCredits = val + } + } + + if refillRate := os.Getenv(e.prefix + "REFILL_RATE"); refillRate != "" { + if val, err := strconv.ParseInt(refillRate, 10, 64); err == nil { + config.RefillRate = val + } + } + + if burstSize := os.Getenv(e.prefix + "BURST_SIZE"); burstSize != "" { + if val, err := strconv.ParseInt(burstSize, 10, 64); err == nil { + config.BurstSize = val + } + } + + // Load duration values + if refillInterval := os.Getenv(e.prefix + "REFILL_INTERVAL"); refillInterval != "" { + if val, err := time.ParseDuration(refillInterval); err == nil { + config.RefillInterval = val + } + } + + // Set defaults if not specified + e.setDefaults(&config) + + return config, nil +} + +// setDefaults sets default values for missing configuration +func (e *EnvConfigProvider) setDefaults(config *FlowControlConfig) { + if config.MaxCredits == 0 { + config.MaxCredits = 1000 + } + if config.MinCredits == 0 { + config.MinCredits = 100 + } + if config.RefillRate == 0 { + config.RefillRate = 10 + } + if config.RefillInterval == 0 { + config.RefillInterval = 100 * time.Millisecond + } + if config.BurstSize == 0 { + config.BurstSize = config.MaxCredits + } +} + +// FileConfigProvider loads configuration from a file +type FileConfigProvider struct { + filePath string +} + +// NewFileConfigProvider creates a new file config provider +func NewFileConfigProvider(filePath string) *FileConfigProvider { + return &FileConfigProvider{filePath: filePath} +} + +// GetConfig loads configuration from a file +func (f *FileConfigProvider) GetConfig() (FlowControlConfig, error) { + data, err := os.ReadFile(f.filePath) + if err != nil { + return FlowControlConfig{}, fmt.Errorf("failed to read config file: %w", err) + } + + var config FlowControlConfig + if err := json.Unmarshal(data, &config); err != nil { + return FlowControlConfig{}, fmt.Errorf("failed to parse config file: %w", err) + } + + // Set defaults for missing values + f.setDefaults(&config) + + return config, nil +} + +// setDefaults sets default values for missing configuration +func (f *FileConfigProvider) setDefaults(config *FlowControlConfig) { + if config.Strategy == "" { + config.Strategy = StrategyTokenBucket + } + if config.MaxCredits == 0 { + config.MaxCredits = 1000 + } + if 
config.MinCredits == 0 { + config.MinCredits = 100 + } + if config.RefillRate == 0 { + config.RefillRate = 10 + } + if config.RefillInterval == 0 { + config.RefillInterval = 100 * time.Millisecond + } + if config.BurstSize == 0 { + config.BurstSize = config.MaxCredits + } +} + +// CompositeConfigProvider combines multiple config providers +type CompositeConfigProvider struct { + providers []FlowControlConfigProvider +} + +// NewCompositeConfigProvider creates a new composite config provider +func NewCompositeConfigProvider(providers ...FlowControlConfigProvider) *CompositeConfigProvider { + return &CompositeConfigProvider{providers: providers} +} + +// GetConfig loads configuration from all providers, with later providers overriding earlier ones +func (c *CompositeConfigProvider) GetConfig() (FlowControlConfig, error) { + var finalConfig FlowControlConfig + + for _, provider := range c.providers { + config, err := provider.GetConfig() + if err != nil { + return FlowControlConfig{}, fmt.Errorf("config provider failed: %w", err) + } + + // Merge configurations (simple override for now) + if config.Strategy != "" { + finalConfig.Strategy = config.Strategy + } + if config.MaxCredits != 0 { + finalConfig.MaxCredits = config.MaxCredits + } + if config.MinCredits != 0 { + finalConfig.MinCredits = config.MinCredits + } + if config.RefillRate != 0 { + finalConfig.RefillRate = config.RefillRate + } + if config.RefillInterval != 0 { + finalConfig.RefillInterval = config.RefillInterval + } + if config.BurstSize != 0 { + finalConfig.BurstSize = config.BurstSize + } + if config.Logger != nil { + finalConfig.Logger = config.Logger + } + } + + return finalConfig, nil +} diff --git a/enhanced_integration.go b/enhanced_integration.go index 91d6ed8..6fe8280 100644 --- a/enhanced_integration.go +++ b/enhanced_integration.go @@ -3,6 +3,7 @@ package mq import ( "context" "encoding/json" + "fmt" "time" "github.com/oarkflow/mq/logger" @@ -42,10 +43,28 @@ type BrokerEnhancedConfig struct { DedupPersistent bool // Flow Control Configuration - MaxCredits int64 - MinCredits int64 - CreditRefillRate int64 - CreditRefillInterval time.Duration + FlowControlStrategy FlowControlStrategyType + FlowControlConfigPath string // Path to flow control config file + FlowControlEnvPrefix string // Environment variable prefix for flow control + MaxCredits int64 + MinCredits int64 + CreditRefillRate int64 + CreditRefillInterval time.Duration + // Token bucket specific + TokenBucketCapacity int64 + TokenBucketRefillRate int64 + TokenBucketRefillInterval time.Duration + // Leaky bucket specific + LeakyBucketCapacity int64 + LeakyBucketLeakInterval time.Duration + // Credit-based specific + CreditBasedMaxCredits int64 + CreditBasedRefillRate int64 + CreditBasedRefillInterval time.Duration + CreditBasedBurstSize int64 + // Rate limiter specific + RateLimiterRequestsPerSecond int64 + RateLimiterBurstSize int64 // Backpressure Configuration QueueDepthThreshold int @@ -166,15 +185,70 @@ func (b *Broker) InitializeEnhancements(config *BrokerEnhancedConfig) error { } features.dedupManager = NewDeduplicationManager(dedupConfig) - // Initialize Flow Controller - flowConfig := FlowControlConfig{ - MaxCredits: config.MaxCredits, - MinCredits: config.MinCredits, - RefillRate: config.CreditRefillRate, - RefillInterval: config.CreditRefillInterval, - Logger: config.Logger, + // Initialize Flow Controller using factory + factory := NewFlowControllerFactory() + + // Try to load configuration from providers + var flowConfig FlowControlConfig + var err 
error + + // First try file-based configuration + if config.FlowControlConfigPath != "" { + fileProvider := NewFileConfigProvider(config.FlowControlConfigPath) + if loadedConfig, loadErr := fileProvider.GetConfig(); loadErr == nil { + flowConfig = loadedConfig + } + } + + // If no file config, try environment variables + if flowConfig.Strategy == "" && config.FlowControlEnvPrefix != "" { + envProvider := NewEnvConfigProvider(config.FlowControlEnvPrefix) + if loadedConfig, loadErr := envProvider.GetConfig(); loadErr == nil { + flowConfig = loadedConfig + } + } + + // If still no config, use broker config defaults based on strategy + if flowConfig.Strategy == "" { + flowConfig = FlowControlConfig{ + Strategy: config.FlowControlStrategy, + Logger: config.Logger, + } + + // Set strategy-specific defaults + switch config.FlowControlStrategy { + case StrategyTokenBucket: + flowConfig.MaxCredits = config.TokenBucketCapacity + flowConfig.RefillRate = config.TokenBucketRefillRate + flowConfig.RefillInterval = config.TokenBucketRefillInterval + case StrategyLeakyBucket: + flowConfig.MaxCredits = config.LeakyBucketCapacity + flowConfig.RefillInterval = config.LeakyBucketLeakInterval + case StrategyCreditBased: + flowConfig.MaxCredits = config.CreditBasedMaxCredits + flowConfig.RefillRate = config.CreditBasedRefillRate + flowConfig.RefillInterval = config.CreditBasedRefillInterval + flowConfig.BurstSize = config.CreditBasedBurstSize + case StrategyRateLimiter: + flowConfig.RefillRate = config.RateLimiterRequestsPerSecond + flowConfig.BurstSize = config.RateLimiterBurstSize + default: + // Fallback to token bucket + flowConfig.Strategy = StrategyTokenBucket + flowConfig.MaxCredits = config.MaxCredits + flowConfig.RefillRate = config.CreditRefillRate + flowConfig.RefillInterval = config.CreditRefillInterval + } + } + + // Ensure logger is set + flowConfig.Logger = config.Logger + + // Create flow controller using factory + features.flowController, err = factory.CreateFlowController(flowConfig) + if err != nil { + return fmt.Errorf("failed to create flow controller: %w", err) } - features.flowController = NewFlowController(flowConfig) // Initialize Backpressure Monitor backpressureConfig := BackpressureConfig{ @@ -237,19 +311,34 @@ func DefaultBrokerEnhancedConfig() *BrokerEnhancedConfig { ScaleDownThreshold: 0.25, DedupWindow: 5 * time.Minute, DedupCleanupInterval: 1 * time.Minute, - MaxCredits: 1000, - MinCredits: 100, - CreditRefillRate: 10, - CreditRefillInterval: 100 * time.Millisecond, - QueueDepthThreshold: 1000, - MemoryThreshold: 1 * 1024 * 1024 * 1024, // 1GB - ErrorRateThreshold: 0.5, - SnapshotInterval: 5 * time.Minute, - SnapshotRetention: 24 * time.Hour, - TracingEnabled: true, - TraceRetention: 24 * time.Hour, - TraceExportInterval: 30 * time.Second, - EnableEnhancements: true, + // Flow Control defaults (Token Bucket strategy) + FlowControlStrategy: StrategyTokenBucket, + FlowControlConfigPath: "", + FlowControlEnvPrefix: "FLOW_", + MaxCredits: 1000, + MinCredits: 100, + CreditRefillRate: 10, + CreditRefillInterval: 100 * time.Millisecond, + TokenBucketCapacity: 1000, + TokenBucketRefillRate: 100, + TokenBucketRefillInterval: 100 * time.Millisecond, + LeakyBucketCapacity: 500, + LeakyBucketLeakInterval: 200 * time.Millisecond, + CreditBasedMaxCredits: 1000, + CreditBasedRefillRate: 100, + CreditBasedRefillInterval: 200 * time.Millisecond, + CreditBasedBurstSize: 50, + RateLimiterRequestsPerSecond: 100, + RateLimiterBurstSize: 200, + QueueDepthThreshold: 1000, + MemoryThreshold: 1 * 1024 * 
1024 * 1024, // 1GB + ErrorRateThreshold: 0.5, + SnapshotInterval: 5 * time.Minute, + SnapshotRetention: 24 * time.Hour, + TracingEnabled: true, + TraceRetention: 24 * time.Hour, + TraceExportInterval: 30 * time.Second, + EnableEnhancements: true, } } diff --git a/examples/WAL_README.md b/examples/WAL_README.md deleted file mode 100644 index bc445b1..0000000 --- a/examples/WAL_README.md +++ /dev/null @@ -1,179 +0,0 @@ -# WAL (Write-Ahead Logging) System - -This directory contains a robust enterprise-grade WAL system implementation designed to prevent database overload from frequent task logging operations. - -## Overview - -The WAL system provides: -- **Buffered Logging**: High-frequency logging operations are buffered in memory -- **Batch Processing**: Periodic batch flushing to database for optimal performance -- **Crash Recovery**: Automatic recovery of unflushed entries on system restart -- **Performance Metrics**: Real-time monitoring of WAL operations -- **Graceful Shutdown**: Ensures data consistency during shutdown - -## Key Components - -### 1. WAL Manager (`dag/wal/wal.go`) -Core WAL functionality with buffering, segment management, and flush operations. - -### 2. WAL Storage (`dag/wal/storage.go`) -Database persistence layer for WAL entries and segments. - -### 3. WAL Recovery (`dag/wal/recovery.go`) -Crash recovery mechanisms to replay unflushed entries. - -### 4. WAL Factory (`dag/wal_factory.go`) -Factory for creating WAL-enabled storage instances. - -## Usage Example - -```go -package main - -import ( - "context" - "time" - - "github.com/oarkflow/mq/dag" - "github.com/oarkflow/mq/dag/storage" - "github.com/oarkflow/mq/dag/wal" - "github.com/oarkflow/mq/logger" -) - -func main() { - // Create logger - l := logger.NewDefaultLogger() - - // Create WAL-enabled storage factory - factory := dag.NewWALEnabledStorageFactory(l) - - // Configure WAL - walConfig := &wal.WALConfig{ - MaxBufferSize: 5000, // Buffer up to 5000 entries - FlushInterval: 2 * time.Second, // Flush every 2 seconds - MaxFlushRetries: 3, // Retry failed flushes - MaxSegmentSize: 10000, // 10K entries per segment - SegmentRetention: 48 * time.Hour, // Keep segments for 48 hours - WorkerCount: 4, // 4 flush workers - BatchSize: 500, // Batch 500 operations - EnableRecovery: true, // Enable crash recovery - EnableMetrics: true, // Enable metrics - } - - // Create WAL-enabled storage - storage, walManager, err := factory.CreateMemoryStorage(walConfig) - if err != nil { - panic(err) - } - defer storage.Close() - - // Create DAG with WAL-enabled storage - d := dag.NewDAG("My DAG", "my-dag", func(taskID string, result mq.Result) { - // Handle final results - }) - - // Set the WAL-enabled storage - d.SetTaskStorage(storage) - - // Now all logging operations will be buffered and batched - ctx := context.Background() - - // Create and log activities - these will be buffered - for i := 0; i < 1000; i++ { - task := &storage.PersistentTask{ - ID: fmt.Sprintf("task-%d", i), - DAGID: "my-dag", - Status: storage.TaskStatusRunning, - } - - // This will be buffered, not written immediately to DB - d.GetTaskStorage().SaveTask(ctx, task) - - // Activity logging will also be buffered - activity := &storage.TaskActivityLog{ - TaskID: task.ID, - DAGID: "my-dag", - Action: "processing", - Message: "Task is being processed", - } - d.GetTaskStorage().LogActivity(ctx, activity) - } - - // Get performance metrics - metrics := walManager.GetMetrics() - fmt.Printf("Buffered: %d, Flushed: %d\n", metrics.EntriesBuffered, 
metrics.EntriesFlushed) -} -``` - -## Configuration Options - -### WALConfig Fields - -- `MaxBufferSize`: Maximum entries to buffer before flush (default: 1000) -- `FlushInterval`: How often to flush buffer (default: 5s) -- `MaxFlushRetries`: Max retries for failed flushes (default: 3) -- `MaxSegmentSize`: Maximum entries per segment (default: 5000) -- `SegmentRetention`: How long to keep flushed segments (default: 24h) -- `WorkerCount`: Number of flush workers (default: 2) -- `BatchSize`: Batch size for database operations (default: 100) -- `EnableRecovery`: Enable crash recovery (default: true) -- `RecoveryTimeout`: Timeout for recovery operations (default: 30s) -- `EnableMetrics`: Enable metrics collection (default: true) -- `MetricsInterval`: Metrics collection interval (default: 10s) - -## Performance Benefits - -1. **Reduced Database Load**: Buffering prevents thousands of individual INSERT operations -2. **Batch Processing**: Database operations are performed in optimized batches -3. **Async Processing**: Logging doesn't block main application flow -4. **Configurable Buffering**: Tune buffer size based on your throughput needs -5. **Crash Recovery**: Never lose data even if system crashes - -## Integration with Task Manager - -The WAL system integrates seamlessly with the existing task manager: - -```go -// The task manager will automatically use WAL buffering -// when WAL-enabled storage is configured -taskManager := NewTaskManager(dag, taskID, resultCh, iterators, walStorage) - -// All activity logging will be buffered -taskManager.logActivity(ctx, "processing", "Task started processing") -``` - -## Monitoring - -Get real-time metrics about WAL performance: - -```go -metrics := walManager.GetMetrics() -fmt.Printf("Entries Buffered: %d\n", metrics.EntriesBuffered) -fmt.Printf("Entries Flushed: %d\n", metrics.EntriesFlushed) -fmt.Printf("Flush Operations: %d\n", metrics.FlushOperations) -fmt.Printf("Average Flush Time: %v\n", metrics.AverageFlushTime) -``` - -## Best Practices - -1. **Tune Buffer Size**: Set based on your expected logging frequency -2. **Monitor Metrics**: Keep an eye on buffer usage and flush performance -3. **Configure Retention**: Set appropriate segment retention for your needs -4. **Use Recovery**: Always enable recovery for production deployments -5. 
**Batch Size**: Optimize batch size based on your database capabilities - -## Database Support - -The WAL system supports: -- PostgreSQL -- SQLite -- MySQL (via storage interface) -- In-memory storage (for testing/development) - -## Error Handling - -The WAL system includes comprehensive error handling: -- Failed flushes are automatically retried -- Recovery process validates entries before replay -- Graceful degradation if storage is unavailable -- Detailed logging for troubleshooting diff --git a/examples/flow_control_example.go b/examples/flow_control_example.go new file mode 100644 index 0000000..5f0725a --- /dev/null +++ b/examples/flow_control_example.go @@ -0,0 +1,101 @@ +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/logger" +) + +func demonstrateFlowControl() { + // Create a logger + l := logger.NewDefaultLogger() + + // Example 1: Using the factory to create different flow controllers + factory := mq.NewFlowControllerFactory() + + // Create a token bucket flow controller + tokenBucketFC := factory.CreateTokenBucketFlowController(1000, 10, 100*time.Millisecond, l) + fmt.Println("Created Token Bucket Flow Controller") + fmt.Printf("Stats: %+v\n", tokenBucketFC.GetStats()) + + // Create a leaky bucket flow controller + leakyBucketFC := factory.CreateLeakyBucketFlowController(500, 200*time.Millisecond, l) + fmt.Println("Created Leaky Bucket Flow Controller") + fmt.Printf("Stats: %+v\n", leakyBucketFC.GetStats()) + + // Create a credit-based flow controller + creditBasedFC := factory.CreateCreditBasedFlowController(1000, 100, 5, 200*time.Millisecond, l) + fmt.Println("Created Credit-Based Flow Controller") + fmt.Printf("Stats: %+v\n", creditBasedFC.GetStats()) + + // Create a rate limiter flow controller + rateLimiterFC := factory.CreateRateLimiterFlowController(50, 100, l) + fmt.Println("Created Rate Limiter Flow Controller") + fmt.Printf("Stats: %+v\n", rateLimiterFC.GetStats()) + + // Example 2: Using configuration providers + fmt.Println("\n--- Configuration Providers ---") + + // Environment-based configuration + envProvider := mq.NewEnvConfigProvider("FLOW_") + envConfig, err := envProvider.GetConfig() + if err != nil { + log.Printf("Error loading env config: %v", err) + } else { + fmt.Printf("Environment Config: %+v\n", envConfig) + } + + // Composite configuration (environment overrides defaults) + compositeProvider := mq.NewCompositeConfigProvider(envProvider) + compositeConfig, err := compositeProvider.GetConfig() + if err != nil { + log.Printf("Error loading composite config: %v", err) + } else { + compositeFC, err := factory.CreateFlowController(compositeConfig) + if err != nil { + log.Printf("Error creating flow controller: %v", err) + } else { + fmt.Printf("Composite Config Flow Controller Stats: %+v\n", compositeFC.GetStats()) + } + } + + // Example 3: Using the flow controllers + fmt.Println("\n--- Flow Controller Usage ---") + + ctx := context.Background() + + // Test token bucket + fmt.Println("Testing Token Bucket...") + for i := 0; i < 5; i++ { + if err := tokenBucketFC.AcquireCredit(ctx, 50); err != nil { + fmt.Printf("Token bucket acquire failed: %v\n", err) + } else { + fmt.Printf("Token bucket acquired 50 credits, remaining: %d\n", tokenBucketFC.GetAvailableCredits()) + tokenBucketFC.ReleaseCredit(25) // Release some credits + } + time.Sleep(50 * time.Millisecond) + } + + // Test leaky bucket + fmt.Println("Testing Leaky Bucket...") + for i := 0; i < 3; i++ { + if err := 
leakyBucketFC.AcquireCredit(ctx, 100); err != nil {
+			fmt.Printf("Leaky bucket acquire failed: %v\n", err)
+		} else {
+			fmt.Printf("Leaky bucket acquired 100 credits, remaining: %d\n", leakyBucketFC.GetAvailableCredits())
+		}
+		time.Sleep(100 * time.Millisecond)
+	}
+
+	// Cleanup
+	tokenBucketFC.Shutdown()
+	leakyBucketFC.Shutdown()
+	creditBasedFC.Shutdown()
+	rateLimiterFC.Shutdown()
+
+	fmt.Println("Flow control example completed!")
+}
diff --git a/examples/flow_control_integration.go b/examples/flow_control_integration.go
new file mode 100644
index 0000000..a70b005
--- /dev/null
+++ b/examples/flow_control_integration.go
@@ -0,0 +1,118 @@
+package main
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"time"
+
+	"github.com/oarkflow/mq"
+	"github.com/oarkflow/mq/logger"
+)
+
+func main() {
+	testFlowControlIntegration()
+}
+
+func testFlowControlIntegration() {
+	// Create a logger
+	l := logger.NewDefaultLogger()
+
+	fmt.Println("Testing Flow Control Factory and Configuration Providers...")
+
+	// Test 1: Factory creates different strategies
+	factory := mq.NewFlowControllerFactory()
+
+	strategies := []struct {
+		name       string
+		createFunc func() *mq.FlowController
+	}{
+		{"Token Bucket", func() *mq.FlowController {
+			return factory.CreateTokenBucketFlowController(100, 10, 100*time.Millisecond, l)
+		}},
+		{"Leaky Bucket", func() *mq.FlowController {
+			return factory.CreateLeakyBucketFlowController(50, 200*time.Millisecond, l)
+		}},
+		{"Credit Based", func() *mq.FlowController {
+			return factory.CreateCreditBasedFlowController(200, 20, 10, 150*time.Millisecond, l)
+		}},
+		{"Rate Limiter", func() *mq.FlowController {
+			return factory.CreateRateLimiterFlowController(50, 100, l)
+		}},
+	}
+
+	for _, s := range strategies {
+		fmt.Printf("\n--- Testing %s ---\n", s.name)
+		fc := s.createFunc()
+		if fc == nil {
+			fmt.Printf("✗ Failed to create %s flow controller\n", s.name)
+			continue
+		}
+
+		fmt.Printf("✓ Created %s flow controller\n", s.name)
+		stats := fc.GetStats()
+		fmt.Printf("  Initial stats: %+v\n", stats)
+
+		// Test acquiring credits
+		ctx := context.Background()
+		err := fc.AcquireCredit(ctx, 5)
+		if err != nil {
+			fmt.Printf("  ✗ Failed to acquire credits: %v\n", err)
+		} else {
+			fmt.Printf("  ✓ Successfully acquired 5 credits\n")
+			stats = fc.GetStats()
+			fmt.Printf("  Stats after acquire: %+v\n", stats)
+
+			// Release credits
+			fc.ReleaseCredit(3)
+			fmt.Printf("  ✓ Released 3 credits\n")
+			stats = fc.GetStats()
+			fmt.Printf("  Stats after release: %+v\n", stats)
+		}
+
+		fc.Shutdown()
+	}
+
+	// Test 2: Configuration providers
+	fmt.Println("\n--- Testing Configuration Providers ---")
+
+	// The env provider never fails: with no TEST_FLOW_* variables set it
+	// falls back to its built-in defaults, so expect the success branch
+	envProvider := mq.NewEnvConfigProvider("TEST_FLOW_")
+	envConfig, err := envProvider.GetConfig()
+	if err != nil {
+		fmt.Printf("✗ Environment config failed: %v\n", err)
+	} else {
+		fmt.Printf("✓ Environment config (defaults where unset): %+v\n", envConfig)
+	}
+
+	// Test composite provider
+	compositeProvider := mq.NewCompositeConfigProvider(envProvider)
+	compositeConfig, err := compositeProvider.GetConfig()
+	if err != nil {
+		fmt.Printf("✗ Composite config failed: %v\n", err)
+	} else {
+		fmt.Printf("✓ Composite config: %+v\n", compositeConfig)
+	}
+
+	// Test 3: Factory with config
+	fmt.Println("\n--- Testing Factory with Config ---")
+	config := mq.FlowControlConfig{
+		Strategy:       mq.StrategyTokenBucket,
+		MaxCredits:     100,
+		RefillRate:     10,
+		RefillInterval: 100 * time.Millisecond,
+		Logger:         l,
+	}
+
+	fc, err :=
factory.CreateFlowController(config) + if err != nil { + log.Printf("Failed to create flow controller with config: %v", err) + } else { + fmt.Printf("✓ Created flow controller with config\n") + stats := fc.GetStats() + fmt.Printf(" Config-based stats: %+v\n", stats) + fc.Shutdown() + } + + fmt.Println("\nFlow control integration test completed successfully!") +} diff --git a/examples/reset_to_example.go b/examples/reset_to_example.go deleted file mode 100644 index 1eeccc9..0000000 --- a/examples/reset_to_example.go +++ /dev/null @@ -1,97 +0,0 @@ -package main - -import ( - "context" - "fmt" - "log" - - "github.com/oarkflow/json" - "github.com/oarkflow/mq" - "github.com/oarkflow/mq/dag" -) - -// ResetToExample demonstrates the ResetTo functionality -type ResetToExample struct { - dag.Operation -} - -func (r *ResetToExample) Process(ctx context.Context, task *mq.Task) mq.Result { - payload := string(task.Payload) - log.Printf("Processing node %s with payload: %s", task.Topic, payload) - - // Simulate some processing logic - if task.Topic == "step1" { - // For step1, we'll return a result that resets to step2 - return mq.Result{ - Status: mq.Completed, - Payload: json.RawMessage(`{"message": "Step 1 completed, resetting to step2"}`), - Ctx: ctx, - TaskID: task.ID, - Topic: task.Topic, - ResetTo: "step2", // Reset to step2 - } - } else if task.Topic == "step2" { - // For step2, we'll return a result that resets to the previous page node - return mq.Result{ - Status: mq.Completed, - Payload: json.RawMessage(`{"message": "Step 2 completed, resetting to back"}`), - Ctx: ctx, - TaskID: task.ID, - Topic: task.Topic, - ResetTo: "back", // Reset to previous page node - } - } else if task.Topic == "step3" { - // Final step - return mq.Result{ - Status: mq.Completed, - Payload: json.RawMessage(`{"message": "Step 3 completed - final result"}`), - Ctx: ctx, - TaskID: task.ID, - Topic: task.Topic, - } - } - - return mq.Result{ - Status: mq.Failed, - Error: fmt.Errorf("unknown step: %s", task.Topic), - Ctx: ctx, - TaskID: task.ID, - Topic: task.Topic, - } -} - -func runResetToExample() { - // Create a DAG with ResetTo functionality - flow := dag.NewDAG("ResetTo Example", "reset-to-example", func(taskID string, result mq.Result) { - log.Printf("Final result for task %s: %s", taskID, string(result.Payload)) - }) - - // Add nodes - flow.AddNode(dag.Function, "Step 1", "step1", &ResetToExample{}, true) - flow.AddNode(dag.Page, "Step 2", "step2", &ResetToExample{}) - flow.AddNode(dag.Page, "Step 3", "step3", &ResetToExample{}) - - // Add edges - flow.AddEdge(dag.Simple, "Step 1 to Step 2", "step1", "step2") - flow.AddEdge(dag.Simple, "Step 2 to Step 3", "step2", "step3") - - // Validate the DAG - if err := flow.Validate(); err != nil { - log.Fatalf("DAG validation failed: %v", err) - } - - // Process a task - data := json.RawMessage(`{"initial": "data"}`) - log.Println("Starting DAG processing...") - result := flow.Process(context.Background(), data) - - if result.Error != nil { - log.Printf("Processing failed: %v", result.Error) - } else { - log.Printf("Processing completed successfully: %s", string(result.Payload)) - } -} - -func main() { - runResetToExample() -} diff --git a/examples/v2.go b/examples/v2.go new file mode 100644 index 0000000..2043768 --- /dev/null +++ b/examples/v2.go @@ -0,0 +1,649 @@ +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "os/signal" + "sync" + "sync/atomic" + "syscall" + "time" +) + +// ---------------------- Public API Interfaces 
---------------------- + +type Processor func(ctx context.Context, in any) (any, error) + +type Node interface { + ID() string + Start(ctx context.Context, in <-chan any) <-chan any +} + +type Pipeline interface { + Start(ctx context.Context, inputs <-chan any) (<-chan any, error) +} + +// ---------------------- Processor Registry ---------------------- + +var procRegistry = map[string]Processor{} + +func RegisterProcessor(name string, p Processor) { + procRegistry[name] = p +} + +func GetProcessor(name string) (Processor, bool) { + p, ok := procRegistry[name] + return p, ok +} + +// ---------------------- Ring Buffer (SPSC lock-free) ---------------------- + +type RingBuffer struct { + buf []any + mask uint64 + head uint64 + tail uint64 +} + +func NewRingBuffer(size uint64) *RingBuffer { + if size == 0 || (size&(size-1)) != 0 { + panic("ring size must be power of two") + } + return &RingBuffer{buf: make([]any, size), mask: size - 1} +} + +func (r *RingBuffer) Push(v any) bool { + t := atomic.LoadUint64(&r.tail) + h := atomic.LoadUint64(&r.head) + if t-h == uint64(len(r.buf)) { + return false + } + r.buf[t&r.mask] = v + atomic.AddUint64(&r.tail, 1) + return true +} + +func (r *RingBuffer) Pop() (any, bool) { + h := atomic.LoadUint64(&r.head) + t := atomic.LoadUint64(&r.tail) + if t == h { + return nil, false + } + v := r.buf[h&r.mask] + atomic.AddUint64(&r.head, 1) + return v, true +} + +// ---------------------- Node Implementations ---------------------- + +type ChannelNode struct { + id string + processor Processor + buf int + workers int +} + +func NewChannelNode(id string, proc Processor, buf int, workers int) *ChannelNode { + if buf <= 0 { + buf = 64 + } + if workers <= 0 { + workers = 1 + } + return &ChannelNode{id: id, processor: proc, buf: buf, workers: workers} +} + +func (c *ChannelNode) ID() string { return c.id } + +func (c *ChannelNode) Start(ctx context.Context, in <-chan any) <-chan any { + out := make(chan any, c.buf) + var wg sync.WaitGroup + for i := 0; i < c.workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + case v, ok := <-in: + if !ok { + return + } + res, err := c.processor(ctx, v) + if err != nil { + fmt.Fprintf(os.Stderr, "processor %s error: %v\n", c.id, err) + continue + } + select { + case out <- res: + case <-ctx.Done(): + return + } + } + } + }() + } + go func() { + wg.Wait() + close(out) + }() + return out +} + +type PageNode struct { + id string + processor Processor + buf int + workers int +} + +func NewPageNode(id string, proc Processor, buf int, workers int) *PageNode { + if buf <= 0 { + buf = 64 + } + if workers <= 0 { + workers = 1 + } + return &PageNode{id: id, processor: proc, buf: buf, workers: workers} +} + +func (c *PageNode) ID() string { return c.id } + +func (c *PageNode) Start(ctx context.Context, in <-chan any) <-chan any { + out := make(chan any, c.buf) + var wg sync.WaitGroup + for i := 0; i < c.workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + case v, ok := <-in: + if !ok { + return + } + res, err := c.processor(ctx, v) + if err != nil { + fmt.Fprintf(os.Stderr, "processor %s error: %v\n", c.id, err) + continue + } + select { + case out <- res: + case <-ctx.Done(): + return + } + } + } + }() + } + go func() { + wg.Wait() + close(out) + }() + return out +} + +type RingNode struct { + id string + processor Processor + size uint64 +} + +func NewRingNode(id string, proc Processor, size uint64) *RingNode { + if size == 0 { + 
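+		// Default capacity. The loop below rounds any size up to the next power
+		// of two so RingBuffer can use index & mask instead of a modulo.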
size = 1024 + } + n := uint64(1) + for n < size { + n <<= 1 + } + return &RingNode{id: id, processor: proc, size: n} +} + +func (r *RingNode) ID() string { return r.id } + +func (r *RingNode) Start(ctx context.Context, in <-chan any) <-chan any { + out := make(chan any, 64) + ring := NewRingBuffer(r.size) + done := make(chan struct{}) + go func() { + defer close(done) + for { + select { + case <-ctx.Done(): + return + case v, ok := <-in: + if !ok { + return + } + for !ring.Push(v) { + time.Sleep(time.Microsecond) + select { + case <-ctx.Done(): + return + default: + } + } + } + } + }() + go func() { + defer close(out) + for { + select { + case <-ctx.Done(): + return + case <-done: + // process remaining items in ring + for { + v, ok := ring.Pop() + if !ok { + return + } + res, err := r.processor(ctx, v) + if err != nil { + fmt.Fprintf(os.Stderr, "processor %s error: %v\n", r.id, err) + continue + } + select { + case out <- res: + case <-ctx.Done(): + return + } + } + default: + v, ok := ring.Pop() + if !ok { + time.Sleep(time.Microsecond) + continue + } + res, err := r.processor(ctx, v) + if err != nil { + fmt.Fprintf(os.Stderr, "processor %s error: %v\n", r.id, err) + continue + } + select { + case out <- res: + case <-ctx.Done(): + return + } + } + } + }() + return out +} + +// ---------------------- DAG Pipeline ---------------------- + +type NodeSpec struct { + ID string `json:"id"` + Type string `json:"type"` + Processor string `json:"processor"` + Buf int `json:"buf,omitempty"` + Workers int `json:"workers,omitempty"` + RingSize uint64 `json:"ring_size,omitempty"` +} + +type EdgeSpec struct { + Source string `json:"source"` + Targets []string `json:"targets"` + Type string `json:"type,omitempty"` +} + +type PipelineSpec struct { + Nodes []NodeSpec `json:"nodes"` + Edges []EdgeSpec `json:"edges"` + EntryIDs []string `json:"entry_ids,omitempty"` + Conditions map[string]map[string]string `json:"conditions,omitempty"` +} + +type DAGPipeline struct { + nodes map[string]Node + edges map[string][]EdgeSpec + rev map[string][]string + entry []string + conditions map[string]map[string]string +} + +func NewDAGPipeline() *DAGPipeline { + return &DAGPipeline{ + nodes: map[string]Node{}, + edges: map[string][]EdgeSpec{}, + rev: map[string][]string{}, + conditions: map[string]map[string]string{}, + } +} + +func (d *DAGPipeline) AddNode(n Node) { + d.nodes[n.ID()] = n +} + +func (d *DAGPipeline) AddEdge(from string, tos []string, typ string) { + if typ == "" { + typ = "simple" + } + e := EdgeSpec{Source: from, Targets: tos, Type: typ} + d.edges[from] = append(d.edges[from], e) + for _, to := range tos { + d.rev[to] = append(d.rev[to], from) + } +} + +func (d *DAGPipeline) AddCondition(id string, cond map[string]string) { + d.conditions[id] = cond + for _, to := range cond { + d.rev[to] = append(d.rev[to], id) + } +} + +func (d *DAGPipeline) Start(ctx context.Context, inputs <-chan any) (<-chan any, error) { + nCh := map[string]chan any{} + outCh := map[string]<-chan any{} + wgMap := map[string]*sync.WaitGroup{} + for id := range d.nodes { + nCh[id] = make(chan any, 128) + wgMap[id] = &sync.WaitGroup{} + } + if len(d.entry) == 0 { + for id := range d.nodes { + if len(d.rev[id]) == 0 { + d.entry = append(d.entry, id) + } + } + } + for id, node := range d.nodes { + in := nCh[id] + out := node.Start(ctx, in) + outCh[id] = out + if cond, ok := d.conditions[id]; ok { + go func(o <-chan any, cond map[string]string) { + for v := range o { + if m, ok := v.(map[string]any); ok { + if status, ok := 
m["condition_status"].(string); ok { + if target, ok := cond[status]; ok { + wgMap[target].Add(1) + go func(c chan any, v any, wg *sync.WaitGroup) { + defer wg.Done() + select { + case c <- v: + case <-ctx.Done(): + } + }(nCh[target], v, wgMap[target]) + } + } + } + } + }(out, cond) + } else { + for _, e := range d.edges[id] { + for _, dep := range e.Targets { + if e.Type == "iterator" { + go func(o <-chan any, c chan any, wg *sync.WaitGroup) { + for v := range o { + if arr, ok := v.([]any); ok { + for _, item := range arr { + wg.Add(1) + go func(item any) { + defer wg.Done() + select { + case c <- item: + case <-ctx.Done(): + } + }(item) + } + } + } + }(out, nCh[dep], wgMap[dep]) + } else { + wgMap[dep].Add(1) + go func(o <-chan any, c chan any, wg *sync.WaitGroup) { + defer wg.Done() + for v := range o { + select { + case c <- v: + case <-ctx.Done(): + return + } + } + }(out, nCh[dep], wgMap[dep]) + } + } + } + } + } + for _, id := range d.entry { + wgMap[id].Add(1) + } + go func() { + defer func() { + for _, id := range d.entry { + wgMap[id].Done() + } + }() + for v := range inputs { + for _, id := range d.entry { + select { + case nCh[id] <- v: + case <-ctx.Done(): + return + } + } + } + }() + for id, wg := range wgMap { + go func(id string, wg *sync.WaitGroup, ch chan any) { + time.Sleep(time.Millisecond) + wg.Wait() + close(ch) + }(id, wg, nCh[id]) + } + finalOut := make(chan any, 128) + var wg sync.WaitGroup + for id := range d.nodes { + if len(d.edges[id]) == 0 && len(d.conditions[id]) == 0 { + wg.Add(1) + go func(o <-chan any) { + defer wg.Done() + for v := range o { + select { + case finalOut <- v: + case <-ctx.Done(): + return + } + } + }(outCh[id]) + } + } + go func() { + wg.Wait() + close(finalOut) + }() + return finalOut, nil +} + +func BuildDAGFromSpec(spec PipelineSpec) (*DAGPipeline, error) { + d := NewDAGPipeline() + for _, ns := range spec.Nodes { + proc, ok := GetProcessor(ns.Processor) + if !ok { + return nil, fmt.Errorf("processor %s not registered", ns.Processor) + } + var node Node + switch ns.Type { + case "channel": + node = NewChannelNode(ns.ID, proc, ns.Buf, ns.Workers) + case "ring": + node = NewRingNode(ns.ID, proc, ns.RingSize) + case "page": + node = NewPageNode(ns.ID, proc, ns.Buf, ns.Workers) + default: + return nil, fmt.Errorf("unknown node type %s", ns.Type) + } + d.AddNode(node) + } + for _, e := range spec.Edges { + if _, ok := d.nodes[e.Source]; !ok { + return nil, fmt.Errorf("edge source %s not found", e.Source) + } + for _, tgt := range e.Targets { + if _, ok := d.nodes[tgt]; !ok { + return nil, fmt.Errorf("edge target %s not found", tgt) + } + } + d.AddEdge(e.Source, e.Targets, e.Type) + } + if len(spec.EntryIDs) > 0 { + d.entry = spec.EntryIDs + } + if spec.Conditions != nil { + for id, cond := range spec.Conditions { + d.AddCondition(id, cond) + } + } + return d, nil +} + +// ---------------------- Example Processors ---------------------- + +func doubleProc(ctx context.Context, in any) (any, error) { + switch v := in.(type) { + case int: + return v * 2, nil + case float64: + return v * 2, nil + default: + return nil, errors.New("unsupported type for double") + } +} + +func incProc(ctx context.Context, in any) (any, error) { + if n, ok := in.(int); ok { + return n + 1, nil + } + return nil, errors.New("inc: not int") +} + +func printProc(ctx context.Context, in any) (any, error) { + fmt.Printf("OUTPUT: %#v\n", in) + return in, nil +} + +func getDataProc(ctx context.Context, in any) (any, error) { + return in, nil +} + +func loopProc(ctx 
context.Context, in any) (any, error) { + return in, nil +} + +func validateAgeProc(ctx context.Context, in any) (any, error) { + m, ok := in.(map[string]any) + if !ok { + return nil, errors.New("not map") + } + age, ok := m["age"].(float64) + if !ok { + return nil, errors.New("no age") + } + status := "default" + if age >= 18 { + status = "pass" + } + m["condition_status"] = status + return m, nil +} + +func validateGenderProc(ctx context.Context, in any) (any, error) { + m, ok := in.(map[string]any) + if !ok { + return nil, errors.New("not map") + } + gender, ok := m["gender"].(string) + if !ok { + return nil, errors.New("no gender") + } + m["female_voter"] = gender == "female" + return m, nil +} + +func finalProc(ctx context.Context, in any) (any, error) { + m, ok := in.(map[string]any) + if !ok { + return nil, errors.New("not map") + } + m["done"] = true + return m, nil +} + +// ---------------------- Main Demo ---------------------- + +func main() { + RegisterProcessor("double", doubleProc) + RegisterProcessor("inc", incProc) + RegisterProcessor("print", printProc) + RegisterProcessor("getData", getDataProc) + RegisterProcessor("loop", loopProc) + RegisterProcessor("validateAge", validateAgeProc) + RegisterProcessor("validateGender", validateGenderProc) + RegisterProcessor("final", finalProc) + + jsonSpec := `{ + "nodes": [ + {"id":"getData","type":"channel","processor":"getData"}, + {"id":"loop","type":"channel","processor":"loop"}, + {"id":"validateAge","type":"channel","processor":"validateAge"}, + {"id":"validateGender","type":"channel","processor":"validateGender"}, + {"id":"final","type":"channel","processor":"final"} + ], + "edges": [ + {"source":"getData","targets":["loop"],"type":"simple"}, + {"source":"loop","targets":["validateAge"],"type":"iterator"}, + {"source":"validateGender","targets":["final"],"type":"simple"} + ], + "entry_ids":["getData"], + "conditions": { + "validateAge": {"pass": "validateGender", "default": "final"} + } + }` + + var spec PipelineSpec + if err := json.Unmarshal([]byte(jsonSpec), &spec); err != nil { + panic(err) + } + dag, err := BuildDAGFromSpec(spec) + if err != nil { + panic(err) + } + + in := make(chan any) + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + out, err := dag.Start(ctx, in) + if err != nil { + panic(err) + } + + go func() { + data := []any{ + map[string]any{"age": 15.0, "gender": "female"}, + map[string]any{"age": 18.0, "gender": "male"}, + } + in <- data + close(in) + }() + + var results []any + for r := range out { + results = append(results, r) + } + + fmt.Println("Final results:", results) + fmt.Println("pipeline finished") +}
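+
+// registerSquare is an illustrative sketch, not part of the demo spec above: it
+// shows that extending the pipeline only requires registering a Processor and
+// adding a NodeSpec entry such as
+// {"id":"square","type":"ring","processor":"square","ring_size":4096}.
+// The "square" name is hypothetical.
+func registerSquare() {
+	RegisterProcessor("square", func(ctx context.Context, in any) (any, error) {
+		if n, ok := in.(int); ok {
+			return n * n, nil // pure, CPU-bound stage: safe to run with several workers
+		}
+		return nil, errors.New("square: not int")
+	})
+}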