Files
mq/snapshot.go
2025-10-01 19:46:14 +05:45

461 lines
12 KiB
Go

package mq
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"time"
"github.com/oarkflow/mq/logger"
)
// QueueSnapshot represents a point-in-time snapshot of queue state.
// CreateSnapshot builds it, persistSnapshot writes it to disk as JSON, and
// GetLatestSnapshot / ListSnapshots decode it back.
type QueueSnapshot struct {
	// Timestamp is when the snapshot was taken; its UnixNano value is also
	// embedded in the snapshot's file name.
	Timestamp time.Time `json:"timestamp"`
	// QueueName names the queue and its per-queue directory under baseDir.
	QueueName string `json:"queue_name"`
	// PendingTasks are the tasks found buffered in the queue's task channel
	// at snapshot time; RestoreFromSnapshot re-enqueues them.
	PendingTasks []*Task `json:"pending_tasks"`
	// ConsumerState maps consumer ID to its recorded metadata.
	ConsumerState map[string]*ConsumerMeta `json:"consumer_state"`
	// Metrics and Config are the live queue's own pointers (not copies) at
	// the time CreateSnapshot runs.
	// NOTE(review): they may be mutated while being marshalled — confirm
	// whether the queue guards them.
	Metrics *QueueMetrics `json:"metrics"`
	Config *QueueConfig `json:"config"`
	// Version is the snapshot format version; CreateSnapshot writes 1.
	Version int `json:"version"`
}
// ConsumerMeta represents consumer metadata recorded in a snapshot.
type ConsumerMeta struct {
	// ID is the consumer's key in the queue's consumer set.
	ID string `json:"id"`
	// State is the consumer's state converted to a string.
	State string `json:"state"`
	// LastActivity is copied from the consumer's metrics.
	LastActivity time.Time `json:"last_activity"`
	// TasksAssigned is part of the format, but CreateSnapshot in this file
	// never sets it, so persisted snapshots carry 0.
	// NOTE(review): populate it or drop it.
	TasksAssigned int `json:"tasks_assigned"`
}
// SnapshotManager periodically snapshots the broker's queues to disk and can
// restore a queue from its most recent snapshot. Create it with
// NewSnapshotManager and stop it with Shutdown.
type SnapshotManager struct {
	// baseDir is the root directory; each queue's snapshots are stored in
	// baseDir/<queue name>/.
	baseDir string
	// broker supplies the queues that are snapshotted and restored.
	broker *Broker
	// snapshotInterval is the tick period of the background snapshot loop.
	snapshotInterval time.Duration
	// retentionPeriod is the age, measured by file modification time, after
	// which CleanupOldSnapshots removes snapshot files.
	retentionPeriod time.Duration
	// mu is read-locked by CreateSnapshot.
	// NOTE(review): nothing in this file takes the write lock — confirm what
	// it is meant to guard.
	mu sync.RWMutex
	// logger receives all log output; it is called without nil checks.
	logger logger.Logger
	// shutdown is closed by Shutdown to stop the background loop.
	shutdown chan struct{}
	// wg tracks the background loop goroutine so Shutdown can wait for it.
	wg sync.WaitGroup
}
// SnapshotConfig holds the options accepted by NewSnapshotManager.
type SnapshotConfig struct {
	// BaseDir is the root directory for snapshot files; NewSnapshotManager
	// creates it (and any parents) if it does not exist.
	BaseDir string
	// SnapshotInterval is how often all queues are snapshotted. A zero value
	// is replaced with 5 minutes by NewSnapshotManager.
	SnapshotInterval time.Duration
	// RetentionPeriod is how long to keep old snapshots. A zero value is
	// replaced with 24 hours by NewSnapshotManager.
	RetentionPeriod time.Duration
	// Logger receives all log output and is used without nil checks.
	// NOTE(review): presumably must be non-nil — confirm.
	Logger logger.Logger
}
// NewSnapshotManager creates a SnapshotManager that stores snapshots under
// config.BaseDir. It starts a background goroutine that, every
// SnapshotInterval, snapshots all queues of broker and prunes snapshots older
// than RetentionPeriod. Call Shutdown to stop that goroutine.
//
// A SnapshotInterval or RetentionPeriod that is zero or negative is replaced
// by its default (5 minutes and 24 hours respectively). Negative values must
// not reach the worker for two reasons. time.NewTicker panics on a
// non-positive interval, and a negative retention would move the cleanup
// cutoff into the future and delete every snapshot.
//
// It returns an error if config.BaseDir cannot be created.
func NewSnapshotManager(broker *Broker, config SnapshotConfig) (*SnapshotManager, error) {
	if config.SnapshotInterval <= 0 {
		config.SnapshotInterval = 5 * time.Minute
	}
	if config.RetentionPeriod <= 0 {
		config.RetentionPeriod = 24 * time.Hour
	}
	if err := os.MkdirAll(config.BaseDir, 0755); err != nil {
		return nil, fmt.Errorf("failed to create snapshot directory: %w", err)
	}

	sm := &SnapshotManager{
		baseDir:          config.BaseDir,
		broker:           broker,
		snapshotInterval: config.SnapshotInterval,
		retentionPeriod:  config.RetentionPeriod,
		logger:           config.Logger,
		shutdown:         make(chan struct{}),
	}

	// The worker is tracked by sm.wg so Shutdown can wait for it to exit.
	sm.wg.Add(1)
	go sm.snapshotLoop()
	return sm, nil
}
// snapshotLoop runs until Shutdown closes sm.shutdown. On every tick it
// snapshots all queues and then prunes snapshots older than the retention
// period; failures are logged and the loop keeps going.
//
// Work started by a tick runs under a context that is cancelled as soon as
// shutdown is signalled. An in-flight periodic snapshot can therefore stop
// early instead of keeping Shutdown waiting for the worker to exit.
func (sm *SnapshotManager) snapshotLoop() {
	defer sm.wg.Done()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	// Bridge the shutdown channel to ctx. This goroutine ends either when
	// shutdown is signalled or when the loop returns and the deferred cancel
	// runs, so it cannot leak.
	go func() {
		select {
		case <-sm.shutdown:
			cancel()
		case <-ctx.Done():
		}
	}()

	ticker := time.NewTicker(sm.snapshotInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			if err := sm.CreateSnapshotAll(ctx); err != nil {
				sm.logger.Error("Failed to create periodic snapshot",
					logger.Field{Key: "error", Value: err})
			}
			// Cleanup old snapshots
			if err := sm.CleanupOldSnapshots(); err != nil {
				sm.logger.Error("Failed to cleanup old snapshots",
					logger.Field{Key: "error", Value: err})
			}
		case <-sm.shutdown:
			return
		}
	}
}
// CreateSnapshot captures the current state of the named queue, persists it
// to disk, and returns it.
//
// Pending tasks are copied out of the queue's task channel without consuming
// them. Each task is received and immediately sent back, so after one full
// rotation the channel holds the same tasks in the same order (assuming no
// concurrent producers or consumers). The rotation is bounded by the channel
// length observed up front. Without that bound the re-sent tasks keep a
// non-empty channel from ever appearing empty, so the loop would spin forever
// and append the same tasks again and again.
//
// Sending a task back blocks if a producer filled the freed slot in the
// meantime. It is deliberately not abandoned, because that would lose the
// task.
//
// Metrics and Config in the returned snapshot are the queue's own pointers,
// not copies.
//
// It returns an error if the queue does not exist, if ctx is cancelled before
// the tasks are collected, or if the snapshot cannot be persisted.
func (sm *SnapshotManager) CreateSnapshot(ctx context.Context, queueName string) (*QueueSnapshot, error) {
	sm.mu.RLock()
	defer sm.mu.RUnlock()

	queue, exists := sm.broker.queues.Get(queueName)
	if !exists {
		return nil, fmt.Errorf("queue not found: %s", queueName)
	}
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	// Collect pending tasks with a single, bounded rotation of the channel.
	var pendingTasks []*Task
rotate:
	for remaining := len(queue.tasks); remaining > 0; remaining-- {
		// Stopping between rotations is safe: every task taken so far has
		// already been put back.
		if err := ctx.Err(); err != nil {
			return nil, err
		}
		select {
		case qt := <-queue.tasks:
			pendingTasks = append(pendingTasks, qt.Task)
			queue.tasks <- qt
		default:
			// Consumers drained the channel concurrently; nothing left to copy.
			break rotate
		}
	}

	// Collect consumer state
	consumerState := make(map[string]*ConsumerMeta)
	queue.consumers.ForEach(func(id string, consumer *consumer) bool {
		consumerState[id] = &ConsumerMeta{
			ID:           id,
			State:        string(consumer.state),
			LastActivity: consumer.metrics.LastActivity,
		}
		return true
	})

	snapshot := &QueueSnapshot{
		Timestamp:     time.Now(),
		QueueName:     queueName,
		PendingTasks:  pendingTasks,
		ConsumerState: consumerState,
		Metrics:       queue.metrics,
		Config:        queue.config,
		Version:       1,
	}

	// Persist snapshot to disk
	if err := sm.persistSnapshot(snapshot); err != nil {
		return nil, fmt.Errorf("failed to persist snapshot: %w", err)
	}

	sm.logger.Info("Created queue snapshot",
		logger.Field{Key: "queue", Value: queueName},
		logger.Field{Key: "pendingTasks", Value: len(pendingTasks)},
		logger.Field{Key: "consumers", Value: len(consumerState)})
	return snapshot, nil
}
// CreateSnapshotAll snapshots every queue registered with the broker. It
// starts one goroutine per queue and waits for all of them to finish.
//
// Each per-queue failure is logged with its queue name. If any queue fails,
// the returned error reports how many failed out of how many were attempted
// and wraps the first failure observed. Callers such as the periodic loop and
// Shutdown can then react to it. It returns nil when every queue was
// snapshotted.
func (sm *SnapshotManager) CreateSnapshotAll(ctx context.Context) error {
	var (
		mu        sync.Mutex
		wg        sync.WaitGroup
		succeeded int
		failed    int
		firstErr  error
	)

	sm.broker.queues.ForEach(func(queueName string, _ *Queue) bool {
		wg.Add(1)
		go func(name string) {
			defer wg.Done()
			_, err := sm.CreateSnapshot(ctx, name)
			if err != nil {
				sm.logger.Error("Failed to create snapshot",
					logger.Field{Key: "queue", Value: name},
					logger.Field{Key: "error", Value: err})
			}

			mu.Lock()
			defer mu.Unlock()
			if err != nil {
				failed++
				if firstErr == nil {
					firstErr = fmt.Errorf("queue %s: %w", name, err)
				}
				return
			}
			succeeded++
		}(queueName)
		return true
	})
	wg.Wait()

	sm.logger.Info("Created snapshots for all queues",
		logger.Field{Key: "count", Value: succeeded})
	if failed > 0 {
		return fmt.Errorf("failed to snapshot %d of %d queues: %w", failed, failed+succeeded, firstErr)
	}
	return nil
}
// persistSnapshot serialises snapshot as indented JSON and stores it at
// <baseDir>/<QueueName>/snapshot-<Timestamp.UnixNano()>.json.
//
// The JSON is first written to "<final name>.tmp" and flushed to stable
// storage with Sync. The file is then closed and only after that renamed over
// the final name. As a result, readers globbing "snapshot-*.json" never see a
// partially written file, and a crash just after the rename does not expose
// an empty snapshot. On any failure the temporary file is closed and removed
// rather than left on disk.
func (sm *SnapshotManager) persistSnapshot(snapshot *QueueSnapshot) error {
	data, err := json.MarshalIndent(snapshot, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal snapshot: %w", err)
	}

	queueDir := filepath.Join(sm.baseDir, snapshot.QueueName)
	if err := os.MkdirAll(queueDir, 0755); err != nil {
		return fmt.Errorf("failed to create queue directory: %w", err)
	}

	filename := fmt.Sprintf("snapshot-%d.json", snapshot.Timestamp.UnixNano())
	filePath := filepath.Join(queueDir, filename)
	tempPath := filePath + ".tmp"

	f, err := os.OpenFile(tempPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
	if err != nil {
		return fmt.Errorf("failed to write snapshot: %w", err)
	}
	// Until the rename succeeds, close and delete the temp file on every
	// return path. Errors here are ignored on purpose: we are already
	// reporting the original failure, and a second Close after a successful
	// one only returns an error.
	committed := false
	defer func() {
		if !committed {
			_ = f.Close()
			_ = os.Remove(tempPath)
		}
	}()

	if _, err := f.Write(data); err != nil {
		return fmt.Errorf("failed to write snapshot: %w", err)
	}
	if err := f.Sync(); err != nil {
		return fmt.Errorf("failed to sync snapshot: %w", err)
	}
	if err := f.Close(); err != nil {
		return fmt.Errorf("failed to close snapshot: %w", err)
	}
	if err := os.Rename(tempPath, filePath); err != nil {
		return fmt.Errorf("failed to rename snapshot: %w", err)
	}
	committed = true
	return nil
}
// RestoreFromSnapshot loads the newest on-disk snapshot for queueName and
// applies it to the broker, creating the queue first if it does not exist.
//
// Each pending task is offered to the queue's task channel without blocking.
// A task that finds the channel full is skipped with a warning. If ctx is
// cancelled during this phase, ctx.Err() is returned and metrics/config are
// left untouched. Non-nil Metrics and Config from the snapshot then replace
// the queue's current values.
//
// It returns an error when no snapshot exists or the latest one cannot be
// loaded.
func (sm *SnapshotManager) RestoreFromSnapshot(ctx context.Context, queueName string) error {
	snapshot, err := sm.GetLatestSnapshot(queueName)
	switch {
	case err != nil:
		return fmt.Errorf("failed to get latest snapshot: %w", err)
	case snapshot == nil:
		return fmt.Errorf("no snapshot found for queue: %s", queueName)
	}

	sm.logger.Info("Restoring queue from snapshot",
		logger.Field{Key: "queue", Value: queueName},
		logger.Field{Key: "timestamp", Value: snapshot.Timestamp},
		logger.Field{Key: "tasks", Value: len(snapshot.PendingTasks)})

	queue, found := sm.broker.queues.Get(queueName)
	if !found {
		queue = sm.broker.NewQueue(queueName)
	}

	enqueued := 0
	for i := range snapshot.PendingTasks {
		pending := snapshot.PendingTasks[i]
		select {
		case <-ctx.Done():
			return ctx.Err()
		case queue.tasks <- &QueuedTask{Task: pending}:
			enqueued++
		default:
			sm.logger.Warn("Queue full during restore, task skipped",
				logger.Field{Key: "taskID", Value: pending.ID})
		}
	}

	if m := snapshot.Metrics; m != nil {
		queue.metrics = m
	}
	if c := snapshot.Config; c != nil {
		queue.config = c
	}

	sm.logger.Info("Queue restored from snapshot",
		logger.Field{Key: "queue", Value: queueName},
		logger.Field{Key: "restoredTasks", Value: enqueued})
	return nil
}
// GetLatestSnapshot loads the newest snapshot stored for queueName.
//
// It returns (nil, nil) when the queue has no "snapshot-*.json" files.
// "Newest" means the path that sorts last as a string. File names embed
// Timestamp.UnixNano(), which has a fixed 19 digits for times between
// September 2001 and the int64 limit in 2262, so for those timestamps string
// order equals time order.
func (sm *SnapshotManager) GetLatestSnapshot(queueName string) (*QueueSnapshot, error) {
	dir := filepath.Join(sm.baseDir, queueName)
	candidates, err := filepath.Glob(filepath.Join(dir, "snapshot-*.json"))
	if err != nil {
		return nil, fmt.Errorf("failed to list snapshots: %w", err)
	}

	// Glob never yields empty paths, so an empty result here means "none".
	var newest string
	for _, candidate := range candidates {
		if candidate > newest {
			newest = candidate
		}
	}
	if newest == "" {
		return nil, nil
	}

	raw, err := os.ReadFile(newest)
	if err != nil {
		return nil, fmt.Errorf("failed to read snapshot: %w", err)
	}
	snapshot := new(QueueSnapshot)
	if err := json.Unmarshal(raw, snapshot); err != nil {
		return nil, fmt.Errorf("failed to unmarshal snapshot: %w", err)
	}
	return snapshot, nil
}
// ListSnapshots decodes every "snapshot-*.json" file stored for queueName,
// in the order filepath.Glob returns them.
//
// Files that cannot be read or parsed are logged and skipped, so the result
// may be shorter than the number of files. Only a failure to list the
// directory is returned as an error.
func (sm *SnapshotManager) ListSnapshots(queueName string) ([]*QueueSnapshot, error) {
	dir := filepath.Join(sm.baseDir, queueName)
	paths, err := filepath.Glob(filepath.Join(dir, "snapshot-*.json"))
	if err != nil {
		return nil, fmt.Errorf("failed to list snapshots: %w", err)
	}

	result := make([]*QueueSnapshot, 0, len(paths))
	for _, p := range paths {
		raw, readErr := os.ReadFile(p)
		if readErr != nil {
			sm.logger.Error("Failed to read snapshot file",
				logger.Field{Key: "file", Value: p},
				logger.Field{Key: "error", Value: readErr})
			continue
		}

		decoded := new(QueueSnapshot)
		if decodeErr := json.Unmarshal(raw, decoded); decodeErr != nil {
			sm.logger.Error("Failed to unmarshal snapshot",
				logger.Field{Key: "file", Value: p},
				logger.Field{Key: "error", Value: decodeErr})
			continue
		}
		result = append(result, decoded)
	}
	return result, nil
}
// CleanupOldSnapshots walks baseDir and deletes snapshot files whose
// modification time is older than the retention period. Stale temporary files
// ("snapshot-*.json.tmp") left by interrupted writes are removed under the
// same rule.
//
// Only files whose names match those two patterns are considered, so
// unrelated files under baseDir are never touched. An entry that disappears
// while the tree is being walked (for example, removed by a concurrent
// DeleteSnapshot) is skipped rather than aborting the whole cleanup. A
// missing or unreadable baseDir is still returned as an error. Individual
// removal failures are logged and skipped.
func (sm *SnapshotManager) CleanupOldSnapshots() error {
	cutoff := time.Now().Add(-sm.retentionPeriod)
	removed := 0

	err := filepath.Walk(sm.baseDir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			if os.IsNotExist(err) && path != sm.baseDir {
				return nil
			}
			return err
		}
		if info.IsDir() {
			return nil
		}

		// The patterns are constant and well-formed, so Match cannot fail.
		isSnapshot, _ := filepath.Match("snapshot-*.json", info.Name())
		isTemp, _ := filepath.Match("snapshot-*.json.tmp", info.Name())
		if !isSnapshot && !isTemp {
			return nil
		}
		if !info.ModTime().Before(cutoff) {
			return nil
		}

		if err := os.Remove(path); err != nil {
			sm.logger.Error("Failed to remove old snapshot",
				logger.Field{Key: "file", Value: path},
				logger.Field{Key: "error", Value: err})
			return nil
		}
		removed++
		return nil
	})
	if err != nil {
		return err
	}

	if removed > 0 {
		sm.logger.Info("Cleaned up old snapshots",
			logger.Field{Key: "removed", Value: removed})
	}
	return nil
}
// DeleteSnapshot removes the snapshot file for queueName identified by
// timestamp. The file name embeds timestamp.UnixNano(), so timestamp must
// equal the snapshot's Timestamp to the nanosecond.
//
// It returns a wrapped os.Remove error if the file cannot be deleted,
// including when it does not exist.
func (sm *SnapshotManager) DeleteSnapshot(queueName string, timestamp time.Time) error {
	target := filepath.Join(
		sm.baseDir,
		queueName,
		fmt.Sprintf("snapshot-%d.json", timestamp.UnixNano()),
	)
	if err := os.Remove(target); err != nil {
		return fmt.Errorf("failed to delete snapshot: %w", err)
	}

	sm.logger.Info("Deleted snapshot",
		logger.Field{Key: "queue", Value: queueName},
		logger.Field{Key: "timestamp", Value: timestamp})
	return nil
}
// GetSnapshotStats reports how many ".json" files exist under baseDir, their
// combined size in bytes, and the manager's retention and interval settings.
//
// The walk is best-effort. It stops at the first walk error, the error is
// discarded, and whatever had been counted up to that point is returned.
func (sm *SnapshotManager) GetSnapshotStats() map[string]any {
	var (
		fileCount  int
		totalBytes int64
	)

	_ = filepath.Walk(sm.baseDir, func(path string, info os.FileInfo, err error) error {
		switch {
		case err != nil:
			return err
		case info.IsDir(), filepath.Ext(path) != ".json":
			return nil
		}
		fileCount++
		totalBytes += info.Size()
		return nil
	})

	stats := make(map[string]any, 4)
	stats["total_snapshots"] = fileCount
	stats["total_size_bytes"] = totalBytes
	stats["retention_period"] = sm.retentionPeriod
	stats["snapshot_interval"] = sm.snapshotInterval
	return stats
}
// Shutdown stops the periodic snapshot worker, takes one final snapshot of
// every queue, and returns.
//
// The worker is stopped and awaited first, so the final snapshot never
// overlaps a periodic one. If ctx ends before the worker exits, Shutdown
// returns ctx.Err() without attempting the final snapshot, which would fail
// its own ctx check anyway. A failed final snapshot is logged, not returned.
//
// Calling Shutdown again after it has run is a no-op that returns nil
// instead of panicking on a double channel close. Concurrent calls to
// Shutdown are not synchronised.
func (sm *SnapshotManager) Shutdown(ctx context.Context) error {
	select {
	case <-sm.shutdown:
		// Already shut down; closing the channel again would panic.
		return nil
	default:
		close(sm.shutdown)
	}

	// Wait for the worker in a helper goroutine so the wait can be bounded
	// by ctx.
	done := make(chan struct{})
	go func() {
		sm.wg.Wait()
		close(done)
	}()
	select {
	case <-done:
	case <-ctx.Done():
		return ctx.Err()
	}

	// Create final snapshot before shutdown
	if err := sm.CreateSnapshotAll(ctx); err != nil {
		sm.logger.Error("Failed to create final snapshot",
			logger.Field{Key: "error", Value: err})
	}
	sm.logger.Info("Snapshot manager shutdown complete")
	return nil
}