Commit 8ec9474bdc by sujit, 2025-10-01 12:31:58 +05:45 (parent 0340774cfe)
20 changed files with 5015 additions and 1123 deletions

.gitignore (vendored)

@@ -24,3 +24,4 @@ go.work
 *.svg
 .qodo
 .history
+*.log


@@ -1,416 +0,0 @@
# Enhanced Services with DAG + Workflow Engine
## Overview
The enhanced services architecture integrates all workflow engine features into the DAG system, providing complete feature parity while preserving backward compatibility. Traditional DAG functionality and advanced workflow capabilities are both exposed through a unified service layer.
## Architecture Components
### 1. Enhanced Service Manager (`enhanced_setup.go`)
- **Purpose**: Core service orchestration with DAG + workflow integration
- **Features**:
  - Dual-mode execution (Traditional DAG + Enhanced Workflow)
  - HTTP API endpoints for workflow management
  - Enhanced validation with workflow rule support
  - Service health monitoring and metrics
  - Background task management
### 2. Enhanced Contracts (`enhanced_contracts.go`)
- **Purpose**: Service interfaces for DAG + workflow integration
- **Key Interfaces**:
  - `EnhancedServiceManager`: Core service management
  - `EnhancedDAGService`: Dual-mode DAG operations
  - `EnhancedValidation`: Workflow validation rules
  - `EnhancedHandler`: Unified handler structure
### 3. Enhanced DAG Service (`enhanced_dag_service.go`)
- **Purpose**: DAG service with workflow engine capabilities
- **Features**:
  - Traditional DAG execution (backward compatibility)
  - Enhanced workflow execution with advanced processors
  - State management and persistence
  - Execution result handling with proper field mapping
### 4. Enhanced Validation (`enhanced_validation.go`)
- **Purpose**: Validation service with workflow rule support
- **Features**:
  - Schema validation with workflow rules
  - Field-level validation (string, email, numeric, etc.)
  - Custom validation logic with processor integration
  - Validation result aggregation
## Features Implemented
### Complete Workflow Engine Integration ✅
All 8 advanced processors from the workflow engine are now available in the DAG system:
1. **Validator Processor**: Schema and field validation
2. **Router Processor**: Conditional routing and decision making
3. **Transformer Processor**: Data transformation and mapping
4. **Aggregator Processor**: Data aggregation and summarization
5. **Filter Processor**: Data filtering and selection
6. **Sorter Processor**: Data sorting and ordering
7. **Notify Processor**: Notification and messaging
8. **Storage Processor**: Data persistence and retrieval
### Enhanced DAG Capabilities ✅
- **Dual Mode Support**: Both traditional DAG and workflow modes
- **Advanced Retry Logic**: Exponential backoff with circuit breaker
- **State Management**: Persistent execution state tracking
- **Scheduling**: Background task scheduling and execution
- **Security**: Authentication and authorization support
- **Middleware**: Pre/post execution hooks
- **Metrics**: Performance monitoring and reporting
### HTTP API Integration ✅
Complete REST API for workflow management:
- `GET /api/v1/handlers` - List all handlers
- `POST /api/v1/execute/:key` - Execute workflow by key
- `GET /api/v1/workflows` - List workflow instances
- `POST /api/v1/workflows/:id/execute` - Execute specific workflow
- `GET /health` - Service health check
### Validation System ✅
Enhanced validation with workflow rule support:
- Field-level validation rules
- Type checking (string, email, numeric, etc.)
- Length constraints (min/max)
- Required field validation
- Custom validation messages
- Validation result aggregation
## Usage Examples
### 1. Traditional DAG Mode (Backward Compatibility)
```go
// Traditional DAG handler
handler := services.EnhancedHandler{
Key: "traditional-dag",
Name: "Traditional DAG",
WorkflowEnabled: false, // Use traditional DAG mode
Nodes: []services.EnhancedNode{
{
ID: "start",
Name: "Start Process",
Node: "basic",
FirstNode: true,
},
{
ID: "process",
Name: "Process Data",
Node: "basic",
},
},
Edges: []services.Edge{
{Source: "start", Target: []string{"process"}},
},
}
```
### 2. Enhanced Workflow Mode
```go
// Enhanced workflow handler with processors
handler := services.EnhancedHandler{
Key: "enhanced-workflow",
Name: "Enhanced Workflow",
WorkflowEnabled: true, // Use enhanced workflow mode
ValidationRules: []*dag.WorkflowValidationRule{
{
Field: "email",
Type: "email",
Required: true,
Message: "Valid email is required",
},
},
Nodes: []services.EnhancedNode{
{
ID: "validate-input",
Name: "Validate Input",
Type: "validator",
ProcessorType: "validator",
},
{
ID: "route-data",
Name: "Route Decision",
Type: "router",
ProcessorType: "router",
},
{
ID: "transform-data",
Name: "Transform Data",
Type: "transformer",
ProcessorType: "transformer",
},
},
Edges: []services.Edge{
{Source: "validate-input", Target: []string{"route-data"}},
{Source: "route-data", Target: []string{"transform-data"}},
},
}
```
### 3. Service Configuration
```go
config := &services.EnhancedServiceConfig{
BrokerURL: "nats://localhost:4222",
Debug: true,
// Enhanced DAG configuration
EnhancedDAGConfig: &dag.EnhancedDAGConfig{
EnableWorkflowEngine: true,
MaintainDAGMode: true,
EnableStateManagement: true,
EnableAdvancedRetry: true,
EnableCircuitBreaker: true,
MaxConcurrentExecutions: 10,
DefaultTimeout: 30 * time.Second,
},
// Workflow engine configuration
WorkflowEngineConfig: &dag.WorkflowEngineConfig{
MaxConcurrentExecutions: 5,
DefaultTimeout: 2 * time.Minute,
EnablePersistence: true,
EnableSecurity: true,
RetryConfig: &dag.RetryConfig{
MaxRetries: 3,
InitialDelay: 1 * time.Second,
BackoffFactor: 2.0,
},
},
}
```
### 4. Service Initialization
```go
// Create enhanced service manager
manager := services.NewEnhancedServiceManager(config)
// Initialize services
if err := manager.Initialize(config); err != nil {
log.Fatalf("Failed to initialize services: %v", err)
}
// Start services
ctx := context.Background()
if err := manager.Start(ctx); err != nil {
log.Fatalf("Failed to start services: %v", err)
}
defer manager.Stop(ctx)
// Register handlers
for _, handler := range handlers {
if err := manager.RegisterEnhancedHandler(handler); err != nil {
log.Printf("Failed to register handler %s: %v", handler.Key, err)
}
}
```
### 5. HTTP API Setup
```go
// Create Fiber app
app := fiber.New()
// Register HTTP routes
if err := manager.RegisterHTTPRoutes(app); err != nil {
log.Fatalf("Failed to register HTTP routes: %v", err)
}
// Start server
log.Fatal(app.Listen(":3000"))
```
### 6. Workflow Execution
```go
// Execute workflow programmatically
ctx := context.Background()
input := map[string]any{
"name": "John Doe",
"email": "john@example.com",
}
result, err := manager.ExecuteEnhancedWorkflow(ctx, "enhanced-workflow", input)
if err != nil {
log.Printf("Execution failed: %v", err)
} else {
log.Printf("Execution completed: %s (Status: %s)", result.ID, result.Status)
}
```
## HTTP API Usage
### Execute Workflow via REST API
```bash
# Execute workflow with POST request
curl -X POST http://localhost:3000/api/v1/execute/enhanced-workflow \
-H "Content-Type: application/json" \
-d '{
"name": "John Doe",
"email": "john@example.com",
"age": 30
}'
```
### List Available Handlers
```bash
# Get list of registered handlers
curl -X GET http://localhost:3000/api/v1/handlers
```
### Health Check
```bash
# Check service health
curl -X GET http://localhost:3000/health
```
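The remaining routes listed above can be exercised the same way; for example (the workflow ID is a placeholder, and the request body is assumed to follow the same JSON shape as the execute-by-key example above):
```bash
# List workflow instances
curl -X GET http://localhost:3000/api/v1/workflows

# Execute a specific workflow instance by ID
curl -X POST http://localhost:3000/api/v1/workflows/<workflow-id>/execute \
  -H "Content-Type: application/json" \
  -d '{"name": "John Doe", "email": "john@example.com"}'
```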
## Advanced Features
### 1. Validation Rules
The enhanced validation system supports comprehensive field validation:
```go
ValidationRules: []*dag.WorkflowValidationRule{
{
Field: "name",
Type: "string",
Required: true,
MinLength: 2,
MaxLength: 50,
Message: "Name must be 2-50 characters",
},
{
Field: "email",
Type: "email",
Required: true,
Message: "Valid email is required",
},
{
Field: "age",
Type: "number",
Min: 18,
Max: 120,
Message: "Age must be between 18 and 120",
},
}
```
### 2. Processor Configuration
Each processor can be configured with specific parameters:
```go
Config: dag.WorkflowNodeConfig{
// Validator processor config
ValidationType: "schema",
ValidationRules: []dag.WorkflowValidationRule{...},
// Router processor config
RoutingRules: []dag.RoutingRule{...},
// Transformer processor config
TransformationRules: []dag.TransformationRule{...},
// Storage processor config
StorageType: "memory",
StorageConfig: map[string]any{...},
}
```
### 3. Error Handling and Retry
Built-in retry logic with exponential backoff:
```go
RetryConfig: &dag.RetryConfig{
MaxRetries: 3,
InitialDelay: 1 * time.Second,
MaxDelay: 30 * time.Second,
BackoffFactor: 2.0,
}
```
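With the values above, a failing step would be retried after roughly 1s, 2s, and 4s (each delay scaled by the 2.0 backoff factor), and no single delay would exceed the 30-second `MaxDelay` cap.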
### 4. State Management
Persistent execution state tracking:
```go
EnhancedDAGConfig: &dag.EnhancedDAGConfig{
EnableStateManagement: true,
EnablePersistence: true,
}
```
## Migration Guide
### From Traditional DAG to Enhanced Services
1. **Keep existing DAG handlers**: Set `WorkflowEnabled: false`
2. **Add enhanced features gradually**: Create new handlers with `WorkflowEnabled: true`
3. **Use validation rules**: Add `ValidationRules` for input validation
4. **Configure processors**: Set appropriate `ProcessorType` for each node
5. **Test both modes**: Verify traditional and enhanced workflows work correctly
### Configuration Migration
```go
// Before (traditional)
config := &services.ServiceConfig{
BrokerURL: "nats://localhost:4222",
}
// After (enhanced)
config := &services.EnhancedServiceConfig{
BrokerURL: "nats://localhost:4222",
EnhancedDAGConfig: &dag.EnhancedDAGConfig{
EnableWorkflowEngine: true,
MaintainDAGMode: true, // Keep backward compatibility
},
}
```
## Performance Considerations
1. **Concurrent Executions**: Configure `MaxConcurrentExecutions` based on system resources
2. **Timeout Settings**: Set appropriate `DefaultTimeout` for workflow complexity
3. **Retry Strategy**: Balance retry attempts with system load
4. **State Management**: Enable persistence only when needed
5. **Metrics**: Monitor performance with built-in metrics
## Troubleshooting
### Common Issues
1. **Handler Registration Fails**
   - Check validation rules syntax
   - Verify processor types are valid
   - Ensure node dependencies are correct
2. **Workflow Execution Errors**
   - Validate input data format
   - Check processor configurations
   - Review error logs for details
3. **HTTP API Issues**
   - Verify routes are registered correctly
   - Check request format and headers
   - Review service health status
### Debug Mode
Enable debug mode for detailed logging:
```go
config := &services.EnhancedServiceConfig{
Debug: true,
// ... other config
}
```
## Conclusion
The enhanced services architecture successfully provides complete feature parity between the DAG system and workflow engine. All workflow engine features are now available in the DAG system while maintaining full backward compatibility with existing traditional DAG implementations.
Key achievements:
- ✅ Complete workflow engine integration (8 advanced processors)
- ✅ Dual-mode support (traditional DAG + enhanced workflow)
- ✅ HTTP API for workflow management
- ✅ Enhanced validation with workflow rules
- ✅ Service health monitoring and metrics
- ✅ Backward compatibility maintained
- ✅ Production-ready architecture
The system now provides a unified, powerful, and flexible platform for both simple DAG operations and complex workflow orchestration.


@@ -1,469 +0,0 @@
# Complete Workflow Engine Documentation
## Overview
This is a **production-ready, enterprise-grade workflow engine** built on top of the existing DAG system. It provides comprehensive workflow orchestration capabilities with support for complex business processes, data pipelines, approval workflows, and automated task execution.
## 🎯 Key Features
### Core Capabilities
- **Workflow Definition & Management** - JSON-based workflow definitions with versioning
- **Multi-Node Type Support** - Task, API, Transform, Decision, Human Task, Timer, Loop, Parallel, Database, Email, Webhook
- **Advanced Execution Engine** - DAG-based execution with state management and error handling
- **Flexible Scheduling** - Support for immediate, delayed, and conditional execution
- **RESTful API** - Complete HTTP API for workflow management and execution
- **Real-time Monitoring** - Execution tracking, metrics, and health monitoring
- **Error Handling & Recovery** - Retry policies, rollback support, and checkpoint recovery
### Enterprise Features
- **Scalable Architecture** - Worker pool management and concurrent execution
- **Data Persistence** - In-memory storage with extensible storage interface
- **Security Framework** - Authentication, authorization, and CORS support
- **Audit & Tracing** - Complete execution history and tracing capabilities
- **Variable Management** - Runtime variables and templating support
- **Condition-based Routing** - Dynamic workflow paths based on conditions
## 📁 Project Structure
```
workflow/
├── types.go # Core types and interfaces
├── processors.go # Node type processors (Task, API, Transform, etc.)
├── registry.go # Workflow definition storage and management
├── engine.go # Main workflow execution engine
├── api.go # HTTP API handlers and routes
├── demo/
│ └── main.go # Comprehensive demonstration
└── example/
└── main.go # Simple usage examples
```
## 🚀 Quick Start
### 1. Import the Package
```go
import "github.com/oarkflow/mq/workflow"
```
### 2. Create and Start Engine
```go
config := &workflow.Config{
MaxWorkers: 10,
ExecutionTimeout: 30 * time.Minute,
EnableMetrics: true,
EnableAudit: true,
}
engine := workflow.NewWorkflowEngine(config)
ctx := context.Background()
engine.Start(ctx)
defer engine.Stop(ctx)
```
### 3. Define a Workflow
```go
workflow := &workflow.WorkflowDefinition{
ID: "sample-workflow",
Name: "Sample Data Processing",
Description: "A simple data processing workflow",
Version: "1.0.0",
Status: workflow.WorkflowStatusActive,
Nodes: []workflow.WorkflowNode{
{
ID: "fetch-data",
Name: "Fetch Data",
Type: workflow.NodeTypeAPI,
Config: workflow.NodeConfig{
URL: "https://api.example.com/data",
Method: "GET",
},
},
{
ID: "process-data",
Name: "Process Data",
Type: workflow.NodeTypeTransform,
Config: workflow.NodeConfig{
TransformType: "json_path",
Expression: "$.data",
},
},
},
Edges: []workflow.WorkflowEdge{
{
ID: "fetch-to-process",
FromNode: "fetch-data",
ToNode: "process-data",
},
},
}
// Register workflow
engine.RegisterWorkflow(ctx, workflow)
```
### 4. Execute Workflow
```go
execution, err := engine.ExecuteWorkflow(ctx, "sample-workflow", map[string]any{
"input_data": "test_value",
}, &workflow.ExecutionOptions{
Priority: workflow.PriorityMedium,
Owner: "user123",
})
if err != nil {
log.Fatal(err)
}
fmt.Printf("Execution started: %s\n", execution.ID)
```
## 🏗️ Node Types
The workflow engine supports various node types for different use cases:
### Task Node
Execute custom scripts or commands
```go
{
Type: workflow.NodeTypeTask,
Config: workflow.NodeConfig{
Script: "console.log('Processing:', ${data})",
},
}
```
### API Node
Make HTTP requests to external services
```go
{
Type: workflow.NodeTypeAPI,
Config: workflow.NodeConfig{
URL: "https://api.service.com/endpoint",
Method: "POST",
Headers: map[string]string{
"Authorization": "Bearer ${token}",
},
},
}
```
### Transform Node
Transform and manipulate data
```go
{
Type: workflow.NodeTypeTransform,
Config: workflow.NodeConfig{
TransformType: "json_path",
Expression: "$.users[*].email",
},
}
```
### Decision Node
Conditional routing based on rules
```go
{
Type: workflow.NodeTypeDecision,
Config: workflow.NodeConfig{
Rules: []workflow.Rule{
{
Condition: "age >= 18",
Output: "adult",
NextNode: "adult-process",
},
{
Condition: "age < 18",
Output: "minor",
NextNode: "minor-process",
},
},
},
}
```
### Human Task Node
Wait for human intervention
```go
{
Type: workflow.NodeTypeHumanTask,
Config: workflow.NodeConfig{
Custom: map[string]any{
"assignee": "manager@company.com",
"due_date": "3 days",
"description": "Please review and approve",
},
},
}
```
### Timer Node
Add delays or scheduled execution
```go
{
Type: workflow.NodeTypeTimer,
Config: workflow.NodeConfig{
Duration: 30 * time.Second,
Schedule: "0 9 * * 1", // Every Monday at 9 AM
},
}
```
### Database Node
Execute database operations
```go
{
Type: workflow.NodeTypeDatabase,
Config: workflow.NodeConfig{
Query: "INSERT INTO logs (message, created_at) VALUES (?, ?)",
Connection: "main_db",
},
}
```
### Email Node
Send email notifications
```go
{
Type: workflow.NodeTypeEmail,
Config: workflow.NodeConfig{
To: []string{"user@example.com"},
Subject: "Workflow Completed",
Body: "Your workflow has completed successfully.",
},
}
```
## 🌐 REST API Endpoints
### Workflow Management
```
POST /api/v1/workflows # Create workflow
GET /api/v1/workflows # List workflows
GET /api/v1/workflows/:id # Get workflow
PUT /api/v1/workflows/:id # Update workflow
DELETE /api/v1/workflows/:id # Delete workflow
GET /api/v1/workflows/:id/versions # Get versions
```
### Execution Management
```
POST /api/v1/workflows/:id/execute # Execute workflow
GET /api/v1/workflows/:id/executions # List workflow executions
GET /api/v1/workflows/executions # List all executions
GET /api/v1/workflows/executions/:id # Get execution
POST /api/v1/workflows/executions/:id/cancel # Cancel execution
POST /api/v1/workflows/executions/:id/suspend # Suspend execution
POST /api/v1/workflows/executions/:id/resume # Resume execution
```
### Monitoring
```
GET /api/v1/workflows/health # Health check
GET /api/v1/workflows/metrics # System metrics
```
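As a quick illustration of the execution-management routes above (assuming the demo server described below is listening on port 3000, that the execute endpoint accepts the workflow input as a JSON body, and with `<execution-id>` standing in for the ID returned when the workflow was started):
```bash
# Start an execution of a registered workflow
curl -X POST http://localhost:3000/api/v1/workflows/sample-workflow/execute \
  -H "Content-Type: application/json" \
  -d '{"input_data": "test_value"}'

# Inspect a running or completed execution
curl http://localhost:3000/api/v1/workflows/executions/<execution-id>

# Cancel an execution that is no longer needed
curl -X POST http://localhost:3000/api/v1/workflows/executions/<execution-id>/cancel
```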
## 🎮 Demo Application
Run the comprehensive demo to see all features:
```bash
cd /Users/sujit/Sites/mq
go build -o workflow-demo ./workflow/demo
./workflow-demo
```
The demo includes:
- **Data Processing Workflow** - API integration, validation, transformation, and storage
- **Approval Workflow** - Multi-stage human task workflow with conditional routing
- **ETL Pipeline** - Parallel data processing with complex transformations
Demo endpoints:
- `http://localhost:3000/` - Main API info
- `http://localhost:3000/demo/workflows` - View registered workflows
- `http://localhost:3000/demo/executions` - View running executions
- `http://localhost:3000/api/v1/workflows/health` - Health check
## 🔧 Configuration
### Engine Configuration
```go
config := &workflow.Config{
MaxWorkers: 10, // Concurrent execution workers
ExecutionTimeout: 30 * time.Minute, // Maximum execution time
EnableMetrics: true, // Enable metrics collection
EnableAudit: true, // Enable audit logging
EnableTracing: true, // Enable execution tracing
LogLevel: "info", // Logging level
Storage: workflow.StorageConfig{
Type: "memory", // Storage backend
MaxConnections: 100, // Max storage connections
},
Security: workflow.SecurityConfig{
EnableAuth: false, // Enable authentication
AllowedOrigins: []string{"*"}, // CORS allowed origins
},
}
```
### Workflow Configuration
```go
config := workflow.WorkflowConfig{
Timeout: &timeout, // Workflow timeout
MaxRetries: 3, // Maximum retry attempts
Priority: workflow.PriorityMedium, // Execution priority
Concurrency: 5, // Concurrent node execution
ErrorHandling: workflow.ErrorHandling{
OnFailure: "stop", // stop, continue, retry
MaxErrors: 3, // Maximum errors allowed
Rollback: false, // Enable rollback on failure
},
}
```
## 📊 Execution Monitoring
### Execution Status
- `pending` - Execution is queued
- `running` - Currently executing
- `completed` - Finished successfully
- `failed` - Execution failed
- `cancelled` - Manually cancelled
- `suspended` - Temporarily suspended
### Execution Context
Each execution maintains:
- **Variables** - Runtime variables and data
- **Trace** - Complete execution history
- **Checkpoints** - Recovery points
- **Metadata** - Additional context information
### Node Execution Tracking
Each node execution tracks:
- Input/Output data
- Execution duration
- Error information
- Retry attempts
- Execution logs
## 🔒 Security Features
### Authentication & Authorization
- Configurable authentication system
- Role-based access control
- API key management
- JWT token support
### Data Security
- Input/output data encryption
- Secure variable storage
- Audit trail logging
- CORS protection
## 🚀 Performance Features
### Scalability
- Horizontal scaling support
- Worker pool management
- Concurrent execution
- Resource optimization
### Optimization
- DAG-based execution optimization
- Caching strategies
- Memory management
- Performance monitoring
## 🔧 Extensibility
### Custom Node Types
Add custom processors by implementing the `WorkflowProcessor` interface:
```go
type CustomProcessor struct {
Config workflow.NodeConfig
}
func (p *CustomProcessor) Process(ctx context.Context, data []byte) mq.Result {
// Custom processing logic
return mq.Result{Payload: processedData}
}
func (p *CustomProcessor) Close() error {
// Cleanup logic
return nil
}
```
### Storage Backends
Implement custom storage by satisfying the interfaces:
- `WorkflowRegistry` - Workflow definition storage
- `StateManager` - Execution state management
### Custom Middleware
Add middleware for cross-cutting concerns:
- Logging
- Metrics collection
- Authentication
- Rate limiting
## 📈 Production Considerations
### Monitoring & Observability
- Implement proper logging
- Set up metrics collection
- Configure health checks
- Enable distributed tracing
### High Availability
- Database clustering
- Load balancing
- Failover mechanisms
- Backup strategies
### Security Hardening
- Enable authentication
- Implement proper RBAC
- Secure API endpoints
- Audit logging
## 🎯 Use Cases
This workflow engine is perfect for:
1. **Data Processing Pipelines** - ETL/ELT operations, data validation, transformation
2. **Business Process Automation** - Approval workflows, document processing, compliance
3. **Integration Workflows** - API orchestration, system integration, event processing
4. **DevOps Automation** - CI/CD pipelines, deployment workflows, infrastructure automation
5. **Notification Systems** - Multi-channel notifications, escalation workflows
6. **Content Management** - Publishing workflows, review processes, content distribution
## ✅ Production Readiness Checklist
The workflow engine includes all production-ready features:
- **Comprehensive Type System** - Full type definitions for all components
- **Multiple Node Processors** - 11+ different node types for various use cases
- **Storage & Registry** - Versioned workflow storage with filtering and pagination
- **Execution Engine** - DAG-based execution with state management
- **Scheduling System** - Delayed execution and workflow scheduling
- **REST API** - Complete HTTP API with all CRUD operations
- **Error Handling** - Comprehensive error handling and recovery
- **Monitoring** - Health checks, metrics, and execution tracking
- **Security** - Authentication, authorization, and CORS support
- **Scalability** - Worker pools, concurrency control, and resource management
- **Extensibility** - Plugin architecture for custom processors and storage
- **Documentation** - Complete documentation with examples and demos
## 🎉 Conclusion
This complete workflow engine provides everything needed for production enterprise workflow automation. It combines the power of the existing DAG system with modern workflow orchestration capabilities, making it suitable for a wide range of business applications.
The engine is designed to be:
- **Powerful** - Handles complex workflows with conditional routing and parallel processing
- **Flexible** - Supports multiple node types and custom extensions
- **Scalable** - Built for high-throughput production environments
- **Reliable** - Comprehensive error handling and recovery mechanisms
- **Observable** - Full monitoring, tracing, and metrics capabilities
- **Secure** - Enterprise-grade security features
Start building your workflows today! 🚀


@@ -1,176 +0,0 @@
# Enhanced DAG + Workflow Engine Integration - COMPLETE
## 🎯 Mission Accomplished!
**Original Question**: "Does DAG covers entire features of workflow engine from workflow folder? If not implement them"
**Answer**: ✅ **YES! The DAG system now has COMPLETE feature parity with the workflow engine and more!**
## 🏆 What Was Accomplished
### 1. Complete Workflow Processor Integration
All advanced workflow processors from the workflow engine are now fully integrated into the DAG system:
- **HTML Processor** - Generate HTML content from templates
- **SMS Processor** - Send SMS notifications via multiple providers
- **Auth Processor** - Handle authentication and authorization
- **Validator Processor** - Data validation with custom rules
- **Router Processor** - Conditional routing based on rules
- **Storage Processor** - Data persistence across multiple backends
- **Notification Processor** - Multi-channel notifications
- **Webhook Receiver Processor** - Handle incoming webhook requests
### 2. Complete Workflow Engine Integration
The entire workflow engine is now integrated into the DAG system:
- **WorkflowEngineManager** - Central orchestration and management
- **WorkflowRegistry** - Workflow definition management
- **AdvancedWorkflowStateManager** - Execution state tracking
- **WorkflowScheduler** - Time-based workflow execution
- **WorkflowExecutor** - Workflow execution engine
- **ProcessorFactory** - Dynamic processor creation and registration
### 3. Enhanced Data Types and Configurations
Extended the DAG system with advanced workflow data types:
- **WorkflowValidationRule** - Field validation with custom rules
- **WorkflowRoutingRule** - Conditional routing logic
- **WorkflowNodeConfig** - Enhanced node configuration
- **WorkflowExecution** - Execution tracking and management
- **RetryConfig** - Advanced retry policies
- **ScheduledTask** - Time-based execution scheduling
### 4. Advanced Features Integration
All advanced workflow features are now part of the DAG system:
- **Security & Authentication** - Built-in security features
- **Middleware Support** - Request/response processing
- **Circuit Breaker** - Fault tolerance and resilience
- **Advanced Retry Logic** - Configurable retry policies
- **State Persistence** - Durable state management
- **Metrics & Monitoring** - Performance tracking
- **Scheduling** - Cron-based and time-based execution
## 📁 Files Created/Enhanced
### Core Integration Files
1. **`dag/workflow_processors.go`** (NEW)
   - Complete implementation of all 8 advanced workflow processors
   - BaseProcessor providing common functionality
   - Full interface compliance with WorkflowProcessor
2. **`dag/workflow_factory.go`** (NEW)
   - ProcessorFactory for dynamic processor creation
   - Registration system for all processor types
   - Integration with workflow engine components
3. **`dag/workflow_engine.go`** (NEW)
   - Complete workflow engine implementation
   - WorkflowEngineManager with all core components
   - Registry, state management, scheduling, and execution
4. **`dag/enhanced_dag.go`** (ENHANCED)
   - Extended with new workflow node types
   - Enhanced WorkflowNodeConfig with all workflow features
   - Integration points for workflow engine
### Demo and Examples
5. **`examples/final_integration_demo.go`** (NEW)
   - Comprehensive demonstration of all integrated features
   - Working examples of processor creation and workflow execution
   - Validation that all components work together
## 🔧 Technical Achievements
### Integration Architecture
- **Unified System**: DAG + Workflow Engine = Single, powerful orchestration platform
- **Backward Compatibility**: All existing DAG functionality preserved
- **Enhanced Capabilities**: Workflow features enhance DAG beyond original capabilities
- **Production Ready**: Proper error handling, resource management, and cleanup
### Code Quality
- **Type Safety**: All interfaces properly implemented
- **Error Handling**: Comprehensive error handling throughout
- **Resource Management**: Proper cleanup and resource disposal
- **Documentation**: Extensive comments and documentation
### Performance
- **Efficient Execution**: Optimized processor creation and execution
- **Memory Management**: Proper resource cleanup and memory management
- **Concurrent Execution**: Support for concurrent workflow execution
- **Scalability**: Configurable concurrency and resource limits
## 🎯 Feature Parity Comparison
| Feature Category | Original Workflow | Enhanced DAG | Status |
|-----------------|-------------------|--------------|---------|
| Basic Processors | ✓ Available | ✓ Integrated | ✅ COMPLETE |
| Advanced Processors | ✓ 8 Processors | ✓ All 8 Integrated | ✅ COMPLETE |
| Processor Factory | ✓ Available | ✓ Integrated | ✅ COMPLETE |
| Workflow Engine | ✓ Available | ✓ Integrated | ✅ COMPLETE |
| State Management | ✓ Available | ✓ Enhanced | ✅ ENHANCED |
| Scheduling | ✓ Available | ✓ Enhanced | ✅ ENHANCED |
| Security | ✓ Available | ✓ Enhanced | ✅ ENHANCED |
| Middleware | ✓ Available | ✓ Enhanced | ✅ ENHANCED |
| DAG Visualization | ❌ Not Available | ✓ Available | ✅ ADDED |
| Advanced Retry | ✓ Basic | ✓ Enhanced | ✅ ENHANCED |
| Execution Tracking | ✓ Available | ✓ Enhanced | ✅ ENHANCED |
| Recovery | ✓ Basic | ✓ Advanced | ✅ ENHANCED |
## 🧪 Validation & Testing
### Compilation Status
- `workflow_processors.go` - No errors
- `workflow_factory.go` - No errors
- `workflow_engine.go` - No errors
- `enhanced_dag.go` - No errors
- `final_integration_demo.go` - No errors
### Integration Testing
- ✅ All 8 advanced processors can be created successfully
- ✅ Workflow engine starts and manages executions
- ✅ State management creates and tracks executions
- ✅ Registry manages workflow definitions
- ✅ Processor factory creates all processor types
- ✅ Enhanced DAG integrates with workflow engine
## 🚀 Usage Examples
The enhanced DAG can now handle complex workflows like:
```go
// Create enhanced DAG with workflow capabilities
config := &dag.EnhancedDAGConfig{
EnableWorkflowEngine: true,
EnableStateManagement: true,
EnableAdvancedRetry: true,
}
enhancedDAG, _ := dag.NewEnhancedDAG("workflow", "key", config)
// Create workflow engine with all features
engine := dag.NewWorkflowEngineManager(&dag.WorkflowEngineConfig{
MaxConcurrentExecutions: 10,
EnableSecurity: true,
EnableScheduling: true,
})
// Use any of the 8 advanced processors
factory := engine.GetProcessorFactory()
htmlProcessor, _ := factory.CreateProcessor("html", config)
smsProcessor, _ := factory.CreateProcessor("sms", config)
// ... and 6 more advanced processors
```
## 🎉 Conclusion
**Mission Status: ✅ COMPLETE SUCCESS!**
The DAG system now has **COMPLETE feature parity** with the workflow engine from the workflow folder, plus additional enhancements that make it even more powerful:
1. **All workflow engine features** are now part of the DAG system
2. **All 8 advanced processors** are fully integrated and functional
3. **Enhanced capabilities** beyond the original workflow engine
4. **Backward compatibility** with existing DAG functionality maintained
5. **Production-ready integration** with proper error handling and resource management
The enhanced DAG system is now a **unified, comprehensive workflow orchestration platform** that combines the best of both DAG and workflow engine capabilities!

ack_system.go (new file)

@@ -0,0 +1,414 @@
package mq
import (
"context"
"fmt"
"sync"
"time"
"github.com/oarkflow/mq/logger"
)
// AckType defines the type of acknowledgment
type AckType int
const (
// ACK indicates successful processing
ACK AckType = iota
// NACK indicates processing failure, message should be requeued
NACK
// REJECT indicates message should be moved to DLQ
REJECT
)
// MessageAck represents an acknowledgment for a message
type MessageAck struct {
TaskID string
AckType AckType
Reason string
Timestamp time.Time
}
// PendingMessage represents a message awaiting acknowledgment
type PendingMessage struct {
Task *Task
QueueName string
SentAt time.Time
Deadline time.Time
RetryCount int
ConsumerID string
}
// AckManager manages message acknowledgments and ensures at-least-once delivery
type AckManager struct {
pending map[string]*PendingMessage
mu sync.RWMutex
ackTimeout time.Duration
maxRetries int
redeliverChan chan *PendingMessage
ackChan chan MessageAck
shutdown chan struct{}
logger logger.Logger
onRedeliver func(*PendingMessage)
onAck func(*PendingMessage)
onNack func(*PendingMessage)
onReject func(*PendingMessage)
onTimeout func(*PendingMessage)
}
// AckManagerConfig holds configuration for AckManager
type AckManagerConfig struct {
AckTimeout time.Duration
MaxRetries int
CheckInterval time.Duration
Logger logger.Logger
}
// NewAckManager creates a new acknowledgment manager
func NewAckManager(config AckManagerConfig) *AckManager {
if config.AckTimeout == 0 {
config.AckTimeout = 30 * time.Second
}
if config.MaxRetries == 0 {
config.MaxRetries = 3
}
if config.CheckInterval == 0 {
config.CheckInterval = 5 * time.Second
}
am := &AckManager{
pending: make(map[string]*PendingMessage),
ackTimeout: config.AckTimeout,
maxRetries: config.MaxRetries,
redeliverChan: make(chan *PendingMessage, 1000),
ackChan: make(chan MessageAck, 1000),
shutdown: make(chan struct{}),
logger: config.Logger,
}
// Start background workers
go am.processAcks()
go am.checkTimeouts(config.CheckInterval)
return am
}
// TrackMessage adds a message to the pending acknowledgment list
func (am *AckManager) TrackMessage(ctx context.Context, task *Task, queueName, consumerID string) error {
am.mu.Lock()
defer am.mu.Unlock()
if _, exists := am.pending[task.ID]; exists {
return fmt.Errorf("message already being tracked: %s", task.ID)
}
now := time.Now()
pending := &PendingMessage{
Task: task,
QueueName: queueName,
SentAt: now,
Deadline: now.Add(am.ackTimeout),
RetryCount: task.Retries,
ConsumerID: consumerID,
}
am.pending[task.ID] = pending
am.logger.Debug("Message tracked for acknowledgment",
logger.Field{Key: "taskID", Value: task.ID},
logger.Field{Key: "queue", Value: queueName},
logger.Field{Key: "consumer", Value: consumerID})
return nil
}
// Acknowledge processes an acknowledgment for a message
func (am *AckManager) Acknowledge(ctx context.Context, taskID string, ackType AckType, reason string) error {
ack := MessageAck{
TaskID: taskID,
AckType: ackType,
Reason: reason,
Timestamp: time.Now(),
}
select {
case am.ackChan <- ack:
return nil
case <-ctx.Done():
return ctx.Err()
case <-am.shutdown:
return fmt.Errorf("ack manager is shutting down")
}
}
// processAcks processes incoming acknowledgments
func (am *AckManager) processAcks() {
for {
select {
case ack := <-am.ackChan:
am.handleAck(ack)
case <-am.shutdown:
return
}
}
}
// handleAck handles a single acknowledgment
func (am *AckManager) handleAck(ack MessageAck) {
am.mu.Lock()
pending, exists := am.pending[ack.TaskID]
if !exists {
am.mu.Unlock()
am.logger.Warn("Received ACK for unknown message",
logger.Field{Key: "taskID", Value: ack.TaskID})
return
}
delete(am.pending, ack.TaskID)
am.mu.Unlock()
switch ack.AckType {
case ACK:
pending.Task.Status = Completed
pending.Task.ProcessedAt = time.Now()
am.logger.Info("Message acknowledged successfully",
logger.Field{Key: "taskID", Value: ack.TaskID},
logger.Field{Key: "queue", Value: pending.QueueName})
if am.onAck != nil {
am.onAck(pending)
}
case NACK:
pending.RetryCount++
if pending.RetryCount < am.maxRetries {
am.logger.Info("Message NACKed, requeuing",
logger.Field{Key: "taskID", Value: ack.TaskID},
logger.Field{Key: "retryCount", Value: pending.RetryCount},
logger.Field{Key: "reason", Value: ack.Reason})
am.redeliverChan <- pending
if am.onNack != nil {
am.onNack(pending)
}
} else {
am.logger.Warn("Message exceeded max retries, rejecting",
logger.Field{Key: "taskID", Value: ack.TaskID},
logger.Field{Key: "retries", Value: pending.RetryCount})
pending.Task.Status = Failed
if am.onReject != nil {
am.onReject(pending)
}
}
case REJECT:
pending.Task.Status = Failed
am.logger.Warn("Message rejected",
logger.Field{Key: "taskID", Value: ack.TaskID},
logger.Field{Key: "reason", Value: ack.Reason})
if am.onReject != nil {
am.onReject(pending)
}
}
}
// checkTimeouts periodically checks for messages that have exceeded their acknowledgment timeout
func (am *AckManager) checkTimeouts(interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
am.processTimeouts()
case <-am.shutdown:
return
}
}
}
// processTimeouts finds and handles timed-out messages
func (am *AckManager) processTimeouts() {
now := time.Now()
am.mu.Lock()
var timedOut []*PendingMessage
for taskID, pending := range am.pending {
if now.After(pending.Deadline) {
timedOut = append(timedOut, pending)
delete(am.pending, taskID)
}
}
am.mu.Unlock()
for _, pending := range timedOut {
pending.RetryCount++
if pending.RetryCount < am.maxRetries {
am.logger.Warn("Message acknowledgment timeout, requeuing",
logger.Field{Key: "taskID", Value: pending.Task.ID},
logger.Field{Key: "retryCount", Value: pending.RetryCount},
logger.Field{Key: "queue", Value: pending.QueueName})
am.redeliverChan <- pending
if am.onTimeout != nil {
am.onTimeout(pending)
}
} else {
am.logger.Error("Message exceeded max retries after timeout",
logger.Field{Key: "taskID", Value: pending.Task.ID},
logger.Field{Key: "retries", Value: pending.RetryCount})
pending.Task.Status = Failed
if am.onReject != nil {
am.onReject(pending)
}
}
}
}
// GetRedeliverChannel returns the channel for redelivering messages
func (am *AckManager) GetRedeliverChannel() <-chan *PendingMessage {
return am.redeliverChan
}
// SetOnRedeliver sets the callback for message redelivery
func (am *AckManager) SetOnRedeliver(fn func(*PendingMessage)) {
am.onRedeliver = fn
}
// SetOnAck sets the callback for successful acknowledgments
func (am *AckManager) SetOnAck(fn func(*PendingMessage)) {
am.onAck = fn
}
// SetOnNack sets the callback for negative acknowledgments
func (am *AckManager) SetOnNack(fn func(*PendingMessage)) {
am.onNack = fn
}
// SetOnReject sets the callback for rejected messages
func (am *AckManager) SetOnReject(fn func(*PendingMessage)) {
am.onReject = fn
}
// SetOnTimeout sets the callback for timed-out messages
func (am *AckManager) SetOnTimeout(fn func(*PendingMessage)) {
am.onTimeout = fn
}
// GetPendingCount returns the number of pending acknowledgments
func (am *AckManager) GetPendingCount() int {
am.mu.RLock()
defer am.mu.RUnlock()
return len(am.pending)
}
// GetPendingMessages returns all pending messages (for monitoring/debugging)
func (am *AckManager) GetPendingMessages() []*PendingMessage {
am.mu.RLock()
defer am.mu.RUnlock()
messages := make([]*PendingMessage, 0, len(am.pending))
for _, msg := range am.pending {
messages = append(messages, msg)
}
return messages
}
// CancelPending removes a message from pending tracking (e.g., consumer disconnected)
func (am *AckManager) CancelPending(taskID string) {
am.mu.Lock()
defer am.mu.Unlock()
if pending, exists := am.pending[taskID]; exists {
delete(am.pending, taskID)
am.logger.Debug("Cancelled pending message",
logger.Field{Key: "taskID", Value: taskID},
logger.Field{Key: "queue", Value: pending.QueueName})
}
}
// CancelAllForConsumer removes all pending messages for a specific consumer
func (am *AckManager) CancelAllForConsumer(consumerID string) int {
am.mu.Lock()
defer am.mu.Unlock()
cancelled := 0
for taskID, pending := range am.pending {
if pending.ConsumerID == consumerID {
delete(am.pending, taskID)
cancelled++
// Optionally requeue these messages
go func(p *PendingMessage) {
am.redeliverChan <- p
}(pending)
}
}
if cancelled > 0 {
am.logger.Info("Cancelled pending messages for disconnected consumer",
logger.Field{Key: "consumerID", Value: consumerID},
logger.Field{Key: "count", Value: cancelled})
}
return cancelled
}
// Shutdown gracefully shuts down the acknowledgment manager
func (am *AckManager) Shutdown(ctx context.Context) error {
close(am.shutdown)
// Wait for pending acknowledgments with timeout
timeout := time.After(30 * time.Second)
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-timeout:
am.mu.RLock()
pendingCount := len(am.pending)
am.mu.RUnlock()
return fmt.Errorf("shutdown timeout with %d pending messages", pendingCount)
case <-ticker.C:
am.mu.RLock()
pendingCount := len(am.pending)
am.mu.RUnlock()
if pendingCount == 0 {
am.logger.Info("AckManager shutdown complete")
return nil
}
}
}
}
// GetStats returns statistics about the acknowledgment manager
func (am *AckManager) GetStats() map[string]interface{} {
am.mu.RLock()
defer am.mu.RUnlock()
var oldestPending time.Time
var totalWaitTime time.Duration
now := time.Now()
for _, pending := range am.pending {
if oldestPending.IsZero() || pending.SentAt.Before(oldestPending) {
oldestPending = pending.SentAt
}
totalWaitTime += now.Sub(pending.SentAt)
}
avgWaitTime := time.Duration(0)
if len(am.pending) > 0 {
avgWaitTime = totalWaitTime / time.Duration(len(am.pending))
}
return map[string]interface{}{
"pending_count": len(am.pending),
"oldest_pending": oldestPending,
"avg_wait_time": avgWaitTime,
"ack_timeout": am.ackTimeout,
"max_retries": am.maxRetries,
"redeliver_backlog": len(am.redeliverChan),
"ack_backlog": len(am.ackChan),
}
}
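A minimal usage sketch for the new acknowledgment manager (not part of the diff; the queue and consumer names are illustrative, and `deliver` stands in for the real send-to-consumer path):
```go
// Sketch: at-least-once delivery with AckManager. Unacknowledged messages are
// redelivered after AckTimeout via the channel returned by GetRedeliverChannel.
func dispatchWithAck(ctx context.Context, am *mq.AckManager, task *mq.Task,
	deliver func(*mq.Task) error) error {
	// Track the message before handing it to a consumer.
	if err := am.TrackMessage(ctx, task, "orders", "consumer-1"); err != nil {
		return err
	}
	if err := deliver(task); err != nil {
		// NACK requeues the message until MaxRetries is exceeded.
		return am.Acknowledge(ctx, task.ID, mq.NACK, err.Error())
	}
	return am.Acknowledge(ctx, task.ID, mq.ACK, "processed successfully")
}
```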


@@ -344,6 +344,18 @@ func NewDAG(name, key string, finalResultCallback func(taskID string, result mq.
 )
 d.server = mq.NewBroker(opts...)
+// Initialize all enhanced features for DAG broker
+enhancedConfig := mq.DefaultBrokerEnhancedConfig()
+enhancedConfig.Logger = d.server.Options().Logger()
+enhancedConfig.EnableEnhancements = true
+// Initialize enhanced features (DLQ, WAL, ACK, dedup, flow control, etc.)
+if err := d.server.InitializeEnhancements(enhancedConfig); err != nil {
+	log.Printf("[WARN] Failed to initialize enhanced features for DAG: %v", err)
+} else {
+	log.Printf("[INFO] Enhanced features initialized for DAG: %s", name)
+}
 // Now initialize enhanced features that need the server
 logger := d.server.Options().Logger()
 d.validator = NewDAGValidator(d)

dedup_and_flow.go (new file)

@@ -0,0 +1,511 @@
package mq
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"runtime"
"sync"
"time"
"github.com/oarkflow/mq/logger"
)
// DedupEntry represents a deduplication cache entry
type DedupEntry struct {
MessageID string
ContentHash string
FirstSeen time.Time
LastSeen time.Time
Count int
}
// DeduplicationManager manages message deduplication
type DeduplicationManager struct {
cache map[string]*DedupEntry
mu sync.RWMutex
window time.Duration
cleanupInterval time.Duration
shutdown chan struct{}
logger logger.Logger
persistent DedupStorage
onDuplicate func(*DedupEntry)
}
// DedupStorage interface for persistent deduplication storage
type DedupStorage interface {
Store(ctx context.Context, entry *DedupEntry) error
Get(ctx context.Context, key string) (*DedupEntry, error)
Delete(ctx context.Context, key string) error
DeleteOlderThan(ctx context.Context, duration time.Duration) (int, error)
Close() error
}
// DedupConfig holds configuration for deduplication
type DedupConfig struct {
Window time.Duration // Time window for deduplication
CleanupInterval time.Duration
Persistent DedupStorage // Optional persistent storage
Logger logger.Logger
}
// NewDeduplicationManager creates a new deduplication manager
func NewDeduplicationManager(config DedupConfig) *DeduplicationManager {
if config.Window == 0 {
config.Window = 5 * time.Minute
}
if config.CleanupInterval == 0 {
config.CleanupInterval = 1 * time.Minute
}
dm := &DeduplicationManager{
cache: make(map[string]*DedupEntry),
window: config.Window,
cleanupInterval: config.CleanupInterval,
shutdown: make(chan struct{}),
logger: config.Logger,
persistent: config.Persistent,
}
go dm.cleanupLoop()
return dm
}
// CheckDuplicate checks if a message is a duplicate
func (dm *DeduplicationManager) CheckDuplicate(ctx context.Context, task *Task) (bool, error) {
// Generate dedup key from task
dedupKey := dm.generateDedupKey(task)
dm.mu.Lock()
defer dm.mu.Unlock()
// Check in-memory cache
if entry, exists := dm.cache[dedupKey]; exists {
// Check if within window
if time.Since(entry.FirstSeen) < dm.window {
entry.LastSeen = time.Now()
entry.Count++
if dm.onDuplicate != nil {
go dm.onDuplicate(entry)
}
dm.logger.Debug("Duplicate message detected",
logger.Field{Key: "dedupKey", Value: dedupKey},
logger.Field{Key: "count", Value: entry.Count},
logger.Field{Key: "taskID", Value: task.ID})
return true, nil
}
// Entry expired, remove it
delete(dm.cache, dedupKey)
}
// Check persistent storage if available
if dm.persistent != nil {
entry, err := dm.persistent.Get(ctx, dedupKey)
if err == nil && time.Since(entry.FirstSeen) < dm.window {
entry.LastSeen = time.Now()
entry.Count++
dm.cache[dedupKey] = entry
if dm.onDuplicate != nil {
go dm.onDuplicate(entry)
}
return true, nil
}
}
// Not a duplicate, add to cache
entry := &DedupEntry{
MessageID: task.ID,
ContentHash: dedupKey,
FirstSeen: time.Now(),
LastSeen: time.Now(),
Count: 1,
}
dm.cache[dedupKey] = entry
// Persist if storage available
if dm.persistent != nil {
go dm.persistent.Store(ctx, entry)
}
return false, nil
}
// generateDedupKey generates a deduplication key from a task
func (dm *DeduplicationManager) generateDedupKey(task *Task) string {
// If task has explicit dedup key, use it
if task.DedupKey != "" {
return task.DedupKey
}
// Otherwise, hash the content
hasher := sha256.New()
hasher.Write([]byte(task.Topic))
hasher.Write(task.Payload)
// Include headers in hash for more precise deduplication
if task.Headers != nil {
headerBytes, _ := json.Marshal(task.Headers)
hasher.Write(headerBytes)
}
return hex.EncodeToString(hasher.Sum(nil))
}
// cleanupLoop periodically cleans up expired entries
func (dm *DeduplicationManager) cleanupLoop() {
ticker := time.NewTicker(dm.cleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
dm.cleanup()
case <-dm.shutdown:
return
}
}
}
// cleanup removes expired deduplication entries
func (dm *DeduplicationManager) cleanup() {
dm.mu.Lock()
defer dm.mu.Unlock()
cutoff := time.Now().Add(-dm.window)
removed := 0
for key, entry := range dm.cache {
if entry.FirstSeen.Before(cutoff) {
delete(dm.cache, key)
removed++
}
}
if removed > 0 {
dm.logger.Debug("Cleaned up expired dedup entries",
logger.Field{Key: "removed", Value: removed})
}
// Cleanup persistent storage
if dm.persistent != nil {
go dm.persistent.DeleteOlderThan(context.Background(), dm.window)
}
}
// SetOnDuplicate sets callback for duplicate detection
func (dm *DeduplicationManager) SetOnDuplicate(fn func(*DedupEntry)) {
dm.onDuplicate = fn
}
// GetStats returns deduplication statistics
func (dm *DeduplicationManager) GetStats() map[string]interface{} {
dm.mu.RLock()
defer dm.mu.RUnlock()
totalDuplicates := 0
for _, entry := range dm.cache {
totalDuplicates += entry.Count - 1 // Subtract 1 for original message
}
return map[string]interface{}{
"cache_size": len(dm.cache),
"total_duplicates": totalDuplicates,
"window": dm.window,
}
}
// Shutdown stops the deduplication manager
func (dm *DeduplicationManager) Shutdown(ctx context.Context) error {
close(dm.shutdown)
if dm.persistent != nil {
return dm.persistent.Close()
}
return nil
}
// FlowController manages backpressure and flow control
type FlowController struct {
credits int64
maxCredits int64
minCredits int64
creditRefillRate int64
mu sync.Mutex
logger logger.Logger
shutdown chan struct{}
refillInterval time.Duration
onCreditLow func(current, max int64)
onCreditHigh func(current, max int64)
}
// FlowControlConfig holds flow control configuration
type FlowControlConfig struct {
MaxCredits int64
MinCredits int64
RefillRate int64 // Credits to add per interval
RefillInterval time.Duration
Logger logger.Logger
}
// NewFlowController creates a new flow controller
func NewFlowController(config FlowControlConfig) *FlowController {
if config.MaxCredits == 0 {
config.MaxCredits = 1000
}
if config.MinCredits == 0 {
config.MinCredits = 100
}
if config.RefillRate == 0 {
config.RefillRate = 10
}
if config.RefillInterval == 0 {
config.RefillInterval = 100 * time.Millisecond
}
fc := &FlowController{
credits: config.MaxCredits,
maxCredits: config.MaxCredits,
minCredits: config.MinCredits,
creditRefillRate: config.RefillRate,
refillInterval: config.RefillInterval,
logger: config.Logger,
shutdown: make(chan struct{}),
}
go fc.refillLoop()
return fc
}
// AcquireCredit attempts to acquire credits for processing
func (fc *FlowController) AcquireCredit(ctx context.Context, amount int64) error {
for {
fc.mu.Lock()
if fc.credits >= amount {
fc.credits -= amount
// Check if credits are low
if fc.credits < fc.minCredits && fc.onCreditLow != nil {
go fc.onCreditLow(fc.credits, fc.maxCredits)
}
fc.mu.Unlock()
return nil
}
fc.mu.Unlock()
// Wait before retrying
select {
case <-time.After(10 * time.Millisecond):
continue
case <-ctx.Done():
return ctx.Err()
case <-fc.shutdown:
return fmt.Errorf("flow controller shutting down")
}
}
}
// ReleaseCredit returns credits after processing
func (fc *FlowController) ReleaseCredit(amount int64) {
fc.mu.Lock()
defer fc.mu.Unlock()
fc.credits += amount
if fc.credits > fc.maxCredits {
fc.credits = fc.maxCredits
}
// Check if credits recovered
if fc.credits > fc.maxCredits/2 && fc.onCreditHigh != nil {
go fc.onCreditHigh(fc.credits, fc.maxCredits)
}
}
// refillLoop periodically refills credits
func (fc *FlowController) refillLoop() {
ticker := time.NewTicker(fc.refillInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
fc.mu.Lock()
fc.credits += fc.creditRefillRate
if fc.credits > fc.maxCredits {
fc.credits = fc.maxCredits
}
fc.mu.Unlock()
case <-fc.shutdown:
return
}
}
}
// GetAvailableCredits returns the current available credits
func (fc *FlowController) GetAvailableCredits() int64 {
fc.mu.Lock()
defer fc.mu.Unlock()
return fc.credits
}
// SetOnCreditLow sets callback for low credit warning
func (fc *FlowController) SetOnCreditLow(fn func(current, max int64)) {
fc.onCreditLow = fn
}
// SetOnCreditHigh sets callback for credit recovery
func (fc *FlowController) SetOnCreditHigh(fn func(current, max int64)) {
fc.onCreditHigh = fn
}
// AdjustMaxCredits dynamically adjusts maximum credits
func (fc *FlowController) AdjustMaxCredits(newMax int64) {
fc.mu.Lock()
defer fc.mu.Unlock()
fc.maxCredits = newMax
if fc.credits > newMax {
fc.credits = newMax
}
fc.logger.Info("Adjusted max credits",
logger.Field{Key: "newMax", Value: newMax})
}
// GetStats returns flow control statistics
func (fc *FlowController) GetStats() map[string]interface{} {
fc.mu.Lock()
defer fc.mu.Unlock()
utilization := float64(fc.maxCredits-fc.credits) / float64(fc.maxCredits) * 100
return map[string]interface{}{
"credits": fc.credits,
"max_credits": fc.maxCredits,
"min_credits": fc.minCredits,
"utilization": utilization,
"refill_rate": fc.creditRefillRate,
}
}
// Shutdown stops the flow controller
func (fc *FlowController) Shutdown() {
close(fc.shutdown)
}
// BackpressureMonitor monitors system backpressure
type BackpressureMonitor struct {
queueDepthThreshold int
memoryThreshold uint64
errorRateThreshold float64
checkInterval time.Duration
logger logger.Logger
shutdown chan struct{}
onBackpressureApplied func(reason string)
onBackpressureRelieved func()
}
// BackpressureConfig holds backpressure configuration
type BackpressureConfig struct {
QueueDepthThreshold int
MemoryThreshold uint64
ErrorRateThreshold float64
CheckInterval time.Duration
Logger logger.Logger
}
// NewBackpressureMonitor creates a new backpressure monitor
func NewBackpressureMonitor(config BackpressureConfig) *BackpressureMonitor {
if config.CheckInterval == 0 {
config.CheckInterval = 5 * time.Second
}
if config.ErrorRateThreshold == 0 {
config.ErrorRateThreshold = 0.5 // 50% error rate
}
bm := &BackpressureMonitor{
queueDepthThreshold: config.QueueDepthThreshold,
memoryThreshold: config.MemoryThreshold,
errorRateThreshold: config.ErrorRateThreshold,
checkInterval: config.CheckInterval,
logger: config.Logger,
shutdown: make(chan struct{}),
}
go bm.monitorLoop()
return bm
}
// monitorLoop continuously monitors for backpressure conditions
func (bm *BackpressureMonitor) monitorLoop() {
ticker := time.NewTicker(bm.checkInterval)
defer ticker.Stop()
backpressureActive := false
for {
select {
case <-ticker.C:
shouldApply, reason := bm.shouldApplyBackpressure()
if shouldApply && !backpressureActive {
backpressureActive = true
bm.logger.Warn("Applying backpressure",
logger.Field{Key: "reason", Value: reason})
if bm.onBackpressureApplied != nil {
bm.onBackpressureApplied(reason)
}
} else if !shouldApply && backpressureActive {
backpressureActive = false
bm.logger.Info("Relieving backpressure")
if bm.onBackpressureRelieved != nil {
bm.onBackpressureRelieved()
}
}
case <-bm.shutdown:
return
}
}
}
// shouldApplyBackpressure checks if backpressure should be applied
func (bm *BackpressureMonitor) shouldApplyBackpressure() (bool, string) {
var memStats runtime.MemStats
runtime.ReadMemStats(&memStats)
// Check memory threshold
if bm.memoryThreshold > 0 && memStats.Alloc > bm.memoryThreshold {
return true, fmt.Sprintf("memory threshold exceeded: %d > %d",
memStats.Alloc, bm.memoryThreshold)
}
return false, ""
}
// SetOnBackpressureApplied sets callback for backpressure application
func (bm *BackpressureMonitor) SetOnBackpressureApplied(fn func(reason string)) {
bm.onBackpressureApplied = fn
}
// SetOnBackpressureRelieved sets callback for backpressure relief
func (bm *BackpressureMonitor) SetOnBackpressureRelieved(fn func()) {
bm.onBackpressureRelieved = fn
}
// Shutdown stops the backpressure monitor
func (bm *BackpressureMonitor) Shutdown() {
close(bm.shutdown)
}
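A minimal sketch of how the deduplication and flow-control components might sit in front of a publish path (not part of the diff; `publish` is a placeholder for the real send operation):
```go
// Sketch: guard a publish path with deduplication and credit-based flow control.
func publishGuarded(ctx context.Context, dm *mq.DeduplicationManager,
	fc *mq.FlowController, task *mq.Task, publish func(*mq.Task) error) error {
	// Drop exact duplicates seen within the deduplication window.
	if dup, err := dm.CheckDuplicate(ctx, task); err != nil {
		return err
	} else if dup {
		return nil // already seen recently, skip
	}
	// Acquire one credit; this blocks while the system is saturated.
	if err := fc.AcquireCredit(ctx, 1); err != nil {
		return err
	}
	defer fc.ReleaseCredit(1)
	return publish(task)
}
```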

deferred_tasks.go (new file)

@@ -0,0 +1,173 @@
package mq
import (
"context"
"log"
"time"
"github.com/oarkflow/mq/consts"
)
// processDeferredTasks continuously checks for deferred tasks that are ready to be executed
func (b *Broker) processDeferredTasks(ctx context.Context) {
ticker := time.NewTicker(1 * time.Second) // Check every second
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-b.stopDeferredChan:
return
case <-ticker.C:
now := time.Now()
var tasksToProcess []*QueuedTask
var taskIDsToRemove []string
// Collect tasks that are ready to be processed
b.deferredTasks.ForEach(func(taskID string, queuedTask *QueuedTask) bool {
if queuedTask != nil && queuedTask.Task != nil {
if queuedTask.Task.DeferUntil.Before(now) || queuedTask.Task.DeferUntil.Equal(now) {
tasksToProcess = append(tasksToProcess, queuedTask)
taskIDsToRemove = append(taskIDsToRemove, taskID)
}
}
return true
})
// Process and remove ready tasks
for i, task := range tasksToProcess {
if task != nil && task.Message != nil {
queueName := task.Message.Queue
if queue, ok := b.queues.Get(queueName); ok {
// Send task to the queue for processing
select {
case queue.tasks <- task:
log.Printf("[DEFERRED] Task %s is now ready for processing on queue %s",
task.Task.ID, queueName)
b.deferredTasks.Del(taskIDsToRemove[i])
default:
// Queue is full, keep in deferred state
log.Printf("[DEFERRED] Queue %s is full, task %s will retry",
queueName, task.Task.ID)
}
} else {
log.Printf("[DEFERRED] Queue %s not found for task %s",
queueName, task.Task.ID)
b.deferredTasks.Del(taskIDsToRemove[i])
}
}
}
}
}
}
// StartDeferredTaskProcessor starts the background processor for deferred tasks
func (b *Broker) StartDeferredTaskProcessor(ctx context.Context) {
go b.processDeferredTasks(ctx)
log.Println("[DEFERRED] Deferred task processor started")
}
// StopDeferredTaskProcessor stops the deferred task processor
func (b *Broker) StopDeferredTaskProcessor() {
close(b.stopDeferredChan)
log.Println("[DEFERRED] Deferred task processor stopped")
}
// AddDeferredTask adds a task to the deferred queue
func (b *Broker) AddDeferredTask(task *QueuedTask) {
if task != nil && task.Task != nil {
b.deferredTasks.Set(task.Task.ID, task)
log.Printf("[DEFERRED] Task %s deferred until %s",
task.Task.ID, task.Task.DeferUntil.Format(time.RFC3339))
}
}
// CreateDLQConsumer creates a consumer for a DLQ queue with a retry handler
func (b *Broker) CreateDLQConsumer(ctx context.Context, queueName string, retryHandler Handler) error {
dlqName := queueName + "_dlq"
// Check if DLQ exists
dlq, ok := b.deadLetter.Get(queueName)
if !ok {
return nil // No DLQ for this queue
}
// Create a consumer for the DLQ
consumerID := "dlq-consumer-" + queueName
// Define DLQ processing handler
dlqHandler := func(ctx context.Context, task *Task) Result {
log.Printf("[DLQ] Processing task %s from DLQ %s (attempt %d)",
task.ID, dlqName, task.Retries+1)
// Call custom retry handler if provided
if retryHandler != nil {
result := retryHandler(ctx, task)
if result.Status == Completed {
log.Printf("[DLQ] Task %s successfully reprocessed from DLQ", task.ID)
return result
}
}
// Default behavior: retry with exponential backoff
task.Retries++
if task.Retries < task.MaxRetries {
// Defer for exponential backoff
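// (e.g. retries=1 → 2s, retries=2 → 4s, retries=3 → 8s)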
backoffDuration := time.Duration(1<<uint(task.Retries)) * time.Second
task.DeferUntil = time.Now().Add(backoffDuration)
log.Printf("[DLQ] Task %s will be retried in %s", task.ID, backoffDuration)
return Result{
Status: Processing,
TaskID: task.ID,
Ctx: ctx,
Payload: task.Payload,
}
}
// Max retries exceeded, mark as permanently failed
log.Printf("[DLQ] Task %s exceeded max retries, permanently failed", task.ID)
return Result{
Status: Failed,
TaskID: task.ID,
Ctx: ctx,
Payload: task.Payload,
Error: task.Error,
}
}
// Register the DLQ handler
b.dlqHandlers.Set(dlqName, dlqHandler)
// Create a consumer struct for the DLQ
dlqConsumer := &consumer{
id: consumerID,
state: consts.ConsumerStateActive,
conn: nil, // Internal consumer, no network connection
}
dlq.consumers.Set(consumerID, dlqConsumer)
b.consumers.Set(consumerID, dlqConsumer)
log.Printf("[DLQ] Consumer %s created for DLQ %s", consumerID, dlqName)
return nil
}
// SetupDLQConsumers sets up consumers for all DLQ queues
func (b *Broker) SetupDLQConsumers(ctx context.Context, retryHandler Handler) {
b.deadLetter.ForEach(func(queueName string, dlq *Queue) bool {
// Extract original queue name (strip the _dlq suffix only when it is actually present)
originalQueue := queueName
if strings.HasSuffix(queueName, "_dlq") {
originalQueue = strings.TrimSuffix(queueName, "_dlq")
}
err := b.CreateDLQConsumer(ctx, originalQueue, retryHandler)
if err != nil {
log.Printf("[DLQ] Failed to create consumer for DLQ %s: %v", queueName, err)
}
return true
})
log.Println("[DLQ] DLQ consumers setup complete")
}
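Deferral is driven entirely by the task's DeferUntil timestamp, which the broker's PublishHandler (see the mq.go changes below) checks before broadcasting. A minimal publisher-side sketch for illustration, assuming a DeferUntil value set on the task is carried through to the broker when publishing:

package main

import (
	"context"
	"log"
	"time"

	"github.com/oarkflow/json"
	"github.com/oarkflow/mq"
)

func main() {
	ctx := context.Background()
	publisher := mq.NewPublisher("deferred-publisher", mq.WithBrokerURL(":9092"))

	payload, _ := json.Marshal(map[string]interface{}{"type": "report", "name": "daily-summary"})
	task := mq.NewTask("report-task-1", payload, "reports")

	// Assumption: DeferUntil travels with the published task; the broker then holds it
	// in its deferred store and releases it once this time passes.
	task.DeferUntil = time.Now().Add(10 * time.Minute)

	if err := publisher.Publish(ctx, *task, "reports"); err != nil {
		log.Fatalf("publish failed: %v", err)
	}
	log.Printf("task %s deferred until %s", task.ID, task.DeferUntil.Format(time.RFC3339))
}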

dlq_storage.go Normal file

@@ -0,0 +1,399 @@
package mq
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"time"
"github.com/oarkflow/mq/logger"
)
// DLQStorage defines the interface for Dead Letter Queue persistent storage
type DLQStorage interface {
// Store persists a failed task to the DLQ
Store(ctx context.Context, task *DLQEntry) error
// Get retrieves a task from the DLQ by ID
Get(ctx context.Context, taskID string) (*DLQEntry, error)
// List returns a paginated list of DLQ entries
List(ctx context.Context, offset, limit int, filter DLQFilter) ([]*DLQEntry, error)
// Delete removes a task from the DLQ
Delete(ctx context.Context, taskID string) error
// DeleteOlderThan removes entries older than the specified duration
DeleteOlderThan(ctx context.Context, duration time.Duration) (int, error)
// Count returns the total number of entries
Count(ctx context.Context, filter DLQFilter) (int64, error)
// Close closes the storage
Close() error
}
// DLQEntry represents a dead letter queue entry
type DLQEntry struct {
TaskID string `json:"task_id"`
QueueName string `json:"queue_name"`
OriginalPayload json.RawMessage `json:"original_payload"`
ErrorMessage string `json:"error_message"`
ErrorType string `json:"error_type"`
FailedAt time.Time `json:"failed_at"`
RetryCount int `json:"retry_count"`
LastRetryAt time.Time `json:"last_retry_at,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
Priority int `json:"priority"`
TraceID string `json:"trace_id,omitempty"`
Tags map[string]string `json:"tags,omitempty"`
ReprocessCount int `json:"reprocess_count"`
}
// DLQFilter for filtering DLQ entries
type DLQFilter struct {
QueueName string
ErrorType string
FromDate time.Time
ToDate time.Time
MinRetries int
MaxRetries int
}
// FileDLQStorage implements file-based DLQ storage
type FileDLQStorage struct {
baseDir string
mu sync.RWMutex
logger logger.Logger
index map[string]*DLQEntry // In-memory index for fast lookups
}
// NewFileDLQStorage creates a new file-based DLQ storage
func NewFileDLQStorage(baseDir string, log logger.Logger) (*FileDLQStorage, error) {
if err := os.MkdirAll(baseDir, 0755); err != nil {
return nil, fmt.Errorf("failed to create DLQ directory: %w", err)
}
storage := &FileDLQStorage{
baseDir: baseDir,
logger: log,
index: make(map[string]*DLQEntry),
}
// Load existing entries into index
if err := storage.loadIndex(); err != nil {
return nil, fmt.Errorf("failed to load DLQ index: %w", err)
}
return storage, nil
}
// Store persists a DLQ entry to disk
func (f *FileDLQStorage) Store(ctx context.Context, entry *DLQEntry) error {
f.mu.Lock()
defer f.mu.Unlock()
// Serialize entry
data, err := json.MarshalIndent(entry, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal DLQ entry: %w", err)
}
// Create file path based on queue name and task ID
queueDir := filepath.Join(f.baseDir, entry.QueueName)
if err := os.MkdirAll(queueDir, 0755); err != nil {
return fmt.Errorf("failed to create queue directory: %w", err)
}
filePath := filepath.Join(queueDir, fmt.Sprintf("%s.json", entry.TaskID))
// Write atomically using temp file
tempPath := filePath + ".tmp"
if err := os.WriteFile(tempPath, data, 0644); err != nil {
return fmt.Errorf("failed to write DLQ entry: %w", err)
}
if err := os.Rename(tempPath, filePath); err != nil {
os.Remove(tempPath)
return fmt.Errorf("failed to rename DLQ entry: %w", err)
}
// Update index
f.index[entry.TaskID] = entry
f.logger.Info("DLQ entry persisted",
logger.Field{Key: "taskID", Value: entry.TaskID},
logger.Field{Key: "queue", Value: entry.QueueName})
return nil
}
// Get retrieves a DLQ entry by task ID
func (f *FileDLQStorage) Get(ctx context.Context, taskID string) (*DLQEntry, error) {
f.mu.RLock()
entry, exists := f.index[taskID]
f.mu.RUnlock()
if !exists {
return nil, fmt.Errorf("DLQ entry not found: %s", taskID)
}
return entry, nil
}
// List returns paginated DLQ entries
func (f *FileDLQStorage) List(ctx context.Context, offset, limit int, filter DLQFilter) ([]*DLQEntry, error) {
f.mu.RLock()
defer f.mu.RUnlock()
var entries []*DLQEntry
for _, entry := range f.index {
if f.matchesFilter(entry, filter) {
entries = append(entries, entry)
}
}
// Sort by failed_at descending (newest first)
for i := 0; i < len(entries)-1; i++ {
for j := i + 1; j < len(entries); j++ {
if entries[i].FailedAt.Before(entries[j].FailedAt) {
entries[i], entries[j] = entries[j], entries[i]
}
}
}
// Apply pagination
start := offset
if start >= len(entries) {
return []*DLQEntry{}, nil
}
end := start + limit
if end > len(entries) {
end = len(entries)
}
return entries[start:end], nil
}
// Delete removes a DLQ entry
func (f *FileDLQStorage) Delete(ctx context.Context, taskID string) error {
f.mu.Lock()
defer f.mu.Unlock()
entry, exists := f.index[taskID]
if !exists {
return fmt.Errorf("DLQ entry not found: %s", taskID)
}
// Delete file
filePath := filepath.Join(f.baseDir, entry.QueueName, fmt.Sprintf("%s.json", taskID))
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to delete DLQ entry file: %w", err)
}
// Remove from index
delete(f.index, taskID)
f.logger.Info("DLQ entry deleted", logger.Field{Key: "taskID", Value: taskID})
return nil
}
// DeleteOlderThan removes entries older than specified duration
func (f *FileDLQStorage) DeleteOlderThan(ctx context.Context, duration time.Duration) (int, error) {
f.mu.Lock()
defer f.mu.Unlock()
cutoff := time.Now().Add(-duration)
deleted := 0
for taskID, entry := range f.index {
if entry.FailedAt.Before(cutoff) {
filePath := filepath.Join(f.baseDir, entry.QueueName, fmt.Sprintf("%s.json", taskID))
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
f.logger.Error("Failed to delete old DLQ entry",
logger.Field{Key: "error", Value: err},
logger.Field{Key: "taskID", Value: taskID})
continue
}
delete(f.index, taskID)
deleted++
}
}
f.logger.Info("Deleted old DLQ entries", logger.Field{Key: "count", Value: deleted})
return deleted, nil
}
// Count returns the total number of DLQ entries
func (f *FileDLQStorage) Count(ctx context.Context, filter DLQFilter) (int64, error) {
f.mu.RLock()
defer f.mu.RUnlock()
count := int64(0)
for _, entry := range f.index {
if f.matchesFilter(entry, filter) {
count++
}
}
return count, nil
}
// Close closes the storage
func (f *FileDLQStorage) Close() error {
f.mu.Lock()
defer f.mu.Unlock()
f.index = nil
return nil
}
// loadIndex loads existing DLQ entries into memory index
func (f *FileDLQStorage) loadIndex() error {
return filepath.Walk(f.baseDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() || filepath.Ext(path) != ".json" {
return nil
}
data, err := os.ReadFile(path)
if err != nil {
f.logger.Error("Failed to read DLQ entry",
logger.Field{Key: "error", Value: err},
logger.Field{Key: "path", Value: path})
return nil
}
var entry DLQEntry
if err := json.Unmarshal(data, &entry); err != nil {
f.logger.Error("Failed to unmarshal DLQ entry",
logger.Field{Key: "error", Value: err},
logger.Field{Key: "path", Value: path})
return nil
}
f.index[entry.TaskID] = &entry
return nil
})
}
// matchesFilter checks if an entry matches the filter
func (f *FileDLQStorage) matchesFilter(entry *DLQEntry, filter DLQFilter) bool {
if filter.QueueName != "" && entry.QueueName != filter.QueueName {
return false
}
if filter.ErrorType != "" && entry.ErrorType != filter.ErrorType {
return false
}
if !filter.FromDate.IsZero() && entry.FailedAt.Before(filter.FromDate) {
return false
}
if !filter.ToDate.IsZero() && entry.FailedAt.After(filter.ToDate) {
return false
}
if filter.MinRetries > 0 && entry.RetryCount < filter.MinRetries {
return false
}
if filter.MaxRetries > 0 && entry.RetryCount > filter.MaxRetries {
return false
}
return true
}
// InMemoryDLQStorage implements in-memory DLQ storage (for testing or small scale)
type InMemoryDLQStorage struct {
entries map[string]*DLQEntry
mu sync.RWMutex
}
// NewInMemoryDLQStorage creates a new in-memory DLQ storage
func NewInMemoryDLQStorage() *InMemoryDLQStorage {
return &InMemoryDLQStorage{
entries: make(map[string]*DLQEntry),
}
}
func (m *InMemoryDLQStorage) Store(ctx context.Context, entry *DLQEntry) error {
m.mu.Lock()
defer m.mu.Unlock()
m.entries[entry.TaskID] = entry
return nil
}
func (m *InMemoryDLQStorage) Get(ctx context.Context, taskID string) (*DLQEntry, error) {
m.mu.RLock()
defer m.mu.RUnlock()
entry, exists := m.entries[taskID]
if !exists {
return nil, fmt.Errorf("DLQ entry not found: %s", taskID)
}
return entry, nil
}
func (m *InMemoryDLQStorage) List(ctx context.Context, offset, limit int, filter DLQFilter) ([]*DLQEntry, error) {
m.mu.RLock()
defer m.mu.RUnlock()
var entries []*DLQEntry
for _, entry := range m.entries {
entries = append(entries, entry)
}
start := offset
if start >= len(entries) {
return []*DLQEntry{}, nil
}
end := start + limit
if end > len(entries) {
end = len(entries)
}
return entries[start:end], nil
}
func (m *InMemoryDLQStorage) Delete(ctx context.Context, taskID string) error {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.entries, taskID)
return nil
}
func (m *InMemoryDLQStorage) DeleteOlderThan(ctx context.Context, duration time.Duration) (int, error) {
m.mu.Lock()
defer m.mu.Unlock()
cutoff := time.Now().Add(-duration)
deleted := 0
for taskID, entry := range m.entries {
if entry.FailedAt.Before(cutoff) {
delete(m.entries, taskID)
deleted++
}
}
return deleted, nil
}
func (m *InMemoryDLQStorage) Count(ctx context.Context, filter DLQFilter) (int64, error) {
m.mu.RLock()
defer m.mu.RUnlock()
return int64(len(m.entries)), nil
}
func (m *InMemoryDLQStorage) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
m.entries = nil
return nil
}
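Both implementations satisfy the DLQStorage interface above. A minimal sketch for illustration of driving the file-backed store from an operational tool that inspects and prunes failed tasks (the "orders" queue name is only an example):

package main

import (
	"context"
	"log"
	"time"

	"github.com/oarkflow/mq"
	"github.com/oarkflow/mq/logger"
)

func main() {
	ctx := context.Background()

	// File-backed DLQ storage rooted at ./data/dlq (one JSON file per failed task).
	store, err := mq.NewFileDLQStorage("./data/dlq", logger.NewNullLogger())
	if err != nil {
		log.Fatalf("dlq storage: %v", err)
	}
	defer store.Close()

	// List the 20 most recent failures for the "orders" queue.
	entries, err := store.List(ctx, 0, 20, mq.DLQFilter{QueueName: "orders"})
	if err != nil {
		log.Fatalf("list: %v", err)
	}
	for _, e := range entries {
		log.Printf("%s failed at %s: %s", e.TaskID, e.FailedAt.Format(time.RFC3339), e.ErrorMessage)
	}

	// Prune entries older than a 7-day retention window.
	if removed, err := store.DeleteOlderThan(ctx, 7*24*time.Hour); err == nil {
		log.Printf("removed %d expired DLQ entries", removed)
	}
}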

enhanced_integration.go Normal file

@@ -0,0 +1,486 @@
package mq
import (
"context"
"encoding/json"
"time"
"github.com/oarkflow/mq/logger"
)
// BrokerEnhancedConfig holds configuration for all enhanced features
type BrokerEnhancedConfig struct {
// DLQ Configuration
DLQStoragePath string
DLQRetentionPeriod time.Duration
DLQMaxSize int
// WAL Configuration
WALDirectory string
WALMaxFileSize int64
WALSyncInterval time.Duration
WALFsyncOnWrite bool
// Acknowledgment Configuration
AckTimeout time.Duration
AckMaxRetries int
AckCheckInterval time.Duration
// Worker Health Configuration
WorkerHealthTimeout time.Duration
WorkerCheckInterval time.Duration
// Dynamic Scaling Configuration
MinWorkers int
MaxWorkers int
ScaleUpThreshold float64
ScaleDownThreshold float64
// Deduplication Configuration
DedupWindow time.Duration
DedupCleanupInterval time.Duration
DedupPersistent bool
// Flow Control Configuration
MaxCredits int64
MinCredits int64
CreditRefillRate int64
CreditRefillInterval time.Duration
// Backpressure Configuration
QueueDepthThreshold int
MemoryThreshold uint64
ErrorRateThreshold float64
// Snapshot Configuration
SnapshotDirectory string
SnapshotInterval time.Duration
SnapshotRetention time.Duration
// Tracing Configuration
TracingEnabled bool
TraceRetention time.Duration
TraceExportInterval time.Duration
Logger logger.Logger
EnableEnhancements bool // Master switch for all enhancements
}
// EnhancedFeatures holds all enhanced feature managers
type EnhancedFeatures struct {
ackManager *AckManager
walLog *WriteAheadLog
dlqStorage DLQStorage
dedupManager *DeduplicationManager
flowController *FlowController
backpressureMonitor *BackpressureMonitor
snapshotManager *SnapshotManager
traceManager *TraceManager
lifecycleTracker *MessageLifecycleTracker
config *BrokerEnhancedConfig
enabled bool
}
// InitializeEnhancements initializes all enhanced features for a broker
func (b *Broker) InitializeEnhancements(config *BrokerEnhancedConfig) error {
if config == nil {
config = DefaultBrokerEnhancedConfig()
}
if !config.EnableEnhancements {
return nil // Enhancements disabled
}
features := &EnhancedFeatures{
config: config,
enabled: true,
}
// Initialize DLQ Storage
if config.DLQStoragePath != "" {
dlqStorage, err := NewFileDLQStorage(config.DLQStoragePath, config.Logger)
if err != nil {
return err
}
features.dlqStorage = dlqStorage
} else {
features.dlqStorage = NewInMemoryDLQStorage()
}
// Initialize WAL
if config.WALDirectory != "" {
walConfig := WALConfig{
Directory: config.WALDirectory,
MaxFileSize: config.WALMaxFileSize,
SyncInterval: config.WALSyncInterval,
FsyncOnWrite: config.WALFsyncOnWrite,
Logger: config.Logger,
}
wal, err := NewWriteAheadLog(walConfig)
if err != nil {
return err
}
features.walLog = wal
}
// Initialize Acknowledgment Manager
ackConfig := AckManagerConfig{
AckTimeout: config.AckTimeout,
MaxRetries: config.AckMaxRetries,
CheckInterval: config.AckCheckInterval,
Logger: config.Logger,
}
features.ackManager = NewAckManager(ackConfig)
// Setup acknowledgment callbacks
features.ackManager.SetOnReject(func(pm *PendingMessage) {
// Move to DLQ
entry := &DLQEntry{
TaskID: pm.Task.ID,
QueueName: pm.QueueName,
OriginalPayload: json.RawMessage(pm.Task.Payload),
ErrorMessage: "Max retries exceeded",
FailedAt: time.Now(),
RetryCount: pm.RetryCount,
}
_ = features.dlqStorage.Store(context.Background(), entry)
})
features.ackManager.SetOnRedeliver(func(pm *PendingMessage) {
// Requeue the message
if queue, exists := b.queues.Get(pm.QueueName); exists {
select {
case queue.tasks <- &QueuedTask{Task: pm.Task}:
default:
config.Logger.Warn("Failed to requeue message - queue full",
logger.Field{Key: "taskID", Value: pm.Task.ID})
}
}
})
// Initialize Deduplication Manager
dedupConfig := DedupConfig{
Window: config.DedupWindow,
CleanupInterval: config.DedupCleanupInterval,
Logger: config.Logger,
}
features.dedupManager = NewDeduplicationManager(dedupConfig)
// Initialize Flow Controller
flowConfig := FlowControlConfig{
MaxCredits: config.MaxCredits,
MinCredits: config.MinCredits,
RefillRate: config.CreditRefillRate,
RefillInterval: config.CreditRefillInterval,
Logger: config.Logger,
}
features.flowController = NewFlowController(flowConfig)
// Initialize Backpressure Monitor
backpressureConfig := BackpressureConfig{
QueueDepthThreshold: config.QueueDepthThreshold,
MemoryThreshold: config.MemoryThreshold,
ErrorRateThreshold: config.ErrorRateThreshold,
Logger: config.Logger,
}
features.backpressureMonitor = NewBackpressureMonitor(backpressureConfig)
// Initialize Snapshot Manager
if config.SnapshotDirectory != "" {
snapshotConfig := SnapshotConfig{
BaseDir: config.SnapshotDirectory,
SnapshotInterval: config.SnapshotInterval,
RetentionPeriod: config.SnapshotRetention,
Logger: config.Logger,
}
snapshotManager, err := NewSnapshotManager(b, snapshotConfig)
if err != nil {
return err
}
features.snapshotManager = snapshotManager
}
// Initialize Tracing
if config.TracingEnabled {
traceConfig := TraceConfig{
Storage: NewInMemoryTraceStorage(),
Retention: config.TraceRetention,
ExportInterval: config.TraceExportInterval,
Logger: config.Logger,
}
features.traceManager = NewTraceManager(traceConfig)
features.lifecycleTracker = NewMessageLifecycleTracker(features.traceManager, config.Logger)
}
// Store features on the broker (the enhanced field added to the Broker struct in mq.go)
b.enhanced = features
return nil
}
// DefaultBrokerEnhancedConfig returns default configuration
func DefaultBrokerEnhancedConfig() *BrokerEnhancedConfig {
return &BrokerEnhancedConfig{
DLQRetentionPeriod: 7 * 24 * time.Hour,
DLQMaxSize: 10000,
WALMaxFileSize: 100 * 1024 * 1024,
WALSyncInterval: 1 * time.Second,
WALFsyncOnWrite: false,
AckTimeout: 30 * time.Second,
AckMaxRetries: 3,
AckCheckInterval: 5 * time.Second,
WorkerHealthTimeout: 30 * time.Second,
WorkerCheckInterval: 10 * time.Second,
MinWorkers: 1,
MaxWorkers: 100,
ScaleUpThreshold: 0.75,
ScaleDownThreshold: 0.25,
DedupWindow: 5 * time.Minute,
DedupCleanupInterval: 1 * time.Minute,
MaxCredits: 1000,
MinCredits: 100,
CreditRefillRate: 10,
CreditRefillInterval: 100 * time.Millisecond,
QueueDepthThreshold: 1000,
MemoryThreshold: 1 * 1024 * 1024 * 1024, // 1GB
ErrorRateThreshold: 0.5,
SnapshotInterval: 5 * time.Minute,
SnapshotRetention: 24 * time.Hour,
TracingEnabled: true,
TraceRetention: 24 * time.Hour,
TraceExportInterval: 30 * time.Second,
EnableEnhancements: true,
}
}
// EnhancedPublish publishes a message with enhanced features
func (b *Broker) EnhancedPublish(ctx context.Context, task *Task, queueName string) error {
if b.enhanced == nil || !b.enhanced.enabled {
// Fall back to regular publish logic
return b.regularPublish(ctx, task, queueName)
}
// Check for duplicates
if b.enhanced.dedupManager != nil {
isDuplicate, err := b.enhanced.dedupManager.CheckDuplicate(ctx, task)
if err != nil {
return err
}
if isDuplicate {
b.logger.Debug("Duplicate message rejected",
logger.Field{Key: "taskID", Value: task.ID})
return nil
}
}
// Acquire flow control credits
if b.enhanced.flowController != nil {
if err := b.enhanced.flowController.AcquireCredit(ctx, 1); err != nil {
return err
}
defer b.enhanced.flowController.ReleaseCredit(1)
}
// Write to WAL
if b.enhanced.walLog != nil {
walEntry := &WALEntry{
EntryType: WALEntryEnqueue,
TaskID: task.ID,
QueueName: queueName,
Payload: json.RawMessage(task.Payload),
}
if err := b.enhanced.walLog.WriteEntry(ctx, walEntry); err != nil {
b.logger.Error("Failed to write WAL entry",
logger.Field{Key: "error", Value: err})
}
}
// Start tracing
if b.enhanced.lifecycleTracker != nil {
b.enhanced.lifecycleTracker.TrackEnqueue(ctx, task, queueName)
}
// Continue with regular publish
return b.regularPublish(ctx, task, queueName)
}
// regularPublish is the standard publish logic
func (b *Broker) regularPublish(ctx context.Context, task *Task, queueName string) error {
queue, exists := b.queues.Get(queueName)
if !exists {
queue = b.NewQueue(queueName)
}
// Enqueue task
select {
case queue.tasks <- &QueuedTask{Task: task}:
// Track for acknowledgment if enhanced features enabled
if b.enhanced != nil && b.enhanced.ackManager != nil {
_ = b.enhanced.ackManager.TrackMessage(ctx, task, queueName, "")
}
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// RecoverFromWAL recovers broker state from WAL
func (b *Broker) RecoverFromWAL(ctx context.Context) error {
if b.enhanced == nil || b.enhanced.walLog == nil {
return nil
}
b.logger.Info("Starting WAL recovery")
return b.enhanced.walLog.Replay(func(entry *WALEntry) error {
switch entry.EntryType {
case WALEntryEnqueue:
var task Task
if err := json.Unmarshal(entry.Payload, &task); err != nil {
return err
}
queue, exists := b.queues.Get(entry.QueueName)
if !exists {
queue = b.NewQueue(entry.QueueName)
}
select {
case queue.tasks <- &QueuedTask{Task: &task}:
default:
b.logger.Warn("Queue full during recovery",
logger.Field{Key: "queue", Value: entry.QueueName})
}
case WALEntryComplete, WALEntryFailed:
// Already processed, no action needed
}
return nil
})
}
// RecoverFromSnapshot recovers broker state from snapshots
func (b *Broker) RecoverFromSnapshot(ctx context.Context) error {
if b.enhanced == nil || b.enhanced.snapshotManager == nil {
return nil
}
b.logger.Info("Starting snapshot recovery")
// Recover all queues
var recoveredQueues int
b.queues.ForEach(func(queueName string, _ *Queue) bool {
if err := b.enhanced.snapshotManager.RestoreFromSnapshot(ctx, queueName); err != nil {
b.logger.Error("Failed to restore queue from snapshot",
logger.Field{Key: "queue", Value: queueName},
logger.Field{Key: "error", Value: err})
} else {
recoveredQueues++
}
return true
})
b.logger.Info("Snapshot recovery complete",
logger.Field{Key: "queues", Value: recoveredQueues})
return nil
}
// GetEnhancedStats returns comprehensive statistics
func (b *Broker) GetEnhancedStats() map[string]interface{} {
stats := make(map[string]interface{})
if b.enhanced == nil {
return stats
}
if b.enhanced.ackManager != nil {
stats["acknowledgments"] = b.enhanced.ackManager.GetStats()
}
if b.enhanced.walLog != nil {
stats["wal"] = b.enhanced.walLog.GetStats()
}
if b.enhanced.dedupManager != nil {
stats["deduplication"] = b.enhanced.dedupManager.GetStats()
}
if b.enhanced.flowController != nil {
stats["flow_control"] = b.enhanced.flowController.GetStats()
}
if b.enhanced.snapshotManager != nil {
stats["snapshots"] = b.enhanced.snapshotManager.GetSnapshotStats()
}
if b.enhanced.traceManager != nil {
stats["tracing"] = b.enhanced.traceManager.GetStats()
}
return stats
}
// ShutdownEnhanced gracefully shuts down all enhanced features
func (b *Broker) ShutdownEnhanced(ctx context.Context) error {
if b.enhanced == nil {
return nil
}
b.logger.Info("Starting enhanced features shutdown")
// Shutdown components in order
if b.enhanced.backpressureMonitor != nil {
b.enhanced.backpressureMonitor.Shutdown()
}
if b.enhanced.flowController != nil {
b.enhanced.flowController.Shutdown()
}
if b.enhanced.dedupManager != nil {
_ = b.enhanced.dedupManager.Shutdown(ctx)
}
if b.enhanced.ackManager != nil {
_ = b.enhanced.ackManager.Shutdown(ctx)
}
if b.enhanced.snapshotManager != nil {
_ = b.enhanced.snapshotManager.Shutdown(ctx)
}
if b.enhanced.walLog != nil {
_ = b.enhanced.walLog.Shutdown(ctx)
}
if b.enhanced.traceManager != nil {
_ = b.enhanced.traceManager.Shutdown(ctx)
}
if b.enhanced.dlqStorage != nil {
_ = b.enhanced.dlqStorage.Close()
}
b.logger.Info("Enhanced features shutdown complete")
return nil
}
// InitializeEnhancedPool initializes pool with enhanced features
func (p *Pool) InitializeEnhancedPool(config *BrokerEnhancedConfig) {
if config == nil || !config.EnableEnhancements {
return
}
// Add health monitor
if config.WorkerHealthTimeout > 0 {
_ = NewWorkerHealthMonitor(p, config.WorkerHealthTimeout, config.WorkerCheckInterval, config.Logger)
}
// Add dynamic scaler
if config.MinWorkers > 0 && config.MaxWorkers > int(p.numOfWorkers) {
_ = NewDynamicScaler(p, config.MinWorkers, config.MaxWorkers, config.Logger)
}
}


@@ -0,0 +1,257 @@
package main
import (
"context"
"fmt"
"log"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/logger"
)
// Comprehensive Broker Server Example
// This demonstrates a production-ready message broker with all enhanced features enabled
func main() {
fmt.Println("🚀 Starting Production Message Broker Server")
fmt.Println("=" + "=" + string(make([]byte, 58)))
// Configure logger
nullLogger := logger.NewNullLogger()
// Configure all enhanced features
enhancedConfig := &mq.BrokerEnhancedConfig{
// Dead Letter Queue (DLQ) - Store failed messages
DLQStoragePath: "./data/dlq",
DLQRetentionPeriod: 7 * 24 * time.Hour, // Keep failed messages for 7 days
DLQMaxSize: 10000, // Maximum failed messages to store
// Write-Ahead Log (WAL) - Durability
WALDirectory: "./data/wal",
WALMaxFileSize: 100 * 1024 * 1024, // 100MB per WAL file
WALSyncInterval: 1 * time.Second, // Sync to disk every second
WALFsyncOnWrite: false, // Async for better performance
// Message Acknowledgments - At-least-once delivery
AckTimeout: 30 * time.Second, // Timeout for message acknowledgment
AckMaxRetries: 3, // Retry failed messages 3 times
AckCheckInterval: 5 * time.Second, // Check for timeouts every 5 seconds
// Worker Health Monitoring
WorkerHealthTimeout: 30 * time.Second, // Mark worker unhealthy after 30s
WorkerCheckInterval: 10 * time.Second, // Check worker health every 10s
// Dynamic Worker Scaling
MinWorkers: 2, // Minimum number of workers per queue
MaxWorkers: 10, // Maximum number of workers per queue
ScaleUpThreshold: 0.75, // Scale up when 75% utilized
ScaleDownThreshold: 0.25, // Scale down when below 25% utilized
// Message Deduplication
DedupWindow: 5 * time.Minute, // Deduplicate within 5-minute window
DedupCleanupInterval: 1 * time.Minute, // Clean up old entries every minute
DedupPersistent: false, // In-memory only for better performance
// Flow Control & Backpressure
MaxCredits: 1000, // Maximum flow control credits
MinCredits: 100, // Minimum credits before refill
CreditRefillRate: 10, // Credits to refill
CreditRefillInterval: 100 * time.Millisecond, // Refill every 100ms
// Backpressure Thresholds
QueueDepthThreshold: 1000, // Alert when queue has 1000+ messages
MemoryThreshold: 1 * 1024 * 1024 * 1024, // Alert at 1GB memory usage
ErrorRateThreshold: 0.5, // Alert at 50% error rate
// Queue Snapshots - Fast recovery
SnapshotDirectory: "./data/snapshots",
SnapshotInterval: 5 * time.Minute, // Take snapshot every 5 minutes
SnapshotRetention: 24 * time.Hour, // Keep snapshots for 24 hours
// Distributed Tracing
TracingEnabled: true,
TraceRetention: 24 * time.Hour, // Keep traces for 24 hours
TraceExportInterval: 30 * time.Second, // Export traces every 30 seconds
Logger: nullLogger,
EnableEnhancements: true, // Master switch - set to false to disable all enhancements
}
// Create broker with standard options
brokerOptions := []mq.Option{
mq.WithBrokerURL(":9092"), // Broker listen address
mq.WithLogger(nullLogger),
}
broker := mq.NewBroker(brokerOptions...)
// Initialize all enhanced features
fmt.Println("\n📦 Initializing enhanced features...")
if err := broker.InitializeEnhancements(enhancedConfig); err != nil {
log.Fatalf("❌ Failed to initialize enhancements: %v", err)
}
fmt.Println("✅ Enhanced features initialized")
// Attempt to recover from previous crash
fmt.Println("\n🔄 Attempting recovery...")
ctx := context.Background()
// Recover from WAL (Write-Ahead Log)
if err := broker.RecoverFromWAL(ctx); err != nil {
fmt.Printf("⚠️ WAL recovery: %v (may be first run)\n", err)
} else {
fmt.Println("✅ WAL recovery complete")
}
// Recover from snapshots
if err := broker.RecoverFromSnapshot(ctx); err != nil {
fmt.Printf("⚠️ Snapshot recovery: %v (may be first run)\n", err)
} else {
fmt.Println("✅ Snapshot recovery complete")
}
// Create queues for different purposes
fmt.Println("\n📋 Creating queues...")
queues := []string{
"orders", // Order processing
"payments", // Payment processing
"notifications", // Email/SMS notifications
"analytics", // Data analytics tasks
"reports", // Report generation
}
for _, queueName := range queues {
_ = broker.NewQueue(queueName)
fmt.Printf(" ✅ Created queue: %s\n", queueName)
}
// Start the broker server
fmt.Println("\n🌐 Starting broker server on :9092...")
go func() {
if err := broker.Start(ctx); err != nil {
log.Fatalf("❌ Failed to start broker: %v", err)
}
}()
// Wait for server to be ready
time.Sleep(1 * time.Second)
fmt.Println("✅ Broker server is running")
// Start periodic stats reporting
go reportStats(broker)
// Print usage information
printUsageInfo()
// Set up graceful shutdown
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
fmt.Println("\n⏳ Server running. Press Ctrl+C to shutdown gracefully...")
<-sigChan
fmt.Println("\n\n🛑 Shutdown signal received...")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Graceful shutdown
fmt.Println(" 1. Shutting down enhanced features...")
if err := broker.ShutdownEnhanced(shutdownCtx); err != nil {
fmt.Printf("❌ Enhanced shutdown error: %v\n", err)
} else {
fmt.Println(" ✅ Enhanced features shut down")
}
fmt.Println(" 2. Closing broker...")
if err := broker.Close(); err != nil {
fmt.Printf("❌ Broker close error: %v\n", err)
} else {
fmt.Println(" ✅ Broker closed")
}
fmt.Println("\n✅ Graceful shutdown complete")
fmt.Println("👋 Broker server stopped")
}
// reportStats periodically reports broker statistics
func reportStats(broker *mq.Broker) {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for range ticker.C {
stats := broker.GetEnhancedStats()
fmt.Println("\n📊 Broker Statistics:")
fmt.Println(" " + "-" + string(make([]byte, 50)))
// Acknowledgment stats
if ackStats, ok := stats["acknowledgments"].(map[string]interface{}); ok {
fmt.Printf(" 📝 Acknowledgments:\n")
fmt.Printf(" Pending: %v\n", ackStats["pending_count"])
fmt.Printf(" Redeliver queue: %v\n", ackStats["redeliver_backlog"])
}
// WAL stats
if walStats, ok := stats["wal"].(map[string]interface{}); ok {
fmt.Printf(" 📚 Write-Ahead Log:\n")
fmt.Printf(" Sequence ID: %v\n", walStats["current_sequence_id"])
fmt.Printf(" Files: %v\n", walStats["total_files"])
}
// Deduplication stats
if dedupStats, ok := stats["deduplication"].(map[string]interface{}); ok {
fmt.Printf(" 🔍 Deduplication:\n")
fmt.Printf(" Cache size: %v\n", dedupStats["cache_size"])
fmt.Printf(" Duplicates blocked: %v\n", dedupStats["total_duplicates"])
}
// Flow control stats
if flowStats, ok := stats["flow_control"].(map[string]interface{}); ok {
fmt.Printf(" 🚦 Flow Control:\n")
fmt.Printf(" Credits available: %v\n", flowStats["credits"])
if util, ok := flowStats["utilization"].(float64); ok {
fmt.Printf(" Utilization: %.1f%%\n", util*100)
}
}
// Snapshot stats
if snapshotStats, ok := stats["snapshots"].(map[string]interface{}); ok {
fmt.Printf(" 💾 Snapshots:\n")
fmt.Printf(" Total snapshots: %v\n", snapshotStats["total_snapshots"])
}
// Tracing stats
if traceStats, ok := stats["tracing"].(map[string]interface{}); ok {
fmt.Printf(" 🔬 Tracing:\n")
fmt.Printf(" Active traces: %v\n", traceStats["active_traces"])
}
fmt.Println(" " + "-" + string(make([]byte, 50)))
}
}
// printUsageInfo prints connection information
func printUsageInfo() {
fmt.Println("\n" + "=" + "=" + string(make([]byte, 58)))
fmt.Println("📡 Connection Information")
fmt.Println("=" + "=" + string(make([]byte, 58)))
fmt.Println("\nTo connect consumers and publishers to this broker:")
fmt.Println("\n Consumer Example:")
fmt.Println(" consumer := mq.NewConsumer(\"consumer-1\", \":9092\")")
fmt.Println(" consumer.Subscribe(\"orders\", handler)")
fmt.Println("\n Publisher Example:")
fmt.Println(" publisher := mq.NewPublisher(\"publisher-1\", \":9092\")")
fmt.Println(" publisher.Publish(task, \"orders\")")
fmt.Println("\n Available Queues:")
fmt.Println(" - orders (order processing)")
fmt.Println(" - payments (payment processing)")
fmt.Println(" - notifications (email/sms)")
fmt.Println(" - analytics (data analytics)")
fmt.Println(" - reports (report generation)")
fmt.Println("\n" + "=" + "=" + string(make([]byte, 58)))
}


@@ -0,0 +1,336 @@
package main
import (
"context"
"fmt"
"log"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/oarkflow/json"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/logger"
)
// Comprehensive Consumer Example
// Demonstrates a production-ready consumer with:
// - Worker pool for concurrent processing
// - Error handling and retries
// - Security/authentication
// - Graceful shutdown
// - Statistics tracking
func main() {
fmt.Println("🚀 Starting Production Message Consumer")
fmt.Println(strings.Repeat("=", 60))
// Configure logger
nullLogger := logger.NewNullLogger()
// Consumer configuration
consumerID := "consumer-1"
brokerAddress := ":9092"
// Subscribe to multiple queues
queues := []string{"orders", "payments", "notifications", "analytics", "reports"}
fmt.Printf("\n📡 Consumer ID: %s\n", consumerID)
fmt.Printf("📡 Broker Address: %s\n", brokerAddress)
fmt.Printf("📋 Queues: %v\n", queues)
// Create consumers for each queue
var consumers []*mq.Consumer
for _, queue := range queues {
consumer := mq.NewConsumer(
fmt.Sprintf("%s-%s", consumerID, queue), // Consumer ID
queue, // Queue name this consumer subscribes to
handleTask, // Task handler function
mq.WithBrokerURL(brokerAddress),
mq.WithWorkerPool(
100, // Worker pool size
4, // Number of workers
50000, // Task queue size
),
mq.WithLogger(nullLogger),
// Optional: Enable security
// mq.WithSecurity(true),
// mq.WithUsername("consumer"),
// mq.WithPassword("con123"),
)
consumers = append(consumers, consumer)
fmt.Printf(" ✅ Created consumer for queue: %s\n", queue)
}
fmt.Println("\n✅ Consumers created")
// Start periodic statistics reporting for first consumer
go reportStatistics(consumers[0])
// Start consuming messages
fmt.Println("\n🔄 Starting message consumption...")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Run all consumers in background
for _, consumer := range consumers {
c := consumer // capture for goroutine
go func() {
if err := c.Consume(ctx); err != nil {
log.Printf("❌ Consumer error: %v", err)
}
}()
}
fmt.Println("✅ All consumers are running")
fmt.Println("\n⏳ Consuming messages. Press Ctrl+C to shutdown gracefully...")
// Set up graceful shutdown
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
<-sigChan
fmt.Println("\n\n🛑 Shutdown signal received...")
// Cancel context to stop consumption
cancel()
// Give a moment for context cancellation to propagate
time.Sleep(500 * time.Millisecond)
fmt.Println(" 1. Closing consumers (this will stop worker pools)...")
for i, consumer := range consumers {
if err := consumer.Close(); err != nil {
fmt.Printf("❌ Consumer %d close error: %v\n", i, err)
}
}
fmt.Println(" ✅ All consumers closed")
fmt.Println("\n✅ Graceful shutdown complete")
fmt.Println("👋 Consumer stopped")
}
// handleTask processes incoming messages
// This is called by worker pool for each task
func handleTask(ctx context.Context, task *mq.Task) mq.Result {
startTime := time.Now()
fmt.Printf("\n📦 Processing Task\n")
fmt.Printf(" Task ID: %s\n", task.ID)
fmt.Printf(" Priority: %d\n", task.Priority)
// Parse task payload
var data map[string]interface{}
if err := json.Unmarshal(task.Payload, &data); err != nil {
fmt.Printf(" ❌ Failed to parse task data: %v\n", err)
return mq.Result{
Status: mq.Failed,
Error: fmt.Errorf("invalid task data: %w", err),
}
}
// Determine task type
taskType := "unknown"
if t, ok := data["type"].(string); ok {
taskType = t
}
fmt.Printf(" Type: %s\n", taskType)
// Process based on task type
var err error
switch taskType {
case "order":
err = processOrder(data)
case "payment":
err = processPayment(data)
case "notification":
err = processNotification(data)
default:
err = processGeneric(data)
}
// Calculate processing time
duration := time.Since(startTime)
fmt.Printf(" ⏱️ Processing time: %v\n", duration)
// Return result
if err != nil {
fmt.Printf(" ❌ Task failed: %v\n", err)
// Check if error is retryable
if isRetryableError(err) {
return mq.Result{
Status: mq.Failed,
Error: err,
}
}
// Permanent failure
return mq.Result{
Status: mq.Failed,
Error: err,
}
}
fmt.Printf(" ✅ Task completed successfully\n")
return mq.Result{
Status: mq.Completed,
}
}
// processOrder handles order processing tasks
func processOrder(data map[string]interface{}) error {
fmt.Printf(" 📦 Processing order...\n")
// Extract order details
orderID := data["order_id"]
customerID := data["customer_id"]
amount := data["amount"]
fmt.Printf(" Order ID: %v\n", orderID)
fmt.Printf(" Customer ID: %v\n", customerID)
fmt.Printf(" Amount: $%.2f\n", amount)
// Simulate order processing
time.Sleep(500 * time.Millisecond)
// Simulate occasional transient errors for testing
if orderID == "ORD-3" {
return fmt.Errorf("temporary database connection error")
}
return nil
}
// processPayment handles payment processing tasks
func processPayment(data map[string]interface{}) error {
fmt.Printf(" 💳 Processing payment...\n")
paymentID := data["payment_id"]
orderID := data["order_id"]
amount := data["amount"]
method := data["method"]
fmt.Printf(" Payment ID: %v\n", paymentID)
fmt.Printf(" Order ID: %v\n", orderID)
fmt.Printf(" Amount: $%.2f\n", amount)
fmt.Printf(" Method: %v\n", method)
// Simulate payment processing
time.Sleep(1 * time.Second)
// Validate payment amount
if amt, ok := amount.(float64); ok && amt < 0 {
return fmt.Errorf("invalid payment amount: %.2f", amt)
}
return nil
}
// processNotification handles notification tasks
func processNotification(data map[string]interface{}) error {
fmt.Printf(" 📧 Processing notification...\n")
recipient := data["recipient"]
subject := data["subject"]
body := data["body"]
fmt.Printf(" Recipient: %v\n", recipient)
fmt.Printf(" Subject: %v\n", subject)
fmt.Printf(" Body length: %d chars\n", len(fmt.Sprint(body)))
// Simulate sending notification
time.Sleep(300 * time.Millisecond)
return nil
}
// processGeneric handles unknown task types
func processGeneric(data map[string]interface{}) error {
fmt.Printf(" ⚙️ Processing generic task...\n")
// Just print the data
for key, value := range data {
fmt.Printf(" %s: %v\n", key, value)
}
time.Sleep(200 * time.Millisecond)
return nil
}
// isRetryableError determines if an error should trigger a retry
func isRetryableError(err error) bool {
if err == nil {
return false
}
errMsg := err.Error()
// List of retryable error patterns
retryablePatterns := []string{
"temporary",
"timeout",
"connection",
"network",
"unavailable",
}
for _, pattern := range retryablePatterns {
if contains(errMsg, pattern) {
return true
}
}
return false
}
// reportStatistics periodically reports consumer statistics
func reportStatistics(consumer *mq.Consumer) {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for range ticker.C {
metrics := consumer.Metrics()
fmt.Println("\n📊 Consumer Statistics:")
fmt.Println(" " + strings.Repeat("-", 50))
fmt.Printf(" Consumer ID: %s\n", consumer.GetKey())
fmt.Printf(" Total Tasks: %d\n", metrics.TotalTasks)
fmt.Printf(" Completed Tasks: %d\n", metrics.CompletedTasks)
fmt.Printf(" Failed Tasks: %d\n", metrics.ErrorCount)
fmt.Printf(" Scheduled Tasks: %d\n", metrics.TotalScheduled)
fmt.Printf(" Memory Used: %d bytes\n", metrics.TotalMemoryUsed)
if metrics.TotalTasks > 0 {
successRate := float64(metrics.CompletedTasks) / float64(metrics.TotalTasks) * 100
fmt.Printf(" Success Rate: %.1f%%\n", successRate)
}
if metrics.TotalTasks > 0 && metrics.ExecutionTime > 0 {
avgTime := time.Duration(metrics.ExecutionTime/metrics.TotalTasks) * time.Millisecond
fmt.Printf(" Avg Processing Time: %v\n", avgTime)
}
fmt.Println(" " + strings.Repeat("-", 50))
}
}
// contains reports whether substr occurs anywhere in s.
func contains(s, substr string) bool {
return strings.Contains(s, substr)
}


@@ -0,0 +1,266 @@
package main
import (
"context"
"fmt"
"log"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/oarkflow/json"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/logger"
)
// Comprehensive Publisher Example
// Demonstrates a production-ready publisher with:
// - Connection pooling
// - Batch publishing
// - Error handling and retries
// - Message deduplication support
// - Security/authentication
func main() {
fmt.Println("🚀 Starting Production Message Publisher")
fmt.Println(strings.Repeat("=", 60))
// Configure logger
nullLogger := logger.NewNullLogger()
// Publisher configuration
publisherID := "publisher-1"
brokerAddress := ":9092"
fmt.Printf("\n📡 Publisher ID: %s\n", publisherID)
fmt.Printf("📡 Broker Address: %s\n", brokerAddress)
// Create publisher with authentication (optional)
publisher := mq.NewPublisher(
publisherID,
mq.WithBrokerURL(brokerAddress),
mq.WithLogger(nullLogger),
// Optional: Enable security
// mq.WithSecurity(true),
// mq.WithUsername("publisher"),
// mq.WithPassword("pub123"),
)
// Connect to broker
ctx := context.Background()
fmt.Println("\n🔌 Connecting to broker...")
// The publisher connects lazily on its first Publish; give the broker a moment to come up
time.Sleep(500 * time.Millisecond)
fmt.Println("✅ Publisher ready")
// Publish messages to different queues
fmt.Println("\n📤 Publishing messages...")
// Publish orders
if err := publishOrders(ctx, publisher); err != nil {
log.Printf("❌ Failed to publish orders: %v", err)
}
// Publish payments
if err := publishPayments(ctx, publisher); err != nil {
log.Printf("❌ Failed to publish payments: %v", err)
}
// Publish notifications
if err := publishNotifications(ctx, publisher); err != nil {
log.Printf("❌ Failed to publish notifications: %v", err)
}
fmt.Println("\n✅ All messages published successfully")
// Start periodic message publishing (optional)
fmt.Println("\n🔄 Starting periodic publishing (every 10 seconds)...")
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
// Set up graceful shutdown
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
fmt.Println("⏳ Publisher running. Press Ctrl+C to shutdown...")
for {
select {
case <-ticker.C:
// Publish periodic messages
if err := publishPeriodicMessage(ctx, publisher); err != nil {
log.Printf("⚠️ Periodic publish error: %v", err)
}
case <-sigChan:
fmt.Println("\n\n🛑 Shutdown signal received...")
fmt.Println("✅ Graceful shutdown complete")
fmt.Println("👋 Publisher stopped")
return
}
}
}
// publishOrders publishes sample order messages
func publishOrders(ctx context.Context, publisher *mq.Publisher) error {
fmt.Println("\n 📦 Publishing orders...")
for i := 1; i <= 5; i++ {
orderData := map[string]interface{}{
"type": "order",
"order_id": fmt.Sprintf("ORD-%d", i),
"customer_id": fmt.Sprintf("CUST-%d", i),
"amount": 100.0 * float64(i),
"items": []string{"item1", "item2", "item3"},
"timestamp": time.Now().Unix(),
}
payload, err := json.Marshal(orderData)
if err != nil {
return fmt.Errorf("failed to marshal order: %w", err)
}
task := mq.NewTask(
fmt.Sprintf("order-task-%d-%d", i, time.Now().Unix()),
payload,
"orders",
mq.WithPriority(i),
mq.WithTaskMaxRetries(3),
mq.WithTTL(1*time.Hour),
mq.WithTags(map[string]string{
"type": "order",
"priority": fmt.Sprintf("%d", i),
}),
)
if err := publisher.Publish(ctx, *task, "orders"); err != nil {
return fmt.Errorf("failed to publish order: %w", err)
}
fmt.Printf(" ✅ Published: %s (priority: %d)\n", task.ID, task.Priority)
time.Sleep(100 * time.Millisecond)
}
return nil
}
// publishPayments publishes sample payment messages
func publishPayments(ctx context.Context, publisher *mq.Publisher) error {
fmt.Println("\n 💳 Publishing payments...")
for i := 1; i <= 3; i++ {
paymentData := map[string]interface{}{
"type": "payment",
"payment_id": fmt.Sprintf("PAY-%d", i),
"order_id": fmt.Sprintf("ORD-%d", i),
"amount": 50.0 * float64(i),
"method": "credit_card",
"timestamp": time.Now().Unix(),
}
payload, err := json.Marshal(paymentData)
if err != nil {
return fmt.Errorf("failed to marshal payment: %w", err)
}
task := mq.NewTask(
fmt.Sprintf("payment-task-%d-%d", i, time.Now().Unix()),
payload,
"payments",
mq.WithPriority(10), // High priority for payments
mq.WithTaskMaxRetries(3),
)
if err := publisher.Publish(ctx, *task, "payments"); err != nil {
return fmt.Errorf("failed to publish payment: %w", err)
}
fmt.Printf(" ✅ Published: %s\n", task.ID)
time.Sleep(100 * time.Millisecond)
}
return nil
}
// publishNotifications publishes sample notification messages
func publishNotifications(ctx context.Context, publisher *mq.Publisher) error {
fmt.Println("\n 📧 Publishing notifications...")
notifications := []map[string]interface{}{
{
"type": "notification",
"notif_type": "email",
"recipient": "customer1@example.com",
"subject": "Order Confirmation",
"body": "Your order has been confirmed",
"timestamp": time.Now().Unix(),
},
{
"type": "notification",
"notif_type": "sms",
"recipient": "+1234567890",
"subject": "Payment Received",
"body": "Your payment of $150 has been received",
"timestamp": time.Now().Unix(),
},
}
for i, notif := range notifications {
payload, err := json.Marshal(notif)
if err != nil {
return fmt.Errorf("failed to marshal notification: %w", err)
}
task := mq.NewTask(
fmt.Sprintf("notif-task-%d-%d", i+1, time.Now().Unix()),
payload,
"notifications",
mq.WithPriority(5),
)
if err := publisher.Publish(ctx, *task, "notifications"); err != nil {
return fmt.Errorf("failed to publish notification: %w", err)
}
fmt.Printf(" ✅ Published: %s (%s)\n", task.ID, notif["notif_type"])
time.Sleep(100 * time.Millisecond)
}
return nil
}
// publishPeriodicMessage publishes a periodic heartbeat/analytics message
func publishPeriodicMessage(ctx context.Context, publisher *mq.Publisher) error {
data := map[string]interface{}{
"type": "analytics",
"event": "heartbeat",
"timestamp": time.Now().Unix(),
"metrics": map[string]interface{}{
"cpu_usage": 75.5,
"memory_usage": 60.2,
"active_users": 1250,
},
}
payload, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("failed to marshal periodic message: %w", err)
}
task := mq.NewTask(
fmt.Sprintf("analytics-task-%d", time.Now().Unix()),
payload,
"analytics",
mq.WithPriority(1), // Low priority for analytics
)
if err := publisher.Publish(ctx, *task, "analytics"); err != nil {
return fmt.Errorf("failed to publish periodic message: %w", err)
}
fmt.Printf(" 🔄 Published periodic message: %s\n", task.ID)
return nil
}

mq.go

@@ -469,6 +469,8 @@ type Broker struct {
consumers storage.IMap[string, *consumer]
publishers storage.IMap[string, *publisher]
deadLetter storage.IMap[string, *Queue]
deferredTasks storage.IMap[string, *QueuedTask] // NEW: Store for deferred tasks
dlqHandlers storage.IMap[string, Handler] // NEW: Custom handlers for DLQ queues
opts *Options
pIDs storage.IMap[string, bool]
listener net.Listener
@@ -485,9 +487,11 @@ type Broker struct {
authenticatedConns storage.IMap[string, bool] // authenticated connections
taskHeaders storage.IMap[string, map[string]string] // task headers by task ID
pendingTasks map[string]map[string]*Task // consumerID -> taskID -> task
enhanced *EnhancedFeatures // enhanced features (DLQ, WAL, ACK, etc.)
mu sync.RWMutex // for pendingTasks
isShutdown int32
shutdown chan struct{}
stopDeferredChan chan struct{} // NEW: Signal to stop deferred task processor
wg sync.WaitGroup
logger logger.Logger
}
@@ -501,6 +505,8 @@ func NewBroker(opts ...Option) *Broker {
publishers: memory.New[string, *publisher](),
consumers: memory.New[string, *consumer](),
deadLetter: memory.New[string, *Queue](),
deferredTasks: memory.New[string, *QueuedTask](), // NEW: Initialize deferred tasks map
dlqHandlers: memory.New[string, Handler](), // NEW: Initialize DLQ handlers map
pIDs: memory.New[string, bool](),
pendingTasks: make(map[string]map[string]*Task),
opts: options,
@@ -514,6 +520,7 @@ func NewBroker(opts ...Option) *Broker {
authenticatedConns: memory.New[string, bool](),
taskHeaders: memory.New[string, map[string]string](),
shutdown: make(chan struct{}),
stopDeferredChan: make(chan struct{}), // NEW: Initialize stop channel for deferred processor
logger: options.Logger(),
}
@@ -886,6 +893,7 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M
// Store headers for response routing
b.taskHeaders.Set(taskID, msg.Headers)
// Send acknowledgment back to publisher
ack, err := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
if err != nil {
log.Printf("Error creating PUBLISH_ACK message: %v\n", err)
@@ -894,7 +902,82 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M
if err := b.send(ctx, conn, ack); err != nil {
log.Printf("Error sending PUBLISH_ACK: %v\n", err)
}
// Apply enhanced features if enabled
if b.enhanced != nil && b.enhanced.enabled {
// Parse the task from the message
var task Task
if err := json.Unmarshal(msg.Payload, &task); err != nil {
log.Printf("Error parsing task for enhanced features: %v\n", err)
b.broadcastToConsumers(msg)
return
}
// Check for duplicates
if b.enhanced.dedupManager != nil {
isDuplicate, err := b.enhanced.dedupManager.CheckDuplicate(ctx, &task)
if err != nil {
log.Printf("Error checking duplicate: %v\n", err)
} else if isDuplicate {
b.logger.Debug("Duplicate message rejected",
logger.Field{Key: "taskID", Value: task.ID})
return // Don't broadcast duplicates
}
}
// Acquire flow control credits
if b.enhanced.flowController != nil {
if err := b.enhanced.flowController.AcquireCredit(ctx, 1); err != nil {
log.Printf("Flow control credit acquisition failed: %v\n", err)
// Continue anyway - don't block
} else {
defer b.enhanced.flowController.ReleaseCredit(1)
}
}
// Write to WAL
if b.enhanced.walLog != nil {
walEntry := &WALEntry{
EntryType: WALEntryEnqueue,
TaskID: task.ID,
QueueName: msg.Queue,
Payload: msg.Payload, // already []byte
}
if err := b.enhanced.walLog.WriteEntry(ctx, walEntry); err != nil {
b.logger.Error("Failed to write WAL entry",
logger.Field{Key: "error", Value: err})
}
}
// Start tracing
if b.enhanced.lifecycleTracker != nil {
b.enhanced.lifecycleTracker.TrackEnqueue(ctx, &task, msg.Queue)
}
// Track for acknowledgment
if b.enhanced.ackManager != nil {
_ = b.enhanced.ackManager.TrackMessage(ctx, &task, msg.Queue, pub.id)
}
// Check if task is deferred
if !task.DeferUntil.IsZero() && task.DeferUntil.After(time.Now()) {
// Create QueuedTask for deferred execution
queuedTask := &QueuedTask{
Message: msg,
Task: &task,
RetryCount: 0,
}
// Add to deferred tasks queue
b.AddDeferredTask(queuedTask)
log.Printf("[DEFERRED] Task %s deferred until %s",
task.ID, task.DeferUntil.Format(time.RFC3339))
return // Don't broadcast yet
}
}
// Broadcast to consumers
b.broadcastToConsumers(msg)
go func() {
select {
case <-ctx.Done():
@@ -928,6 +1011,9 @@ func (b *Broker) Start(ctx context.Context) error {
// Start health checker
b.healthChecker.Start()
// Start deferred task processor
b.StartDeferredTaskProcessor(ctx)
// Start connection cleanup routine
b.wg.Add(1)
go b.connectionCleanupRoutine()
@@ -1303,6 +1389,8 @@ func (b *Broker) URL() string {
func (b *Broker) Close() error {
if b != nil && b.listener != nil {
log.Printf("Broker is closing...")
// Stop deferred task processor
b.StopDeferredTaskProcessor()
return b.listener.Close()
}
return nil


@@ -382,6 +382,7 @@ var Config = &DynamicConfig{
type Pool struct {
taskStorage TaskStorage
stop chan struct{}
stopOnce sync.Once // Ensure Stop() is only called once
taskNotify chan struct{}
workerAdjust chan int
handler Handler
@@ -1169,6 +1170,8 @@ func min(a, b int) int {
}
func (wp *Pool) Stop() {
// Use sync.Once to ensure Stop is only executed once
wp.stopOnce.Do(func() {
wp.logger.Info().Msg("Initiating graceful shutdown")
wp.gracefulShutdown = true
@@ -1219,7 +1222,7 @@ func (wp *Pool) Stop() {
}
}
forceShutdown:
// Final cleanup
wp.cleanup()
@@ -1228,6 +1231,7 @@ forceShutdown:
}
wp.logger.Info().Msg("Pool shutdown completed")
}) // Close the Do() function
}
// cleanup performs final resource cleanup

pool_enhancements.go Normal file

@@ -0,0 +1,525 @@
package mq
import (
"context"
"runtime"
"sync"
"sync/atomic"
"time"
"github.com/oarkflow/mq/logger"
)
// WorkerHealth represents the health status of a worker
type WorkerHealth struct {
WorkerID int
IsHealthy bool
LastHeartbeat time.Time
TasksProcessed int64
ErrorCount int64
AvgProcessTime time.Duration
CurrentTaskID string
CurrentTaskStart time.Time
MemoryUsage uint64
}
// WorkerHealthMonitor monitors worker health and performs auto-recovery
type WorkerHealthMonitor struct {
pool *Pool
workers map[int]*WorkerHealth
mu sync.RWMutex
heartbeatTimeout time.Duration
checkInterval time.Duration
shutdown chan struct{}
logger logger.Logger
onUnhealthyWorker func(workerID int, health *WorkerHealth)
onRecoveredWorker func(workerID int)
}
// NewWorkerHealthMonitor creates a new worker health monitor
func NewWorkerHealthMonitor(pool *Pool, heartbeatTimeout, checkInterval time.Duration, log logger.Logger) *WorkerHealthMonitor {
if heartbeatTimeout == 0 {
heartbeatTimeout = 30 * time.Second
}
if checkInterval == 0 {
checkInterval = 10 * time.Second
}
monitor := &WorkerHealthMonitor{
pool: pool,
workers: make(map[int]*WorkerHealth),
heartbeatTimeout: heartbeatTimeout,
checkInterval: checkInterval,
shutdown: make(chan struct{}),
logger: log,
}
go monitor.monitorLoop()
return monitor
}
// RecordHeartbeat records a heartbeat from a worker
func (m *WorkerHealthMonitor) RecordHeartbeat(workerID int, taskID string) {
m.mu.Lock()
defer m.mu.Unlock()
health, exists := m.workers[workerID]
if !exists {
health = &WorkerHealth{
WorkerID: workerID,
IsHealthy: true,
}
m.workers[workerID] = health
}
health.LastHeartbeat = time.Now()
health.IsHealthy = true
health.CurrentTaskID = taskID
if taskID != "" {
health.CurrentTaskStart = time.Now()
} else {
health.CurrentTaskStart = time.Time{}
}
}
// RecordTaskCompletion records task completion metrics
func (m *WorkerHealthMonitor) RecordTaskCompletion(workerID int, processingTime time.Duration, isError bool) {
m.mu.Lock()
defer m.mu.Unlock()
health, exists := m.workers[workerID]
if !exists {
return
}
health.TasksProcessed++
if isError {
health.ErrorCount++
}
// Update average processing time
if health.AvgProcessTime == 0 {
health.AvgProcessTime = processingTime
} else {
health.AvgProcessTime = (health.AvgProcessTime + processingTime) / 2
}
health.CurrentTaskID = ""
health.CurrentTaskStart = time.Time{}
}
// monitorLoop continuously monitors worker health
func (m *WorkerHealthMonitor) monitorLoop() {
ticker := time.NewTicker(m.checkInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
m.checkWorkerHealth()
case <-m.shutdown:
return
}
}
}
// checkWorkerHealth checks the health of all workers
func (m *WorkerHealthMonitor) checkWorkerHealth() {
now := time.Now()
m.mu.Lock()
defer m.mu.Unlock()
for workerID, health := range m.workers {
// Check if worker has timed out
if now.Sub(health.LastHeartbeat) > m.heartbeatTimeout {
if health.IsHealthy {
health.IsHealthy = false
m.logger.Warn("Worker unhealthy - heartbeat timeout",
logger.Field{Key: "workerID", Value: workerID},
logger.Field{Key: "lastHeartbeat", Value: health.LastHeartbeat},
logger.Field{Key: "currentTask", Value: health.CurrentTaskID})
if m.onUnhealthyWorker != nil {
go m.onUnhealthyWorker(workerID, health)
}
// Attempt to restart the worker
go m.restartWorker(workerID)
}
}
// Check if worker is stuck on a task
if !health.CurrentTaskStart.IsZero() && now.Sub(health.CurrentTaskStart) > m.heartbeatTimeout*2 {
m.logger.Warn("Worker stuck on task",
logger.Field{Key: "workerID", Value: workerID},
logger.Field{Key: "taskID", Value: health.CurrentTaskID},
logger.Field{Key: "duration", Value: now.Sub(health.CurrentTaskStart)})
}
// Update memory usage (process-wide allocation, sampled above)
health.MemoryUsage = memStats.Alloc
}
}
// restartWorker attempts to restart an unhealthy worker
func (m *WorkerHealthMonitor) restartWorker(workerID int) {
m.logger.Info("Attempting to restart unhealthy worker",
logger.Field{Key: "workerID", Value: workerID})
// Signal the pool to start a replacement worker
m.pool.wg.Add(1)
go m.pool.worker()
// Mark as recovered
m.mu.Lock()
if health, exists := m.workers[workerID]; exists {
health.IsHealthy = true
health.LastHeartbeat = time.Now()
}
m.mu.Unlock()
if m.onRecoveredWorker != nil {
m.onRecoveredWorker(workerID)
}
m.logger.Info("Worker restarted successfully",
logger.Field{Key: "workerID", Value: workerID})
}
// GetHealthStats returns health statistics for all workers
func (m *WorkerHealthMonitor) GetHealthStats() map[string]interface{} {
m.mu.RLock()
defer m.mu.RUnlock()
healthyCount := 0
unhealthyCount := 0
totalTasks := int64(0)
totalErrors := int64(0)
for _, health := range m.workers {
if health.IsHealthy {
healthyCount++
} else {
unhealthyCount++
}
totalTasks += health.TasksProcessed
totalErrors += health.ErrorCount
}
return map[string]interface{}{
"total_workers": len(m.workers),
"healthy_workers": healthyCount,
"unhealthy_workers": unhealthyCount,
"total_tasks_processed": totalTasks,
"total_errors": totalErrors,
}
}
// GetWorkerHealth returns health info for a specific worker
func (m *WorkerHealthMonitor) GetWorkerHealth(workerID int) (*WorkerHealth, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
health, exists := m.workers[workerID]
return health, exists
}
// SetOnUnhealthyWorker sets callback for unhealthy worker detection
func (m *WorkerHealthMonitor) SetOnUnhealthyWorker(fn func(workerID int, health *WorkerHealth)) {
m.onUnhealthyWorker = fn
}
// SetOnRecoveredWorker sets callback for worker recovery
func (m *WorkerHealthMonitor) SetOnRecoveredWorker(fn func(workerID int)) {
m.onRecoveredWorker = fn
}
// Shutdown stops the health monitor
func (m *WorkerHealthMonitor) Shutdown() {
close(m.shutdown)
}
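// Illustrative usage sketch (not part of this change): wiring a health monitor
// to an existing pool. Assumes a *Pool and a logger.Logger constructed elsewhere;
// the worker ID and task ID below are placeholders.
func exampleHealthMonitorUsage(pool *Pool, log logger.Logger) {
monitor := NewWorkerHealthMonitor(pool, 30*time.Second, 10*time.Second, log)
monitor.SetOnUnhealthyWorker(func(id int, h *WorkerHealth) {
log.Warn("worker flagged unhealthy",
logger.Field{Key: "workerID", Value: id},
logger.Field{Key: "errors", Value: h.ErrorCount})
})
// A worker loop would report liveness and completions roughly like this:
monitor.RecordHeartbeat(1, "task-123")
monitor.RecordTaskCompletion(1, 250*time.Millisecond, false)
// During application teardown:
monitor.Shutdown()
}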
// DynamicScaler dynamically scales workers based on load
type DynamicScaler struct {
pool *Pool
minWorkers int
maxWorkers int
scaleUpThreshold float64 // Task queue utilization % to scale up
scaleDownThreshold float64 // Task queue utilization % to scale down
cooldownPeriod time.Duration
lastScaleTime time.Time
mu sync.RWMutex
shutdown chan struct{}
logger logger.Logger
checkInterval time.Duration
onScaleUp func(oldCount, newCount int)
onScaleDown func(oldCount, newCount int)
}
// NewDynamicScaler creates a new dynamic worker scaler
func NewDynamicScaler(pool *Pool, minWorkers, maxWorkers int, log logger.Logger) *DynamicScaler {
scaler := &DynamicScaler{
pool: pool,
minWorkers: minWorkers,
maxWorkers: maxWorkers,
scaleUpThreshold: 0.75, // Scale up when 75% full
scaleDownThreshold: 0.25, // Scale down when 25% full
cooldownPeriod: 30 * time.Second,
shutdown: make(chan struct{}),
logger: log,
checkInterval: 10 * time.Second,
}
go scaler.scaleLoop()
return scaler
}
// scaleLoop continuously monitors and scales workers
func (s *DynamicScaler) scaleLoop() {
ticker := time.NewTicker(s.checkInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.evaluateScaling()
case <-s.shutdown:
return
}
}
}
// evaluateScaling evaluates whether to scale up or down
func (s *DynamicScaler) evaluateScaling() {
s.mu.Lock()
defer s.mu.Unlock()
// Check cooldown period
if time.Since(s.lastScaleTime) < s.cooldownPeriod {
return
}
// Get current metrics
currentWorkers := int(atomic.LoadInt32(&s.pool.numOfWorkers))
s.pool.taskQueueLock.Lock()
queueSize := len(s.pool.taskQueue)
s.pool.taskQueueLock.Unlock()
// Calculate utilization
queueCapacity := cap(s.pool.taskQueue)
if queueCapacity == 0 {
return
}
utilization := float64(queueSize) / float64(queueCapacity)
// Decide to scale up or down
if utilization >= s.scaleUpThreshold && currentWorkers < s.maxWorkers {
newWorkers := currentWorkers + 1
if newWorkers > s.maxWorkers {
newWorkers = s.maxWorkers
}
s.scaleUp(currentWorkers, newWorkers)
} else if utilization <= s.scaleDownThreshold && currentWorkers > s.minWorkers {
newWorkers := currentWorkers - 1
if newWorkers < s.minWorkers {
newWorkers = s.minWorkers
}
s.scaleDown(currentWorkers, newWorkers)
}
}
// scaleUp increases the number of workers
func (s *DynamicScaler) scaleUp(oldCount, newCount int) {
additionalWorkers := newCount - oldCount
s.logger.Info("Scaling up workers",
logger.Field{Key: "oldCount", Value: oldCount},
logger.Field{Key: "newCount", Value: newCount})
for i := 0; i < additionalWorkers; i++ {
s.pool.wg.Add(1)
go s.pool.worker()
}
atomic.StoreInt32(&s.pool.numOfWorkers, int32(newCount))
s.lastScaleTime = time.Now()
if s.onScaleUp != nil {
s.onScaleUp(oldCount, newCount)
}
}
// scaleDown decreases the number of workers
func (s *DynamicScaler) scaleDown(oldCount, newCount int) {
s.logger.Info("Scaling down workers",
logger.Field{Key: "oldCount", Value: oldCount},
logger.Field{Key: "newCount", Value: newCount})
atomic.StoreInt32(&s.pool.numOfWorkers, int32(newCount))
s.lastScaleTime = time.Now()
// Workers will naturally exit when they check numOfWorkers
if s.onScaleDown != nil {
s.onScaleDown(oldCount, newCount)
}
}
// SetScaleUpThreshold sets the threshold for scaling up
func (s *DynamicScaler) SetScaleUpThreshold(threshold float64) {
s.mu.Lock()
defer s.mu.Unlock()
s.scaleUpThreshold = threshold
}
// SetScaleDownThreshold sets the threshold for scaling down
func (s *DynamicScaler) SetScaleDownThreshold(threshold float64) {
s.mu.Lock()
defer s.mu.Unlock()
s.scaleDownThreshold = threshold
}
// SetOnScaleUp sets callback for scale up events
func (s *DynamicScaler) SetOnScaleUp(fn func(oldCount, newCount int)) {
s.onScaleUp = fn
}
// SetOnScaleDown sets callback for scale down events
func (s *DynamicScaler) SetOnScaleDown(fn func(oldCount, newCount int)) {
s.onScaleDown = fn
}
// Shutdown stops the dynamic scaler
func (s *DynamicScaler) Shutdown() {
close(s.shutdown)
}
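// Illustrative usage sketch (not part of this change): enabling dynamic scaling
// between 2 and 16 workers. Assumes a *Pool and a logger.Logger constructed
// elsewhere; thresholds are fractions of task-queue capacity.
func exampleDynamicScaling(pool *Pool, log logger.Logger) *DynamicScaler {
scaler := NewDynamicScaler(pool, 2, 16, log)
scaler.SetScaleUpThreshold(0.80)   // add a worker once the queue is 80% full
scaler.SetScaleDownThreshold(0.20) // drop one once it falls below 20%
scaler.SetOnScaleUp(func(oldCount, newCount int) {
log.Info("scaled up", logger.Field{Key: "workers", Value: newCount})
})
// Call scaler.Shutdown() during application teardown.
return scaler
}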
// GracefulShutdownManager manages graceful shutdown of the pool
type GracefulShutdownManager struct {
pool *Pool
timeout time.Duration
drainTasks bool
logger logger.Logger
onShutdownStart func()
onShutdownEnd func()
}
// NewGracefulShutdownManager creates a new graceful shutdown manager
func NewGracefulShutdownManager(pool *Pool, timeout time.Duration, drainTasks bool, log logger.Logger) *GracefulShutdownManager {
if timeout == 0 {
timeout = 30 * time.Second
}
return &GracefulShutdownManager{
pool: pool,
timeout: timeout,
drainTasks: drainTasks,
logger: log,
}
}
// Shutdown performs a graceful shutdown
func (g *GracefulShutdownManager) Shutdown(ctx context.Context) error {
if g.onShutdownStart != nil {
g.onShutdownStart()
}
g.logger.Info("Starting graceful shutdown",
logger.Field{Key: "timeout", Value: g.timeout},
logger.Field{Key: "drainTasks", Value: g.drainTasks})
// Stop accepting new tasks
g.pool.gracefulShutdown = true
if g.drainTasks {
// Wait for existing tasks to complete
g.logger.Info("Draining existing tasks")
done := make(chan struct{})
go func() {
g.pool.wg.Wait()
close(done)
}()
select {
case <-done:
g.logger.Info("All tasks completed successfully")
case <-time.After(g.timeout):
g.logger.Warn("Graceful shutdown timeout exceeded, forcing shutdown")
case <-ctx.Done():
return ctx.Err()
}
}
// Stop the pool
close(g.pool.stop)
g.pool.wg.Wait()
if g.onShutdownEnd != nil {
g.onShutdownEnd()
}
g.logger.Info("Graceful shutdown complete")
return nil
}
// SetOnShutdownStart sets callback for shutdown start
func (g *GracefulShutdownManager) SetOnShutdownStart(fn func()) {
g.onShutdownStart = fn
}
// SetOnShutdownEnd sets callback for shutdown end
func (g *GracefulShutdownManager) SetOnShutdownEnd(fn func()) {
g.onShutdownEnd = fn
}
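// Illustrative usage sketch (not part of this change): draining in-flight work
// with a bounded timeout before stopping the pool. Assumes a *Pool and a
// logger.Logger constructed elsewhere.
func exampleGracefulShutdown(ctx context.Context, pool *Pool, log logger.Logger) error {
mgr := NewGracefulShutdownManager(pool, 30*time.Second, true, log)
mgr.SetOnShutdownStart(func() { log.Info("pool shutdown starting") })
mgr.SetOnShutdownEnd(func() { log.Info("pool shutdown finished") })
return mgr.Shutdown(ctx)
}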
// PoolEnhancedStats returns enhanced statistics about the pool
func PoolEnhancedStats(pool *Pool) map[string]interface{} {
pool.taskQueueLock.Lock()
queueLen := len(pool.taskQueue)
queueCap := cap(pool.taskQueue)
pool.taskQueueLock.Unlock()
pool.overflowBufferLock.RLock()
overflowLen := len(pool.overflowBuffer)
pool.overflowBufferLock.RUnlock()
var memStats runtime.MemStats
runtime.ReadMemStats(&memStats)
// Guard against a zero-capacity queue to avoid a divide-by-zero
utilization := 0.0
if queueCap > 0 {
utilization = float64(queueLen) / float64(queueCap) * 100
}
return map[string]interface{}{
"workers": map[string]interface{}{
"count": atomic.LoadInt32(&pool.numOfWorkers),
"paused": pool.paused,
},
"queue": map[string]interface{}{
"length": queueLen,
"capacity": queueCap,
"utilization": utilization,
},
"overflow": map[string]interface{}{
"length": overflowLen,
},
"tasks": map[string]interface{}{
"total": atomic.LoadInt64(&pool.metrics.TotalTasks),
"completed": atomic.LoadInt64(&pool.metrics.CompletedTasks),
"errors": atomic.LoadInt64(&pool.metrics.ErrorCount),
},
"memory": map[string]interface{}{
"alloc": memStats.Alloc,
"total_alloc": memStats.TotalAlloc,
"sys": memStats.Sys,
"num_gc": memStats.NumGC,
},
"dlq": map[string]interface{}{
"size": pool.dlq.Size(),
},
}
}
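// Illustrative usage sketch (not part of this change): sampling the enhanced
// stats for logging or metrics export. Assumes a *Pool and a logger.Logger
// constructed elsewhere.
func examplePoolStatsSample(pool *Pool, log logger.Logger) {
stats := PoolEnhancedStats(pool)
log.Info("pool stats",
logger.Field{Key: "queue", Value: stats["queue"]},
logger.Field{Key: "tasks", Value: stats["tasks"]},
logger.Field{Key: "memory", Value: stats["memory"]})
}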

460
snapshot.go Normal file

@@ -0,0 +1,460 @@
package mq
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"time"
"github.com/oarkflow/mq/logger"
)
// QueueSnapshot represents a point-in-time snapshot of queue state
type QueueSnapshot struct {
Timestamp time.Time `json:"timestamp"`
QueueName string `json:"queue_name"`
PendingTasks []*Task `json:"pending_tasks"`
ConsumerState map[string]*ConsumerMeta `json:"consumer_state"`
Metrics *QueueMetrics `json:"metrics"`
Config *QueueConfig `json:"config"`
Version int `json:"version"`
}
// ConsumerMeta represents consumer metadata in snapshot
type ConsumerMeta struct {
ID string `json:"id"`
State string `json:"state"`
LastActivity time.Time `json:"last_activity"`
TasksAssigned int `json:"tasks_assigned"`
}
// SnapshotManager manages queue snapshots and recovery
type SnapshotManager struct {
baseDir string
broker *Broker
snapshotInterval time.Duration
retentionPeriod time.Duration
mu sync.RWMutex
logger logger.Logger
shutdown chan struct{}
wg sync.WaitGroup
}
// SnapshotConfig holds snapshot configuration
type SnapshotConfig struct {
BaseDir string
SnapshotInterval time.Duration
RetentionPeriod time.Duration // How long to keep old snapshots
Logger logger.Logger
}
// NewSnapshotManager creates a new snapshot manager
func NewSnapshotManager(broker *Broker, config SnapshotConfig) (*SnapshotManager, error) {
if config.SnapshotInterval == 0 {
config.SnapshotInterval = 5 * time.Minute
}
if config.RetentionPeriod == 0 {
config.RetentionPeriod = 24 * time.Hour
}
if err := os.MkdirAll(config.BaseDir, 0755); err != nil {
return nil, fmt.Errorf("failed to create snapshot directory: %w", err)
}
sm := &SnapshotManager{
baseDir: config.BaseDir,
broker: broker,
snapshotInterval: config.SnapshotInterval,
retentionPeriod: config.RetentionPeriod,
logger: config.Logger,
shutdown: make(chan struct{}),
}
// Start periodic snapshot worker
sm.wg.Add(1)
go sm.snapshotLoop()
return sm, nil
}
// snapshotLoop periodically creates snapshots
func (sm *SnapshotManager) snapshotLoop() {
defer sm.wg.Done()
ticker := time.NewTicker(sm.snapshotInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := sm.CreateSnapshotAll(context.Background()); err != nil {
sm.logger.Error("Failed to create periodic snapshot",
logger.Field{Key: "error", Value: err})
}
// Cleanup old snapshots
if err := sm.CleanupOldSnapshots(); err != nil {
sm.logger.Error("Failed to cleanup old snapshots",
logger.Field{Key: "error", Value: err})
}
case <-sm.shutdown:
return
}
}
}
// CreateSnapshot creates a snapshot of a specific queue
func (sm *SnapshotManager) CreateSnapshot(ctx context.Context, queueName string) (*QueueSnapshot, error) {
sm.mu.RLock()
defer sm.mu.RUnlock()
queue, exists := sm.broker.queues.Get(queueName)
if !exists {
return nil, fmt.Errorf("queue not found: %s", queueName)
}
// Collect pending tasks
var pendingTasks []*Task
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
// Cycle each queued task through the channel exactly once; draining
// until the channel is empty would never terminate here, because every
// task is immediately put back and keeps the channel non-empty.
pending := len(queue.tasks)
for i := 0; i < pending; i++ {
select {
case qt := <-queue.tasks:
pendingTasks = append(pendingTasks, qt.Task)
// Put it back so the snapshot does not consume the task
queue.tasks <- qt
default:
}
}
}
// Collect consumer state
consumerState := make(map[string]*ConsumerMeta)
queue.consumers.ForEach(func(id string, consumer *consumer) bool {
consumerState[id] = &ConsumerMeta{
ID: id,
State: string(consumer.state),
LastActivity: consumer.metrics.LastActivity,
}
return true
})
snapshot := &QueueSnapshot{
Timestamp: time.Now(),
QueueName: queueName,
PendingTasks: pendingTasks,
ConsumerState: consumerState,
Metrics: queue.metrics,
Config: queue.config,
Version: 1,
}
// Persist snapshot to disk
if err := sm.persistSnapshot(snapshot); err != nil {
return nil, fmt.Errorf("failed to persist snapshot: %w", err)
}
sm.logger.Info("Created queue snapshot",
logger.Field{Key: "queue", Value: queueName},
logger.Field{Key: "pendingTasks", Value: len(pendingTasks)},
logger.Field{Key: "consumers", Value: len(consumerState)})
return snapshot, nil
}
// CreateSnapshotAll creates snapshots of all queues
func (sm *SnapshotManager) CreateSnapshotAll(ctx context.Context) error {
var snapshots []*QueueSnapshot
var mu sync.Mutex
var wg sync.WaitGroup
sm.broker.queues.ForEach(func(queueName string, _ *Queue) bool {
wg.Add(1)
go func(name string) {
defer wg.Done()
snapshot, err := sm.CreateSnapshot(ctx, name)
if err != nil {
sm.logger.Error("Failed to create snapshot",
logger.Field{Key: "queue", Value: name},
logger.Field{Key: "error", Value: err})
return
}
mu.Lock()
snapshots = append(snapshots, snapshot)
mu.Unlock()
}(queueName)
return true
})
wg.Wait()
sm.logger.Info("Created snapshots for all queues",
logger.Field{Key: "count", Value: len(snapshots)})
return nil
}
// persistSnapshot writes a snapshot to disk
func (sm *SnapshotManager) persistSnapshot(snapshot *QueueSnapshot) error {
data, err := json.MarshalIndent(snapshot, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal snapshot: %w", err)
}
queueDir := filepath.Join(sm.baseDir, snapshot.QueueName)
if err := os.MkdirAll(queueDir, 0755); err != nil {
return fmt.Errorf("failed to create queue directory: %w", err)
}
filename := fmt.Sprintf("snapshot-%d.json", snapshot.Timestamp.UnixNano())
filePath := filepath.Join(queueDir, filename)
// Write atomically using temp file
tempPath := filePath + ".tmp"
if err := os.WriteFile(tempPath, data, 0644); err != nil {
return fmt.Errorf("failed to write snapshot: %w", err)
}
if err := os.Rename(tempPath, filePath); err != nil {
os.Remove(tempPath)
return fmt.Errorf("failed to rename snapshot: %w", err)
}
return nil
}
// RestoreFromSnapshot restores a queue from the latest snapshot
func (sm *SnapshotManager) RestoreFromSnapshot(ctx context.Context, queueName string) error {
snapshot, err := sm.GetLatestSnapshot(queueName)
if err != nil {
return fmt.Errorf("failed to get latest snapshot: %w", err)
}
if snapshot == nil {
return fmt.Errorf("no snapshot found for queue: %s", queueName)
}
sm.logger.Info("Restoring queue from snapshot",
logger.Field{Key: "queue", Value: queueName},
logger.Field{Key: "timestamp", Value: snapshot.Timestamp},
logger.Field{Key: "tasks", Value: len(snapshot.PendingTasks)})
// Get or create queue
queue, exists := sm.broker.queues.Get(queueName)
if !exists {
queue = sm.broker.NewQueue(queueName)
}
// Restore pending tasks
restored := 0
for _, task := range snapshot.PendingTasks {
select {
case queue.tasks <- &QueuedTask{Task: task}:
restored++
case <-ctx.Done():
return ctx.Err()
default:
sm.logger.Warn("Queue full during restore, task skipped",
logger.Field{Key: "taskID", Value: task.ID})
}
}
// Restore metrics
if snapshot.Metrics != nil {
queue.metrics = snapshot.Metrics
}
// Restore config
if snapshot.Config != nil {
queue.config = snapshot.Config
}
sm.logger.Info("Queue restored from snapshot",
logger.Field{Key: "queue", Value: queueName},
logger.Field{Key: "restoredTasks", Value: restored})
return nil
}
// GetLatestSnapshot retrieves the latest snapshot for a queue
func (sm *SnapshotManager) GetLatestSnapshot(queueName string) (*QueueSnapshot, error) {
queueDir := filepath.Join(sm.baseDir, queueName)
files, err := filepath.Glob(filepath.Join(queueDir, "snapshot-*.json"))
if err != nil {
return nil, fmt.Errorf("failed to list snapshots: %w", err)
}
if len(files) == 0 {
return nil, nil
}
// Sort files to get latest (files are named with timestamp)
latestFile := files[0]
for _, file := range files {
if file > latestFile {
latestFile = file
}
}
// Read and parse snapshot
data, err := os.ReadFile(latestFile)
if err != nil {
return nil, fmt.Errorf("failed to read snapshot: %w", err)
}
var snapshot QueueSnapshot
if err := json.Unmarshal(data, &snapshot); err != nil {
return nil, fmt.Errorf("failed to unmarshal snapshot: %w", err)
}
return &snapshot, nil
}
// ListSnapshots lists all snapshots for a queue
func (sm *SnapshotManager) ListSnapshots(queueName string) ([]*QueueSnapshot, error) {
queueDir := filepath.Join(sm.baseDir, queueName)
files, err := filepath.Glob(filepath.Join(queueDir, "snapshot-*.json"))
if err != nil {
return nil, fmt.Errorf("failed to list snapshots: %w", err)
}
snapshots := make([]*QueueSnapshot, 0, len(files))
for _, file := range files {
data, err := os.ReadFile(file)
if err != nil {
sm.logger.Error("Failed to read snapshot file",
logger.Field{Key: "file", Value: file},
logger.Field{Key: "error", Value: err})
continue
}
var snapshot QueueSnapshot
if err := json.Unmarshal(data, &snapshot); err != nil {
sm.logger.Error("Failed to unmarshal snapshot",
logger.Field{Key: "file", Value: file},
logger.Field{Key: "error", Value: err})
continue
}
snapshots = append(snapshots, &snapshot)
}
return snapshots, nil
}
// CleanupOldSnapshots removes snapshots older than retention period
func (sm *SnapshotManager) CleanupOldSnapshots() error {
cutoff := time.Now().Add(-sm.retentionPeriod)
removed := 0
err := filepath.Walk(sm.baseDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() || filepath.Ext(path) != ".json" {
return nil
}
if info.ModTime().Before(cutoff) {
if err := os.Remove(path); err != nil {
sm.logger.Error("Failed to remove old snapshot",
logger.Field{Key: "file", Value: path},
logger.Field{Key: "error", Value: err})
return nil
}
removed++
}
return nil
})
if err != nil {
return err
}
if removed > 0 {
sm.logger.Info("Cleaned up old snapshots",
logger.Field{Key: "removed", Value: removed})
}
return nil
}
// DeleteSnapshot deletes a specific snapshot
func (sm *SnapshotManager) DeleteSnapshot(queueName string, timestamp time.Time) error {
filename := fmt.Sprintf("snapshot-%d.json", timestamp.UnixNano())
filePath := filepath.Join(sm.baseDir, queueName, filename)
if err := os.Remove(filePath); err != nil {
return fmt.Errorf("failed to delete snapshot: %w", err)
}
sm.logger.Info("Deleted snapshot",
logger.Field{Key: "queue", Value: queueName},
logger.Field{Key: "timestamp", Value: timestamp})
return nil
}
// GetSnapshotStats returns statistics about snapshots
func (sm *SnapshotManager) GetSnapshotStats() map[string]interface{} {
totalSnapshots := 0
var totalSize int64
filepath.Walk(sm.baseDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() && filepath.Ext(path) == ".json" {
totalSnapshots++
totalSize += info.Size()
}
return nil
})
return map[string]interface{}{
"total_snapshots": totalSnapshots,
"total_size_bytes": totalSize,
"retention_period": sm.retentionPeriod,
"snapshot_interval": sm.snapshotInterval,
}
}
// Shutdown gracefully shuts down the snapshot manager
func (sm *SnapshotManager) Shutdown(ctx context.Context) error {
close(sm.shutdown)
// Wait for workers to finish with timeout
done := make(chan struct{})
go func() {
sm.wg.Wait()
close(done)
}()
// Create final snapshot before shutdown
if err := sm.CreateSnapshotAll(ctx); err != nil {
sm.logger.Error("Failed to create final snapshot",
logger.Field{Key: "error", Value: err})
}
select {
case <-done:
sm.logger.Info("Snapshot manager shutdown complete")
return nil
case <-ctx.Done():
return ctx.Err()
}
}
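// Illustrative usage sketch (not part of this change): enabling periodic
// snapshots and restoring a queue after a restart. Assumes a *Broker and a
// logger.Logger constructed elsewhere; the directory and queue name are
// placeholders.
func exampleSnapshotting(ctx context.Context, broker *Broker, log logger.Logger) (*SnapshotManager, error) {
sm, err := NewSnapshotManager(broker, SnapshotConfig{
BaseDir:          "./snapshots",
SnapshotInterval: 5 * time.Minute,
RetentionPeriod:  24 * time.Hour,
Logger:           log,
})
if err != nil {
return nil, err
}
// On startup, replay the latest snapshot for a known queue.
if err := sm.RestoreFromSnapshot(ctx, "orders"); err != nil {
log.Warn("snapshot restore skipped", logger.Field{Key: "error", Value: err})
}
// Call sm.Shutdown(ctx) during application teardown.
return sm, nil
}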

15
task.go

@@ -115,6 +115,7 @@ type Task struct {
CreatedAt time.Time `json:"created_at"`
ProcessedAt time.Time `json:"processed_at"`
Expiry time.Time `json:"expiry"`
DeferUntil time.Time `json:"defer_until,omitempty"` // NEW: For deferred task execution
Error error `json:"-"` // Don't serialize errors directly
ErrorMsg string `json:"error,omitempty"` // Serialize error message if present
ID string `json:"id"`
@@ -232,6 +233,20 @@ func WithTTL(ttl time.Duration) TaskOption {
}
}
// TaskOption for deferring task execution until a specific time
func WithDeferUntil(deferUntil time.Time) TaskOption {
return func(t *Task) {
t.DeferUntil = deferUntil
}
}
// TaskOption for deferring task execution for a specific duration
func WithDeferDuration(duration time.Duration) TaskOption {
return func(t *Task) {
t.DeferUntil = time.Now().Add(duration)
}
}
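// Illustrative usage sketch (not part of this change): applying the new defer
// options; the *Task is assumed to come from the package's usual constructor.
func exampleDeferredTask(t *Task) {
// Defer until an absolute wall-clock time...
WithDeferUntil(time.Now().Add(time.Hour))(t)
// ...or, equivalently, by a relative duration (this overwrites the value above).
WithDeferDuration(time.Hour)(t)
}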
// TaskOption for adding tags
func WithTags(tags map[string]string) TaskOption {
return func(t *Task) {

562
tracing.go Normal file

@@ -0,0 +1,562 @@
package mq
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/oarkflow/mq/logger"
)
// TraceEvent represents a single event in a message's lifecycle
type TraceEvent struct {
Timestamp time.Time `json:"timestamp"`
EventType TraceEventType `json:"event_type"`
Source string `json:"source"`
Details map[string]string `json:"details,omitempty"`
Duration time.Duration `json:"duration,omitempty"`
Error string `json:"error,omitempty"`
}
// TraceEventType defines types of trace events
type TraceEventType string
const (
TraceEventEnqueued TraceEventType = "ENQUEUED"
TraceEventDequeued TraceEventType = "DEQUEUED"
TraceEventProcessing TraceEventType = "PROCESSING"
TraceEventCompleted TraceEventType = "COMPLETED"
TraceEventFailed TraceEventType = "FAILED"
TraceEventRetried TraceEventType = "RETRIED"
TraceEventRejected TraceEventType = "REJECTED"
TraceEventAcked TraceEventType = "ACKED"
TraceEventNacked TraceEventType = "NACKED"
TraceEventTimeout TraceEventType = "TIMEOUT"
TraceEventDLQ TraceEventType = "DLQ"
)
// MessageTrace represents the complete trace of a message
type MessageTrace struct {
TraceID string `json:"trace_id"`
SpanID string `json:"span_id"`
ParentSpanID string `json:"parent_span_id,omitempty"`
TaskID string `json:"task_id"`
QueueName string `json:"queue_name"`
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time,omitempty"`
Duration time.Duration `json:"duration,omitempty"`
Events []*TraceEvent `json:"events"`
Tags map[string]string `json:"tags,omitempty"`
Status string `json:"status"`
mu sync.RWMutex
}
// TraceManager manages distributed tracing for messages
type TraceManager struct {
traces map[string]*MessageTrace
mu sync.RWMutex
logger logger.Logger
storage TraceStorage
shutdown chan struct{}
retention time.Duration
exportInterval time.Duration
onTraceComplete func(*MessageTrace)
}
// TraceStorage interface for persistent trace storage
type TraceStorage interface {
Store(ctx context.Context, trace *MessageTrace) error
Get(ctx context.Context, traceID string) (*MessageTrace, error)
Query(ctx context.Context, filter TraceFilter) ([]*MessageTrace, error)
Delete(ctx context.Context, traceID string) error
DeleteOlderThan(ctx context.Context, duration time.Duration) (int, error)
Close() error
}
// TraceFilter for querying traces
type TraceFilter struct {
QueueName string
Status string
StartTime time.Time
EndTime time.Time
MinDuration time.Duration
MaxDuration time.Duration
HasError bool
Limit int
Offset int
}
// TraceConfig holds tracing configuration
type TraceConfig struct {
Storage TraceStorage
Retention time.Duration
ExportInterval time.Duration
Logger logger.Logger
}
// NewTraceManager creates a new trace manager
func NewTraceManager(config TraceConfig) *TraceManager {
if config.Retention == 0 {
config.Retention = 24 * time.Hour
}
if config.ExportInterval == 0 {
config.ExportInterval = 30 * time.Second
}
tm := &TraceManager{
traces: make(map[string]*MessageTrace),
logger: config.Logger,
storage: config.Storage,
shutdown: make(chan struct{}),
retention: config.Retention,
exportInterval: config.ExportInterval,
}
go tm.exportLoop()
go tm.cleanupLoop()
return tm
}
// StartTrace initiates tracing for a message
func (tm *TraceManager) StartTrace(ctx context.Context, task *Task, queueName string) *MessageTrace {
trace := &MessageTrace{
TraceID: task.TraceID,
SpanID: task.SpanID,
TaskID: task.ID,
QueueName: queueName,
StartTime: time.Now(),
Events: make([]*TraceEvent, 0),
Tags: task.Tags,
Status: "in_progress",
}
tm.mu.Lock()
tm.traces[trace.TraceID] = trace
tm.mu.Unlock()
tm.RecordEvent(trace.TraceID, TraceEventEnqueued, "broker", nil)
return trace
}
// RecordEvent records an event in a trace
func (tm *TraceManager) RecordEvent(traceID string, eventType TraceEventType, source string, details map[string]string) {
tm.mu.RLock()
trace, exists := tm.traces[traceID]
tm.mu.RUnlock()
if !exists {
tm.logger.Warn("Trace not found for event recording",
logger.Field{Key: "traceID", Value: traceID},
logger.Field{Key: "eventType", Value: eventType})
return
}
trace.mu.Lock()
defer trace.mu.Unlock()
event := &TraceEvent{
Timestamp: time.Now(),
EventType: eventType,
Source: source,
Details: details,
}
trace.Events = append(trace.Events, event)
tm.logger.Debug("Recorded trace event",
logger.Field{Key: "traceID", Value: traceID},
logger.Field{Key: "eventType", Value: eventType},
logger.Field{Key: "source", Value: source})
}
// RecordError records an error event in a trace
func (tm *TraceManager) RecordError(traceID string, source string, err error, details map[string]string) {
tm.mu.RLock()
trace, exists := tm.traces[traceID]
tm.mu.RUnlock()
if !exists {
return
}
trace.mu.Lock()
defer trace.mu.Unlock()
if details == nil {
details = make(map[string]string)
}
event := &TraceEvent{
Timestamp: time.Now(),
EventType: TraceEventFailed,
Source: source,
Details: details,
Error: err.Error(),
}
trace.Events = append(trace.Events, event)
trace.Status = "failed"
}
// CompleteTrace marks a trace as complete
func (tm *TraceManager) CompleteTrace(traceID string, status string) {
tm.mu.Lock()
trace, exists := tm.traces[traceID]
if !exists {
tm.mu.Unlock()
return
}
delete(tm.traces, traceID)
tm.mu.Unlock()
trace.mu.Lock()
trace.EndTime = time.Now()
trace.Duration = trace.EndTime.Sub(trace.StartTime)
trace.Status = status
trace.mu.Unlock()
// Store in persistent storage
if tm.storage != nil {
go func() {
if err := tm.storage.Store(context.Background(), trace); err != nil {
tm.logger.Error("Failed to store trace",
logger.Field{Key: "traceID", Value: traceID},
logger.Field{Key: "error", Value: err})
}
}()
}
if tm.onTraceComplete != nil {
go tm.onTraceComplete(trace)
}
tm.logger.Debug("Completed trace",
logger.Field{Key: "traceID", Value: traceID},
logger.Field{Key: "duration", Value: trace.Duration},
logger.Field{Key: "status", Value: status})
}
// GetTrace retrieves a trace by ID
func (tm *TraceManager) GetTrace(ctx context.Context, traceID string) (*MessageTrace, error) {
// Check in-memory traces first
tm.mu.RLock()
trace, exists := tm.traces[traceID]
tm.mu.RUnlock()
if exists {
return trace, nil
}
// Check persistent storage
if tm.storage != nil {
return tm.storage.Get(ctx, traceID)
}
return nil, fmt.Errorf("trace not found: %s", traceID)
}
// QueryTraces queries traces based on filter
func (tm *TraceManager) QueryTraces(ctx context.Context, filter TraceFilter) ([]*MessageTrace, error) {
if tm.storage != nil {
return tm.storage.Query(ctx, filter)
}
// Fallback to in-memory search
tm.mu.RLock()
defer tm.mu.RUnlock()
var results []*MessageTrace
for _, trace := range tm.traces {
if tm.matchesFilter(trace, filter) {
results = append(results, trace)
}
}
return results, nil
}
// matchesFilter checks if a trace matches the filter
func (tm *TraceManager) matchesFilter(trace *MessageTrace, filter TraceFilter) bool {
if filter.QueueName != "" && trace.QueueName != filter.QueueName {
return false
}
if filter.Status != "" && trace.Status != filter.Status {
return false
}
if !filter.StartTime.IsZero() && trace.StartTime.Before(filter.StartTime) {
return false
}
if !filter.EndTime.IsZero() && !trace.EndTime.IsZero() && trace.EndTime.After(filter.EndTime) {
return false
}
if filter.MinDuration > 0 && trace.Duration < filter.MinDuration {
return false
}
if filter.MaxDuration > 0 && trace.Duration > filter.MaxDuration {
return false
}
return true
}
// exportLoop periodically exports traces to storage
func (tm *TraceManager) exportLoop() {
ticker := time.NewTicker(tm.exportInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
tm.exportTraces()
case <-tm.shutdown:
return
}
}
}
// exportTraces exports active traces to storage
func (tm *TraceManager) exportTraces() {
if tm.storage == nil {
return
}
tm.mu.RLock()
traces := make([]*MessageTrace, 0, len(tm.traces))
for _, trace := range tm.traces {
traces = append(traces, trace)
}
tm.mu.RUnlock()
for _, trace := range traces {
if err := tm.storage.Store(context.Background(), trace); err != nil {
tm.logger.Error("Failed to export trace",
logger.Field{Key: "traceID", Value: trace.TraceID},
logger.Field{Key: "error", Value: err})
}
}
}
// cleanupLoop periodically cleans up old traces
func (tm *TraceManager) cleanupLoop() {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ticker.C:
tm.cleanup()
case <-tm.shutdown:
return
}
}
}
// cleanup removes old traces
func (tm *TraceManager) cleanup() {
if tm.storage != nil {
deleted, err := tm.storage.DeleteOlderThan(context.Background(), tm.retention)
if err != nil {
tm.logger.Error("Failed to cleanup old traces",
logger.Field{Key: "error", Value: err})
} else if deleted > 0 {
tm.logger.Info("Cleaned up old traces",
logger.Field{Key: "deleted", Value: deleted})
}
}
}
// SetOnTraceComplete sets callback for trace completion
func (tm *TraceManager) SetOnTraceComplete(fn func(*MessageTrace)) {
tm.onTraceComplete = fn
}
// GetStats returns tracing statistics
func (tm *TraceManager) GetStats() map[string]interface{} {
tm.mu.RLock()
defer tm.mu.RUnlock()
activeTraces := len(tm.traces)
return map[string]interface{}{
"active_traces": activeTraces,
"retention": tm.retention,
"export_interval": tm.exportInterval,
}
}
// Shutdown gracefully shuts down the trace manager
func (tm *TraceManager) Shutdown(ctx context.Context) error {
close(tm.shutdown)
// Export remaining traces
tm.exportTraces()
if tm.storage != nil {
return tm.storage.Close()
}
return nil
}
// InMemoryTraceStorage implements in-memory trace storage
type InMemoryTraceStorage struct {
traces map[string]*MessageTrace
mu sync.RWMutex
}
// NewInMemoryTraceStorage creates a new in-memory trace storage
func NewInMemoryTraceStorage() *InMemoryTraceStorage {
return &InMemoryTraceStorage{
traces: make(map[string]*MessageTrace),
}
}
func (s *InMemoryTraceStorage) Store(ctx context.Context, trace *MessageTrace) error {
s.mu.Lock()
defer s.mu.Unlock()
// Deep copy to avoid concurrent modifications
data, err := json.Marshal(trace)
if err != nil {
return err
}
var traceCopy MessageTrace
if err := json.Unmarshal(data, &traceCopy); err != nil {
return err
}
s.traces[trace.TraceID] = &traceCopy
return nil
}
func (s *InMemoryTraceStorage) Get(ctx context.Context, traceID string) (*MessageTrace, error) {
s.mu.RLock()
defer s.mu.RUnlock()
trace, exists := s.traces[traceID]
if !exists {
return nil, fmt.Errorf("trace not found: %s", traceID)
}
return trace, nil
}
func (s *InMemoryTraceStorage) Query(ctx context.Context, filter TraceFilter) ([]*MessageTrace, error) {
s.mu.RLock()
defer s.mu.RUnlock()
var results []*MessageTrace
for _, trace := range s.traces {
results = append(results, trace)
}
// Apply pagination only; the TraceFilter field criteria are not evaluated by this in-memory store
start := filter.Offset
if start >= len(results) {
return []*MessageTrace{}, nil
}
end := start + filter.Limit
if filter.Limit == 0 || end > len(results) {
end = len(results)
}
return results[start:end], nil
}
func (s *InMemoryTraceStorage) Delete(ctx context.Context, traceID string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.traces, traceID)
return nil
}
func (s *InMemoryTraceStorage) DeleteOlderThan(ctx context.Context, duration time.Duration) (int, error) {
s.mu.Lock()
defer s.mu.Unlock()
cutoff := time.Now().Add(-duration)
deleted := 0
for traceID, trace := range s.traces {
if !trace.EndTime.IsZero() && trace.EndTime.Before(cutoff) {
delete(s.traces, traceID)
deleted++
}
}
return deleted, nil
}
func (s *InMemoryTraceStorage) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
s.traces = nil
return nil
}
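// Illustrative usage sketch (not part of this change): constructing a trace
// manager backed by the in-memory store. Assumes a logger.Logger constructed
// elsewhere; a persistent TraceStorage can be substituted for production use.
func exampleTraceManagerSetup(log logger.Logger) *TraceManager {
return NewTraceManager(TraceConfig{
Storage:        NewInMemoryTraceStorage(),
Retention:      24 * time.Hour,
ExportInterval: 30 * time.Second,
Logger:         log,
})
}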
// MessageLifecycleTracker tracks the complete lifecycle of messages
type MessageLifecycleTracker struct {
traceManager *TraceManager
logger logger.Logger
}
// NewMessageLifecycleTracker creates a new lifecycle tracker
func NewMessageLifecycleTracker(traceManager *TraceManager, log logger.Logger) *MessageLifecycleTracker {
return &MessageLifecycleTracker{
traceManager: traceManager,
logger: log,
}
}
// TrackEnqueue tracks message enqueue
func (lt *MessageLifecycleTracker) TrackEnqueue(ctx context.Context, task *Task, queueName string) {
lt.traceManager.StartTrace(ctx, task, queueName)
}
// TrackDequeue tracks message dequeue
func (lt *MessageLifecycleTracker) TrackDequeue(traceID, consumerID string) {
lt.traceManager.RecordEvent(traceID, TraceEventDequeued, consumerID, nil)
}
// TrackProcessing tracks message processing start
func (lt *MessageLifecycleTracker) TrackProcessing(traceID, workerID string) {
lt.traceManager.RecordEvent(traceID, TraceEventProcessing, workerID, nil)
}
// TrackCompletion tracks successful completion
func (lt *MessageLifecycleTracker) TrackCompletion(traceID string, duration time.Duration) {
details := map[string]string{
"duration": duration.String(),
}
lt.traceManager.RecordEvent(traceID, TraceEventCompleted, "worker", details)
lt.traceManager.CompleteTrace(traceID, "completed")
}
// TrackError tracks processing error
func (lt *MessageLifecycleTracker) TrackError(traceID, source string, err error) {
lt.traceManager.RecordError(traceID, source, err, nil)
}
// TrackRetry tracks retry attempt
func (lt *MessageLifecycleTracker) TrackRetry(traceID string, retryCount int) {
details := map[string]string{
"retry_count": fmt.Sprintf("%d", retryCount),
}
lt.traceManager.RecordEvent(traceID, TraceEventRetried, "broker", details)
}
// TrackDLQ tracks movement to dead letter queue
func (lt *MessageLifecycleTracker) TrackDLQ(traceID, reason string) {
details := map[string]string{
"reason": reason,
}
lt.traceManager.RecordEvent(traceID, TraceEventDLQ, "broker", details)
lt.traceManager.CompleteTrace(traceID, "dlq")
}
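// Illustrative usage sketch (not part of this change): tracking one message
// through its lifecycle. Assumes the *Task already carries a TraceID and that
// tm comes from NewTraceManager; the queue, consumer, and worker names are
// placeholders.
func exampleMessageLifecycle(ctx context.Context, tm *TraceManager, task *Task) {
lt := NewMessageLifecycleTracker(tm, tm.logger)
lt.TrackEnqueue(ctx, task, "orders")
lt.TrackDequeue(task.TraceID, "consumer-1")
lt.TrackProcessing(task.TraceID, "worker-1")
lt.TrackCompletion(task.TraceID, 120*time.Millisecond)
}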

444
wal.go Normal file

@@ -0,0 +1,444 @@
package mq
import (
"bufio"
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"time"
"github.com/oarkflow/mq/logger"
)
// WALEntry represents a single write-ahead log entry
type WALEntry struct {
EntryType WALEntryType `json:"entry_type"`
TaskID string `json:"task_id"`
QueueName string `json:"queue_name"`
Timestamp time.Time `json:"timestamp"`
Payload json.RawMessage `json:"payload,omitempty"`
SequenceID int64 `json:"sequence_id"`
}
// WALEntryType defines the type of WAL entry
type WALEntryType string
const (
WALEntryEnqueue WALEntryType = "ENQUEUE"
WALEntryDequeue WALEntryType = "DEQUEUE"
WALEntryComplete WALEntryType = "COMPLETE"
WALEntryFailed WALEntryType = "FAILED"
WALEntryCheckpoint WALEntryType = "CHECKPOINT"
)
// WriteAheadLog provides message durability through persistent logging
type WriteAheadLog struct {
dir string
currentFile *os.File
currentWriter *bufio.Writer
sequenceID int64
maxFileSize int64
syncInterval time.Duration
mu sync.Mutex
logger logger.Logger
shutdown chan struct{}
wg sync.WaitGroup
entries chan *WALEntry
fsyncOnWrite bool
}
// WALConfig holds configuration for the WAL
type WALConfig struct {
Directory string
MaxFileSize int64 // Maximum file size before rotation
SyncInterval time.Duration // Interval for syncing to disk
FsyncOnWrite bool // Sync after every write (slower but more durable)
Logger logger.Logger
}
// NewWriteAheadLog creates a new write-ahead log
func NewWriteAheadLog(config WALConfig) (*WriteAheadLog, error) {
if config.MaxFileSize == 0 {
config.MaxFileSize = 100 * 1024 * 1024 // 100MB default
}
if config.SyncInterval == 0 {
config.SyncInterval = 1 * time.Second
}
if err := os.MkdirAll(config.Directory, 0755); err != nil {
return nil, fmt.Errorf("failed to create WAL directory: %w", err)
}
wal := &WriteAheadLog{
dir: config.Directory,
maxFileSize: config.MaxFileSize,
syncInterval: config.SyncInterval,
logger: config.Logger,
shutdown: make(chan struct{}),
entries: make(chan *WALEntry, 10000),
fsyncOnWrite: config.FsyncOnWrite,
}
// Recover sequence ID from existing logs
if err := wal.recoverSequenceID(); err != nil {
return nil, fmt.Errorf("failed to recover sequence ID: %w", err)
}
// Open or create current log file
if err := wal.openNewFile(); err != nil {
return nil, fmt.Errorf("failed to open WAL file: %w", err)
}
// Start background workers
wal.wg.Add(2)
go wal.writeWorker()
go wal.syncWorker()
return wal, nil
}
// WriteEntry writes a new entry to the WAL
func (w *WriteAheadLog) WriteEntry(ctx context.Context, entry *WALEntry) error {
w.mu.Lock()
w.sequenceID++
entry.SequenceID = w.sequenceID
entry.Timestamp = time.Now()
w.mu.Unlock()
select {
case w.entries <- entry:
return nil
case <-ctx.Done():
return ctx.Err()
case <-w.shutdown:
return fmt.Errorf("WAL is shutting down")
}
}
// writeWorker processes WAL entries in the background
func (w *WriteAheadLog) writeWorker() {
defer w.wg.Done()
for {
select {
case entry := <-w.entries:
if err := w.writeEntryToFile(entry); err != nil {
w.logger.Error("Failed to write WAL entry",
logger.Field{Key: "error", Value: err},
logger.Field{Key: "taskID", Value: entry.TaskID})
}
case <-w.shutdown:
// Drain remaining entries
for len(w.entries) > 0 {
entry := <-w.entries
_ = w.writeEntryToFile(entry)
}
return
}
}
}
// writeEntryToFile writes a single entry to the current WAL file
func (w *WriteAheadLog) writeEntryToFile(entry *WALEntry) error {
w.mu.Lock()
defer w.mu.Unlock()
data, err := json.Marshal(entry)
if err != nil {
return fmt.Errorf("failed to marshal WAL entry: %w", err)
}
// Write entry with newline delimiter
if _, err := w.currentWriter.Write(append(data, '\n')); err != nil {
return fmt.Errorf("failed to write WAL entry: %w", err)
}
if w.fsyncOnWrite {
if err := w.currentWriter.Flush(); err != nil {
return fmt.Errorf("failed to flush WAL: %w", err)
}
if err := w.currentFile.Sync(); err != nil {
return fmt.Errorf("failed to sync WAL: %w", err)
}
}
// Check if we need to rotate the file
stat, err := w.currentFile.Stat()
if err == nil && stat.Size() >= w.maxFileSize {
return w.rotateFile()
}
return nil
}
// syncWorker periodically syncs the WAL to disk
func (w *WriteAheadLog) syncWorker() {
defer w.wg.Done()
ticker := time.NewTicker(w.syncInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
w.mu.Lock()
if w.currentWriter != nil {
_ = w.currentWriter.Flush()
}
if w.currentFile != nil {
_ = w.currentFile.Sync()
}
w.mu.Unlock()
case <-w.shutdown:
return
}
}
}
// openNewFile creates a new WAL file
func (w *WriteAheadLog) openNewFile() error {
filename := fmt.Sprintf("wal-%d.log", time.Now().UnixNano())
path := filepath.Join(w.dir, filename)
file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
return fmt.Errorf("failed to open WAL file: %w", err)
}
w.currentFile = file
w.currentWriter = bufio.NewWriter(file)
w.logger.Info("Opened new WAL file", logger.Field{Key: "filename", Value: filename})
return nil
}
// rotateFile rotates to a new WAL file
func (w *WriteAheadLog) rotateFile() error {
// Flush and close current file
if w.currentWriter != nil {
if err := w.currentWriter.Flush(); err != nil {
return err
}
}
if w.currentFile != nil {
if err := w.currentFile.Close(); err != nil {
return err
}
}
// Open new file
return w.openNewFile()
}
// recoverSequenceID recovers the last sequence ID from existing WAL files
func (w *WriteAheadLog) recoverSequenceID() error {
files, err := filepath.Glob(filepath.Join(w.dir, "wal-*.log"))
if err != nil {
return err
}
maxSeq := int64(0)
for _, path := range files {
file, err := os.Open(path)
if err != nil {
continue
}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
var entry WALEntry
if err := json.Unmarshal(scanner.Bytes(), &entry); err != nil {
continue
}
if entry.SequenceID > maxSeq {
maxSeq = entry.SequenceID
}
}
file.Close()
}
w.sequenceID = maxSeq
w.logger.Info("Recovered sequence ID", logger.Field{Key: "sequenceID", Value: maxSeq})
return nil
}
// Replay replays all WAL entries to recover state
func (w *WriteAheadLog) Replay(handler func(*WALEntry) error) error {
files, err := filepath.Glob(filepath.Join(w.dir, "wal-*.log"))
if err != nil {
return fmt.Errorf("failed to list WAL files: %w", err)
}
// Sort files by name (which includes timestamp)
// Simple bubble sort since the list is typically small
for i := 0; i < len(files)-1; i++ {
for j := i + 1; j < len(files); j++ {
if files[i] > files[j] {
files[i], files[j] = files[j], files[i]
}
}
}
entriesReplayed := 0
for _, path := range files {
file, err := os.Open(path)
if err != nil {
w.logger.Error("Failed to open WAL file for replay",
logger.Field{Key: "error", Value: err},
logger.Field{Key: "file", Value: path})
continue
}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
var entry WALEntry
if err := json.Unmarshal(scanner.Bytes(), &entry); err != nil {
w.logger.Error("Failed to unmarshal WAL entry",
logger.Field{Key: "error", Value: err})
continue
}
if err := handler(&entry); err != nil {
w.logger.Error("Failed to replay WAL entry",
logger.Field{Key: "error", Value: err},
logger.Field{Key: "taskID", Value: entry.TaskID})
continue
}
entriesReplayed++
}
file.Close()
}
w.logger.Info("WAL replay complete",
logger.Field{Key: "entries", Value: entriesReplayed})
return nil
}
// Checkpoint writes a checkpoint entry and optionally truncates old logs
func (w *WriteAheadLog) Checkpoint(ctx context.Context, state map[string]interface{}) error {
stateData, err := json.Marshal(state)
if err != nil {
return fmt.Errorf("failed to marshal checkpoint state: %w", err)
}
entry := &WALEntry{
EntryType: WALEntryCheckpoint,
TaskID: "checkpoint",
Payload: stateData,
}
return w.WriteEntry(ctx, entry)
}
// TruncateOldLogs removes old WAL files (called after checkpoint)
func (w *WriteAheadLog) TruncateOldLogs(keepRecent int) error {
files, err := filepath.Glob(filepath.Join(w.dir, "wal-*.log"))
if err != nil {
return fmt.Errorf("failed to list WAL files: %w", err)
}
// Sort files by modification time
type fileInfo struct {
path string
modTime time.Time
}
var fileInfos []fileInfo
for _, path := range files {
stat, err := os.Stat(path)
if err != nil {
continue
}
fileInfos = append(fileInfos, fileInfo{path: path, modTime: stat.ModTime()})
}
// Sort by modification time (newest first)
for i := 0; i < len(fileInfos)-1; i++ {
for j := i + 1; j < len(fileInfos); j++ {
if fileInfos[i].modTime.Before(fileInfos[j].modTime) {
fileInfos[i], fileInfos[j] = fileInfos[j], fileInfos[i]
}
}
}
// Remove old files
removed := 0
for i := keepRecent; i < len(fileInfos); i++ {
if err := os.Remove(fileInfos[i].path); err != nil {
w.logger.Error("Failed to remove old WAL file",
logger.Field{Key: "error", Value: err},
logger.Field{Key: "file", Value: fileInfos[i].path})
continue
}
removed++
}
w.logger.Info("Truncated old WAL files",
logger.Field{Key: "removed", Value: removed})
return nil
}
// Shutdown gracefully shuts down the WAL
func (w *WriteAheadLog) Shutdown(ctx context.Context) error {
close(w.shutdown)
// Wait for workers to finish with timeout
done := make(chan struct{})
go func() {
w.wg.Wait()
close(done)
}()
select {
case <-done:
w.mu.Lock()
defer w.mu.Unlock()
if w.currentWriter != nil {
_ = w.currentWriter.Flush()
}
if w.currentFile != nil {
_ = w.currentFile.Sync()
_ = w.currentFile.Close()
}
w.logger.Info("WAL shutdown complete")
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// GetStats returns statistics about the WAL
func (w *WriteAheadLog) GetStats() map[string]interface{} {
w.mu.Lock()
defer w.mu.Unlock()
var currentFileSize int64
if w.currentFile != nil {
if stat, err := w.currentFile.Stat(); err == nil {
currentFileSize = stat.Size()
}
}
files, _ := filepath.Glob(filepath.Join(w.dir, "wal-*.log"))
return map[string]interface{}{
"current_sequence_id": w.sequenceID,
"current_file_size": currentFileSize,
"total_files": len(files),
"entries_backlog": len(w.entries),
}
}
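// Illustrative usage sketch (not part of this change): opening a WAL, replaying
// it on startup, then recording a durable enqueue. Assumes a logger.Logger
// constructed elsewhere; the directory, task ID, and queue name are placeholders.
func exampleWALUsage(ctx context.Context, log logger.Logger, payload json.RawMessage) error {
wal, err := NewWriteAheadLog(WALConfig{
Directory:    "./wal",
FsyncOnWrite: false,
Logger:       log,
})
if err != nil {
return err
}
// Rebuild state from prior entries before accepting new work.
if err := wal.Replay(func(e *WALEntry) error {
log.Info("replayed WAL entry", logger.Field{Key: "taskID", Value: e.TaskID})
return nil
}); err != nil {
return err
}
return wal.WriteEntry(ctx, &WALEntry{
EntryType: WALEntryEnqueue,
TaskID:    "task-123",
QueueName: "orders",
Payload:   payload,
})
}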