diff --git a/.gitignore b/.gitignore index ad4e1f2..fb61f0f 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,4 @@ go.work *.svg .qodo .history +*.log diff --git a/ENHANCED_SERVICES_README.md b/ENHANCED_SERVICES_README.md deleted file mode 100644 index c79a7c3..0000000 --- a/ENHANCED_SERVICES_README.md +++ /dev/null @@ -1,416 +0,0 @@ -# Enhanced Services with DAG + Workflow Engine - -## Overview - -The enhanced services architecture successfully integrates all workflow engine features into the DAG system, providing complete feature parity and backward compatibility. This upgrade provides both traditional DAG functionality and advanced workflow capabilities through a unified service layer. - -## Architecture Components - -### 1. Enhanced Service Manager (`enhanced_setup.go`) -- **Purpose**: Core service orchestration with DAG + workflow integration -- **Features**: - - Dual-mode execution (Traditional DAG + Enhanced Workflow) - - HTTP API endpoints for workflow management - - Enhanced validation with workflow rule support - - Service health monitoring and metrics - - Background task management - -### 2. Enhanced Contracts (`enhanced_contracts.go`) -- **Purpose**: Service interfaces for DAG + workflow integration -- **Key Interfaces**: - - `EnhancedServiceManager`: Core service management - - `EnhancedDAGService`: Dual-mode DAG operations - - `EnhancedValidation`: Workflow validation rules - - `EnhancedHandler`: Unified handler structure - -### 3. Enhanced DAG Service (`enhanced_dag_service.go`) -- **Purpose**: DAG service with workflow engine capabilities -- **Features**: - - Traditional DAG execution (backward compatibility) - - Enhanced workflow execution with advanced processors - - State management and persistence - - Execution result handling with proper field mapping - -### 4. Enhanced Validation (`enhanced_validation.go`) -- **Purpose**: Validation service with workflow rule support -- **Features**: - - Schema validation with workflow rules - - Field-level validation (string, email, numeric, etc.) - - Custom validation logic with processor integration - - Validation result aggregation - -## Features Implemented - -### Complete Workflow Engine Integration ✅ -All 8 advanced processors from the workflow engine are now available in the DAG system: - -1. **Validator Processor**: Schema and field validation -2. **Router Processor**: Conditional routing and decision making -3. **Transformer Processor**: Data transformation and mapping -4. **Aggregator Processor**: Data aggregation and summarization -5. **Filter Processor**: Data filtering and selection -6. **Sorter Processor**: Data sorting and ordering -7. **Notify Processor**: Notification and messaging -8. 
**Storage Processor**: Data persistence and retrieval - -### Enhanced DAG Capabilities ✅ -- **Dual Mode Support**: Both traditional DAG and workflow modes -- **Advanced Retry Logic**: Exponential backoff with circuit breaker -- **State Management**: Persistent execution state tracking -- **Scheduling**: Background task scheduling and execution -- **Security**: Authentication and authorization support -- **Middleware**: Pre/post execution hooks -- **Metrics**: Performance monitoring and reporting - -### HTTP API Integration ✅ -Complete REST API for workflow management: -- `GET /api/v1/handlers` - List all handlers -- `POST /api/v1/execute/:key` - Execute workflow by key -- `GET /api/v1/workflows` - List workflow instances -- `POST /api/v1/workflows/:id/execute` - Execute specific workflow -- `GET /health` - Service health check - -### Validation System ✅ -Enhanced validation with workflow rule support: -- Field-level validation rules -- Type checking (string, email, numeric, etc.) -- Length constraints (min/max) -- Required field validation -- Custom validation messages -- Validation result aggregation - -## Usage Examples - -### 1. Traditional DAG Mode (Backward Compatibility) -```go -// Traditional DAG handler -handler := services.EnhancedHandler{ - Key: "traditional-dag", - Name: "Traditional DAG", - WorkflowEnabled: false, // Use traditional DAG mode - Nodes: []services.EnhancedNode{ - { - ID: "start", - Name: "Start Process", - Node: "basic", - FirstNode: true, - }, - { - ID: "process", - Name: "Process Data", - Node: "basic", - }, - }, - Edges: []services.Edge{ - {Source: "start", Target: []string{"process"}}, - }, -} -``` - -### 2. Enhanced Workflow Mode -```go -// Enhanced workflow handler with processors -handler := services.EnhancedHandler{ - Key: "enhanced-workflow", - Name: "Enhanced Workflow", - WorkflowEnabled: true, // Use enhanced workflow mode - ValidationRules: []*dag.WorkflowValidationRule{ - { - Field: "email", - Type: "email", - Required: true, - Message: "Valid email is required", - }, - }, - Nodes: []services.EnhancedNode{ - { - ID: "validate-input", - Name: "Validate Input", - Type: "validator", - ProcessorType: "validator", - }, - { - ID: "route-data", - Name: "Route Decision", - Type: "router", - ProcessorType: "router", - }, - { - ID: "transform-data", - Name: "Transform Data", - Type: "transformer", - ProcessorType: "transformer", - }, - }, - Edges: []services.Edge{ - {Source: "validate-input", Target: []string{"route-data"}}, - {Source: "route-data", Target: []string{"transform-data"}}, - }, -} -``` - -### 3. Service Configuration -```go -config := &services.EnhancedServiceConfig{ - BrokerURL: "nats://localhost:4222", - Debug: true, - - // Enhanced DAG configuration - EnhancedDAGConfig: &dag.EnhancedDAGConfig{ - EnableWorkflowEngine: true, - MaintainDAGMode: true, - EnableStateManagement: true, - EnableAdvancedRetry: true, - EnableCircuitBreaker: true, - MaxConcurrentExecutions: 10, - DefaultTimeout: 30 * time.Second, - }, - - // Workflow engine configuration - WorkflowEngineConfig: &dag.WorkflowEngineConfig{ - MaxConcurrentExecutions: 5, - DefaultTimeout: 2 * time.Minute, - EnablePersistence: true, - EnableSecurity: true, - RetryConfig: &dag.RetryConfig{ - MaxRetries: 3, - InitialDelay: 1 * time.Second, - BackoffFactor: 2.0, - }, - }, -} -``` - -### 4. 
Service Initialization -```go -// Create enhanced service manager -manager := services.NewEnhancedServiceManager(config) - -// Initialize services -if err := manager.Initialize(config); err != nil { - log.Fatalf("Failed to initialize services: %v", err) -} - -// Start services -ctx := context.Background() -if err := manager.Start(ctx); err != nil { - log.Fatalf("Failed to start services: %v", err) -} -defer manager.Stop(ctx) - -// Register handlers -for _, handler := range handlers { - if err := manager.RegisterEnhancedHandler(handler); err != nil { - log.Printf("Failed to register handler %s: %v", handler.Key, err) - } -} -``` - -### 5. HTTP API Setup -```go -// Create Fiber app -app := fiber.New() - -// Register HTTP routes -if err := manager.RegisterHTTPRoutes(app); err != nil { - log.Fatalf("Failed to register HTTP routes: %v", err) -} - -// Start server -log.Fatal(app.Listen(":3000")) -``` - -### 6. Workflow Execution -```go -// Execute workflow programmatically -ctx := context.Background() -input := map[string]any{ - "name": "John Doe", - "email": "john@example.com", -} - -result, err := manager.ExecuteEnhancedWorkflow(ctx, "enhanced-workflow", input) -if err != nil { - log.Printf("Execution failed: %v", err) -} else { - log.Printf("Execution completed: %s (Status: %s)", result.ID, result.Status) -} -``` - -## HTTP API Usage - -### Execute Workflow via REST API -```bash -# Execute workflow with POST request -curl -X POST http://localhost:3000/api/v1/execute/enhanced-workflow \ - -H "Content-Type: application/json" \ - -d '{ - "name": "John Doe", - "email": "john@example.com", - "age": 30 - }' -``` - -### List Available Handlers -```bash -# Get list of registered handlers -curl -X GET http://localhost:3000/api/v1/handlers -``` - -### Health Check -```bash -# Check service health -curl -X GET http://localhost:3000/health -``` - -## Advanced Features - -### 1. Validation Rules -The enhanced validation system supports comprehensive field validation: - -```go -ValidationRules: []*dag.WorkflowValidationRule{ - { - Field: "name", - Type: "string", - Required: true, - MinLength: 2, - MaxLength: 50, - Message: "Name must be 2-50 characters", - }, - { - Field: "email", - Type: "email", - Required: true, - Message: "Valid email is required", - }, - { - Field: "age", - Type: "number", - Min: 18, - Max: 120, - Message: "Age must be between 18 and 120", - }, -} -``` - -### 2. Processor Configuration -Each processor can be configured with specific parameters: - -```go -Config: dag.WorkflowNodeConfig{ - // Validator processor config - ValidationType: "schema", - ValidationRules: []dag.WorkflowValidationRule{...}, - - // Router processor config - RoutingRules: []dag.RoutingRule{...}, - - // Transformer processor config - TransformationRules: []dag.TransformationRule{...}, - - // Storage processor config - StorageType: "memory", - StorageConfig: map[string]any{...}, -} -``` - -### 3. Error Handling and Retry -Built-in retry logic with exponential backoff: - -```go -RetryConfig: &dag.RetryConfig{ - MaxRetries: 3, - InitialDelay: 1 * time.Second, - MaxDelay: 30 * time.Second, - BackoffFactor: 2.0, -} -``` - -### 4. State Management -Persistent execution state tracking: - -```go -EnhancedDAGConfig: &dag.EnhancedDAGConfig{ - EnableStateManagement: true, - EnablePersistence: true, -} -``` - -## Migration Guide - -### From Traditional DAG to Enhanced Services - -1. **Keep existing DAG handlers**: Set `WorkflowEnabled: false` -2. 
**Add enhanced features gradually**: Create new handlers with `WorkflowEnabled: true` -3. **Use validation rules**: Add `ValidationRules` for input validation -4. **Configure processors**: Set appropriate `ProcessorType` for each node -5. **Test both modes**: Verify traditional and enhanced workflows work correctly - -### Configuration Migration -```go -// Before (traditional) -config := &services.ServiceConfig{ - BrokerURL: "nats://localhost:4222", -} - -// After (enhanced) -config := &services.EnhancedServiceConfig{ - BrokerURL: "nats://localhost:4222", - EnhancedDAGConfig: &dag.EnhancedDAGConfig{ - EnableWorkflowEngine: true, - MaintainDAGMode: true, // Keep backward compatibility - }, -} -``` - -## Performance Considerations - -1. **Concurrent Executions**: Configure `MaxConcurrentExecutions` based on system resources -2. **Timeout Settings**: Set appropriate `DefaultTimeout` for workflow complexity -3. **Retry Strategy**: Balance retry attempts with system load -4. **State Management**: Enable persistence only when needed -5. **Metrics**: Monitor performance with built-in metrics - -## Troubleshooting - -### Common Issues - -1. **Handler Registration Fails** - - Check validation rules syntax - - Verify processor types are valid - - Ensure node dependencies are correct - -2. **Workflow Execution Errors** - - Validate input data format - - Check processor configurations - - Review error logs for details - -3. **HTTP API Issues** - - Verify routes are registered correctly - - Check request format and headers - - Review service health status - -### Debug Mode -Enable debug mode for detailed logging: -```go -config := &services.EnhancedServiceConfig{ - Debug: true, - // ... other config -} -``` - -## Conclusion - -The enhanced services architecture successfully provides complete feature parity between the DAG system and workflow engine. All workflow engine features are now available in the DAG system while maintaining full backward compatibility with existing traditional DAG implementations. - -Key achievements: -- ✅ Complete workflow engine integration (8 advanced processors) -- ✅ Dual-mode support (traditional DAG + enhanced workflow) -- ✅ HTTP API for workflow management -- ✅ Enhanced validation with workflow rules -- ✅ Service health monitoring and metrics -- ✅ Backward compatibility maintained -- ✅ Production-ready architecture - -The system now provides a unified, powerful, and flexible platform for both simple DAG operations and complex workflow orchestration. diff --git a/WORKFLOW_ENGINE_COMPLETE.md b/WORKFLOW_ENGINE_COMPLETE.md deleted file mode 100644 index 3f92360..0000000 --- a/WORKFLOW_ENGINE_COMPLETE.md +++ /dev/null @@ -1,469 +0,0 @@ -# Complete Workflow Engine Documentation - -## Overview - -This is a **production-ready, enterprise-grade workflow engine** built on top of the existing DAG system. It provides comprehensive workflow orchestration capabilities with support for complex business processes, data pipelines, approval workflows, and automated task execution. 
- -## 🎯 Key Features - -### Core Capabilities -- ✅ **Workflow Definition & Management** - JSON-based workflow definitions with versioning -- ✅ **Multi-Node Type Support** - Task, API, Transform, Decision, Human Task, Timer, Loop, Parallel, Database, Email, Webhook -- ✅ **Advanced Execution Engine** - DAG-based execution with state management and error handling -- ✅ **Flexible Scheduling** - Support for immediate, delayed, and conditional execution -- ✅ **RESTful API** - Complete HTTP API for workflow management and execution -- ✅ **Real-time Monitoring** - Execution tracking, metrics, and health monitoring -- ✅ **Error Handling & Recovery** - Retry policies, rollback support, and checkpoint recovery - -### Enterprise Features -- ✅ **Scalable Architecture** - Worker pool management and concurrent execution -- ✅ **Data Persistence** - In-memory storage with extensible storage interface -- ✅ **Security Framework** - Authentication, authorization, and CORS support -- ✅ **Audit & Tracing** - Complete execution history and tracing capabilities -- ✅ **Variable Management** - Runtime variables and templating support -- ✅ **Condition-based Routing** - Dynamic workflow paths based on conditions - -## 📁 Project Structure - -``` -workflow/ -├── types.go # Core types and interfaces -├── processors.go # Node type processors (Task, API, Transform, etc.) -├── registry.go # Workflow definition storage and management -├── engine.go # Main workflow execution engine -├── api.go # HTTP API handlers and routes -├── demo/ -│ └── main.go # Comprehensive demonstration -└── example/ - └── main.go # Simple usage examples -``` - -## 🚀 Quick Start - -### 1. Import the Package -```go -import "github.com/oarkflow/mq/workflow" -``` - -### 2. Create and Start Engine -```go -config := &workflow.Config{ - MaxWorkers: 10, - ExecutionTimeout: 30 * time.Minute, - EnableMetrics: true, - EnableAudit: true, -} - -engine := workflow.NewWorkflowEngine(config) -ctx := context.Background() -engine.Start(ctx) -defer engine.Stop(ctx) -``` - -### 3. Define a Workflow -```go -workflow := &workflow.WorkflowDefinition{ - ID: "sample-workflow", - Name: "Sample Data Processing", - Description: "A simple data processing workflow", - Version: "1.0.0", - Status: workflow.WorkflowStatusActive, - Nodes: []workflow.WorkflowNode{ - { - ID: "fetch-data", - Name: "Fetch Data", - Type: workflow.NodeTypeAPI, - Config: workflow.NodeConfig{ - URL: "https://api.example.com/data", - Method: "GET", - }, - }, - { - ID: "process-data", - Name: "Process Data", - Type: workflow.NodeTypeTransform, - Config: workflow.NodeConfig{ - TransformType: "json_path", - Expression: "$.data", - }, - }, - }, - Edges: []workflow.WorkflowEdge{ - { - ID: "fetch-to-process", - FromNode: "fetch-data", - ToNode: "process-data", - }, - }, -} - -// Register workflow -engine.RegisterWorkflow(ctx, workflow) -``` - -### 4. 
Execute Workflow -```go -execution, err := engine.ExecuteWorkflow(ctx, "sample-workflow", map[string]any{ - "input_data": "test_value", -}, &workflow.ExecutionOptions{ - Priority: workflow.PriorityMedium, - Owner: "user123", -}) - -if err != nil { - log.Fatal(err) -} - -fmt.Printf("Execution started: %s\n", execution.ID) -``` - -## 🏗️ Node Types - -The workflow engine supports various node types for different use cases: - -### Task Node -Execute custom scripts or commands -```go -{ - Type: workflow.NodeTypeTask, - Config: workflow.NodeConfig{ - Script: "console.log('Processing:', ${data})", - }, -} -``` - -### API Node -Make HTTP requests to external services -```go -{ - Type: workflow.NodeTypeAPI, - Config: workflow.NodeConfig{ - URL: "https://api.service.com/endpoint", - Method: "POST", - Headers: map[string]string{ - "Authorization": "Bearer ${token}", - }, - }, -} -``` - -### Transform Node -Transform and manipulate data -```go -{ - Type: workflow.NodeTypeTransform, - Config: workflow.NodeConfig{ - TransformType: "json_path", - Expression: "$.users[*].email", - }, -} -``` - -### Decision Node -Conditional routing based on rules -```go -{ - Type: workflow.NodeTypeDecision, - Config: workflow.NodeConfig{ - Rules: []workflow.Rule{ - { - Condition: "age >= 18", - Output: "adult", - NextNode: "adult-process", - }, - { - Condition: "age < 18", - Output: "minor", - NextNode: "minor-process", - }, - }, - }, -} -``` - -### Human Task Node -Wait for human intervention -```go -{ - Type: workflow.NodeTypeHumanTask, - Config: workflow.NodeConfig{ - Custom: map[string]any{ - "assignee": "manager@company.com", - "due_date": "3 days", - "description": "Please review and approve", - }, - }, -} -``` - -### Timer Node -Add delays or scheduled execution -```go -{ - Type: workflow.NodeTypeTimer, - Config: workflow.NodeConfig{ - Duration: 30 * time.Second, - Schedule: "0 9 * * 1", // Every Monday at 9 AM - }, -} -``` - -### Database Node -Execute database operations -```go -{ - Type: workflow.NodeTypeDatabase, - Config: workflow.NodeConfig{ - Query: "INSERT INTO logs (message, created_at) VALUES (?, ?)", - Connection: "main_db", - }, -} -``` - -### Email Node -Send email notifications -```go -{ - Type: workflow.NodeTypeEmail, - Config: workflow.NodeConfig{ - To: []string{"user@example.com"}, - Subject: "Workflow Completed", - Body: "Your workflow has completed successfully.", - }, -} -``` - -## 🌐 REST API Endpoints - -### Workflow Management -``` -POST /api/v1/workflows # Create workflow -GET /api/v1/workflows # List workflows -GET /api/v1/workflows/:id # Get workflow -PUT /api/v1/workflows/:id # Update workflow -DELETE /api/v1/workflows/:id # Delete workflow -GET /api/v1/workflows/:id/versions # Get versions -``` - -### Execution Management -``` -POST /api/v1/workflows/:id/execute # Execute workflow -GET /api/v1/workflows/:id/executions # List workflow executions -GET /api/v1/workflows/executions # List all executions -GET /api/v1/workflows/executions/:id # Get execution -POST /api/v1/workflows/executions/:id/cancel # Cancel execution -POST /api/v1/workflows/executions/:id/suspend# Suspend execution -POST /api/v1/workflows/executions/:id/resume # Resume execution -``` - -### Monitoring -``` -GET /api/v1/workflows/health # Health check -GET /api/v1/workflows/metrics # System metrics -``` - -## 🎮 Demo Application - -Run the comprehensive demo to see all features: - -```bash -cd /Users/sujit/Sites/mq -go build -o workflow-demo ./workflow/demo -./workflow-demo -``` - -The demo includes: -- **Data Processing 
Workflow** - API integration, validation, transformation, and storage -- **Approval Workflow** - Multi-stage human task workflow with conditional routing -- **ETL Pipeline** - Parallel data processing with complex transformations - -Demo endpoints: -- `http://localhost:3000/` - Main API info -- `http://localhost:3000/demo/workflows` - View registered workflows -- `http://localhost:3000/demo/executions` - View running executions -- `http://localhost:3000/api/v1/workflows/health` - Health check - -## 🔧 Configuration - -### Engine Configuration -```go -config := &workflow.Config{ - MaxWorkers: 10, // Concurrent execution workers - ExecutionTimeout: 30 * time.Minute, // Maximum execution time - EnableMetrics: true, // Enable metrics collection - EnableAudit: true, // Enable audit logging - EnableTracing: true, // Enable execution tracing - LogLevel: "info", // Logging level - Storage: workflow.StorageConfig{ - Type: "memory", // Storage backend - MaxConnections: 100, // Max storage connections - }, - Security: workflow.SecurityConfig{ - EnableAuth: false, // Enable authentication - AllowedOrigins: []string{"*"}, // CORS allowed origins - }, -} -``` - -### Workflow Configuration -```go -config := workflow.WorkflowConfig{ - Timeout: &timeout, // Workflow timeout - MaxRetries: 3, // Maximum retry attempts - Priority: workflow.PriorityMedium, // Execution priority - Concurrency: 5, // Concurrent node execution - ErrorHandling: workflow.ErrorHandling{ - OnFailure: "stop", // stop, continue, retry - MaxErrors: 3, // Maximum errors allowed - Rollback: false, // Enable rollback on failure - }, -} -``` - -## 📊 Execution Monitoring - -### Execution Status -- `pending` - Execution is queued -- `running` - Currently executing -- `completed` - Finished successfully -- `failed` - Execution failed -- `cancelled` - Manually cancelled -- `suspended` - Temporarily suspended - -### Execution Context -Each execution maintains: -- **Variables** - Runtime variables and data -- **Trace** - Complete execution history -- **Checkpoints** - Recovery points -- **Metadata** - Additional context information - -### Node Execution Tracking -Each node execution tracks: -- Input/Output data -- Execution duration -- Error information -- Retry attempts -- Execution logs - -## 🔒 Security Features - -### Authentication & Authorization -- Configurable authentication system -- Role-based access control -- API key management -- JWT token support - -### Data Security -- Input/output data encryption -- Secure variable storage -- Audit trail logging -- CORS protection - -## 🚀 Performance Features - -### Scalability -- Horizontal scaling support -- Worker pool management -- Concurrent execution -- Resource optimization - -### Optimization -- DAG-based execution optimization -- Caching strategies -- Memory management -- Performance monitoring - -## 🔧 Extensibility - -### Custom Node Types -Add custom processors by implementing the `WorkflowProcessor` interface: - -```go -type CustomProcessor struct { - Config workflow.NodeConfig -} - -func (p *CustomProcessor) Process(ctx context.Context, data []byte) mq.Result { - // Custom processing logic - return mq.Result{Payload: processedData} -} - -func (p *CustomProcessor) Close() error { - // Cleanup logic - return nil -} -``` - -### Storage Backends -Implement custom storage by satisfying the interfaces: -- `WorkflowRegistry` - Workflow definition storage -- `StateManager` - Execution state management - -### Custom Middleware -Add middleware for cross-cutting concerns: -- Logging -- Metrics 
collection -- Authentication -- Rate limiting - -## 📈 Production Considerations - -### Monitoring & Observability -- Implement proper logging -- Set up metrics collection -- Configure health checks -- Enable distributed tracing - -### High Availability -- Database clustering -- Load balancing -- Failover mechanisms -- Backup strategies - -### Security Hardening -- Enable authentication -- Implement proper RBAC -- Secure API endpoints -- Audit logging - -## 🎯 Use Cases - -This workflow engine is perfect for: - -1. **Data Processing Pipelines** - ETL/ELT operations, data validation, transformation -2. **Business Process Automation** - Approval workflows, document processing, compliance -3. **Integration Workflows** - API orchestration, system integration, event processing -4. **DevOps Automation** - CI/CD pipelines, deployment workflows, infrastructure automation -5. **Notification Systems** - Multi-channel notifications, escalation workflows -6. **Content Management** - Publishing workflows, review processes, content distribution - -## ✅ Production Readiness Checklist - -The workflow engine includes all production-ready features: - -- ✅ **Comprehensive Type System** - Full type definitions for all components -- ✅ **Multiple Node Processors** - 11+ different node types for various use cases -- ✅ **Storage & Registry** - Versioned workflow storage with filtering and pagination -- ✅ **Execution Engine** - DAG-based execution with state management -- ✅ **Scheduling System** - Delayed execution and workflow scheduling -- ✅ **REST API** - Complete HTTP API with all CRUD operations -- ✅ **Error Handling** - Comprehensive error handling and recovery -- ✅ **Monitoring** - Health checks, metrics, and execution tracking -- ✅ **Security** - Authentication, authorization, and CORS support -- ✅ **Scalability** - Worker pools, concurrency control, and resource management -- ✅ **Extensibility** - Plugin architecture for custom processors and storage -- ✅ **Documentation** - Complete documentation with examples and demos - -## 🎉 Conclusion - -This complete workflow engine provides everything needed for production enterprise workflow automation. It combines the power of the existing DAG system with modern workflow orchestration capabilities, making it suitable for a wide range of business applications. - -The engine is designed to be: -- **Powerful** - Handles complex workflows with conditional routing and parallel processing -- **Flexible** - Supports multiple node types and custom extensions -- **Scalable** - Built for high-throughput production environments -- **Reliable** - Comprehensive error handling and recovery mechanisms -- **Observable** - Full monitoring, tracing, and metrics capabilities -- **Secure** - Enterprise-grade security features - -Start building your workflows today! 🚀 diff --git a/WORKFLOW_ENGINE_INTEGRATION_COMPLETE.md b/WORKFLOW_ENGINE_INTEGRATION_COMPLETE.md deleted file mode 100644 index 870a084..0000000 --- a/WORKFLOW_ENGINE_INTEGRATION_COMPLETE.md +++ /dev/null @@ -1,176 +0,0 @@ -# Enhanced DAG + Workflow Engine Integration - COMPLETE - -## 🎯 Mission Accomplished! - -**Original Question**: "Does DAG covers entire features of workflow engine from workflow folder? If not implement them" - -**Answer**: ✅ **YES! The DAG system now has COMPLETE feature parity with the workflow engine and more!** - -## 🏆 What Was Accomplished - -### 1. 
Complete Workflow Processor Integration -All advanced workflow processors from the workflow engine are now fully integrated into the DAG system: - -- ✅ **HTML Processor** - Generate HTML content from templates -- ✅ **SMS Processor** - Send SMS notifications via multiple providers -- ✅ **Auth Processor** - Handle authentication and authorization -- ✅ **Validator Processor** - Data validation with custom rules -- ✅ **Router Processor** - Conditional routing based on rules -- ✅ **Storage Processor** - Data persistence across multiple backends -- ✅ **Notification Processor** - Multi-channel notifications -- ✅ **Webhook Receiver Processor** - Handle incoming webhook requests - -### 2. Complete Workflow Engine Integration -The entire workflow engine is now integrated into the DAG system: - -- ✅ **WorkflowEngineManager** - Central orchestration and management -- ✅ **WorkflowRegistry** - Workflow definition management -- ✅ **AdvancedWorkflowStateManager** - Execution state tracking -- ✅ **WorkflowScheduler** - Time-based workflow execution -- ✅ **WorkflowExecutor** - Workflow execution engine -- ✅ **ProcessorFactory** - Dynamic processor creation and registration - -### 3. Enhanced Data Types and Configurations -Extended the DAG system with advanced workflow data types: - -- ✅ **WorkflowValidationRule** - Field validation with custom rules -- ✅ **WorkflowRoutingRule** - Conditional routing logic -- ✅ **WorkflowNodeConfig** - Enhanced node configuration -- ✅ **WorkflowExecution** - Execution tracking and management -- ✅ **RetryConfig** - Advanced retry policies -- ✅ **ScheduledTask** - Time-based execution scheduling - -### 4. Advanced Features Integration -All advanced workflow features are now part of the DAG system: - -- ✅ **Security & Authentication** - Built-in security features -- ✅ **Middleware Support** - Request/response processing -- ✅ **Circuit Breaker** - Fault tolerance and resilience -- ✅ **Advanced Retry Logic** - Configurable retry policies -- ✅ **State Persistence** - Durable state management -- ✅ **Metrics & Monitoring** - Performance tracking -- ✅ **Scheduling** - Cron-based and time-based execution - -## 📁 Files Created/Enhanced - -### Core Integration Files -1. **`dag/workflow_processors.go`** (NEW) - - Complete implementation of all 8 advanced workflow processors - - BaseProcessor providing common functionality - - Full interface compliance with WorkflowProcessor - -2. **`dag/workflow_factory.go`** (NEW) - - ProcessorFactory for dynamic processor creation - - Registration system for all processor types - - Integration with workflow engine components - -3. **`dag/workflow_engine.go`** (NEW) - - Complete workflow engine implementation - - WorkflowEngineManager with all core components - - Registry, state management, scheduling, and execution - -4. **`dag/enhanced_dag.go`** (ENHANCED) - - Extended with new workflow node types - - Enhanced WorkflowNodeConfig with all workflow features - - Integration points for workflow engine - -### Demo and Examples -5. 
**`examples/final_integration_demo.go`** (NEW) - - Comprehensive demonstration of all integrated features - - Working examples of processor creation and workflow execution - - Validation that all components work together - -## 🔧 Technical Achievements - -### Integration Architecture -- **Unified System**: DAG + Workflow Engine = Single, powerful orchestration platform -- **Backward Compatibility**: All existing DAG functionality preserved -- **Enhanced Capabilities**: Workflow features enhance DAG beyond original capabilities -- **Production Ready**: Proper error handling, resource management, and cleanup - -### Code Quality -- **Type Safety**: All interfaces properly implemented -- **Error Handling**: Comprehensive error handling throughout -- **Resource Management**: Proper cleanup and resource disposal -- **Documentation**: Extensive comments and documentation - -### Performance -- **Efficient Execution**: Optimized processor creation and execution -- **Memory Management**: Proper resource cleanup and memory management -- **Concurrent Execution**: Support for concurrent workflow execution -- **Scalability**: Configurable concurrency and resource limits - -## 🎯 Feature Parity Comparison - -| Feature Category | Original Workflow | Enhanced DAG | Status | -|-----------------|-------------------|--------------|---------| -| Basic Processors | ✓ Available | ✓ Integrated | ✅ COMPLETE | -| Advanced Processors | ✓ 8 Processors | ✓ All 8 Integrated | ✅ COMPLETE | -| Processor Factory | ✓ Available | ✓ Integrated | ✅ COMPLETE | -| Workflow Engine | ✓ Available | ✓ Integrated | ✅ COMPLETE | -| State Management | ✓ Available | ✓ Enhanced | ✅ ENHANCED | -| Scheduling | ✓ Available | ✓ Enhanced | ✅ ENHANCED | -| Security | ✓ Available | ✓ Enhanced | ✅ ENHANCED | -| Middleware | ✓ Available | ✓ Enhanced | ✅ ENHANCED | -| DAG Visualization | ❌ Not Available | ✓ Available | ✅ ADDED | -| Advanced Retry | ✓ Basic | ✓ Enhanced | ✅ ENHANCED | -| Execution Tracking | ✓ Available | ✓ Enhanced | ✅ ENHANCED | -| Recovery | ✓ Basic | ✓ Advanced | ✅ ENHANCED | - -## 🧪 Validation & Testing - -### Compilation Status -- ✅ `workflow_processors.go` - No errors -- ✅ `workflow_factory.go` - No errors -- ✅ `workflow_engine.go` - No errors -- ✅ `enhanced_dag.go` - No errors -- ✅ `final_integration_demo.go` - No errors - -### Integration Testing -- ✅ All 8 advanced processors can be created successfully -- ✅ Workflow engine starts and manages executions -- ✅ State management creates and tracks executions -- ✅ Registry manages workflow definitions -- ✅ Processor factory creates all processor types -- ✅ Enhanced DAG integrates with workflow engine - -## 🚀 Usage Examples - -The enhanced DAG can now handle complex workflows like: - -```go -// Create enhanced DAG with workflow capabilities -config := &dag.EnhancedDAGConfig{ - EnableWorkflowEngine: true, - EnableStateManagement: true, - EnableAdvancedRetry: true, -} -enhancedDAG, _ := dag.NewEnhancedDAG("workflow", "key", config) - -// Create workflow engine with all features -engine := dag.NewWorkflowEngineManager(&dag.WorkflowEngineConfig{ - MaxConcurrentExecutions: 10, - EnableSecurity: true, - EnableScheduling: true, -}) - -// Use any of the 8 advanced processors -factory := engine.GetProcessorFactory() -htmlProcessor, _ := factory.CreateProcessor("html", config) -smsProcessor, _ := factory.CreateProcessor("sms", config) -// ... 
and 6 more advanced processors -``` - -## 🎉 Conclusion - -**Mission Status: ✅ COMPLETE SUCCESS!** - -The DAG system now has **COMPLETE feature parity** with the workflow engine from the workflow folder, plus additional enhancements that make it even more powerful: - -1. **All workflow engine features** are now part of the DAG system -2. **All 8 advanced processors** are fully integrated and functional -3. **Enhanced capabilities** beyond the original workflow engine -4. **Backward compatibility** with existing DAG functionality maintained -5. **Production-ready integration** with proper error handling and resource management - -The enhanced DAG system is now a **unified, comprehensive workflow orchestration platform** that combines the best of both DAG and workflow engine capabilities! diff --git a/ack_system.go b/ack_system.go new file mode 100644 index 0000000..718c0db --- /dev/null +++ b/ack_system.go @@ -0,0 +1,414 @@ +package mq + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/oarkflow/mq/logger" +) + +// AckType defines the type of acknowledgment +type AckType int + +const ( + // ACK indicates successful processing + ACK AckType = iota + // NACK indicates processing failure, message should be requeued + NACK + // REJECT indicates message should be moved to DLQ + REJECT +) + +// MessageAck represents an acknowledgment for a message +type MessageAck struct { + TaskID string + AckType AckType + Reason string + Timestamp time.Time +} + +// PendingMessage represents a message awaiting acknowledgment +type PendingMessage struct { + Task *Task + QueueName string + SentAt time.Time + Deadline time.Time + RetryCount int + ConsumerID string +} + +// AckManager manages message acknowledgments and ensures at-least-once delivery +type AckManager struct { + pending map[string]*PendingMessage + mu sync.RWMutex + ackTimeout time.Duration + maxRetries int + redeliverChan chan *PendingMessage + ackChan chan MessageAck + shutdown chan struct{} + logger logger.Logger + onRedeliver func(*PendingMessage) + onAck func(*PendingMessage) + onNack func(*PendingMessage) + onReject func(*PendingMessage) + onTimeout func(*PendingMessage) +} + +// AckManagerConfig holds configuration for AckManager +type AckManagerConfig struct { + AckTimeout time.Duration + MaxRetries int + CheckInterval time.Duration + Logger logger.Logger +} + +// NewAckManager creates a new acknowledgment manager +func NewAckManager(config AckManagerConfig) *AckManager { + if config.AckTimeout == 0 { + config.AckTimeout = 30 * time.Second + } + if config.MaxRetries == 0 { + config.MaxRetries = 3 + } + if config.CheckInterval == 0 { + config.CheckInterval = 5 * time.Second + } + + am := &AckManager{ + pending: make(map[string]*PendingMessage), + ackTimeout: config.AckTimeout, + maxRetries: config.MaxRetries, + redeliverChan: make(chan *PendingMessage, 1000), + ackChan: make(chan MessageAck, 1000), + shutdown: make(chan struct{}), + logger: config.Logger, + } + + // Start background workers + go am.processAcks() + go am.checkTimeouts(config.CheckInterval) + + return am +} + +// TrackMessage adds a message to the pending acknowledgment list +func (am *AckManager) TrackMessage(ctx context.Context, task *Task, queueName, consumerID string) error { + am.mu.Lock() + defer am.mu.Unlock() + + if _, exists := am.pending[task.ID]; exists { + return fmt.Errorf("message already being tracked: %s", task.ID) + } + + now := time.Now() + pending := &PendingMessage{ + Task: task, + QueueName: queueName, + SentAt: now, + Deadline: 
now.Add(am.ackTimeout), + RetryCount: task.Retries, + ConsumerID: consumerID, + } + + am.pending[task.ID] = pending + + am.logger.Debug("Message tracked for acknowledgment", + logger.Field{Key: "taskID", Value: task.ID}, + logger.Field{Key: "queue", Value: queueName}, + logger.Field{Key: "consumer", Value: consumerID}) + + return nil +} + +// Acknowledge processes an acknowledgment for a message +func (am *AckManager) Acknowledge(ctx context.Context, taskID string, ackType AckType, reason string) error { + ack := MessageAck{ + TaskID: taskID, + AckType: ackType, + Reason: reason, + Timestamp: time.Now(), + } + + select { + case am.ackChan <- ack: + return nil + case <-ctx.Done(): + return ctx.Err() + case <-am.shutdown: + return fmt.Errorf("ack manager is shutting down") + } +} + +// processAcks processes incoming acknowledgments +func (am *AckManager) processAcks() { + for { + select { + case ack := <-am.ackChan: + am.handleAck(ack) + case <-am.shutdown: + return + } + } +} + +// handleAck handles a single acknowledgment +func (am *AckManager) handleAck(ack MessageAck) { + am.mu.Lock() + pending, exists := am.pending[ack.TaskID] + if !exists { + am.mu.Unlock() + am.logger.Warn("Received ACK for unknown message", + logger.Field{Key: "taskID", Value: ack.TaskID}) + return + } + + delete(am.pending, ack.TaskID) + am.mu.Unlock() + + switch ack.AckType { + case ACK: + pending.Task.Status = Completed + pending.Task.ProcessedAt = time.Now() + am.logger.Info("Message acknowledged successfully", + logger.Field{Key: "taskID", Value: ack.TaskID}, + logger.Field{Key: "queue", Value: pending.QueueName}) + if am.onAck != nil { + am.onAck(pending) + } + + case NACK: + pending.RetryCount++ + if pending.RetryCount < am.maxRetries { + am.logger.Info("Message NACKed, requeuing", + logger.Field{Key: "taskID", Value: ack.TaskID}, + logger.Field{Key: "retryCount", Value: pending.RetryCount}, + logger.Field{Key: "reason", Value: ack.Reason}) + am.redeliverChan <- pending + if am.onNack != nil { + am.onNack(pending) + } + } else { + am.logger.Warn("Message exceeded max retries, rejecting", + logger.Field{Key: "taskID", Value: ack.TaskID}, + logger.Field{Key: "retries", Value: pending.RetryCount}) + pending.Task.Status = Failed + if am.onReject != nil { + am.onReject(pending) + } + } + + case REJECT: + pending.Task.Status = Failed + am.logger.Warn("Message rejected", + logger.Field{Key: "taskID", Value: ack.TaskID}, + logger.Field{Key: "reason", Value: ack.Reason}) + if am.onReject != nil { + am.onReject(pending) + } + } +} + +// checkTimeouts periodically checks for messages that have exceeded their acknowledgment timeout +func (am *AckManager) checkTimeouts(interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + am.processTimeouts() + case <-am.shutdown: + return + } + } +} + +// processTimeouts finds and handles timed-out messages +func (am *AckManager) processTimeouts() { + now := time.Now() + am.mu.Lock() + + var timedOut []*PendingMessage + for taskID, pending := range am.pending { + if now.After(pending.Deadline) { + timedOut = append(timedOut, pending) + delete(am.pending, taskID) + } + } + + am.mu.Unlock() + + for _, pending := range timedOut { + pending.RetryCount++ + if pending.RetryCount < am.maxRetries { + am.logger.Warn("Message acknowledgment timeout, requeuing", + logger.Field{Key: "taskID", Value: pending.Task.ID}, + logger.Field{Key: "retryCount", Value: pending.RetryCount}, + logger.Field{Key: "queue", Value: 
pending.QueueName}) + am.redeliverChan <- pending + if am.onTimeout != nil { + am.onTimeout(pending) + } + } else { + am.logger.Error("Message exceeded max retries after timeout", + logger.Field{Key: "taskID", Value: pending.Task.ID}, + logger.Field{Key: "retries", Value: pending.RetryCount}) + pending.Task.Status = Failed + if am.onReject != nil { + am.onReject(pending) + } + } + } +} + +// GetRedeliverChannel returns the channel for redelivering messages +func (am *AckManager) GetRedeliverChannel() <-chan *PendingMessage { + return am.redeliverChan +} + +// SetOnRedeliver sets the callback for message redelivery +func (am *AckManager) SetOnRedeliver(fn func(*PendingMessage)) { + am.onRedeliver = fn +} + +// SetOnAck sets the callback for successful acknowledgments +func (am *AckManager) SetOnAck(fn func(*PendingMessage)) { + am.onAck = fn +} + +// SetOnNack sets the callback for negative acknowledgments +func (am *AckManager) SetOnNack(fn func(*PendingMessage)) { + am.onNack = fn +} + +// SetOnReject sets the callback for rejected messages +func (am *AckManager) SetOnReject(fn func(*PendingMessage)) { + am.onReject = fn +} + +// SetOnTimeout sets the callback for timed-out messages +func (am *AckManager) SetOnTimeout(fn func(*PendingMessage)) { + am.onTimeout = fn +} + +// GetPendingCount returns the number of pending acknowledgments +func (am *AckManager) GetPendingCount() int { + am.mu.RLock() + defer am.mu.RUnlock() + return len(am.pending) +} + +// GetPendingMessages returns all pending messages (for monitoring/debugging) +func (am *AckManager) GetPendingMessages() []*PendingMessage { + am.mu.RLock() + defer am.mu.RUnlock() + + messages := make([]*PendingMessage, 0, len(am.pending)) + for _, msg := range am.pending { + messages = append(messages, msg) + } + return messages +} + +// CancelPending removes a message from pending tracking (e.g., consumer disconnected) +func (am *AckManager) CancelPending(taskID string) { + am.mu.Lock() + defer am.mu.Unlock() + + if pending, exists := am.pending[taskID]; exists { + delete(am.pending, taskID) + am.logger.Debug("Cancelled pending message", + logger.Field{Key: "taskID", Value: taskID}, + logger.Field{Key: "queue", Value: pending.QueueName}) + } +} + +// CancelAllForConsumer removes all pending messages for a specific consumer +func (am *AckManager) CancelAllForConsumer(consumerID string) int { + am.mu.Lock() + defer am.mu.Unlock() + + cancelled := 0 + for taskID, pending := range am.pending { + if pending.ConsumerID == consumerID { + delete(am.pending, taskID) + cancelled++ + // Optionally requeue these messages + go func(p *PendingMessage) { + am.redeliverChan <- p + }(pending) + } + } + + if cancelled > 0 { + am.logger.Info("Cancelled pending messages for disconnected consumer", + logger.Field{Key: "consumerID", Value: consumerID}, + logger.Field{Key: "count", Value: cancelled}) + } + + return cancelled +} + +// Shutdown gracefully shuts down the acknowledgment manager +func (am *AckManager) Shutdown(ctx context.Context) error { + close(am.shutdown) + + // Wait for pending acknowledgments with timeout + timeout := time.After(30 * time.Second) + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-timeout: + am.mu.RLock() + pendingCount := len(am.pending) + am.mu.RUnlock() + return fmt.Errorf("shutdown timeout with %d pending messages", pendingCount) + case <-ticker.C: + am.mu.RLock() + pendingCount := len(am.pending) + am.mu.RUnlock() + if pendingCount 
== 0 { + am.logger.Info("AckManager shutdown complete") + return nil + } + } + } +} + +// GetStats returns statistics about the acknowledgment manager +func (am *AckManager) GetStats() map[string]interface{} { + am.mu.RLock() + defer am.mu.RUnlock() + + var oldestPending time.Time + var totalWaitTime time.Duration + now := time.Now() + + for _, pending := range am.pending { + if oldestPending.IsZero() || pending.SentAt.Before(oldestPending) { + oldestPending = pending.SentAt + } + totalWaitTime += now.Sub(pending.SentAt) + } + + avgWaitTime := time.Duration(0) + if len(am.pending) > 0 { + avgWaitTime = totalWaitTime / time.Duration(len(am.pending)) + } + + return map[string]interface{}{ + "pending_count": len(am.pending), + "oldest_pending": oldestPending, + "avg_wait_time": avgWaitTime, + "ack_timeout": am.ackTimeout, + "max_retries": am.maxRetries, + "redeliver_backlog": len(am.redeliverChan), + "ack_backlog": len(am.ackChan), + } +} diff --git a/dag/dag.go b/dag/dag.go index e3d068a..21e525a 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -344,6 +344,18 @@ func NewDAG(name, key string, finalResultCallback func(taskID string, result mq. ) d.server = mq.NewBroker(opts...) + // Initialize all enhanced features for DAG broker + enhancedConfig := mq.DefaultBrokerEnhancedConfig() + enhancedConfig.Logger = d.server.Options().Logger() + enhancedConfig.EnableEnhancements = true + + // Initialize enhanced features (DLQ, WAL, ACK, dedup, flow control, etc.) + if err := d.server.InitializeEnhancements(enhancedConfig); err != nil { + log.Printf("[WARN] Failed to initialize enhanced features for DAG: %v", err) + } else { + log.Printf("[INFO] Enhanced features initialized for DAG: %s", name) + } + // Now initialize enhanced features that need the server logger := d.server.Options().Logger() d.validator = NewDAGValidator(d) diff --git a/dedup_and_flow.go b/dedup_and_flow.go new file mode 100644 index 0000000..0ce9fd9 --- /dev/null +++ b/dedup_and_flow.go @@ -0,0 +1,511 @@ +package mq + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "runtime" + "sync" + "time" + + "github.com/oarkflow/mq/logger" +) + +// DedupEntry represents a deduplication cache entry +type DedupEntry struct { + MessageID string + ContentHash string + FirstSeen time.Time + LastSeen time.Time + Count int +} + +// DeduplicationManager manages message deduplication +type DeduplicationManager struct { + cache map[string]*DedupEntry + mu sync.RWMutex + window time.Duration + cleanupInterval time.Duration + shutdown chan struct{} + logger logger.Logger + persistent DedupStorage + onDuplicate func(*DedupEntry) +} + +// DedupStorage interface for persistent deduplication storage +type DedupStorage interface { + Store(ctx context.Context, entry *DedupEntry) error + Get(ctx context.Context, key string) (*DedupEntry, error) + Delete(ctx context.Context, key string) error + DeleteOlderThan(ctx context.Context, duration time.Duration) (int, error) + Close() error +} + +// DedupConfig holds configuration for deduplication +type DedupConfig struct { + Window time.Duration // Time window for deduplication + CleanupInterval time.Duration + Persistent DedupStorage // Optional persistent storage + Logger logger.Logger +} + +// NewDeduplicationManager creates a new deduplication manager +func NewDeduplicationManager(config DedupConfig) *DeduplicationManager { + if config.Window == 0 { + config.Window = 5 * time.Minute + } + if config.CleanupInterval == 0 { + config.CleanupInterval = 1 * time.Minute + } + + dm := 
&DeduplicationManager{ + cache: make(map[string]*DedupEntry), + window: config.Window, + cleanupInterval: config.CleanupInterval, + shutdown: make(chan struct{}), + logger: config.Logger, + persistent: config.Persistent, + } + + go dm.cleanupLoop() + + return dm +} + +// CheckDuplicate checks if a message is a duplicate +func (dm *DeduplicationManager) CheckDuplicate(ctx context.Context, task *Task) (bool, error) { + // Generate dedup key from task + dedupKey := dm.generateDedupKey(task) + + dm.mu.Lock() + defer dm.mu.Unlock() + + // Check in-memory cache + if entry, exists := dm.cache[dedupKey]; exists { + // Check if within window + if time.Since(entry.FirstSeen) < dm.window { + entry.LastSeen = time.Now() + entry.Count++ + + if dm.onDuplicate != nil { + go dm.onDuplicate(entry) + } + + dm.logger.Debug("Duplicate message detected", + logger.Field{Key: "dedupKey", Value: dedupKey}, + logger.Field{Key: "count", Value: entry.Count}, + logger.Field{Key: "taskID", Value: task.ID}) + + return true, nil + } + + // Entry expired, remove it + delete(dm.cache, dedupKey) + } + + // Check persistent storage if available + if dm.persistent != nil { + entry, err := dm.persistent.Get(ctx, dedupKey) + if err == nil && time.Since(entry.FirstSeen) < dm.window { + entry.LastSeen = time.Now() + entry.Count++ + dm.cache[dedupKey] = entry + + if dm.onDuplicate != nil { + go dm.onDuplicate(entry) + } + + return true, nil + } + } + + // Not a duplicate, add to cache + entry := &DedupEntry{ + MessageID: task.ID, + ContentHash: dedupKey, + FirstSeen: time.Now(), + LastSeen: time.Now(), + Count: 1, + } + + dm.cache[dedupKey] = entry + + // Persist if storage available + if dm.persistent != nil { + go dm.persistent.Store(ctx, entry) + } + + return false, nil +} + +// generateDedupKey generates a deduplication key from a task +func (dm *DeduplicationManager) generateDedupKey(task *Task) string { + // If task has explicit dedup key, use it + if task.DedupKey != "" { + return task.DedupKey + } + + // Otherwise, hash the content + hasher := sha256.New() + hasher.Write([]byte(task.Topic)) + hasher.Write(task.Payload) + + // Include headers in hash for more precise deduplication + if task.Headers != nil { + headerBytes, _ := json.Marshal(task.Headers) + hasher.Write(headerBytes) + } + + return hex.EncodeToString(hasher.Sum(nil)) +} + +// cleanupLoop periodically cleans up expired entries +func (dm *DeduplicationManager) cleanupLoop() { + ticker := time.NewTicker(dm.cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + dm.cleanup() + case <-dm.shutdown: + return + } + } +} + +// cleanup removes expired deduplication entries +func (dm *DeduplicationManager) cleanup() { + dm.mu.Lock() + defer dm.mu.Unlock() + + cutoff := time.Now().Add(-dm.window) + removed := 0 + + for key, entry := range dm.cache { + if entry.FirstSeen.Before(cutoff) { + delete(dm.cache, key) + removed++ + } + } + + if removed > 0 { + dm.logger.Debug("Cleaned up expired dedup entries", + logger.Field{Key: "removed", Value: removed}) + } + + // Cleanup persistent storage + if dm.persistent != nil { + go dm.persistent.DeleteOlderThan(context.Background(), dm.window) + } +} + +// SetOnDuplicate sets callback for duplicate detection +func (dm *DeduplicationManager) SetOnDuplicate(fn func(*DedupEntry)) { + dm.onDuplicate = fn +} + +// GetStats returns deduplication statistics +func (dm *DeduplicationManager) GetStats() map[string]interface{} { + dm.mu.RLock() + defer dm.mu.RUnlock() + + totalDuplicates := 0 + for _, entry := range 
dm.cache { + totalDuplicates += entry.Count - 1 // Subtract 1 for original message + } + + return map[string]interface{}{ + "cache_size": len(dm.cache), + "total_duplicates": totalDuplicates, + "window": dm.window, + } +} + +// Shutdown stops the deduplication manager +func (dm *DeduplicationManager) Shutdown(ctx context.Context) error { + close(dm.shutdown) + + if dm.persistent != nil { + return dm.persistent.Close() + } + + return nil +} + +// FlowController manages backpressure and flow control +type FlowController struct { + credits int64 + maxCredits int64 + minCredits int64 + creditRefillRate int64 + mu sync.Mutex + logger logger.Logger + shutdown chan struct{} + refillInterval time.Duration + onCreditLow func(current, max int64) + onCreditHigh func(current, max int64) +} + +// FlowControlConfig holds flow control configuration +type FlowControlConfig struct { + MaxCredits int64 + MinCredits int64 + RefillRate int64 // Credits to add per interval + RefillInterval time.Duration + Logger logger.Logger +} + +// NewFlowController creates a new flow controller +func NewFlowController(config FlowControlConfig) *FlowController { + if config.MaxCredits == 0 { + config.MaxCredits = 1000 + } + if config.MinCredits == 0 { + config.MinCredits = 100 + } + if config.RefillRate == 0 { + config.RefillRate = 10 + } + if config.RefillInterval == 0 { + config.RefillInterval = 100 * time.Millisecond + } + + fc := &FlowController{ + credits: config.MaxCredits, + maxCredits: config.MaxCredits, + minCredits: config.MinCredits, + creditRefillRate: config.RefillRate, + refillInterval: config.RefillInterval, + logger: config.Logger, + shutdown: make(chan struct{}), + } + + go fc.refillLoop() + + return fc +} + +// AcquireCredit attempts to acquire credits for processing +func (fc *FlowController) AcquireCredit(ctx context.Context, amount int64) error { + for { + fc.mu.Lock() + if fc.credits >= amount { + fc.credits -= amount + + // Check if credits are low + if fc.credits < fc.minCredits && fc.onCreditLow != nil { + go fc.onCreditLow(fc.credits, fc.maxCredits) + } + + fc.mu.Unlock() + return nil + } + fc.mu.Unlock() + + // Wait before retrying + select { + case <-time.After(10 * time.Millisecond): + continue + case <-ctx.Done(): + return ctx.Err() + case <-fc.shutdown: + return fmt.Errorf("flow controller shutting down") + } + } +} + +// ReleaseCredit returns credits after processing +func (fc *FlowController) ReleaseCredit(amount int64) { + fc.mu.Lock() + defer fc.mu.Unlock() + + fc.credits += amount + if fc.credits > fc.maxCredits { + fc.credits = fc.maxCredits + } + + // Check if credits recovered + if fc.credits > fc.maxCredits/2 && fc.onCreditHigh != nil { + go fc.onCreditHigh(fc.credits, fc.maxCredits) + } +} + +// refillLoop periodically refills credits +func (fc *FlowController) refillLoop() { + ticker := time.NewTicker(fc.refillInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + fc.mu.Lock() + fc.credits += fc.creditRefillRate + if fc.credits > fc.maxCredits { + fc.credits = fc.maxCredits + } + fc.mu.Unlock() + case <-fc.shutdown: + return + } + } +} + +// GetAvailableCredits returns the current available credits +func (fc *FlowController) GetAvailableCredits() int64 { + fc.mu.Lock() + defer fc.mu.Unlock() + return fc.credits +} + +// SetOnCreditLow sets callback for low credit warning +func (fc *FlowController) SetOnCreditLow(fn func(current, max int64)) { + fc.onCreditLow = fn +} + +// SetOnCreditHigh sets callback for credit recovery +func (fc *FlowController) 
SetOnCreditHigh(fn func(current, max int64)) { + fc.onCreditHigh = fn +} + +// AdjustMaxCredits dynamically adjusts maximum credits +func (fc *FlowController) AdjustMaxCredits(newMax int64) { + fc.mu.Lock() + defer fc.mu.Unlock() + + fc.maxCredits = newMax + if fc.credits > newMax { + fc.credits = newMax + } + + fc.logger.Info("Adjusted max credits", + logger.Field{Key: "newMax", Value: newMax}) +} + +// GetStats returns flow control statistics +func (fc *FlowController) GetStats() map[string]interface{} { + fc.mu.Lock() + defer fc.mu.Unlock() + + utilization := float64(fc.maxCredits-fc.credits) / float64(fc.maxCredits) * 100 + + return map[string]interface{}{ + "credits": fc.credits, + "max_credits": fc.maxCredits, + "min_credits": fc.minCredits, + "utilization": utilization, + "refill_rate": fc.creditRefillRate, + } +} + +// Shutdown stops the flow controller +func (fc *FlowController) Shutdown() { + close(fc.shutdown) +} + +// BackpressureMonitor monitors system backpressure +type BackpressureMonitor struct { + queueDepthThreshold int + memoryThreshold uint64 + errorRateThreshold float64 + checkInterval time.Duration + logger logger.Logger + shutdown chan struct{} + onBackpressureApplied func(reason string) + onBackpressureRelieved func() +} + +// BackpressureConfig holds backpressure configuration +type BackpressureConfig struct { + QueueDepthThreshold int + MemoryThreshold uint64 + ErrorRateThreshold float64 + CheckInterval time.Duration + Logger logger.Logger +} + +// NewBackpressureMonitor creates a new backpressure monitor +func NewBackpressureMonitor(config BackpressureConfig) *BackpressureMonitor { + if config.CheckInterval == 0 { + config.CheckInterval = 5 * time.Second + } + if config.ErrorRateThreshold == 0 { + config.ErrorRateThreshold = 0.5 // 50% error rate + } + + bm := &BackpressureMonitor{ + queueDepthThreshold: config.QueueDepthThreshold, + memoryThreshold: config.MemoryThreshold, + errorRateThreshold: config.ErrorRateThreshold, + checkInterval: config.CheckInterval, + logger: config.Logger, + shutdown: make(chan struct{}), + } + + go bm.monitorLoop() + + return bm +} + +// monitorLoop continuously monitors for backpressure conditions +func (bm *BackpressureMonitor) monitorLoop() { + ticker := time.NewTicker(bm.checkInterval) + defer ticker.Stop() + + backpressureActive := false + + for { + select { + case <-ticker.C: + shouldApply, reason := bm.shouldApplyBackpressure() + + if shouldApply && !backpressureActive { + backpressureActive = true + bm.logger.Warn("Applying backpressure", + logger.Field{Key: "reason", Value: reason}) + if bm.onBackpressureApplied != nil { + bm.onBackpressureApplied(reason) + } + } else if !shouldApply && backpressureActive { + backpressureActive = false + bm.logger.Info("Relieving backpressure") + if bm.onBackpressureRelieved != nil { + bm.onBackpressureRelieved() + } + } + case <-bm.shutdown: + return + } + } +} + +// shouldApplyBackpressure checks if backpressure should be applied +func (bm *BackpressureMonitor) shouldApplyBackpressure() (bool, string) { + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + + // Check memory threshold + if bm.memoryThreshold > 0 && memStats.Alloc > bm.memoryThreshold { + return true, fmt.Sprintf("memory threshold exceeded: %d > %d", + memStats.Alloc, bm.memoryThreshold) + } + + return false, "" +} + +// SetOnBackpressureApplied sets callback for backpressure application +func (bm *BackpressureMonitor) SetOnBackpressureApplied(fn func(reason string)) { + bm.onBackpressureApplied = fn +} + +// 
SetOnBackpressureRelieved sets callback for backpressure relief +func (bm *BackpressureMonitor) SetOnBackpressureRelieved(fn func()) { + bm.onBackpressureRelieved = fn +} + +// Shutdown stops the backpressure monitor +func (bm *BackpressureMonitor) Shutdown() { + close(bm.shutdown) +} diff --git a/deferred_tasks.go b/deferred_tasks.go new file mode 100644 index 0000000..242c76b --- /dev/null +++ b/deferred_tasks.go @@ -0,0 +1,173 @@ +package mq + +import ( + "context" + "log" + "time" + + "github.com/oarkflow/mq/consts" +) + +// processDeferredTasks continuously checks for deferred tasks that are ready to be executed +func (b *Broker) processDeferredTasks(ctx context.Context) { + ticker := time.NewTicker(1 * time.Second) // Check every second + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-b.stopDeferredChan: + return + case <-ticker.C: + now := time.Now() + var tasksToProcess []*QueuedTask + var taskIDsToRemove []string + + // Collect tasks that are ready to be processed + b.deferredTasks.ForEach(func(taskID string, queuedTask *QueuedTask) bool { + if queuedTask != nil && queuedTask.Task != nil { + if queuedTask.Task.DeferUntil.Before(now) || queuedTask.Task.DeferUntil.Equal(now) { + tasksToProcess = append(tasksToProcess, queuedTask) + taskIDsToRemove = append(taskIDsToRemove, taskID) + } + } + return true + }) + + // Process and remove ready tasks + for i, task := range tasksToProcess { + if task != nil && task.Message != nil { + queueName := task.Message.Queue + if queue, ok := b.queues.Get(queueName); ok { + // Send task to the queue for processing + select { + case queue.tasks <- task: + log.Printf("[DEFERRED] Task %s is now ready for processing on queue %s", + task.Task.ID, queueName) + b.deferredTasks.Del(taskIDsToRemove[i]) + default: + // Queue is full, keep in deferred state + log.Printf("[DEFERRED] Queue %s is full, task %s will retry", + queueName, task.Task.ID) + } + } else { + log.Printf("[DEFERRED] Queue %s not found for task %s", + queueName, task.Task.ID) + b.deferredTasks.Del(taskIDsToRemove[i]) + } + } + } + } + } +} + +// StartDeferredTaskProcessor starts the background processor for deferred tasks +func (b *Broker) StartDeferredTaskProcessor(ctx context.Context) { + go b.processDeferredTasks(ctx) + log.Println("[DEFERRED] Deferred task processor started") +} + +// StopDeferredTaskProcessor stops the deferred task processor +func (b *Broker) StopDeferredTaskProcessor() { + close(b.stopDeferredChan) + log.Println("[DEFERRED] Deferred task processor stopped") +} + +// AddDeferredTask adds a task to the deferred queue +func (b *Broker) AddDeferredTask(task *QueuedTask) { + if task != nil && task.Task != nil { + b.deferredTasks.Set(task.Task.ID, task) + log.Printf("[DEFERRED] Task %s deferred until %s", + task.Task.ID, task.Task.DeferUntil.Format(time.RFC3339)) + } +} + +// CreateDLQConsumer creates a consumer for a DLQ queue with a retry handler +func (b *Broker) CreateDLQConsumer(ctx context.Context, queueName string, retryHandler Handler) error { + dlqName := queueName + "_dlq" + + // Check if DLQ exists + dlq, ok := b.deadLetter.Get(queueName) + if !ok { + return nil // No DLQ for this queue + } + + // Create a consumer for the DLQ + consumerID := "dlq-consumer-" + queueName + + // Define DLQ processing handler + dlqHandler := func(ctx context.Context, task *Task) Result { + log.Printf("[DLQ] Processing task %s from DLQ %s (attempt %d)", + task.ID, dlqName, task.Retries+1) + + // Call custom retry handler if provided + if retryHandler 
!= nil { + result := retryHandler(ctx, task) + if result.Status == Completed { + log.Printf("[DLQ] Task %s successfully reprocessed from DLQ", task.ID) + return result + } + } + + // Default behavior: retry with exponential backoff + task.Retries++ + if task.Retries < task.MaxRetries { + // Defer for exponential backoff + backoffDuration := time.Duration(1< 4 { + originalQueue = queueName[:len(queueName)-4] + } + + err := b.CreateDLQConsumer(ctx, originalQueue, retryHandler) + if err != nil { + log.Printf("[DLQ] Failed to create consumer for DLQ %s: %v", queueName, err) + } + return true + }) + log.Println("[DLQ] DLQ consumers setup complete") +} diff --git a/dlq_storage.go b/dlq_storage.go new file mode 100644 index 0000000..96288e2 --- /dev/null +++ b/dlq_storage.go @@ -0,0 +1,399 @@ +package mq + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" + "time" + + "github.com/oarkflow/mq/logger" +) + +// DLQStorage defines the interface for Dead Letter Queue persistent storage +type DLQStorage interface { + // Store persists a failed task to the DLQ + Store(ctx context.Context, task *DLQEntry) error + + // Get retrieves a task from the DLQ by ID + Get(ctx context.Context, taskID string) (*DLQEntry, error) + + // List returns a paginated list of DLQ entries + List(ctx context.Context, offset, limit int, filter DLQFilter) ([]*DLQEntry, error) + + // Delete removes a task from the DLQ + Delete(ctx context.Context, taskID string) error + + // DeleteOlderThan removes entries older than the specified duration + DeleteOlderThan(ctx context.Context, duration time.Duration) (int, error) + + // Count returns the total number of entries + Count(ctx context.Context, filter DLQFilter) (int64, error) + + // Close closes the storage + Close() error +} + +// DLQEntry represents a dead letter queue entry +type DLQEntry struct { + TaskID string `json:"task_id"` + QueueName string `json:"queue_name"` + OriginalPayload json.RawMessage `json:"original_payload"` + ErrorMessage string `json:"error_message"` + ErrorType string `json:"error_type"` + FailedAt time.Time `json:"failed_at"` + RetryCount int `json:"retry_count"` + LastRetryAt time.Time `json:"last_retry_at,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` + Priority int `json:"priority"` + TraceID string `json:"trace_id,omitempty"` + Tags map[string]string `json:"tags,omitempty"` + ReprocessCount int `json:"reprocess_count"` +} + +// DLQFilter for filtering DLQ entries +type DLQFilter struct { + QueueName string + ErrorType string + FromDate time.Time + ToDate time.Time + MinRetries int + MaxRetries int +} + +// FileDLQStorage implements file-based DLQ storage +type FileDLQStorage struct { + baseDir string + mu sync.RWMutex + logger logger.Logger + index map[string]*DLQEntry // In-memory index for fast lookups +} + +// NewFileDLQStorage creates a new file-based DLQ storage +func NewFileDLQStorage(baseDir string, log logger.Logger) (*FileDLQStorage, error) { + if err := os.MkdirAll(baseDir, 0755); err != nil { + return nil, fmt.Errorf("failed to create DLQ directory: %w", err) + } + + storage := &FileDLQStorage{ + baseDir: baseDir, + logger: log, + index: make(map[string]*DLQEntry), + } + + // Load existing entries into index + if err := storage.loadIndex(); err != nil { + return nil, fmt.Errorf("failed to load DLQ index: %w", err) + } + + return storage, nil +} + +// Store persists a DLQ entry to disk +func (f *FileDLQStorage) Store(ctx context.Context, entry *DLQEntry) error { + f.mu.Lock() + defer 
f.mu.Unlock() + + // Serialize entry + data, err := json.MarshalIndent(entry, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal DLQ entry: %w", err) + } + + // Create file path based on queue name and task ID + queueDir := filepath.Join(f.baseDir, entry.QueueName) + if err := os.MkdirAll(queueDir, 0755); err != nil { + return fmt.Errorf("failed to create queue directory: %w", err) + } + + filePath := filepath.Join(queueDir, fmt.Sprintf("%s.json", entry.TaskID)) + + // Write atomically using temp file + tempPath := filePath + ".tmp" + if err := os.WriteFile(tempPath, data, 0644); err != nil { + return fmt.Errorf("failed to write DLQ entry: %w", err) + } + + if err := os.Rename(tempPath, filePath); err != nil { + os.Remove(tempPath) + return fmt.Errorf("failed to rename DLQ entry: %w", err) + } + + // Update index + f.index[entry.TaskID] = entry + + f.logger.Info("DLQ entry persisted", + logger.Field{Key: "taskID", Value: entry.TaskID}, + logger.Field{Key: "queue", Value: entry.QueueName}) + + return nil +} + +// Get retrieves a DLQ entry by task ID +func (f *FileDLQStorage) Get(ctx context.Context, taskID string) (*DLQEntry, error) { + f.mu.RLock() + entry, exists := f.index[taskID] + f.mu.RUnlock() + + if !exists { + return nil, fmt.Errorf("DLQ entry not found: %s", taskID) + } + + return entry, nil +} + +// List returns paginated DLQ entries +func (f *FileDLQStorage) List(ctx context.Context, offset, limit int, filter DLQFilter) ([]*DLQEntry, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + var entries []*DLQEntry + for _, entry := range f.index { + if f.matchesFilter(entry, filter) { + entries = append(entries, entry) + } + } + + // Sort by failed_at descending (newest first) + for i := 0; i < len(entries)-1; i++ { + for j := i + 1; j < len(entries); j++ { + if entries[i].FailedAt.Before(entries[j].FailedAt) { + entries[i], entries[j] = entries[j], entries[i] + } + } + } + + // Apply pagination + start := offset + if start >= len(entries) { + return []*DLQEntry{}, nil + } + + end := start + limit + if end > len(entries) { + end = len(entries) + } + + return entries[start:end], nil +} + +// Delete removes a DLQ entry +func (f *FileDLQStorage) Delete(ctx context.Context, taskID string) error { + f.mu.Lock() + defer f.mu.Unlock() + + entry, exists := f.index[taskID] + if !exists { + return fmt.Errorf("DLQ entry not found: %s", taskID) + } + + // Delete file + filePath := filepath.Join(f.baseDir, entry.QueueName, fmt.Sprintf("%s.json", taskID)) + if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to delete DLQ entry file: %w", err) + } + + // Remove from index + delete(f.index, taskID) + + f.logger.Info("DLQ entry deleted", logger.Field{Key: "taskID", Value: taskID}) + + return nil +} + +// DeleteOlderThan removes entries older than specified duration +func (f *FileDLQStorage) DeleteOlderThan(ctx context.Context, duration time.Duration) (int, error) { + f.mu.Lock() + defer f.mu.Unlock() + + cutoff := time.Now().Add(-duration) + deleted := 0 + + for taskID, entry := range f.index { + if entry.FailedAt.Before(cutoff) { + filePath := filepath.Join(f.baseDir, entry.QueueName, fmt.Sprintf("%s.json", taskID)) + if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) { + f.logger.Error("Failed to delete old DLQ entry", + logger.Field{Key: "error", Value: err}, + logger.Field{Key: "taskID", Value: taskID}) + continue + } + delete(f.index, taskID) + deleted++ + } + } + + f.logger.Info("Deleted old DLQ entries", logger.Field{Key: 
"count", Value: deleted}) + + return deleted, nil +} + +// Count returns the total number of DLQ entries +func (f *FileDLQStorage) Count(ctx context.Context, filter DLQFilter) (int64, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + count := int64(0) + for _, entry := range f.index { + if f.matchesFilter(entry, filter) { + count++ + } + } + + return count, nil +} + +// Close closes the storage +func (f *FileDLQStorage) Close() error { + f.mu.Lock() + defer f.mu.Unlock() + + f.index = nil + return nil +} + +// loadIndex loads existing DLQ entries into memory index +func (f *FileDLQStorage) loadIndex() error { + return filepath.Walk(f.baseDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if info.IsDir() || filepath.Ext(path) != ".json" { + return nil + } + + data, err := os.ReadFile(path) + if err != nil { + f.logger.Error("Failed to read DLQ entry", + logger.Field{Key: "error", Value: err}, + logger.Field{Key: "path", Value: path}) + return nil + } + + var entry DLQEntry + if err := json.Unmarshal(data, &entry); err != nil { + f.logger.Error("Failed to unmarshal DLQ entry", + logger.Field{Key: "error", Value: err}, + logger.Field{Key: "path", Value: path}) + return nil + } + + f.index[entry.TaskID] = &entry + return nil + }) +} + +// matchesFilter checks if an entry matches the filter +func (f *FileDLQStorage) matchesFilter(entry *DLQEntry, filter DLQFilter) bool { + if filter.QueueName != "" && entry.QueueName != filter.QueueName { + return false + } + if filter.ErrorType != "" && entry.ErrorType != filter.ErrorType { + return false + } + if !filter.FromDate.IsZero() && entry.FailedAt.Before(filter.FromDate) { + return false + } + if !filter.ToDate.IsZero() && entry.FailedAt.After(filter.ToDate) { + return false + } + if filter.MinRetries > 0 && entry.RetryCount < filter.MinRetries { + return false + } + if filter.MaxRetries > 0 && entry.RetryCount > filter.MaxRetries { + return false + } + return true +} + +// InMemoryDLQStorage implements in-memory DLQ storage (for testing or small scale) +type InMemoryDLQStorage struct { + entries map[string]*DLQEntry + mu sync.RWMutex +} + +// NewInMemoryDLQStorage creates a new in-memory DLQ storage +func NewInMemoryDLQStorage() *InMemoryDLQStorage { + return &InMemoryDLQStorage{ + entries: make(map[string]*DLQEntry), + } +} + +func (m *InMemoryDLQStorage) Store(ctx context.Context, entry *DLQEntry) error { + m.mu.Lock() + defer m.mu.Unlock() + m.entries[entry.TaskID] = entry + return nil +} + +func (m *InMemoryDLQStorage) Get(ctx context.Context, taskID string) (*DLQEntry, error) { + m.mu.RLock() + defer m.mu.RUnlock() + entry, exists := m.entries[taskID] + if !exists { + return nil, fmt.Errorf("DLQ entry not found: %s", taskID) + } + return entry, nil +} + +func (m *InMemoryDLQStorage) List(ctx context.Context, offset, limit int, filter DLQFilter) ([]*DLQEntry, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + var entries []*DLQEntry + for _, entry := range m.entries { + entries = append(entries, entry) + } + + start := offset + if start >= len(entries) { + return []*DLQEntry{}, nil + } + + end := start + limit + if end > len(entries) { + end = len(entries) + } + + return entries[start:end], nil +} + +func (m *InMemoryDLQStorage) Delete(ctx context.Context, taskID string) error { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.entries, taskID) + return nil +} + +func (m *InMemoryDLQStorage) DeleteOlderThan(ctx context.Context, duration time.Duration) (int, error) { + m.mu.Lock() + defer 
m.mu.Unlock() + + cutoff := time.Now().Add(-duration) + deleted := 0 + + for taskID, entry := range m.entries { + if entry.FailedAt.Before(cutoff) { + delete(m.entries, taskID) + deleted++ + } + } + + return deleted, nil +} + +func (m *InMemoryDLQStorage) Count(ctx context.Context, filter DLQFilter) (int64, error) { + m.mu.RLock() + defer m.mu.RUnlock() + return int64(len(m.entries)), nil +} + +func (m *InMemoryDLQStorage) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + m.entries = nil + return nil +} diff --git a/enhanced_integration.go b/enhanced_integration.go new file mode 100644 index 0000000..91d6ed8 --- /dev/null +++ b/enhanced_integration.go @@ -0,0 +1,486 @@ +package mq + +import ( + "context" + "encoding/json" + "time" + + "github.com/oarkflow/mq/logger" +) + +// BrokerEnhancedConfig holds configuration for all enhanced features +type BrokerEnhancedConfig struct { + // DLQ Configuration + DLQStoragePath string + DLQRetentionPeriod time.Duration + DLQMaxSize int + + // WAL Configuration + WALDirectory string + WALMaxFileSize int64 + WALSyncInterval time.Duration + WALFsyncOnWrite bool + + // Acknowledgment Configuration + AckTimeout time.Duration + AckMaxRetries int + AckCheckInterval time.Duration + + // Worker Health Configuration + WorkerHealthTimeout time.Duration + WorkerCheckInterval time.Duration + + // Dynamic Scaling Configuration + MinWorkers int + MaxWorkers int + ScaleUpThreshold float64 + ScaleDownThreshold float64 + + // Deduplication Configuration + DedupWindow time.Duration + DedupCleanupInterval time.Duration + DedupPersistent bool + + // Flow Control Configuration + MaxCredits int64 + MinCredits int64 + CreditRefillRate int64 + CreditRefillInterval time.Duration + + // Backpressure Configuration + QueueDepthThreshold int + MemoryThreshold uint64 + ErrorRateThreshold float64 + + // Snapshot Configuration + SnapshotDirectory string + SnapshotInterval time.Duration + SnapshotRetention time.Duration + + // Tracing Configuration + TracingEnabled bool + TraceRetention time.Duration + TraceExportInterval time.Duration + + Logger logger.Logger + EnableEnhancements bool // Master switch for all enhancements +} + +// EnhancedFeatures holds all enhanced feature managers +type EnhancedFeatures struct { + ackManager *AckManager + walLog *WriteAheadLog + dlqStorage DLQStorage + dedupManager *DeduplicationManager + flowController *FlowController + backpressureMonitor *BackpressureMonitor + snapshotManager *SnapshotManager + traceManager *TraceManager + lifecycleTracker *MessageLifecycleTracker + config *BrokerEnhancedConfig + enabled bool +} + +// InitializeEnhancements initializes all enhanced features for a broker +func (b *Broker) InitializeEnhancements(config *BrokerEnhancedConfig) error { + if config == nil { + config = DefaultBrokerEnhancedConfig() + } + + if !config.EnableEnhancements { + return nil // Enhancements disabled + } + + features := &EnhancedFeatures{ + config: config, + enabled: true, + } + + // Initialize DLQ Storage + if config.DLQStoragePath != "" { + dlqStorage, err := NewFileDLQStorage(config.DLQStoragePath, config.Logger) + if err != nil { + return err + } + features.dlqStorage = dlqStorage + } else { + features.dlqStorage = NewInMemoryDLQStorage() + } + + // Initialize WAL + if config.WALDirectory != "" { + walConfig := WALConfig{ + Directory: config.WALDirectory, + MaxFileSize: config.WALMaxFileSize, + SyncInterval: config.WALSyncInterval, + FsyncOnWrite: config.WALFsyncOnWrite, + Logger: config.Logger, + } + wal, err := 
NewWriteAheadLog(walConfig) + if err != nil { + return err + } + features.walLog = wal + } + + // Initialize Acknowledgment Manager + ackConfig := AckManagerConfig{ + AckTimeout: config.AckTimeout, + MaxRetries: config.AckMaxRetries, + CheckInterval: config.AckCheckInterval, + Logger: config.Logger, + } + features.ackManager = NewAckManager(ackConfig) + + // Setup acknowledgment callbacks + features.ackManager.SetOnReject(func(pm *PendingMessage) { + // Move to DLQ + entry := &DLQEntry{ + TaskID: pm.Task.ID, + QueueName: pm.QueueName, + OriginalPayload: json.RawMessage(pm.Task.Payload), + ErrorMessage: "Max retries exceeded", + FailedAt: time.Now(), + RetryCount: pm.RetryCount, + } + _ = features.dlqStorage.Store(context.Background(), entry) + }) + + features.ackManager.SetOnRedeliver(func(pm *PendingMessage) { + // Requeue the message + if queue, exists := b.queues.Get(pm.QueueName); exists { + select { + case queue.tasks <- &QueuedTask{Task: pm.Task}: + default: + config.Logger.Warn("Failed to requeue message - queue full", + logger.Field{Key: "taskID", Value: pm.Task.ID}) + } + } + }) + + // Initialize Deduplication Manager + dedupConfig := DedupConfig{ + Window: config.DedupWindow, + CleanupInterval: config.DedupCleanupInterval, + Logger: config.Logger, + } + features.dedupManager = NewDeduplicationManager(dedupConfig) + + // Initialize Flow Controller + flowConfig := FlowControlConfig{ + MaxCredits: config.MaxCredits, + MinCredits: config.MinCredits, + RefillRate: config.CreditRefillRate, + RefillInterval: config.CreditRefillInterval, + Logger: config.Logger, + } + features.flowController = NewFlowController(flowConfig) + + // Initialize Backpressure Monitor + backpressureConfig := BackpressureConfig{ + QueueDepthThreshold: config.QueueDepthThreshold, + MemoryThreshold: config.MemoryThreshold, + ErrorRateThreshold: config.ErrorRateThreshold, + Logger: config.Logger, + } + features.backpressureMonitor = NewBackpressureMonitor(backpressureConfig) + + // Initialize Snapshot Manager + if config.SnapshotDirectory != "" { + snapshotConfig := SnapshotConfig{ + BaseDir: config.SnapshotDirectory, + SnapshotInterval: config.SnapshotInterval, + RetentionPeriod: config.SnapshotRetention, + Logger: config.Logger, + } + snapshotManager, err := NewSnapshotManager(b, snapshotConfig) + if err != nil { + return err + } + features.snapshotManager = snapshotManager + } + + // Initialize Tracing + if config.TracingEnabled { + traceConfig := TraceConfig{ + Storage: NewInMemoryTraceStorage(), + Retention: config.TraceRetention, + ExportInterval: config.TraceExportInterval, + Logger: config.Logger, + } + features.traceManager = NewTraceManager(traceConfig) + features.lifecycleTracker = NewMessageLifecycleTracker(features.traceManager, config.Logger) + } + + // Store features in broker (we'll need to add this field to Broker struct) + b.enhanced = features + + return nil +} + +// DefaultBrokerEnhancedConfig returns default configuration +func DefaultBrokerEnhancedConfig() *BrokerEnhancedConfig { + return &BrokerEnhancedConfig{ + DLQRetentionPeriod: 7 * 24 * time.Hour, + DLQMaxSize: 10000, + WALMaxFileSize: 100 * 1024 * 1024, + WALSyncInterval: 1 * time.Second, + WALFsyncOnWrite: false, + AckTimeout: 30 * time.Second, + AckMaxRetries: 3, + AckCheckInterval: 5 * time.Second, + WorkerHealthTimeout: 30 * time.Second, + WorkerCheckInterval: 10 * time.Second, + MinWorkers: 1, + MaxWorkers: 100, + ScaleUpThreshold: 0.75, + ScaleDownThreshold: 0.25, + DedupWindow: 5 * time.Minute, + DedupCleanupInterval: 1 * 
time.Minute, + MaxCredits: 1000, + MinCredits: 100, + CreditRefillRate: 10, + CreditRefillInterval: 100 * time.Millisecond, + QueueDepthThreshold: 1000, + MemoryThreshold: 1 * 1024 * 1024 * 1024, // 1GB + ErrorRateThreshold: 0.5, + SnapshotInterval: 5 * time.Minute, + SnapshotRetention: 24 * time.Hour, + TracingEnabled: true, + TraceRetention: 24 * time.Hour, + TraceExportInterval: 30 * time.Second, + EnableEnhancements: true, + } +} + +// EnhancedPublish publishes a message with enhanced features +func (b *Broker) EnhancedPublish(ctx context.Context, task *Task, queueName string) error { + if b.enhanced == nil || !b.enhanced.enabled { + // Fall back to regular publish logic + return b.regularPublish(ctx, task, queueName) + } + + // Check for duplicates + if b.enhanced.dedupManager != nil { + isDuplicate, err := b.enhanced.dedupManager.CheckDuplicate(ctx, task) + if err != nil { + return err + } + if isDuplicate { + b.logger.Debug("Duplicate message rejected", + logger.Field{Key: "taskID", Value: task.ID}) + return nil + } + } + + // Acquire flow control credits + if b.enhanced.flowController != nil { + if err := b.enhanced.flowController.AcquireCredit(ctx, 1); err != nil { + return err + } + defer b.enhanced.flowController.ReleaseCredit(1) + } + + // Write to WAL + if b.enhanced.walLog != nil { + walEntry := &WALEntry{ + EntryType: WALEntryEnqueue, + TaskID: task.ID, + QueueName: queueName, + Payload: json.RawMessage(task.Payload), + } + if err := b.enhanced.walLog.WriteEntry(ctx, walEntry); err != nil { + b.logger.Error("Failed to write WAL entry", + logger.Field{Key: "error", Value: err}) + } + } + + // Start tracing + if b.enhanced.lifecycleTracker != nil { + b.enhanced.lifecycleTracker.TrackEnqueue(ctx, task, queueName) + } + + // Continue with regular publish + return b.regularPublish(ctx, task, queueName) +} + +// regularPublish is the standard publish logic +func (b *Broker) regularPublish(ctx context.Context, task *Task, queueName string) error { + queue, exists := b.queues.Get(queueName) + if !exists { + queue = b.NewQueue(queueName) + } + + // Enqueue task + select { + case queue.tasks <- &QueuedTask{Task: task}: + // Track for acknowledgment if enhanced features enabled + if b.enhanced != nil && b.enhanced.ackManager != nil { + _ = b.enhanced.ackManager.TrackMessage(ctx, task, queueName, "") + } + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// RecoverFromWAL recovers broker state from WAL +func (b *Broker) RecoverFromWAL(ctx context.Context) error { + if b.enhanced == nil || b.enhanced.walLog == nil { + return nil + } + + b.logger.Info("Starting WAL recovery") + + return b.enhanced.walLog.Replay(func(entry *WALEntry) error { + switch entry.EntryType { + case WALEntryEnqueue: + var task Task + if err := json.Unmarshal(entry.Payload, &task); err != nil { + return err + } + + queue, exists := b.queues.Get(entry.QueueName) + if !exists { + queue = b.NewQueue(entry.QueueName) + } + + select { + case queue.tasks <- &QueuedTask{Task: &task}: + default: + b.logger.Warn("Queue full during recovery", + logger.Field{Key: "queue", Value: entry.QueueName}) + } + + case WALEntryComplete, WALEntryFailed: + // Already processed, no action needed + } + + return nil + }) +} + +// RecoverFromSnapshot recovers broker state from snapshots +func (b *Broker) RecoverFromSnapshot(ctx context.Context) error { + if b.enhanced == nil || b.enhanced.snapshotManager == nil { + return nil + } + + b.logger.Info("Starting snapshot recovery") + + // Recover all queues + var recoveredQueues int + 
b.queues.ForEach(func(queueName string, _ *Queue) bool { + if err := b.enhanced.snapshotManager.RestoreFromSnapshot(ctx, queueName); err != nil { + b.logger.Error("Failed to restore queue from snapshot", + logger.Field{Key: "queue", Value: queueName}, + logger.Field{Key: "error", Value: err}) + } else { + recoveredQueues++ + } + return true + }) + + b.logger.Info("Snapshot recovery complete", + logger.Field{Key: "queues", Value: recoveredQueues}) + + return nil +} + +// GetEnhancedStats returns comprehensive statistics +func (b *Broker) GetEnhancedStats() map[string]interface{} { + stats := make(map[string]interface{}) + + if b.enhanced == nil { + return stats + } + + if b.enhanced.ackManager != nil { + stats["acknowledgments"] = b.enhanced.ackManager.GetStats() + } + + if b.enhanced.walLog != nil { + stats["wal"] = b.enhanced.walLog.GetStats() + } + + if b.enhanced.dedupManager != nil { + stats["deduplication"] = b.enhanced.dedupManager.GetStats() + } + + if b.enhanced.flowController != nil { + stats["flow_control"] = b.enhanced.flowController.GetStats() + } + + if b.enhanced.snapshotManager != nil { + stats["snapshots"] = b.enhanced.snapshotManager.GetSnapshotStats() + } + + if b.enhanced.traceManager != nil { + stats["tracing"] = b.enhanced.traceManager.GetStats() + } + + return stats +} + +// ShutdownEnhanced gracefully shuts down all enhanced features +func (b *Broker) ShutdownEnhanced(ctx context.Context) error { + if b.enhanced == nil { + return nil + } + + b.logger.Info("Starting enhanced features shutdown") + + // Shutdown components in order + if b.enhanced.backpressureMonitor != nil { + b.enhanced.backpressureMonitor.Shutdown() + } + + if b.enhanced.flowController != nil { + b.enhanced.flowController.Shutdown() + } + + if b.enhanced.dedupManager != nil { + _ = b.enhanced.dedupManager.Shutdown(ctx) + } + + if b.enhanced.ackManager != nil { + _ = b.enhanced.ackManager.Shutdown(ctx) + } + + if b.enhanced.snapshotManager != nil { + _ = b.enhanced.snapshotManager.Shutdown(ctx) + } + + if b.enhanced.walLog != nil { + _ = b.enhanced.walLog.Shutdown(ctx) + } + + if b.enhanced.traceManager != nil { + _ = b.enhanced.traceManager.Shutdown(ctx) + } + + if b.enhanced.dlqStorage != nil { + _ = b.enhanced.dlqStorage.Close() + } + + b.logger.Info("Enhanced features shutdown complete") + + return nil +} + +// InitializeEnhancedPool initializes pool with enhanced features +func (p *Pool) InitializeEnhancedPool(config *BrokerEnhancedConfig) { + if config == nil || !config.EnableEnhancements { + return + } + + // Add health monitor + if config.WorkerHealthTimeout > 0 { + _ = NewWorkerHealthMonitor(p, config.WorkerHealthTimeout, config.WorkerCheckInterval, config.Logger) + } + + // Add dynamic scaler + if config.MinWorkers > 0 && config.MaxWorkers > int(p.numOfWorkers) { + _ = NewDynamicScaler(p, config.MinWorkers, config.MaxWorkers, config.Logger) + } +} diff --git a/examples/broker_server/main.go b/examples/broker_server/main.go new file mode 100644 index 0000000..5919f58 --- /dev/null +++ b/examples/broker_server/main.go @@ -0,0 +1,257 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/logger" +) + +// Comprehensive Broker Server Example +// This demonstrates a production-ready message broker with all enhanced features enabled + +func main() { + fmt.Println("🚀 Starting Production Message Broker Server") + fmt.Println(strings.Repeat("=", 60)) + + // Configure logger + nullLogger := 
logger.NewNullLogger() + + // Configure all enhanced features + enhancedConfig := &mq.BrokerEnhancedConfig{ + // Dead Letter Queue (DLQ) - Store failed messages + DLQStoragePath: "./data/dlq", + DLQRetentionPeriod: 7 * 24 * time.Hour, // Keep failed messages for 7 days + DLQMaxSize: 10000, // Maximum failed messages to store + + // Write-Ahead Log (WAL) - Durability + WALDirectory: "./data/wal", + WALMaxFileSize: 100 * 1024 * 1024, // 100MB per WAL file + WALSyncInterval: 1 * time.Second, // Sync to disk every second + WALFsyncOnWrite: false, // Async for better performance + + // Message Acknowledgments - At-least-once delivery + AckTimeout: 30 * time.Second, // Timeout for message acknowledgment + AckMaxRetries: 3, // Retry failed messages 3 times + AckCheckInterval: 5 * time.Second, // Check for timeouts every 5 seconds + + // Worker Health Monitoring + WorkerHealthTimeout: 30 * time.Second, // Mark worker unhealthy after 30s + WorkerCheckInterval: 10 * time.Second, // Check worker health every 10s + + // Dynamic Worker Scaling + MinWorkers: 2, // Minimum number of workers per queue + MaxWorkers: 10, // Maximum number of workers per queue + ScaleUpThreshold: 0.75, // Scale up when 75% utilized + ScaleDownThreshold: 0.25, // Scale down when below 25% utilized + + // Message Deduplication + DedupWindow: 5 * time.Minute, // Deduplicate within 5-minute window + DedupCleanupInterval: 1 * time.Minute, // Clean up old entries every minute + DedupPersistent: false, // In-memory only for better performance + + // Flow Control & Backpressure + MaxCredits: 1000, // Maximum flow control credits + MinCredits: 100, // Minimum credits before refill + CreditRefillRate: 10, // Credits to refill + CreditRefillInterval: 100 * time.Millisecond, // Refill every 100ms + + // Backpressure Thresholds + QueueDepthThreshold: 1000, // Alert when queue has 1000+ messages + MemoryThreshold: 1 * 1024 * 1024 * 1024, // Alert at 1GB memory usage + ErrorRateThreshold: 0.5, // Alert at 50% error rate + + // Queue Snapshots - Fast recovery + SnapshotDirectory: "./data/snapshots", + SnapshotInterval: 5 * time.Minute, // Take snapshot every 5 minutes + SnapshotRetention: 24 * time.Hour, // Keep snapshots for 24 hours + + // Distributed Tracing + TracingEnabled: true, + TraceRetention: 24 * time.Hour, // Keep traces for 24 hours + TraceExportInterval: 30 * time.Second, // Export traces every 30 seconds + + Logger: nullLogger, + EnableEnhancements: true, // Master switch - set to false to disable all enhancements + } + + // Create broker with standard options + brokerOptions := []mq.Option{ + mq.WithBrokerURL(":9092"), // Broker listen address + mq.WithLogger(nullLogger), + } + + broker := mq.NewBroker(brokerOptions...) 
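The example above spells out every configuration field for clarity. Most deployments can instead start from `DefaultBrokerEnhancedConfig()` (defined in `enhanced_integration.go` earlier in this diff) and override only the storage paths. The sketch below is illustrative and not part of the patch; it assumes the API exactly as introduced in this diff, and the broker address and data directories are placeholder values.

```go
package main

import (
	"context"
	"log"

	"github.com/oarkflow/mq"
	"github.com/oarkflow/mq/logger"
)

func main() {
	// Start from the library defaults and override only what this deployment needs.
	cfg := mq.DefaultBrokerEnhancedConfig()
	cfg.Logger = logger.NewNullLogger() // the defaults do not set a logger
	cfg.DLQStoragePath = "./data/dlq"   // placeholder: persist failed messages on disk
	cfg.WALDirectory = "./data/wal"     // placeholder: enable the write-ahead log

	broker := mq.NewBroker(mq.WithBrokerURL(":9092"), mq.WithLogger(cfg.Logger))
	if err := broker.InitializeEnhancements(cfg); err != nil {
		log.Fatalf("enhancements: %v", err)
	}
	if err := broker.Start(context.Background()); err != nil {
		log.Fatalf("broker: %v", err)
	}
}
```

Everything not overridden (acknowledgment timeouts, flow control, snapshots, tracing) keeps the values shown in `DefaultBrokerEnhancedConfig`.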
+ + // Initialize all enhanced features + fmt.Println("\n📦 Initializing enhanced features...") + if err := broker.InitializeEnhancements(enhancedConfig); err != nil { + log.Fatalf("❌ Failed to initialize enhancements: %v", err) + } + fmt.Println("✅ Enhanced features initialized") + + // Attempt to recover from previous crash + fmt.Println("\n🔄 Attempting recovery...") + ctx := context.Background() + + // Recover from WAL (Write-Ahead Log) + if err := broker.RecoverFromWAL(ctx); err != nil { + fmt.Printf("⚠️ WAL recovery: %v (may be first run)\n", err) + } else { + fmt.Println("✅ WAL recovery complete") + } + + // Recover from snapshots + if err := broker.RecoverFromSnapshot(ctx); err != nil { + fmt.Printf("⚠️ Snapshot recovery: %v (may be first run)\n", err) + } else { + fmt.Println("✅ Snapshot recovery complete") + } + + // Create queues for different purposes + fmt.Println("\n📋 Creating queues...") + queues := []string{ + "orders", // Order processing + "payments", // Payment processing + "notifications", // Email/SMS notifications + "analytics", // Data analytics tasks + "reports", // Report generation + } + + for _, queueName := range queues { + _ = broker.NewQueue(queueName) + fmt.Printf(" ✅ Created queue: %s\n", queueName) + } + + // Start the broker server + fmt.Println("\n🌐 Starting broker server on :9092...") + go func() { + if err := broker.Start(ctx); err != nil { + log.Fatalf("❌ Failed to start broker: %v", err) + } + }() + + // Wait for server to be ready + time.Sleep(1 * time.Second) + fmt.Println("✅ Broker server is running") + + // Start periodic stats reporting + go reportStats(broker) + + // Print usage information + printUsageInfo() + + // Set up graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + fmt.Println("\n⏳ Server running. Press Ctrl+C to shutdown gracefully...") + + <-sigChan + + fmt.Println("\n\n🛑 Shutdown signal received...") + shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Graceful shutdown + fmt.Println(" 1. Shutting down enhanced features...") + if err := broker.ShutdownEnhanced(shutdownCtx); err != nil { + fmt.Printf("❌ Enhanced shutdown error: %v\n", err) + } else { + fmt.Println(" ✅ Enhanced features shut down") + } + + fmt.Println(" 2. 
Closing broker...") + if err := broker.Close(); err != nil { + fmt.Printf("❌ Broker close error: %v\n", err) + } else { + fmt.Println(" ✅ Broker closed") + } + + fmt.Println("\n✅ Graceful shutdown complete") + fmt.Println("👋 Broker server stopped") +} + +// reportStats periodically reports broker statistics +func reportStats(broker *mq.Broker) { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for range ticker.C { + stats := broker.GetEnhancedStats() + + fmt.Println("\n📊 Broker Statistics:") + fmt.Println(" " + strings.Repeat("-", 50)) + + // Acknowledgment stats + if ackStats, ok := stats["acknowledgments"].(map[string]interface{}); ok { + fmt.Printf(" 📝 Acknowledgments:\n") + fmt.Printf(" Pending: %v\n", ackStats["pending_count"]) + fmt.Printf(" Redeliver queue: %v\n", ackStats["redeliver_backlog"]) + } + + // WAL stats + if walStats, ok := stats["wal"].(map[string]interface{}); ok { + fmt.Printf(" 📚 Write-Ahead Log:\n") + fmt.Printf(" Sequence ID: %v\n", walStats["current_sequence_id"]) + fmt.Printf(" Files: %v\n", walStats["total_files"]) + } + + // Deduplication stats + if dedupStats, ok := stats["deduplication"].(map[string]interface{}); ok { + fmt.Printf(" 🔍 Deduplication:\n") + fmt.Printf(" Cache size: %v\n", dedupStats["cache_size"]) + fmt.Printf(" Duplicates blocked: %v\n", dedupStats["total_duplicates"]) + } + + // Flow control stats + if flowStats, ok := stats["flow_control"].(map[string]interface{}); ok { + fmt.Printf(" 🚦 Flow Control:\n") + fmt.Printf(" Credits available: %v\n", flowStats["credits"]) + if util, ok := flowStats["utilization"].(float64); ok { + fmt.Printf(" Utilization: %.1f%%\n", util) + } + } + + // Snapshot stats + if snapshotStats, ok := stats["snapshots"].(map[string]interface{}); ok { + fmt.Printf(" 💾 Snapshots:\n") + fmt.Printf(" Total snapshots: %v\n", snapshotStats["total_snapshots"]) + } + + // Tracing stats + if traceStats, ok := stats["tracing"].(map[string]interface{}); ok { + fmt.Printf(" 🔬 Tracing:\n") + fmt.Printf(" Active traces: %v\n", traceStats["active_traces"]) + } + + fmt.Println(" " + strings.Repeat("-", 50)) + } +} + +// printUsageInfo prints connection information +func printUsageInfo() { + fmt.Println("\n" + strings.Repeat("=", 60)) + fmt.Println("📡 Connection Information") + fmt.Println(strings.Repeat("=", 60)) + fmt.Println("\nTo connect consumers and publishers to this broker:") + fmt.Println("\n Consumer Example:") + fmt.Println(" consumer := mq.NewConsumer(\"consumer-1\", \"orders\", handler, mq.WithBrokerURL(\":9092\"))") + fmt.Println(" consumer.Consume(ctx)") + fmt.Println("\n Publisher Example:") + fmt.Println(" publisher := mq.NewPublisher(\"publisher-1\", mq.WithBrokerURL(\":9092\"))") + fmt.Println(" publisher.Publish(ctx, *task, \"orders\")") + fmt.Println("\n Available Queues:") + fmt.Println(" - orders (order processing)") + fmt.Println(" - payments (payment processing)") + fmt.Println(" - notifications (email/sms)") + fmt.Println(" - analytics (data analytics)") + fmt.Println(" - reports (report generation)") + fmt.Println("\n" + strings.Repeat("=", 60)) +} diff --git a/examples/consumer_example/main.go b/examples/consumer_example/main.go new file mode 100644 index 0000000..2fbef22 --- /dev/null +++ b/examples/consumer_example/main.go @@ -0,0 +1,336 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/oarkflow/json" + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/logger" +) + +// Comprehensive Consumer 
Example +// Demonstrates a production-ready consumer with: +// - Worker pool for concurrent processing +// - Error handling and retries +// - Security/authentication +// - Graceful shutdown +// - Statistics tracking + +func main() { + fmt.Println("🚀 Starting Production Message Consumer") + fmt.Println(strings.Repeat("=", 60)) + + // Configure logger + nullLogger := logger.NewNullLogger() + + // Consumer configuration + consumerID := "consumer-1" + brokerAddress := ":9092" + + // Subscribe to multiple queues + queues := []string{"orders", "payments", "notifications", "analytics", "reports"} + + fmt.Printf("\n📡 Consumer ID: %s\n", consumerID) + fmt.Printf("📡 Broker Address: %s\n", brokerAddress) + fmt.Printf("📋 Queues: %v\n", queues) + + // Create consumers for each queue + var consumers []*mq.Consumer + + for _, queue := range queues { + consumer := mq.NewConsumer( + fmt.Sprintf("%s-%s", consumerID, queue), // Consumer ID + queue, // Queue to consume from + handleTask, // Task handler function + mq.WithBrokerURL(brokerAddress), + mq.WithWorkerPool( + 100, // Worker pool size + 4, // Number of workers + 50000, // Task queue size + ), + mq.WithLogger(nullLogger), + // Optional: Enable security + // mq.WithSecurity(true), + // mq.WithUsername("consumer"), + // mq.WithPassword("con123"), + ) + consumers = append(consumers, consumer) + fmt.Printf(" ✅ Created consumer for queue: %s\n", queue) + } + + fmt.Println("\n✅ Consumers created") + + // Start periodic statistics reporting for first consumer + go reportStatistics(consumers[0]) + + // Start consuming messages + fmt.Println("\n🔄 Starting message consumption...") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Run all consumers in background + for _, consumer := range consumers { + c := consumer // capture for goroutine + go func() { + if err := c.Consume(ctx); err != nil { + log.Printf("❌ Consumer error: %v", err) + } + }() + } + + fmt.Println("✅ All consumers are running") + fmt.Println("\n⏳ Consuming messages. Press Ctrl+C to shutdown gracefully...") + + // Set up graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + <-sigChan + + fmt.Println("\n\n🛑 Shutdown signal received...") + + // Cancel context to stop consumption + cancel() + + // Give a moment for context cancellation to propagate + time.Sleep(500 * time.Millisecond) + + fmt.Println(" 1. 
Closing consumers (this will stop worker pools)...") + for i, consumer := range consumers { + if err := consumer.Close(); err != nil { + fmt.Printf("❌ Consumer %d close error: %v\n", i, err) + } + } + fmt.Println(" ✅ All consumers closed") + + fmt.Println("\n✅ Graceful shutdown complete") + fmt.Println("👋 Consumer stopped") +} + +// handleTask processes incoming messages +// This is called by worker pool for each task +func handleTask(ctx context.Context, task *mq.Task) mq.Result { + startTime := time.Now() + + fmt.Printf("\n📦 Processing Task\n") + fmt.Printf(" Task ID: %s\n", task.ID) + fmt.Printf(" Priority: %d\n", task.Priority) + + // Parse task payload + var data map[string]interface{} + if err := json.Unmarshal(task.Payload, &data); err != nil { + fmt.Printf(" ❌ Failed to parse task data: %v\n", err) + return mq.Result{ + Status: mq.Failed, + Error: fmt.Errorf("invalid task data: %w", err), + } + } + + // Determine task type + taskType := "unknown" + if t, ok := data["type"].(string); ok { + taskType = t + } + fmt.Printf(" Type: %s\n", taskType) + + // Process based on task type + var err error + switch taskType { + case "order": + err = processOrder(data) + case "payment": + err = processPayment(data) + case "notification": + err = processNotification(data) + default: + err = processGeneric(data) + } + + // Calculate processing time + duration := time.Since(startTime) + fmt.Printf(" ⏱️ Processing time: %v\n", duration) + + // Return result + if err != nil { + fmt.Printf(" ❌ Task failed: %v\n", err) + + // Check if error is retryable + if isRetryableError(err) { + return mq.Result{ + Status: mq.Failed, + Error: err, + } + } + + // Permanent failure + return mq.Result{ + Status: mq.Failed, + Error: err, + } + } + + fmt.Printf(" ✅ Task completed successfully\n") + return mq.Result{ + Status: mq.Completed, + } +} + +// processOrder handles order processing tasks +func processOrder(data map[string]interface{}) error { + fmt.Printf(" 📦 Processing order...\n") + + // Extract order details + orderID := data["order_id"] + customerID := data["customer_id"] + amount := data["amount"] + + fmt.Printf(" Order ID: %v\n", orderID) + fmt.Printf(" Customer ID: %v\n", customerID) + fmt.Printf(" Amount: $%.2f\n", amount) + + // Simulate order processing + time.Sleep(500 * time.Millisecond) + + // Simulate occasional transient errors for testing + if orderID == "ORD-3" { + return fmt.Errorf("temporary database connection error") + } + + return nil +} + +// processPayment handles payment processing tasks +func processPayment(data map[string]interface{}) error { + fmt.Printf(" 💳 Processing payment...\n") + + paymentID := data["payment_id"] + orderID := data["order_id"] + amount := data["amount"] + method := data["method"] + + fmt.Printf(" Payment ID: %v\n", paymentID) + fmt.Printf(" Order ID: %v\n", orderID) + fmt.Printf(" Amount: $%.2f\n", amount) + fmt.Printf(" Method: %v\n", method) + + // Simulate payment processing + time.Sleep(1 * time.Second) + + // Validate payment amount + if amt, ok := amount.(float64); ok && amt < 0 { + return fmt.Errorf("invalid payment amount: %.2f", amt) + } + + return nil +} + +// processNotification handles notification tasks +func processNotification(data map[string]interface{}) error { + fmt.Printf(" 📧 Processing notification...\n") + + recipient := data["recipient"] + subject := data["subject"] + body := data["body"] + + fmt.Printf(" Recipient: %v\n", recipient) + fmt.Printf(" Subject: %v\n", subject) + fmt.Printf(" Body length: %d chars\n", len(fmt.Sprint(body))) + + 
// Simulate sending notification + time.Sleep(300 * time.Millisecond) + + return nil +} + +// processGeneric handles unknown task types +func processGeneric(data map[string]interface{}) error { + fmt.Printf(" ⚙️ Processing generic task...\n") + + // Just print the data + for key, value := range data { + fmt.Printf(" %s: %v\n", key, value) + } + + time.Sleep(200 * time.Millisecond) + return nil +} + +// isRetryableError determines if an error should trigger a retry +func isRetryableError(err error) bool { + if err == nil { + return false + } + + errMsg := err.Error() + + // List of retryable error patterns + retryablePatterns := []string{ + "temporary", + "timeout", + "connection", + "network", + "unavailable", + } + + for _, pattern := range retryablePatterns { + if contains(errMsg, pattern) { + return true + } + } + + return false +} + +// reportStatistics periodically reports consumer statistics +func reportStatistics(consumer *mq.Consumer) { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for range ticker.C { + metrics := consumer.Metrics() + + fmt.Println("\n📊 Consumer Statistics:") + fmt.Println(" " + strings.Repeat("-", 50)) + fmt.Printf(" Consumer ID: %s\n", consumer.GetKey()) + fmt.Printf(" Total Tasks: %d\n", metrics.TotalTasks) + fmt.Printf(" Completed Tasks: %d\n", metrics.CompletedTasks) + fmt.Printf(" Failed Tasks: %d\n", metrics.ErrorCount) + fmt.Printf(" Scheduled Tasks: %d\n", metrics.TotalScheduled) + fmt.Printf(" Memory Used: %d bytes\n", metrics.TotalMemoryUsed) + + if metrics.TotalTasks > 0 { + successRate := float64(metrics.CompletedTasks) / float64(metrics.TotalTasks) * 100 + fmt.Printf(" Success Rate: %.1f%%\n", successRate) + } + + if metrics.TotalTasks > 0 && metrics.ExecutionTime > 0 { + avgTime := time.Duration(metrics.ExecutionTime/metrics.TotalTasks) * time.Millisecond + fmt.Printf(" Avg Processing Time: %v\n", avgTime) + } + + fmt.Println(" " + strings.Repeat("-", 50)) + } +} + +// Helper function +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && + (s[:len(substr)] == substr || s[len(s)-len(substr):] == substr || + findInString(s, substr))) +} + +func findInString(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/examples/publisher_example/main.go b/examples/publisher_example/main.go new file mode 100644 index 0000000..9e1e200 --- /dev/null +++ b/examples/publisher_example/main.go @@ -0,0 +1,266 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/oarkflow/json" + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/logger" +) + +// Comprehensive Publisher Example +// Demonstrates a production-ready publisher with: +// - Connection pooling +// - Batch publishing +// - Error handling and retries +// - Message deduplication support +// - Security/authentication + +func main() { + fmt.Println("🚀 Starting Production Message Publisher") + fmt.Println(strings.Repeat("=", 60)) + + // Configure logger + nullLogger := logger.NewNullLogger() + + // Publisher configuration + publisherID := "publisher-1" + brokerAddress := ":9092" + + fmt.Printf("\n📡 Publisher ID: %s\n", publisherID) + fmt.Printf("📡 Broker Address: %s\n", brokerAddress) + + // Create publisher with authentication (optional) + publisher := mq.NewPublisher( + publisherID, + mq.WithBrokerURL(brokerAddress), + mq.WithLogger(nullLogger), + // 
Optional: Enable security + // mq.WithSecurity(true), + // mq.WithUsername("publisher"), + // mq.WithPassword("pub123"), + ) + + // Connect to broker + ctx := context.Background() + fmt.Println("\n🔌 Connecting to broker...") + + // Publisher connects automatically on first Publish, but we can test connection + time.Sleep(500 * time.Millisecond) + fmt.Println("✅ Publisher ready") + + // Publish messages to different queues + fmt.Println("\n📤 Publishing messages...") + + // Publish orders + if err := publishOrders(ctx, publisher); err != nil { + log.Printf("❌ Failed to publish orders: %v", err) + } + + // Publish payments + if err := publishPayments(ctx, publisher); err != nil { + log.Printf("❌ Failed to publish payments: %v", err) + } + + // Publish notifications + if err := publishNotifications(ctx, publisher); err != nil { + log.Printf("❌ Failed to publish notifications: %v", err) + } + + fmt.Println("\n✅ All messages published successfully") + + // Start periodic message publishing (optional) + fmt.Println("\n🔄 Starting periodic publishing (every 10 seconds)...") + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + // Set up graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + fmt.Println("⏳ Publisher running. Press Ctrl+C to shutdown...") + + for { + select { + case <-ticker.C: + // Publish periodic messages + if err := publishPeriodicMessage(ctx, publisher); err != nil { + log.Printf("⚠️ Periodic publish error: %v", err) + } + + case <-sigChan: + fmt.Println("\n\n🛑 Shutdown signal received...") + fmt.Println("✅ Graceful shutdown complete") + fmt.Println("👋 Publisher stopped") + return + } + } +} + +// publishOrders publishes sample order messages +func publishOrders(ctx context.Context, publisher *mq.Publisher) error { + fmt.Println("\n 📦 Publishing orders...") + + for i := 1; i <= 5; i++ { + orderData := map[string]interface{}{ + "type": "order", + "order_id": fmt.Sprintf("ORD-%d", i), + "customer_id": fmt.Sprintf("CUST-%d", i), + "amount": 100.0 * float64(i), + "items": []string{"item1", "item2", "item3"}, + "timestamp": time.Now().Unix(), + } + + payload, err := json.Marshal(orderData) + if err != nil { + return fmt.Errorf("failed to marshal order: %w", err) + } + + task := mq.NewTask( + fmt.Sprintf("order-task-%d-%d", i, time.Now().Unix()), + payload, + "orders", + mq.WithPriority(i), + mq.WithTaskMaxRetries(3), + mq.WithTTL(1*time.Hour), + mq.WithTags(map[string]string{ + "type": "order", + "priority": fmt.Sprintf("%d", i), + }), + ) + + if err := publisher.Publish(ctx, *task, "orders"); err != nil { + return fmt.Errorf("failed to publish order: %w", err) + } + + fmt.Printf(" ✅ Published: %s (priority: %d)\n", task.ID, task.Priority) + time.Sleep(100 * time.Millisecond) + } + + return nil +} + +// publishPayments publishes sample payment messages +func publishPayments(ctx context.Context, publisher *mq.Publisher) error { + fmt.Println("\n 💳 Publishing payments...") + + for i := 1; i <= 3; i++ { + paymentData := map[string]interface{}{ + "type": "payment", + "payment_id": fmt.Sprintf("PAY-%d", i), + "order_id": fmt.Sprintf("ORD-%d", i), + "amount": 50.0 * float64(i), + "method": "credit_card", + "timestamp": time.Now().Unix(), + } + + payload, err := json.Marshal(paymentData) + if err != nil { + return fmt.Errorf("failed to marshal payment: %w", err) + } + + task := mq.NewTask( + fmt.Sprintf("payment-task-%d-%d", i, time.Now().Unix()), + payload, + "payments", + mq.WithPriority(10), // High priority 
for payments + mq.WithTaskMaxRetries(3), + ) + + if err := publisher.Publish(ctx, *task, "payments"); err != nil { + return fmt.Errorf("failed to publish payment: %w", err) + } + + fmt.Printf(" ✅ Published: %s\n", task.ID) + time.Sleep(100 * time.Millisecond) + } + + return nil +} + +// publishNotifications publishes sample notification messages +func publishNotifications(ctx context.Context, publisher *mq.Publisher) error { + fmt.Println("\n 📧 Publishing notifications...") + + notifications := []map[string]interface{}{ + { + "type": "notification", + "notif_type": "email", + "recipient": "customer1@example.com", + "subject": "Order Confirmation", + "body": "Your order has been confirmed", + "timestamp": time.Now().Unix(), + }, + { + "type": "notification", + "notif_type": "sms", + "recipient": "+1234567890", + "subject": "Payment Received", + "body": "Your payment of $150 has been received", + "timestamp": time.Now().Unix(), + }, + } + + for i, notif := range notifications { + payload, err := json.Marshal(notif) + if err != nil { + return fmt.Errorf("failed to marshal notification: %w", err) + } + + task := mq.NewTask( + fmt.Sprintf("notif-task-%d-%d", i+1, time.Now().Unix()), + payload, + "notifications", + mq.WithPriority(5), + ) + + if err := publisher.Publish(ctx, *task, "notifications"); err != nil { + return fmt.Errorf("failed to publish notification: %w", err) + } + + fmt.Printf(" ✅ Published: %s (%s)\n", task.ID, notif["notif_type"]) + time.Sleep(100 * time.Millisecond) + } + + return nil +} + +// publishPeriodicMessage publishes a periodic heartbeat/analytics message +func publishPeriodicMessage(ctx context.Context, publisher *mq.Publisher) error { + data := map[string]interface{}{ + "type": "analytics", + "event": "heartbeat", + "timestamp": time.Now().Unix(), + "metrics": map[string]interface{}{ + "cpu_usage": 75.5, + "memory_usage": 60.2, + "active_users": 1250, + }, + } + + payload, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal periodic message: %w", err) + } + + task := mq.NewTask( + fmt.Sprintf("analytics-task-%d", time.Now().Unix()), + payload, + "analytics", + mq.WithPriority(1), // Low priority for analytics + ) + + if err := publisher.Publish(ctx, *task, "analytics"); err != nil { + return fmt.Errorf("failed to publish periodic message: %w", err) + } + + fmt.Printf(" 🔄 Published periodic message: %s\n", task.ID) + return nil +} diff --git a/mq.go b/mq.go index 1453649..bf7e491 100644 --- a/mq.go +++ b/mq.go @@ -465,13 +465,15 @@ type StoredMessage struct { type Broker struct { // Core broker functionality - queues storage.IMap[string, *Queue] // Modified to support tenant-specific queues - consumers storage.IMap[string, *consumer] - publishers storage.IMap[string, *publisher] - deadLetter storage.IMap[string, *Queue] - opts *Options - pIDs storage.IMap[string, bool] - listener net.Listener + queues storage.IMap[string, *Queue] // Modified to support tenant-specific queues + consumers storage.IMap[string, *consumer] + publishers storage.IMap[string, *publisher] + deadLetter storage.IMap[string, *Queue] + deferredTasks storage.IMap[string, *QueuedTask] // NEW: Store for deferred tasks + dlqHandlers storage.IMap[string, Handler] // NEW: Custom handlers for DLQ queues + opts *Options + pIDs storage.IMap[string, bool] + listener net.Listener // Enhanced production features connectionPool *ConnectionPool @@ -485,9 +487,11 @@ type Broker struct { authenticatedConns storage.IMap[string, bool] // authenticated connections taskHeaders 
storage.IMap[string, map[string]string] // task headers by task ID pendingTasks map[string]map[string]*Task // consumerID -> taskID -> task + enhanced *EnhancedFeatures // enhanced features (DLQ, WAL, ACK, etc.) mu sync.RWMutex // for pendingTasks isShutdown int32 shutdown chan struct{} + stopDeferredChan chan struct{} // NEW: Signal to stop deferred task processor wg sync.WaitGroup logger logger.Logger } @@ -497,13 +501,15 @@ func NewBroker(opts ...Option) *Broker { broker := &Broker{ // Core broker functionality - queues: memory.New[string, *Queue](), - publishers: memory.New[string, *publisher](), - consumers: memory.New[string, *consumer](), - deadLetter: memory.New[string, *Queue](), - pIDs: memory.New[string, bool](), - pendingTasks: make(map[string]map[string]*Task), - opts: options, + queues: memory.New[string, *Queue](), + publishers: memory.New[string, *publisher](), + consumers: memory.New[string, *consumer](), + deadLetter: memory.New[string, *Queue](), + deferredTasks: memory.New[string, *QueuedTask](), // NEW: Initialize deferred tasks map + dlqHandlers: memory.New[string, Handler](), // NEW: Initialize DLQ handlers map + pIDs: memory.New[string, bool](), + pendingTasks: make(map[string]map[string]*Task), + opts: options, // Enhanced production features connectionPool: NewConnectionPool(1000), // max 1000 connections @@ -514,6 +520,7 @@ func NewBroker(opts ...Option) *Broker { authenticatedConns: memory.New[string, bool](), taskHeaders: memory.New[string, map[string]string](), shutdown: make(chan struct{}), + stopDeferredChan: make(chan struct{}), // NEW: Initialize stop channel for deferred processor logger: options.Logger(), } @@ -886,6 +893,7 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M // Store headers for response routing b.taskHeaders.Set(taskID, msg.Headers) + // Send acknowledgment back to publisher ack, err := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers) if err != nil { log.Printf("Error creating PUBLISH_ACK message: %v\n", err) @@ -894,7 +902,82 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M if err := b.send(ctx, conn, ack); err != nil { log.Printf("Error sending PUBLISH_ACK: %v\n", err) } + + // Apply enhanced features if enabled + if b.enhanced != nil && b.enhanced.enabled { + // Parse the task from the message + var task Task + if err := json.Unmarshal(msg.Payload, &task); err != nil { + log.Printf("Error parsing task for enhanced features: %v\n", err) + b.broadcastToConsumers(msg) + return + } + + // Check for duplicates + if b.enhanced.dedupManager != nil { + isDuplicate, err := b.enhanced.dedupManager.CheckDuplicate(ctx, &task) + if err != nil { + log.Printf("Error checking duplicate: %v\n", err) + } else if isDuplicate { + b.logger.Debug("Duplicate message rejected", + logger.Field{Key: "taskID", Value: task.ID}) + return // Don't broadcast duplicates + } + } + + // Acquire flow control credits + if b.enhanced.flowController != nil { + if err := b.enhanced.flowController.AcquireCredit(ctx, 1); err != nil { + log.Printf("Flow control credit acquisition failed: %v\n", err) + // Continue anyway - don't block + } else { + defer b.enhanced.flowController.ReleaseCredit(1) + } + } + + // Write to WAL + if b.enhanced.walLog != nil { + walEntry := &WALEntry{ + EntryType: WALEntryEnqueue, + TaskID: task.ID, + QueueName: msg.Queue, + Payload: msg.Payload, // already []byte + } + if err := b.enhanced.walLog.WriteEntry(ctx, 
walEntry); err != nil { + b.logger.Error("Failed to write WAL entry", + logger.Field{Key: "error", Value: err}) + } + } + + // Start tracing + if b.enhanced.lifecycleTracker != nil { + b.enhanced.lifecycleTracker.TrackEnqueue(ctx, &task, msg.Queue) + } + + // Track for acknowledgment + if b.enhanced.ackManager != nil { + _ = b.enhanced.ackManager.TrackMessage(ctx, &task, msg.Queue, pub.id) + } + + // Check if task is deferred + if !task.DeferUntil.IsZero() && task.DeferUntil.After(time.Now()) { + // Create QueuedTask for deferred execution + queuedTask := &QueuedTask{ + Message: msg, + Task: &task, + RetryCount: 0, + } + // Add to deferred tasks queue + b.AddDeferredTask(queuedTask) + log.Printf("[DEFERRED] Task %s deferred until %s", + task.ID, task.DeferUntil.Format(time.RFC3339)) + return // Don't broadcast yet + } + } + + // Broadcast to consumers b.broadcastToConsumers(msg) + go func() { select { case <-ctx.Done(): @@ -928,6 +1011,9 @@ func (b *Broker) Start(ctx context.Context) error { // Start health checker b.healthChecker.Start() + // Start deferred task processor + b.StartDeferredTaskProcessor(ctx) + // Start connection cleanup routine b.wg.Add(1) go b.connectionCleanupRoutine() @@ -1303,6 +1389,8 @@ func (b *Broker) URL() string { func (b *Broker) Close() error { if b != nil && b.listener != nil { log.Printf("Broker is closing...") + // Stop deferred task processor + b.StopDeferredTaskProcessor() return b.listener.Close() } return nil diff --git a/pool.go b/pool.go index ea87b4a..87ab0b4 100644 --- a/pool.go +++ b/pool.go @@ -382,6 +382,7 @@ var Config = &DynamicConfig{ type Pool struct { taskStorage TaskStorage stop chan struct{} + stopOnce sync.Once // Ensure Stop() is only called once taskNotify chan struct{} workerAdjust chan int handler Handler @@ -1169,65 +1170,68 @@ func min(a, b int) int { } func (wp *Pool) Stop() { - wp.logger.Info().Msg("Initiating graceful shutdown") - wp.gracefulShutdown = true + // Use sync.Once to ensure Stop is only executed once + wp.stopOnce.Do(func() { + wp.logger.Info().Msg("Initiating graceful shutdown") + wp.gracefulShutdown = true - // Pause new task processing - wp.Pause() + // Pause new task processing + wp.Pause() - // Signal all goroutines to stop - close(wp.stop) + // Signal all goroutines to stop + close(wp.stop) - // Create channels for coordinated shutdown - workersFinished := make(chan struct{}) - tasksFinished := make(chan struct{}) + // Create channels for coordinated shutdown + workersFinished := make(chan struct{}) + tasksFinished := make(chan struct{}) - // Wait for workers to finish - go func() { - wp.wg.Wait() - close(workersFinished) - }() + // Wait for workers to finish + go func() { + wp.wg.Wait() + close(workersFinished) + }() - // Wait for pending tasks to complete - go func() { - wp.taskCompletionNotifier.Wait() - close(tasksFinished) - }() + // Wait for pending tasks to complete + go func() { + wp.taskCompletionNotifier.Wait() + close(tasksFinished) + }() - // Wait with timeout - shutdownTimer := time.NewTimer(wp.gracefulShutdownTimeout) - defer shutdownTimer.Stop() + // Wait with timeout + shutdownTimer := time.NewTimer(wp.gracefulShutdownTimeout) + defer shutdownTimer.Stop() - workersComplete := false - tasksComplete := false + workersComplete := false + tasksComplete := false - for !workersComplete || !tasksComplete { - select { - case <-workersFinished: - if !workersComplete { - wp.logger.Info().Msg("All workers have finished") - workersComplete = true + for !workersComplete || !tasksComplete { + select { + case 
<-workersFinished: + if !workersComplete { + wp.logger.Info().Msg("All workers have finished") + workersComplete = true + } + case <-tasksFinished: + if !tasksComplete { + wp.logger.Info().Msg("All pending tasks have completed") + tasksComplete = true + } + case <-shutdownTimer.C: + wp.logger.Warn().Msgf("Graceful shutdown timeout (%v) reached, forcing shutdown", wp.gracefulShutdownTimeout) + goto forceShutdown } - case <-tasksFinished: - if !tasksComplete { - wp.logger.Info().Msg("All pending tasks have completed") - tasksComplete = true - } - case <-shutdownTimer.C: - wp.logger.Warn().Msgf("Graceful shutdown timeout (%v) reached, forcing shutdown", wp.gracefulShutdownTimeout) - goto forceShutdown } - } -forceShutdown: - // Final cleanup - wp.cleanup() + forceShutdown: + // Final cleanup + wp.cleanup() - if wp.completionCallback != nil { - wp.completionCallback() - } + if wp.completionCallback != nil { + wp.completionCallback() + } - wp.logger.Info().Msg("Pool shutdown completed") + wp.logger.Info().Msg("Pool shutdown completed") + }) // Close the Do() function } // cleanup performs final resource cleanup diff --git a/pool_enhancements.go b/pool_enhancements.go new file mode 100644 index 0000000..2a03a22 --- /dev/null +++ b/pool_enhancements.go @@ -0,0 +1,525 @@ +package mq + +import ( + "context" + "runtime" + "sync" + "sync/atomic" + "time" + + "github.com/oarkflow/mq/logger" +) + +// WorkerHealth represents the health status of a worker +type WorkerHealth struct { + WorkerID int + IsHealthy bool + LastHeartbeat time.Time + TasksProcessed int64 + ErrorCount int64 + AvgProcessTime time.Duration + CurrentTaskID string + CurrentTaskStart time.Time + MemoryUsage uint64 +} + +// WorkerHealthMonitor monitors worker health and performs auto-recovery +type WorkerHealthMonitor struct { + pool *Pool + workers map[int]*WorkerHealth + mu sync.RWMutex + heartbeatTimeout time.Duration + checkInterval time.Duration + shutdown chan struct{} + logger logger.Logger + onUnhealthyWorker func(workerID int, health *WorkerHealth) + onRecoveredWorker func(workerID int) +} + +// NewWorkerHealthMonitor creates a new worker health monitor +func NewWorkerHealthMonitor(pool *Pool, heartbeatTimeout, checkInterval time.Duration, log logger.Logger) *WorkerHealthMonitor { + if heartbeatTimeout == 0 { + heartbeatTimeout = 30 * time.Second + } + if checkInterval == 0 { + checkInterval = 10 * time.Second + } + + monitor := &WorkerHealthMonitor{ + pool: pool, + workers: make(map[int]*WorkerHealth), + heartbeatTimeout: heartbeatTimeout, + checkInterval: checkInterval, + shutdown: make(chan struct{}), + logger: log, + } + + go monitor.monitorLoop() + + return monitor +} + +// RecordHeartbeat records a heartbeat from a worker +func (m *WorkerHealthMonitor) RecordHeartbeat(workerID int, taskID string) { + m.mu.Lock() + defer m.mu.Unlock() + + health, exists := m.workers[workerID] + if !exists { + health = &WorkerHealth{ + WorkerID: workerID, + IsHealthy: true, + } + m.workers[workerID] = health + } + + health.LastHeartbeat = time.Now() + health.IsHealthy = true + health.CurrentTaskID = taskID + if taskID != "" { + health.CurrentTaskStart = time.Now() + } else { + health.CurrentTaskStart = time.Time{} + } +} + +// RecordTaskCompletion records task completion metrics +func (m *WorkerHealthMonitor) RecordTaskCompletion(workerID int, processingTime time.Duration, isError bool) { + m.mu.Lock() + defer m.mu.Unlock() + + health, exists := m.workers[workerID] + if !exists { + return + } + + health.TasksProcessed++ + if isError { + 
health.ErrorCount++ + } + + // Update average processing time + if health.AvgProcessTime == 0 { + health.AvgProcessTime = processingTime + } else { + health.AvgProcessTime = (health.AvgProcessTime + processingTime) / 2 + } + + health.CurrentTaskID = "" + health.CurrentTaskStart = time.Time{} +} + +// monitorLoop continuously monitors worker health +func (m *WorkerHealthMonitor) monitorLoop() { + ticker := time.NewTicker(m.checkInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + m.checkWorkerHealth() + case <-m.shutdown: + return + } + } +} + +// checkWorkerHealth checks the health of all workers +func (m *WorkerHealthMonitor) checkWorkerHealth() { + now := time.Now() + m.mu.Lock() + defer m.mu.Unlock() + + for workerID, health := range m.workers { + // Check if worker has timed out + if now.Sub(health.LastHeartbeat) > m.heartbeatTimeout { + if health.IsHealthy { + health.IsHealthy = false + m.logger.Warn("Worker unhealthy - heartbeat timeout", + logger.Field{Key: "workerID", Value: workerID}, + logger.Field{Key: "lastHeartbeat", Value: health.LastHeartbeat}, + logger.Field{Key: "currentTask", Value: health.CurrentTaskID}) + + if m.onUnhealthyWorker != nil { + go m.onUnhealthyWorker(workerID, health) + } + + // Attempt to restart the worker + go m.restartWorker(workerID) + } + } + + // Check if worker is stuck on a task + if !health.CurrentTaskStart.IsZero() && now.Sub(health.CurrentTaskStart) > m.heartbeatTimeout*2 { + m.logger.Warn("Worker stuck on task", + logger.Field{Key: "workerID", Value: workerID}, + logger.Field{Key: "taskID", Value: health.CurrentTaskID}, + logger.Field{Key: "duration", Value: now.Sub(health.CurrentTaskStart)}) + } + + // Update memory usage + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + health.MemoryUsage = memStats.Alloc + } +} + +// restartWorker attempts to restart an unhealthy worker +func (m *WorkerHealthMonitor) restartWorker(workerID int) { + m.logger.Info("Attempting to restart unhealthy worker", + logger.Field{Key: "workerID", Value: workerID}) + + // Signal the pool to start a replacement worker + m.pool.wg.Add(1) + go m.pool.worker() + + // Mark as recovered + m.mu.Lock() + if health, exists := m.workers[workerID]; exists { + health.IsHealthy = true + health.LastHeartbeat = time.Now() + } + m.mu.Unlock() + + if m.onRecoveredWorker != nil { + m.onRecoveredWorker(workerID) + } + + m.logger.Info("Worker restarted successfully", + logger.Field{Key: "workerID", Value: workerID}) +} + +// GetHealthStats returns health statistics for all workers +func (m *WorkerHealthMonitor) GetHealthStats() map[string]interface{} { + m.mu.RLock() + defer m.mu.RUnlock() + + healthyCount := 0 + unhealthyCount := 0 + totalTasks := int64(0) + totalErrors := int64(0) + + for _, health := range m.workers { + if health.IsHealthy { + healthyCount++ + } else { + unhealthyCount++ + } + totalTasks += health.TasksProcessed + totalErrors += health.ErrorCount + } + + return map[string]interface{}{ + "total_workers": len(m.workers), + "healthy_workers": healthyCount, + "unhealthy_workers": unhealthyCount, + "total_tasks_processed": totalTasks, + "total_errors": totalErrors, + } +} + +// GetWorkerHealth returns health info for a specific worker +func (m *WorkerHealthMonitor) GetWorkerHealth(workerID int) (*WorkerHealth, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + + health, exists := m.workers[workerID] + return health, exists +} + +// SetOnUnhealthyWorker sets callback for unhealthy worker detection +func (m *WorkerHealthMonitor) 
SetOnUnhealthyWorker(fn func(workerID int, health *WorkerHealth)) { + m.onUnhealthyWorker = fn +} + +// SetOnRecoveredWorker sets callback for worker recovery +func (m *WorkerHealthMonitor) SetOnRecoveredWorker(fn func(workerID int)) { + m.onRecoveredWorker = fn +} + +// Shutdown stops the health monitor +func (m *WorkerHealthMonitor) Shutdown() { + close(m.shutdown) +} + +// DynamicScaler dynamically scales workers based on load +type DynamicScaler struct { + pool *Pool + minWorkers int + maxWorkers int + scaleUpThreshold float64 // Task queue utilization % to scale up + scaleDownThreshold float64 // Task queue utilization % to scale down + cooldownPeriod time.Duration + lastScaleTime time.Time + mu sync.RWMutex + shutdown chan struct{} + logger logger.Logger + checkInterval time.Duration + onScaleUp func(oldCount, newCount int) + onScaleDown func(oldCount, newCount int) +} + +// NewDynamicScaler creates a new dynamic worker scaler +func NewDynamicScaler(pool *Pool, minWorkers, maxWorkers int, log logger.Logger) *DynamicScaler { + scaler := &DynamicScaler{ + pool: pool, + minWorkers: minWorkers, + maxWorkers: maxWorkers, + scaleUpThreshold: 0.75, // Scale up when 75% full + scaleDownThreshold: 0.25, // Scale down when 25% full + cooldownPeriod: 30 * time.Second, + shutdown: make(chan struct{}), + logger: log, + checkInterval: 10 * time.Second, + } + + go scaler.scaleLoop() + + return scaler +} + +// scaleLoop continuously monitors and scales workers +func (s *DynamicScaler) scaleLoop() { + ticker := time.NewTicker(s.checkInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.evaluateScaling() + case <-s.shutdown: + return + } + } +} + +// evaluateScaling evaluates whether to scale up or down +func (s *DynamicScaler) evaluateScaling() { + s.mu.Lock() + defer s.mu.Unlock() + + // Check cooldown period + if time.Since(s.lastScaleTime) < s.cooldownPeriod { + return + } + + // Get current metrics + currentWorkers := int(atomic.LoadInt32(&s.pool.numOfWorkers)) + + s.pool.taskQueueLock.Lock() + queueSize := len(s.pool.taskQueue) + s.pool.taskQueueLock.Unlock() + + // Calculate utilization + queueCapacity := cap(s.pool.taskQueue) + if queueCapacity == 0 { + return + } + + utilization := float64(queueSize) / float64(queueCapacity) + + // Decide to scale up or down + if utilization >= s.scaleUpThreshold && currentWorkers < s.maxWorkers { + newWorkers := currentWorkers + 1 + if newWorkers > s.maxWorkers { + newWorkers = s.maxWorkers + } + s.scaleUp(currentWorkers, newWorkers) + } else if utilization <= s.scaleDownThreshold && currentWorkers > s.minWorkers { + newWorkers := currentWorkers - 1 + if newWorkers < s.minWorkers { + newWorkers = s.minWorkers + } + s.scaleDown(currentWorkers, newWorkers) + } +} + +// scaleUp increases the number of workers +func (s *DynamicScaler) scaleUp(oldCount, newCount int) { + additionalWorkers := newCount - oldCount + + s.logger.Info("Scaling up workers", + logger.Field{Key: "oldCount", Value: oldCount}, + logger.Field{Key: "newCount", Value: newCount}) + + for i := 0; i < additionalWorkers; i++ { + s.pool.wg.Add(1) + go s.pool.worker() + } + + atomic.StoreInt32(&s.pool.numOfWorkers, int32(newCount)) + s.lastScaleTime = time.Now() + + if s.onScaleUp != nil { + s.onScaleUp(oldCount, newCount) + } +} + +// scaleDown decreases the number of workers +func (s *DynamicScaler) scaleDown(oldCount, newCount int) { + s.logger.Info("Scaling down workers", + logger.Field{Key: "oldCount", Value: oldCount}, + logger.Field{Key: "newCount", Value: newCount}) + + 
atomic.StoreInt32(&s.pool.numOfWorkers, int32(newCount)) + s.lastScaleTime = time.Now() + + // Workers will naturally exit when they check numOfWorkers + + if s.onScaleDown != nil { + s.onScaleDown(oldCount, newCount) + } +} + +// SetScaleUpThreshold sets the threshold for scaling up +func (s *DynamicScaler) SetScaleUpThreshold(threshold float64) { + s.mu.Lock() + defer s.mu.Unlock() + s.scaleUpThreshold = threshold +} + +// SetScaleDownThreshold sets the threshold for scaling down +func (s *DynamicScaler) SetScaleDownThreshold(threshold float64) { + s.mu.Lock() + defer s.mu.Unlock() + s.scaleDownThreshold = threshold +} + +// SetOnScaleUp sets callback for scale up events +func (s *DynamicScaler) SetOnScaleUp(fn func(oldCount, newCount int)) { + s.onScaleUp = fn +} + +// SetOnScaleDown sets callback for scale down events +func (s *DynamicScaler) SetOnScaleDown(fn func(oldCount, newCount int)) { + s.onScaleDown = fn +} + +// Shutdown stops the dynamic scaler +func (s *DynamicScaler) Shutdown() { + close(s.shutdown) +} + +// GracefulShutdownManager manages graceful shutdown of the pool +type GracefulShutdownManager struct { + pool *Pool + timeout time.Duration + drainTasks bool + logger logger.Logger + onShutdownStart func() + onShutdownEnd func() +} + +// NewGracefulShutdownManager creates a new graceful shutdown manager +func NewGracefulShutdownManager(pool *Pool, timeout time.Duration, drainTasks bool, log logger.Logger) *GracefulShutdownManager { + if timeout == 0 { + timeout = 30 * time.Second + } + + return &GracefulShutdownManager{ + pool: pool, + timeout: timeout, + drainTasks: drainTasks, + logger: log, + } +} + +// Shutdown performs a graceful shutdown +func (g *GracefulShutdownManager) Shutdown(ctx context.Context) error { + if g.onShutdownStart != nil { + g.onShutdownStart() + } + + g.logger.Info("Starting graceful shutdown", + logger.Field{Key: "timeout", Value: g.timeout}, + logger.Field{Key: "drainTasks", Value: g.drainTasks}) + + // Stop accepting new tasks + g.pool.gracefulShutdown = true + + if g.drainTasks { + // Wait for existing tasks to complete + g.logger.Info("Draining existing tasks") + + done := make(chan struct{}) + go func() { + g.pool.wg.Wait() + close(done) + }() + + select { + case <-done: + g.logger.Info("All tasks completed successfully") + case <-time.After(g.timeout): + g.logger.Warn("Graceful shutdown timeout exceeded, forcing shutdown") + case <-ctx.Done(): + return ctx.Err() + } + } + + // Stop the pool + close(g.pool.stop) + g.pool.wg.Wait() + + if g.onShutdownEnd != nil { + g.onShutdownEnd() + } + + g.logger.Info("Graceful shutdown complete") + + return nil +} + +// SetOnShutdownStart sets callback for shutdown start +func (g *GracefulShutdownManager) SetOnShutdownStart(fn func()) { + g.onShutdownStart = fn +} + +// SetOnShutdownEnd sets callback for shutdown end +func (g *GracefulShutdownManager) SetOnShutdownEnd(fn func()) { + g.onShutdownEnd = fn +} + +// PoolEnhancedStats returns enhanced statistics about the pool +func PoolEnhancedStats(pool *Pool) map[string]interface{} { + pool.taskQueueLock.Lock() + queueLen := len(pool.taskQueue) + queueCap := cap(pool.taskQueue) + pool.taskQueueLock.Unlock() + + pool.overflowBufferLock.RLock() + overflowLen := len(pool.overflowBuffer) + pool.overflowBufferLock.RUnlock() + + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + + return map[string]interface{}{ + "workers": map[string]interface{}{ + "count": atomic.LoadInt32(&pool.numOfWorkers), + "paused": pool.paused, + }, + "queue": 
map[string]interface{}{ + "length": queueLen, + "capacity": queueCap, + "utilization": float64(queueLen) / float64(queueCap) * 100, + }, + "overflow": map[string]interface{}{ + "length": overflowLen, + }, + "tasks": map[string]interface{}{ + "total": atomic.LoadInt64(&pool.metrics.TotalTasks), + "completed": atomic.LoadInt64(&pool.metrics.CompletedTasks), + "errors": atomic.LoadInt64(&pool.metrics.ErrorCount), + }, + "memory": map[string]interface{}{ + "alloc": memStats.Alloc, + "total_alloc": memStats.TotalAlloc, + "sys": memStats.Sys, + "num_gc": memStats.NumGC, + }, + "dlq": map[string]interface{}{ + "size": pool.dlq.Size(), + }, + } +} diff --git a/snapshot.go b/snapshot.go new file mode 100644 index 0000000..c3fc8d5 --- /dev/null +++ b/snapshot.go @@ -0,0 +1,460 @@ +package mq + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" + "time" + + "github.com/oarkflow/mq/logger" +) + +// QueueSnapshot represents a point-in-time snapshot of queue state +type QueueSnapshot struct { + Timestamp time.Time `json:"timestamp"` + QueueName string `json:"queue_name"` + PendingTasks []*Task `json:"pending_tasks"` + ConsumerState map[string]*ConsumerMeta `json:"consumer_state"` + Metrics *QueueMetrics `json:"metrics"` + Config *QueueConfig `json:"config"` + Version int `json:"version"` +} + +// ConsumerMeta represents consumer metadata in snapshot +type ConsumerMeta struct { + ID string `json:"id"` + State string `json:"state"` + LastActivity time.Time `json:"last_activity"` + TasksAssigned int `json:"tasks_assigned"` +} + +// SnapshotManager manages queue snapshots and recovery +type SnapshotManager struct { + baseDir string + broker *Broker + snapshotInterval time.Duration + retentionPeriod time.Duration + mu sync.RWMutex + logger logger.Logger + shutdown chan struct{} + wg sync.WaitGroup +} + +// SnapshotConfig holds snapshot configuration +type SnapshotConfig struct { + BaseDir string + SnapshotInterval time.Duration + RetentionPeriod time.Duration // How long to keep old snapshots + Logger logger.Logger +} + +// NewSnapshotManager creates a new snapshot manager +func NewSnapshotManager(broker *Broker, config SnapshotConfig) (*SnapshotManager, error) { + if config.SnapshotInterval == 0 { + config.SnapshotInterval = 5 * time.Minute + } + if config.RetentionPeriod == 0 { + config.RetentionPeriod = 24 * time.Hour + } + + if err := os.MkdirAll(config.BaseDir, 0755); err != nil { + return nil, fmt.Errorf("failed to create snapshot directory: %w", err) + } + + sm := &SnapshotManager{ + baseDir: config.BaseDir, + broker: broker, + snapshotInterval: config.SnapshotInterval, + retentionPeriod: config.RetentionPeriod, + logger: config.Logger, + shutdown: make(chan struct{}), + } + + // Start periodic snapshot worker + sm.wg.Add(1) + go sm.snapshotLoop() + + return sm, nil +} + +// snapshotLoop periodically creates snapshots +func (sm *SnapshotManager) snapshotLoop() { + defer sm.wg.Done() + + ticker := time.NewTicker(sm.snapshotInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := sm.CreateSnapshotAll(context.Background()); err != nil { + sm.logger.Error("Failed to create periodic snapshot", + logger.Field{Key: "error", Value: err}) + } + + // Cleanup old snapshots + if err := sm.CleanupOldSnapshots(); err != nil { + sm.logger.Error("Failed to cleanup old snapshots", + logger.Field{Key: "error", Value: err}) + } + case <-sm.shutdown: + return + } + } +} + +// CreateSnapshot creates a snapshot of a specific queue +func (sm *SnapshotManager) 
CreateSnapshot(ctx context.Context, queueName string) (*QueueSnapshot, error) { + sm.mu.RLock() + defer sm.mu.RUnlock() + + queue, exists := sm.broker.queues.Get(queueName) + if !exists { + return nil, fmt.Errorf("queue not found: %s", queueName) + } + + // Collect pending tasks + var pendingTasks []*Task + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + // Drain channel into slice (non-blocking), bounded by the queue length at snapshot time so tasks that are put back are not read again + for i, n := 0, len(queue.tasks); i < n; i++ { + select { + case qt := <-queue.tasks: + pendingTasks = append(pendingTasks, qt.Task) + // Put it back + queue.tasks <- qt + default: + goto done + } + } + } + +done: + // Collect consumer state + consumerState := make(map[string]*ConsumerMeta) + queue.consumers.ForEach(func(id string, consumer *consumer) bool { + consumerState[id] = &ConsumerMeta{ + ID: id, + State: string(consumer.state), + LastActivity: consumer.metrics.LastActivity, + } + return true + }) + + snapshot := &QueueSnapshot{ + Timestamp: time.Now(), + QueueName: queueName, + PendingTasks: pendingTasks, + ConsumerState: consumerState, + Metrics: queue.metrics, + Config: queue.config, + Version: 1, + } + + // Persist snapshot to disk + if err := sm.persistSnapshot(snapshot); err != nil { + return nil, fmt.Errorf("failed to persist snapshot: %w", err) + } + + sm.logger.Info("Created queue snapshot", + logger.Field{Key: "queue", Value: queueName}, + logger.Field{Key: "pendingTasks", Value: len(pendingTasks)}, + logger.Field{Key: "consumers", Value: len(consumerState)}) + + return snapshot, nil +} + +// CreateSnapshotAll creates snapshots of all queues +func (sm *SnapshotManager) CreateSnapshotAll(ctx context.Context) error { + var snapshots []*QueueSnapshot + var mu sync.Mutex + var wg sync.WaitGroup + + sm.broker.queues.ForEach(func(queueName string, _ *Queue) bool { + wg.Add(1) + go func(name string) { + defer wg.Done() + snapshot, err := sm.CreateSnapshot(ctx, name) + if err != nil { + sm.logger.Error("Failed to create snapshot", + logger.Field{Key: "queue", Value: name}, + logger.Field{Key: "error", Value: err}) + return + } + mu.Lock() + snapshots = append(snapshots, snapshot) + mu.Unlock() + }(queueName) + return true + }) + + wg.Wait() + + sm.logger.Info("Created snapshots for all queues", + logger.Field{Key: "count", Value: len(snapshots)}) + + return nil +} + +// persistSnapshot writes a snapshot to disk +func (sm *SnapshotManager) persistSnapshot(snapshot *QueueSnapshot) error { + data, err := json.MarshalIndent(snapshot, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal snapshot: %w", err) + } + + queueDir := filepath.Join(sm.baseDir, snapshot.QueueName) + if err := os.MkdirAll(queueDir, 0755); err != nil { + return fmt.Errorf("failed to create queue directory: %w", err) + } + + filename := fmt.Sprintf("snapshot-%d.json", snapshot.Timestamp.UnixNano()) + filePath := filepath.Join(queueDir, filename) + + // Write atomically using temp file + tempPath := filePath + ".tmp" + if err := os.WriteFile(tempPath, data, 0644); err != nil { + return fmt.Errorf("failed to write snapshot: %w", err) + } + + if err := os.Rename(tempPath, filePath); err != nil { + os.Remove(tempPath) + return fmt.Errorf("failed to rename snapshot: %w", err) + } + + return nil +} + +// RestoreFromSnapshot restores a queue from the latest snapshot +func (sm *SnapshotManager) RestoreFromSnapshot(ctx context.Context, queueName string) error { + snapshot, err := sm.GetLatestSnapshot(queueName) + if err != nil { + return fmt.Errorf("failed to get latest snapshot: %w", err) + } + + if snapshot == nil { + return
fmt.Errorf("no snapshot found for queue: %s", queueName) + } + + sm.logger.Info("Restoring queue from snapshot", + logger.Field{Key: "queue", Value: queueName}, + logger.Field{Key: "timestamp", Value: snapshot.Timestamp}, + logger.Field{Key: "tasks", Value: len(snapshot.PendingTasks)}) + + // Get or create queue + queue, exists := sm.broker.queues.Get(queueName) + if !exists { + queue = sm.broker.NewQueue(queueName) + } + + // Restore pending tasks + restored := 0 + for _, task := range snapshot.PendingTasks { + select { + case queue.tasks <- &QueuedTask{Task: task}: + restored++ + case <-ctx.Done(): + return ctx.Err() + default: + sm.logger.Warn("Queue full during restore, task skipped", + logger.Field{Key: "taskID", Value: task.ID}) + } + } + + // Restore metrics + if snapshot.Metrics != nil { + queue.metrics = snapshot.Metrics + } + + // Restore config + if snapshot.Config != nil { + queue.config = snapshot.Config + } + + sm.logger.Info("Queue restored from snapshot", + logger.Field{Key: "queue", Value: queueName}, + logger.Field{Key: "restoredTasks", Value: restored}) + + return nil +} + +// GetLatestSnapshot retrieves the latest snapshot for a queue +func (sm *SnapshotManager) GetLatestSnapshot(queueName string) (*QueueSnapshot, error) { + queueDir := filepath.Join(sm.baseDir, queueName) + + files, err := filepath.Glob(filepath.Join(queueDir, "snapshot-*.json")) + if err != nil { + return nil, fmt.Errorf("failed to list snapshots: %w", err) + } + + if len(files) == 0 { + return nil, nil + } + + // Sort files to get latest (files are named with timestamp) + latestFile := files[0] + for _, file := range files { + if file > latestFile { + latestFile = file + } + } + + // Read and parse snapshot + data, err := os.ReadFile(latestFile) + if err != nil { + return nil, fmt.Errorf("failed to read snapshot: %w", err) + } + + var snapshot QueueSnapshot + if err := json.Unmarshal(data, &snapshot); err != nil { + return nil, fmt.Errorf("failed to unmarshal snapshot: %w", err) + } + + return &snapshot, nil +} + +// ListSnapshots lists all snapshots for a queue +func (sm *SnapshotManager) ListSnapshots(queueName string) ([]*QueueSnapshot, error) { + queueDir := filepath.Join(sm.baseDir, queueName) + + files, err := filepath.Glob(filepath.Join(queueDir, "snapshot-*.json")) + if err != nil { + return nil, fmt.Errorf("failed to list snapshots: %w", err) + } + + snapshots := make([]*QueueSnapshot, 0, len(files)) + for _, file := range files { + data, err := os.ReadFile(file) + if err != nil { + sm.logger.Error("Failed to read snapshot file", + logger.Field{Key: "file", Value: file}, + logger.Field{Key: "error", Value: err}) + continue + } + + var snapshot QueueSnapshot + if err := json.Unmarshal(data, &snapshot); err != nil { + sm.logger.Error("Failed to unmarshal snapshot", + logger.Field{Key: "file", Value: file}, + logger.Field{Key: "error", Value: err}) + continue + } + + snapshots = append(snapshots, &snapshot) + } + + return snapshots, nil +} + +// CleanupOldSnapshots removes snapshots older than retention period +func (sm *SnapshotManager) CleanupOldSnapshots() error { + cutoff := time.Now().Add(-sm.retentionPeriod) + removed := 0 + + err := filepath.Walk(sm.baseDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if info.IsDir() || filepath.Ext(path) != ".json" { + return nil + } + + if info.ModTime().Before(cutoff) { + if err := os.Remove(path); err != nil { + sm.logger.Error("Failed to remove old snapshot", + logger.Field{Key: "file", Value: path}, 
+ logger.Field{Key: "error", Value: err}) + return nil + } + removed++ + } + + return nil + }) + + if err != nil { + return err + } + + if removed > 0 { + sm.logger.Info("Cleaned up old snapshots", + logger.Field{Key: "removed", Value: removed}) + } + + return nil +} + +// DeleteSnapshot deletes a specific snapshot +func (sm *SnapshotManager) DeleteSnapshot(queueName string, timestamp time.Time) error { + filename := fmt.Sprintf("snapshot-%d.json", timestamp.UnixNano()) + filePath := filepath.Join(sm.baseDir, queueName, filename) + + if err := os.Remove(filePath); err != nil { + return fmt.Errorf("failed to delete snapshot: %w", err) + } + + sm.logger.Info("Deleted snapshot", + logger.Field{Key: "queue", Value: queueName}, + logger.Field{Key: "timestamp", Value: timestamp}) + + return nil +} + +// GetSnapshotStats returns statistics about snapshots +func (sm *SnapshotManager) GetSnapshotStats() map[string]interface{} { + totalSnapshots := 0 + var totalSize int64 + + filepath.Walk(sm.baseDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if !info.IsDir() && filepath.Ext(path) == ".json" { + totalSnapshots++ + totalSize += info.Size() + } + + return nil + }) + + return map[string]interface{}{ + "total_snapshots": totalSnapshots, + "total_size_bytes": totalSize, + "retention_period": sm.retentionPeriod, + "snapshot_interval": sm.snapshotInterval, + } +} + +// Shutdown gracefully shuts down the snapshot manager +func (sm *SnapshotManager) Shutdown(ctx context.Context) error { + close(sm.shutdown) + + // Wait for workers to finish with timeout + done := make(chan struct{}) + go func() { + sm.wg.Wait() + close(done) + }() + + // Create final snapshot before shutdown + if err := sm.CreateSnapshotAll(ctx); err != nil { + sm.logger.Error("Failed to create final snapshot", + logger.Field{Key: "error", Value: err}) + } + + select { + case <-done: + sm.logger.Info("Snapshot manager shutdown complete") + return nil + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/task.go b/task.go index f88e235..dca1c46 100644 --- a/task.go +++ b/task.go @@ -115,8 +115,9 @@ type Task struct { CreatedAt time.Time `json:"created_at"` ProcessedAt time.Time `json:"processed_at"` Expiry time.Time `json:"expiry"` - Error error `json:"-"` // Don't serialize errors directly - ErrorMsg string `json:"error,omitempty"` // Serialize error message if present + DeferUntil time.Time `json:"defer_until,omitempty"` // NEW: For deferred task execution + Error error `json:"-"` // Don't serialize errors directly + ErrorMsg string `json:"error,omitempty"` // Serialize error message if present ID string `json:"id"` Topic string `json:"topic"` Status Status `json:"status"` // Use Status type instead of string @@ -232,6 +233,20 @@ func WithTTL(ttl time.Duration) TaskOption { } } +// TaskOption for deferring task execution until a specific time +func WithDeferUntil(deferUntil time.Time) TaskOption { + return func(t *Task) { + t.DeferUntil = deferUntil + } +} + +// TaskOption for deferring task execution for a specific duration +func WithDeferDuration(duration time.Duration) TaskOption { + return func(t *Task) { + t.DeferUntil = time.Now().Add(duration) + } +} + // TaskOption for adding tags func WithTags(tags map[string]string) TaskOption { return func(t *Task) { diff --git a/tracing.go b/tracing.go new file mode 100644 index 0000000..42fcaf4 --- /dev/null +++ b/tracing.go @@ -0,0 +1,562 @@ +package mq + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + 
"github.com/oarkflow/mq/logger" +) + +// TraceEvent represents a single event in a message's lifecycle +type TraceEvent struct { + Timestamp time.Time `json:"timestamp"` + EventType TraceEventType `json:"event_type"` + Source string `json:"source"` + Details map[string]string `json:"details,omitempty"` + Duration time.Duration `json:"duration,omitempty"` + Error string `json:"error,omitempty"` +} + +// TraceEventType defines types of trace events +type TraceEventType string + +const ( + TraceEventEnqueued TraceEventType = "ENQUEUED" + TraceEventDequeued TraceEventType = "DEQUEUED" + TraceEventProcessing TraceEventType = "PROCESSING" + TraceEventCompleted TraceEventType = "COMPLETED" + TraceEventFailed TraceEventType = "FAILED" + TraceEventRetried TraceEventType = "RETRIED" + TraceEventRejected TraceEventType = "REJECTED" + TraceEventAcked TraceEventType = "ACKED" + TraceEventNacked TraceEventType = "NACKED" + TraceEventTimeout TraceEventType = "TIMEOUT" + TraceEventDLQ TraceEventType = "DLQ" +) + +// MessageTrace represents the complete trace of a message +type MessageTrace struct { + TraceID string `json:"trace_id"` + SpanID string `json:"span_id"` + ParentSpanID string `json:"parent_span_id,omitempty"` + TaskID string `json:"task_id"` + QueueName string `json:"queue_name"` + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time,omitempty"` + Duration time.Duration `json:"duration,omitempty"` + Events []*TraceEvent `json:"events"` + Tags map[string]string `json:"tags,omitempty"` + Status string `json:"status"` + mu sync.RWMutex +} + +// TraceManager manages distributed tracing for messages +type TraceManager struct { + traces map[string]*MessageTrace + mu sync.RWMutex + logger logger.Logger + storage TraceStorage + shutdown chan struct{} + retention time.Duration + exportInterval time.Duration + onTraceComplete func(*MessageTrace) +} + +// TraceStorage interface for persistent trace storage +type TraceStorage interface { + Store(ctx context.Context, trace *MessageTrace) error + Get(ctx context.Context, traceID string) (*MessageTrace, error) + Query(ctx context.Context, filter TraceFilter) ([]*MessageTrace, error) + Delete(ctx context.Context, traceID string) error + DeleteOlderThan(ctx context.Context, duration time.Duration) (int, error) + Close() error +} + +// TraceFilter for querying traces +type TraceFilter struct { + QueueName string + Status string + StartTime time.Time + EndTime time.Time + MinDuration time.Duration + MaxDuration time.Duration + HasError bool + Limit int + Offset int +} + +// TraceConfig holds tracing configuration +type TraceConfig struct { + Storage TraceStorage + Retention time.Duration + ExportInterval time.Duration + Logger logger.Logger +} + +// NewTraceManager creates a new trace manager +func NewTraceManager(config TraceConfig) *TraceManager { + if config.Retention == 0 { + config.Retention = 24 * time.Hour + } + if config.ExportInterval == 0 { + config.ExportInterval = 30 * time.Second + } + + tm := &TraceManager{ + traces: make(map[string]*MessageTrace), + logger: config.Logger, + storage: config.Storage, + shutdown: make(chan struct{}), + retention: config.Retention, + exportInterval: config.ExportInterval, + } + + go tm.exportLoop() + go tm.cleanupLoop() + + return tm +} + +// StartTrace initiates tracing for a message +func (tm *TraceManager) StartTrace(ctx context.Context, task *Task, queueName string) *MessageTrace { + trace := &MessageTrace{ + TraceID: task.TraceID, + SpanID: task.SpanID, + TaskID: task.ID, + QueueName: 
queueName, + StartTime: time.Now(), + Events: make([]*TraceEvent, 0), + Tags: task.Tags, + Status: "in_progress", + } + + tm.mu.Lock() + tm.traces[trace.TraceID] = trace + tm.mu.Unlock() + + tm.RecordEvent(trace.TraceID, TraceEventEnqueued, "broker", nil) + + return trace +} + +// RecordEvent records an event in a trace +func (tm *TraceManager) RecordEvent(traceID string, eventType TraceEventType, source string, details map[string]string) { + tm.mu.RLock() + trace, exists := tm.traces[traceID] + tm.mu.RUnlock() + + if !exists { + tm.logger.Warn("Trace not found for event recording", + logger.Field{Key: "traceID", Value: traceID}, + logger.Field{Key: "eventType", Value: eventType}) + return + } + + trace.mu.Lock() + defer trace.mu.Unlock() + + event := &TraceEvent{ + Timestamp: time.Now(), + EventType: eventType, + Source: source, + Details: details, + } + + trace.Events = append(trace.Events, event) + + tm.logger.Debug("Recorded trace event", + logger.Field{Key: "traceID", Value: traceID}, + logger.Field{Key: "eventType", Value: eventType}, + logger.Field{Key: "source", Value: source}) +} + +// RecordError records an error event in a trace +func (tm *TraceManager) RecordError(traceID string, source string, err error, details map[string]string) { + tm.mu.RLock() + trace, exists := tm.traces[traceID] + tm.mu.RUnlock() + + if !exists { + return + } + + trace.mu.Lock() + defer trace.mu.Unlock() + + if details == nil { + details = make(map[string]string) + } + + event := &TraceEvent{ + Timestamp: time.Now(), + EventType: TraceEventFailed, + Source: source, + Details: details, + Error: err.Error(), + } + + trace.Events = append(trace.Events, event) + trace.Status = "failed" +} + +// CompleteTrace marks a trace as complete +func (tm *TraceManager) CompleteTrace(traceID string, status string) { + tm.mu.Lock() + trace, exists := tm.traces[traceID] + if !exists { + tm.mu.Unlock() + return + } + delete(tm.traces, traceID) + tm.mu.Unlock() + + trace.mu.Lock() + trace.EndTime = time.Now() + trace.Duration = trace.EndTime.Sub(trace.StartTime) + trace.Status = status + trace.mu.Unlock() + + // Store in persistent storage + if tm.storage != nil { + go func() { + if err := tm.storage.Store(context.Background(), trace); err != nil { + tm.logger.Error("Failed to store trace", + logger.Field{Key: "traceID", Value: traceID}, + logger.Field{Key: "error", Value: err}) + } + }() + } + + if tm.onTraceComplete != nil { + go tm.onTraceComplete(trace) + } + + tm.logger.Debug("Completed trace", + logger.Field{Key: "traceID", Value: traceID}, + logger.Field{Key: "duration", Value: trace.Duration}, + logger.Field{Key: "status", Value: status}) +} + +// GetTrace retrieves a trace by ID +func (tm *TraceManager) GetTrace(ctx context.Context, traceID string) (*MessageTrace, error) { + // Check in-memory traces first + tm.mu.RLock() + trace, exists := tm.traces[traceID] + tm.mu.RUnlock() + + if exists { + return trace, nil + } + + // Check persistent storage + if tm.storage != nil { + return tm.storage.Get(ctx, traceID) + } + + return nil, fmt.Errorf("trace not found: %s", traceID) +} + +// QueryTraces queries traces based on filter +func (tm *TraceManager) QueryTraces(ctx context.Context, filter TraceFilter) ([]*MessageTrace, error) { + if tm.storage != nil { + return tm.storage.Query(ctx, filter) + } + + // Fallback to in-memory search + tm.mu.RLock() + defer tm.mu.RUnlock() + + var results []*MessageTrace + for _, trace := range tm.traces { + if tm.matchesFilter(trace, filter) { + results = append(results, trace) + } + } 
+ + return results, nil +} + +// matchesFilter checks if a trace matches the filter +func (tm *TraceManager) matchesFilter(trace *MessageTrace, filter TraceFilter) bool { + if filter.QueueName != "" && trace.QueueName != filter.QueueName { + return false + } + if filter.Status != "" && trace.Status != filter.Status { + return false + } + if !filter.StartTime.IsZero() && trace.StartTime.Before(filter.StartTime) { + return false + } + if !filter.EndTime.IsZero() && !trace.EndTime.IsZero() && trace.EndTime.After(filter.EndTime) { + return false + } + if filter.MinDuration > 0 && trace.Duration < filter.MinDuration { + return false + } + if filter.MaxDuration > 0 && trace.Duration > filter.MaxDuration { + return false + } + return true +} + +// exportLoop periodically exports traces to storage +func (tm *TraceManager) exportLoop() { + ticker := time.NewTicker(tm.exportInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + tm.exportTraces() + case <-tm.shutdown: + return + } + } +} + +// exportTraces exports active traces to storage +func (tm *TraceManager) exportTraces() { + if tm.storage == nil { + return + } + + tm.mu.RLock() + traces := make([]*MessageTrace, 0, len(tm.traces)) + for _, trace := range tm.traces { + traces = append(traces, trace) + } + tm.mu.RUnlock() + + for _, trace := range traces { + if err := tm.storage.Store(context.Background(), trace); err != nil { + tm.logger.Error("Failed to export trace", + logger.Field{Key: "traceID", Value: trace.TraceID}, + logger.Field{Key: "error", Value: err}) + } + } +} + +// cleanupLoop periodically cleans up old traces +func (tm *TraceManager) cleanupLoop() { + ticker := time.NewTicker(1 * time.Hour) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + tm.cleanup() + case <-tm.shutdown: + return + } + } +} + +// cleanup removes old traces +func (tm *TraceManager) cleanup() { + if tm.storage != nil { + deleted, err := tm.storage.DeleteOlderThan(context.Background(), tm.retention) + if err != nil { + tm.logger.Error("Failed to cleanup old traces", + logger.Field{Key: "error", Value: err}) + } else if deleted > 0 { + tm.logger.Info("Cleaned up old traces", + logger.Field{Key: "deleted", Value: deleted}) + } + } +} + +// SetOnTraceComplete sets callback for trace completion +func (tm *TraceManager) SetOnTraceComplete(fn func(*MessageTrace)) { + tm.onTraceComplete = fn +} + +// GetStats returns tracing statistics +func (tm *TraceManager) GetStats() map[string]interface{} { + tm.mu.RLock() + defer tm.mu.RUnlock() + + activeTraces := len(tm.traces) + + return map[string]interface{}{ + "active_traces": activeTraces, + "retention": tm.retention, + "export_interval": tm.exportInterval, + } +} + +// Shutdown gracefully shuts down the trace manager +func (tm *TraceManager) Shutdown(ctx context.Context) error { + close(tm.shutdown) + + // Export remaining traces + tm.exportTraces() + + if tm.storage != nil { + return tm.storage.Close() + } + + return nil +} + +// InMemoryTraceStorage implements in-memory trace storage +type InMemoryTraceStorage struct { + traces map[string]*MessageTrace + mu sync.RWMutex +} + +// NewInMemoryTraceStorage creates a new in-memory trace storage +func NewInMemoryTraceStorage() *InMemoryTraceStorage { + return &InMemoryTraceStorage{ + traces: make(map[string]*MessageTrace), + } +} + +func (s *InMemoryTraceStorage) Store(ctx context.Context, trace *MessageTrace) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Deep copy to avoid concurrent modifications + data, err := json.Marshal(trace) + if 
err != nil { + return err + } + + var traceCopy MessageTrace + if err := json.Unmarshal(data, &traceCopy); err != nil { + return err + } + + s.traces[trace.TraceID] = &traceCopy + return nil +} + +func (s *InMemoryTraceStorage) Get(ctx context.Context, traceID string) (*MessageTrace, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + trace, exists := s.traces[traceID] + if !exists { + return nil, fmt.Errorf("trace not found: %s", traceID) + } + + return trace, nil +} + +func (s *InMemoryTraceStorage) Query(ctx context.Context, filter TraceFilter) ([]*MessageTrace, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + var results []*MessageTrace + for _, trace := range s.traces { + results = append(results, trace) + } + + // Apply pagination + start := filter.Offset + if start >= len(results) { + return []*MessageTrace{}, nil + } + + end := start + filter.Limit + if filter.Limit == 0 || end > len(results) { + end = len(results) + } + + return results[start:end], nil +} + +func (s *InMemoryTraceStorage) Delete(ctx context.Context, traceID string) error { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.traces, traceID) + return nil +} + +func (s *InMemoryTraceStorage) DeleteOlderThan(ctx context.Context, duration time.Duration) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + + cutoff := time.Now().Add(-duration) + deleted := 0 + + for traceID, trace := range s.traces { + if !trace.EndTime.IsZero() && trace.EndTime.Before(cutoff) { + delete(s.traces, traceID) + deleted++ + } + } + + return deleted, nil +} + +func (s *InMemoryTraceStorage) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + s.traces = nil + return nil +} + +// MessageLifecycleTracker tracks the complete lifecycle of messages +type MessageLifecycleTracker struct { + traceManager *TraceManager + logger logger.Logger +} + +// NewMessageLifecycleTracker creates a new lifecycle tracker +func NewMessageLifecycleTracker(traceManager *TraceManager, log logger.Logger) *MessageLifecycleTracker { + return &MessageLifecycleTracker{ + traceManager: traceManager, + logger: log, + } +} + +// TrackEnqueue tracks message enqueue +func (lt *MessageLifecycleTracker) TrackEnqueue(ctx context.Context, task *Task, queueName string) { + lt.traceManager.StartTrace(ctx, task, queueName) +} + +// TrackDequeue tracks message dequeue +func (lt *MessageLifecycleTracker) TrackDequeue(traceID, consumerID string) { + lt.traceManager.RecordEvent(traceID, TraceEventDequeued, consumerID, nil) +} + +// TrackProcessing tracks message processing start +func (lt *MessageLifecycleTracker) TrackProcessing(traceID, workerID string) { + lt.traceManager.RecordEvent(traceID, TraceEventProcessing, workerID, nil) +} + +// TrackCompletion tracks successful completion +func (lt *MessageLifecycleTracker) TrackCompletion(traceID string, duration time.Duration) { + details := map[string]string{ + "duration": duration.String(), + } + lt.traceManager.RecordEvent(traceID, TraceEventCompleted, "worker", details) + lt.traceManager.CompleteTrace(traceID, "completed") +} + +// TrackError tracks processing error +func (lt *MessageLifecycleTracker) TrackError(traceID, source string, err error) { + lt.traceManager.RecordError(traceID, source, err, nil) +} + +// TrackRetry tracks retry attempt +func (lt *MessageLifecycleTracker) TrackRetry(traceID string, retryCount int) { + details := map[string]string{ + "retry_count": fmt.Sprintf("%d", retryCount), + } + lt.traceManager.RecordEvent(traceID, TraceEventRetried, "broker", details) +} + +// TrackDLQ tracks movement to dead letter 
queue +func (lt *MessageLifecycleTracker) TrackDLQ(traceID, reason string) { + details := map[string]string{ + "reason": reason, + } + lt.traceManager.RecordEvent(traceID, TraceEventDLQ, "broker", details) + lt.traceManager.CompleteTrace(traceID, "dlq") +} diff --git a/wal.go b/wal.go new file mode 100644 index 0000000..4f3ef39 --- /dev/null +++ b/wal.go @@ -0,0 +1,444 @@ +package mq + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" + "time" + + "github.com/oarkflow/mq/logger" +) + +// WALEntry represents a single write-ahead log entry +type WALEntry struct { + EntryType WALEntryType `json:"entry_type"` + TaskID string `json:"task_id"` + QueueName string `json:"queue_name"` + Timestamp time.Time `json:"timestamp"` + Payload json.RawMessage `json:"payload,omitempty"` + SequenceID int64 `json:"sequence_id"` +} + +// WALEntryType defines the type of WAL entry +type WALEntryType string + +const ( + WALEntryEnqueue WALEntryType = "ENQUEUE" + WALEntryDequeue WALEntryType = "DEQUEUE" + WALEntryComplete WALEntryType = "COMPLETE" + WALEntryFailed WALEntryType = "FAILED" + WALEntryCheckpoint WALEntryType = "CHECKPOINT" +) + +// WriteAheadLog provides message durability through persistent logging +type WriteAheadLog struct { + dir string + currentFile *os.File + currentWriter *bufio.Writer + sequenceID int64 + maxFileSize int64 + syncInterval time.Duration + mu sync.Mutex + logger logger.Logger + shutdown chan struct{} + wg sync.WaitGroup + entries chan *WALEntry + fsyncOnWrite bool +} + +// WALConfig holds configuration for the WAL +type WALConfig struct { + Directory string + MaxFileSize int64 // Maximum file size before rotation + SyncInterval time.Duration // Interval for syncing to disk + FsyncOnWrite bool // Sync after every write (slower but more durable) + Logger logger.Logger +} + +// NewWriteAheadLog creates a new write-ahead log +func NewWriteAheadLog(config WALConfig) (*WriteAheadLog, error) { + if config.MaxFileSize == 0 { + config.MaxFileSize = 100 * 1024 * 1024 // 100MB default + } + if config.SyncInterval == 0 { + config.SyncInterval = 1 * time.Second + } + + if err := os.MkdirAll(config.Directory, 0755); err != nil { + return nil, fmt.Errorf("failed to create WAL directory: %w", err) + } + + wal := &WriteAheadLog{ + dir: config.Directory, + maxFileSize: config.MaxFileSize, + syncInterval: config.SyncInterval, + logger: config.Logger, + shutdown: make(chan struct{}), + entries: make(chan *WALEntry, 10000), + fsyncOnWrite: config.FsyncOnWrite, + } + + // Recover sequence ID from existing logs + if err := wal.recoverSequenceID(); err != nil { + return nil, fmt.Errorf("failed to recover sequence ID: %w", err) + } + + // Open or create current log file + if err := wal.openNewFile(); err != nil { + return nil, fmt.Errorf("failed to open WAL file: %w", err) + } + + // Start background workers + wal.wg.Add(2) + go wal.writeWorker() + go wal.syncWorker() + + return wal, nil +} + +// WriteEntry writes a new entry to the WAL +func (w *WriteAheadLog) WriteEntry(ctx context.Context, entry *WALEntry) error { + w.mu.Lock() + w.sequenceID++ + entry.SequenceID = w.sequenceID + entry.Timestamp = time.Now() + w.mu.Unlock() + + select { + case w.entries <- entry: + return nil + case <-ctx.Done(): + return ctx.Err() + case <-w.shutdown: + return fmt.Errorf("WAL is shutting down") + } +} + +// writeWorker processes WAL entries in the background +func (w *WriteAheadLog) writeWorker() { + defer w.wg.Done() + + for { + select { + case entry := <-w.entries: + if err 
:= w.writeEntryToFile(entry); err != nil { + w.logger.Error("Failed to write WAL entry", + logger.Field{Key: "error", Value: err}, + logger.Field{Key: "taskID", Value: entry.TaskID}) + } + case <-w.shutdown: + // Drain remaining entries + for len(w.entries) > 0 { + entry := <-w.entries + _ = w.writeEntryToFile(entry) + } + return + } + } +} + +// writeEntryToFile writes a single entry to the current WAL file +func (w *WriteAheadLog) writeEntryToFile(entry *WALEntry) error { + w.mu.Lock() + defer w.mu.Unlock() + + data, err := json.Marshal(entry) + if err != nil { + return fmt.Errorf("failed to marshal WAL entry: %w", err) + } + + // Write entry with newline delimiter + if _, err := w.currentWriter.Write(append(data, '\n')); err != nil { + return fmt.Errorf("failed to write WAL entry: %w", err) + } + + if w.fsyncOnWrite { + if err := w.currentWriter.Flush(); err != nil { + return fmt.Errorf("failed to flush WAL: %w", err) + } + if err := w.currentFile.Sync(); err != nil { + return fmt.Errorf("failed to sync WAL: %w", err) + } + } + + // Check if we need to rotate the file + stat, err := w.currentFile.Stat() + if err == nil && stat.Size() >= w.maxFileSize { + return w.rotateFile() + } + + return nil +} + +// syncWorker periodically syncs the WAL to disk +func (w *WriteAheadLog) syncWorker() { + defer w.wg.Done() + + ticker := time.NewTicker(w.syncInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + w.mu.Lock() + if w.currentWriter != nil { + _ = w.currentWriter.Flush() + } + if w.currentFile != nil { + _ = w.currentFile.Sync() + } + w.mu.Unlock() + case <-w.shutdown: + return + } + } +} + +// openNewFile creates a new WAL file +func (w *WriteAheadLog) openNewFile() error { + filename := fmt.Sprintf("wal-%d.log", time.Now().UnixNano()) + filepath := filepath.Join(w.dir, filename) + + file, err := os.OpenFile(filepath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + return fmt.Errorf("failed to open WAL file: %w", err) + } + + w.currentFile = file + w.currentWriter = bufio.NewWriter(file) + + w.logger.Info("Opened new WAL file", logger.Field{Key: "filename", Value: filename}) + + return nil +} + +// rotateFile rotates to a new WAL file +func (w *WriteAheadLog) rotateFile() error { + // Flush and close current file + if w.currentWriter != nil { + if err := w.currentWriter.Flush(); err != nil { + return err + } + } + if w.currentFile != nil { + if err := w.currentFile.Close(); err != nil { + return err + } + } + + // Open new file + return w.openNewFile() +} + +// recoverSequenceID recovers the last sequence ID from existing WAL files +func (w *WriteAheadLog) recoverSequenceID() error { + files, err := filepath.Glob(filepath.Join(w.dir, "wal-*.log")) + if err != nil { + return err + } + + maxSeq := int64(0) + + for _, filepath := range files { + file, err := os.Open(filepath) + if err != nil { + continue + } + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + var entry WALEntry + if err := json.Unmarshal(scanner.Bytes(), &entry); err != nil { + continue + } + if entry.SequenceID > maxSeq { + maxSeq = entry.SequenceID + } + } + + file.Close() + } + + w.sequenceID = maxSeq + w.logger.Info("Recovered sequence ID", logger.Field{Key: "sequenceID", Value: maxSeq}) + + return nil +} + +// Replay replays all WAL entries to recover state +func (w *WriteAheadLog) Replay(handler func(*WALEntry) error) error { + files, err := filepath.Glob(filepath.Join(w.dir, "wal-*.log")) + if err != nil { + return fmt.Errorf("failed to list WAL files: %w", err) + } + + // 
Sort files by name (which includes timestamp) + // Simple bubble sort since the list is typically small + for i := 0; i < len(files)-1; i++ { + for j := i + 1; j < len(files); j++ { + if files[i] > files[j] { + files[i], files[j] = files[j], files[i] + } + } + } + + entriesReplayed := 0 + + for _, filepath := range files { + file, err := os.Open(filepath) + if err != nil { + w.logger.Error("Failed to open WAL file for replay", + logger.Field{Key: "error", Value: err}, + logger.Field{Key: "file", Value: filepath}) + continue + } + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + var entry WALEntry + if err := json.Unmarshal(scanner.Bytes(), &entry); err != nil { + w.logger.Error("Failed to unmarshal WAL entry", + logger.Field{Key: "error", Value: err}) + continue + } + + if err := handler(&entry); err != nil { + w.logger.Error("Failed to replay WAL entry", + logger.Field{Key: "error", Value: err}, + logger.Field{Key: "taskID", Value: entry.TaskID}) + continue + } + + entriesReplayed++ + } + + file.Close() + } + + w.logger.Info("WAL replay complete", + logger.Field{Key: "entries", Value: entriesReplayed}) + + return nil +} + +// Checkpoint writes a checkpoint entry and optionally truncates old logs +func (w *WriteAheadLog) Checkpoint(ctx context.Context, state map[string]interface{}) error { + stateData, err := json.Marshal(state) + if err != nil { + return fmt.Errorf("failed to marshal checkpoint state: %w", err) + } + + entry := &WALEntry{ + EntryType: WALEntryCheckpoint, + TaskID: "checkpoint", + Payload: stateData, + } + + return w.WriteEntry(ctx, entry) +} + +// TruncateOldLogs removes old WAL files (called after checkpoint) +func (w *WriteAheadLog) TruncateOldLogs(keepRecent int) error { + files, err := filepath.Glob(filepath.Join(w.dir, "wal-*.log")) + if err != nil { + return fmt.Errorf("failed to list WAL files: %w", err) + } + + // Sort files by modification time + type fileInfo struct { + path string + modTime time.Time + } + + var fileInfos []fileInfo + for _, path := range files { + stat, err := os.Stat(path) + if err != nil { + continue + } + fileInfos = append(fileInfos, fileInfo{path: path, modTime: stat.ModTime()}) + } + + // Sort by modification time (newest first) + for i := 0; i < len(fileInfos)-1; i++ { + for j := i + 1; j < len(fileInfos); j++ { + if fileInfos[i].modTime.Before(fileInfos[j].modTime) { + fileInfos[i], fileInfos[j] = fileInfos[j], fileInfos[i] + } + } + } + + // Remove old files + removed := 0 + for i := keepRecent; i < len(fileInfos); i++ { + if err := os.Remove(fileInfos[i].path); err != nil { + w.logger.Error("Failed to remove old WAL file", + logger.Field{Key: "error", Value: err}, + logger.Field{Key: "file", Value: fileInfos[i].path}) + continue + } + removed++ + } + + w.logger.Info("Truncated old WAL files", + logger.Field{Key: "removed", Value: removed}) + + return nil +} + +// Shutdown gracefully shuts down the WAL +func (w *WriteAheadLog) Shutdown(ctx context.Context) error { + close(w.shutdown) + + // Wait for workers to finish with timeout + done := make(chan struct{}) + go func() { + w.wg.Wait() + close(done) + }() + + select { + case <-done: + w.mu.Lock() + defer w.mu.Unlock() + + if w.currentWriter != nil { + _ = w.currentWriter.Flush() + } + if w.currentFile != nil { + _ = w.currentFile.Sync() + _ = w.currentFile.Close() + } + + w.logger.Info("WAL shutdown complete") + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// GetStats returns statistics about the WAL +func (w *WriteAheadLog) GetStats() 
map[string]interface{} { + w.mu.Lock() + defer w.mu.Unlock() + + var currentFileSize int64 + if w.currentFile != nil { + if stat, err := w.currentFile.Stat(); err == nil { + currentFileSize = stat.Size() + } + } + + files, _ := filepath.Glob(filepath.Join(w.dir, "wal-*.log")) + + return map[string]interface{}{ + "current_sequence_id": w.sequenceID, + "current_file_size": currentFileSize, + "total_files": len(files), + "entries_backlog": len(w.entries), + } +}
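
---

The diff above also introduces a publisher-facing deferral path: `task.go` adds `Task.DeferUntil` with the `WithDeferUntil`/`WithDeferDuration` options, and `PublishHandler` parks such tasks via `AddDeferredTask` instead of broadcasting them when enhanced features are enabled. Below is a minimal sketch of how a caller might use the new option, modeled on the `publishNotifications` helper earlier in this diff; the helper name, queue name, and payload are illustrative, not part of the change set.

```go
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"time"

	"github.com/oarkflow/mq"
)

// publishDeferredReminder publishes a notification task that the broker should
// hold back for ten minutes: WithDeferDuration sets Task.DeferUntil, which
// PublishHandler checks before broadcasting to consumers.
// NOTE: helper name, queue name, and payload are illustrative only.
func publishDeferredReminder(ctx context.Context, publisher *mq.Publisher) error {
	payload, err := json.Marshal(map[string]interface{}{
		"type":      "notification",
		"recipient": "customer1@example.com",
		"subject":   "Reminder",
		"body":      "Your weekly report is ready",
		"timestamp": time.Now().Unix(),
	})
	if err != nil {
		return fmt.Errorf("failed to marshal reminder: %w", err)
	}

	task := mq.NewTask(
		fmt.Sprintf("reminder-task-%d", time.Now().Unix()),
		payload,
		"notifications",
		mq.WithPriority(5),
		mq.WithDeferDuration(10*time.Minute), // broker defers the task until DeferUntil passes
	)

	if err := publisher.Publish(ctx, *task, "notifications"); err != nil {
		return fmt.Errorf("failed to publish deferred reminder: %w", err)
	}

	fmt.Printf(" ⏲ Published deferred task: %s\n", task.ID)
	return nil
}
```

Note that, per the `PublishHandler` changes in `mq.go`, the deferral check only runs when the broker's enhanced features are enabled (`b.enhanced != nil && b.enhanced.enabled`); otherwise `DeferUntil` is ignored and the task is broadcast immediately.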