Files
mq/wal.go
2025-10-01 19:46:14 +05:45

445 lines
10 KiB
Go

package mq
import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"sync"
	"time"

	"github.com/oarkflow/mq/logger"
)
// WALEntry represents a single write-ahead log entry. Entries are serialized
// as one JSON document per line in the WAL segment files.
type WALEntry struct {
	// EntryType classifies the logged operation (see WALEntryType constants).
	EntryType WALEntryType `json:"entry_type"`
	// TaskID identifies the affected task ("checkpoint" for checkpoint entries).
	TaskID string `json:"task_id"`
	// QueueName names the queue associated with this entry, if any.
	QueueName string `json:"queue_name"`
	// Timestamp is assigned by WriteEntry when the entry is accepted.
	Timestamp time.Time `json:"timestamp"`
	// Payload carries optional raw data (e.g. checkpoint state) verbatim.
	Payload json.RawMessage `json:"payload,omitempty"`
	// SequenceID is a monotonically increasing ID assigned by WriteEntry.
	SequenceID int64 `json:"sequence_id"`
}
// WALEntryType defines the type of WAL entry.
type WALEntryType string

// WAL entry types written by the queue lifecycle and by Checkpoint.
const (
	// WALEntryEnqueue marks an enqueue operation.
	WALEntryEnqueue WALEntryType = "ENQUEUE"
	// WALEntryDequeue marks a dequeue operation.
	WALEntryDequeue WALEntryType = "DEQUEUE"
	// WALEntryComplete marks a task completion.
	WALEntryComplete WALEntryType = "COMPLETE"
	// WALEntryFailed marks a task failure.
	WALEntryFailed WALEntryType = "FAILED"
	// WALEntryCheckpoint marks a checkpoint record carrying state in Payload.
	WALEntryCheckpoint WALEntryType = "CHECKPOINT"
)
// WriteAheadLog provides message durability through persistent logging.
// Entries are queued via WriteEntry and written asynchronously by a
// background worker; a second worker periodically flushes and fsyncs.
type WriteAheadLog struct {
	// dir is the directory holding all WAL segment files (wal-<nanots>.log).
	dir string
	// currentFile is the segment currently being appended to; guarded by mu.
	currentFile *os.File
	// currentWriter buffers writes to currentFile; guarded by mu.
	currentWriter *bufio.Writer
	// sequenceID is the last sequence number handed out; guarded by mu.
	sequenceID int64
	// maxFileSize triggers segment rotation once reached.
	maxFileSize int64
	// syncInterval is how often syncWorker flushes and fsyncs.
	syncInterval time.Duration
	// mu guards currentFile, currentWriter, and sequenceID.
	mu sync.Mutex
	// logger receives WAL diagnostics; must be non-nil (used unconditionally).
	logger logger.Logger
	// shutdown is closed exactly once (by Shutdown) to stop both workers.
	shutdown chan struct{}
	// wg tracks the two background workers.
	wg sync.WaitGroup
	// entries buffers records between WriteEntry and writeWorker.
	entries chan *WALEntry
	// fsyncOnWrite forces flush+fsync after every entry (durable but slower).
	fsyncOnWrite bool
}
// WALConfig holds configuration for the WAL.
type WALConfig struct {
	// Directory is where WAL segment files are created.
	Directory string
	// MaxFileSize is the maximum file size before rotation (0 means 100MB).
	MaxFileSize int64
	// SyncInterval is the interval for syncing to disk (0 means 1s).
	SyncInterval time.Duration
	// FsyncOnWrite syncs after every write (slower but more durable).
	FsyncOnWrite bool
	// Logger receives WAL diagnostics; it is called unconditionally, so it
	// must be non-nil.
	Logger logger.Logger
}
// NewWriteAheadLog creates a new write-ahead log rooted at config.Directory.
// It applies defaults (MaxFileSize 100MB, SyncInterval 1s), recovers the last
// sequence ID from any existing segment files, opens a fresh segment, and
// starts the background write and sync workers.
//
// A nil config.Logger is rejected with an error up front; previously it would
// panic on the first internal log call.
func NewWriteAheadLog(config WALConfig) (*WriteAheadLog, error) {
	if config.Logger == nil {
		return nil, fmt.Errorf("WAL logger must not be nil")
	}
	if config.MaxFileSize == 0 {
		config.MaxFileSize = 100 * 1024 * 1024 // 100MB default
	}
	if config.SyncInterval == 0 {
		config.SyncInterval = 1 * time.Second
	}
	if err := os.MkdirAll(config.Directory, 0755); err != nil {
		return nil, fmt.Errorf("failed to create WAL directory: %w", err)
	}
	wal := &WriteAheadLog{
		dir:          config.Directory,
		maxFileSize:  config.MaxFileSize,
		syncInterval: config.SyncInterval,
		logger:       config.Logger,
		shutdown:     make(chan struct{}),
		entries:      make(chan *WALEntry, 10000),
		fsyncOnWrite: config.FsyncOnWrite,
	}
	// Recover the sequence ID from existing logs so new entries continue the
	// monotonic sequence instead of restarting at 1.
	if err := wal.recoverSequenceID(); err != nil {
		return nil, fmt.Errorf("failed to recover sequence ID: %w", err)
	}
	// Open or create the current log file.
	if err := wal.openNewFile(); err != nil {
		return nil, fmt.Errorf("failed to open WAL file: %w", err)
	}
	// Start background workers: one serializes entries to disk, the other
	// periodically flushes and fsyncs.
	wal.wg.Add(2)
	go wal.writeWorker()
	go wal.syncWorker()
	return wal, nil
}
// WriteEntry stamps entry with the next sequence ID and the current time,
// then hands it to the background writer. It blocks while the internal
// buffer is full, giving up early if ctx is cancelled or the WAL is
// shutting down.
func (w *WriteAheadLog) WriteEntry(ctx context.Context, entry *WALEntry) error {
	// Assign sequence and timestamp under the lock so sequence IDs are
	// strictly increasing and timestamps follow sequence order.
	w.mu.Lock()
	w.sequenceID++
	entry.SequenceID = w.sequenceID
	entry.Timestamp = time.Now()
	w.mu.Unlock()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-w.shutdown:
		return fmt.Errorf("WAL is shutting down")
	case w.entries <- entry:
	}
	return nil
}
// writeWorker is the single consumer of w.entries: it persists each queued
// entry and, on shutdown, drains whatever remains in the buffer best-effort.
func (w *WriteAheadLog) writeWorker() {
	defer w.wg.Done()
	for {
		select {
		case <-w.shutdown:
			// Drain remaining buffered entries without blocking, then exit.
			for {
				select {
				case pending := <-w.entries:
					_ = w.writeEntryToFile(pending)
				default:
					return
				}
			}
		case pending := <-w.entries:
			if err := w.writeEntryToFile(pending); err != nil {
				w.logger.Error("Failed to write WAL entry",
					logger.Field{Key: "error", Value: err},
					logger.Field{Key: "taskID", Value: pending.TaskID})
			}
		}
	}
}
// writeEntryToFile serializes entry as a single JSON line, appends it to the
// current segment, optionally fsyncs, and rotates the segment once it has
// reached maxFileSize.
func (w *WriteAheadLog) writeEntryToFile(entry *WALEntry) error {
	w.mu.Lock()
	defer w.mu.Unlock()
	data, err := json.Marshal(entry)
	if err != nil {
		return fmt.Errorf("failed to marshal WAL entry: %w", err)
	}
	// Write entry with newline delimiter (one JSON document per line).
	if _, err := w.currentWriter.Write(append(data, '\n')); err != nil {
		return fmt.Errorf("failed to write WAL entry: %w", err)
	}
	if w.fsyncOnWrite {
		if err := w.currentWriter.Flush(); err != nil {
			return fmt.Errorf("failed to flush WAL: %w", err)
		}
		if err := w.currentFile.Sync(); err != nil {
			return fmt.Errorf("failed to sync WAL: %w", err)
		}
	}
	// Rotate when on-disk size plus still-buffered bytes reaches the limit.
	// Counting stat.Size() alone would ignore unflushed data and let the
	// segment overshoot maxFileSize by up to one bufio buffer.
	stat, err := w.currentFile.Stat()
	if err == nil && stat.Size()+int64(w.currentWriter.Buffered()) >= w.maxFileSize {
		return w.rotateFile()
	}
	return nil
}
// syncWorker flushes the buffered writer and fsyncs the current segment on
// every syncInterval tick, bounding data loss when fsyncOnWrite is disabled.
func (w *WriteAheadLog) syncWorker() {
	defer w.wg.Done()
	ticker := time.NewTicker(w.syncInterval)
	defer ticker.Stop()
	for {
		select {
		case <-w.shutdown:
			return
		case <-ticker.C:
			w.mu.Lock()
			writer, file := w.currentWriter, w.currentFile
			if writer != nil {
				// Best-effort: a failed periodic flush is retried next tick.
				_ = writer.Flush()
			}
			if file != nil {
				_ = file.Sync()
			}
			w.mu.Unlock()
		}
	}
}
// openNewFile creates a fresh, uniquely named (nanosecond-timestamped) WAL
// segment and installs it as the current write target. The previous file
// handle, if any, is simply replaced; rotateFile is responsible for flushing
// and closing it first.
func (w *WriteAheadLog) openNewFile() error {
	filename := fmt.Sprintf("wal-%d.log", time.Now().UnixNano())
	// Named fullPath (not "filepath") to avoid shadowing the path/filepath
	// package.
	fullPath := filepath.Join(w.dir, filename)
	file, err := os.OpenFile(fullPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
	if err != nil {
		return fmt.Errorf("failed to open WAL file: %w", err)
	}
	w.currentFile = file
	w.currentWriter = bufio.NewWriter(file)
	w.logger.Info("Opened new WAL file", logger.Field{Key: "filename", Value: filename})
	return nil
}
// rotateFile finishes the current segment and opens a new one. The caller
// must hold w.mu (writeEntryToFile does).
func (w *WriteAheadLog) rotateFile() error {
	// Flush buffered data to the OS.
	if w.currentWriter != nil {
		if err := w.currentWriter.Flush(); err != nil {
			return err
		}
	}
	if w.currentFile != nil {
		// Fsync before closing so the completed segment is durable on disk;
		// Close alone does not guarantee the flushed bytes survive a crash.
		if err := w.currentFile.Sync(); err != nil {
			return err
		}
		if err := w.currentFile.Close(); err != nil {
			return err
		}
	}
	// Open new file
	return w.openNewFile()
}
// recoverSequenceID scans every existing WAL segment and sets w.sequenceID to
// the highest sequence number found, so new entries continue the sequence.
// Unreadable files and corrupt lines (e.g. a torn final write) are skipped
// best-effort.
func (w *WriteAheadLog) recoverSequenceID() error {
	files, err := filepath.Glob(filepath.Join(w.dir, "wal-*.log"))
	if err != nil {
		return err
	}
	maxSeq := int64(0)
	for _, path := range files {
		file, err := os.Open(path)
		if err != nil {
			continue // unreadable segment: skip it rather than abort recovery
		}
		scanner := bufio.NewScanner(file)
		// Entries can exceed bufio.Scanner's default 64KB token limit when
		// payloads are large; allow lines up to 16MB.
		scanner.Buffer(make([]byte, 0, 64*1024), 16*1024*1024)
		for scanner.Scan() {
			var entry WALEntry
			if err := json.Unmarshal(scanner.Bytes(), &entry); err != nil {
				continue // corrupt/partial line: ignore
			}
			if entry.SequenceID > maxSeq {
				maxSeq = entry.SequenceID
			}
		}
		// Surface scan failures (oversized or truncated lines) instead of
		// silently treating the file as fully read.
		if err := scanner.Err(); err != nil {
			w.logger.Error("Error scanning WAL file during recovery",
				logger.Field{Key: "error", Value: err},
				logger.Field{Key: "file", Value: path})
		}
		file.Close()
	}
	w.sequenceID = maxSeq
	w.logger.Info("Recovered sequence ID", logger.Field{Key: "sequenceID", Value: maxSeq})
	return nil
}
// Replay feeds every entry from every WAL segment, oldest segment first, to
// handler. Entries that fail to decode or that handler rejects are logged
// and skipped; replay continues with the next entry. It returns an error
// only if the segment files cannot be listed.
func (w *WriteAheadLog) Replay(handler func(*WALEntry) error) error {
	files, err := filepath.Glob(filepath.Join(w.dir, "wal-*.log"))
	if err != nil {
		return fmt.Errorf("failed to list WAL files: %w", err)
	}
	// Segment names embed a nanosecond timestamp, so lexicographic order is
	// chronological order. sort.Strings replaces the hand-rolled bubble sort.
	sort.Strings(files)
	entriesReplayed := 0
	for _, path := range files {
		file, err := os.Open(path)
		if err != nil {
			w.logger.Error("Failed to open WAL file for replay",
				logger.Field{Key: "error", Value: err},
				logger.Field{Key: "file", Value: path})
			continue
		}
		scanner := bufio.NewScanner(file)
		// Allow entries larger than bufio.Scanner's default 64KB token limit.
		scanner.Buffer(make([]byte, 0, 64*1024), 16*1024*1024)
		for scanner.Scan() {
			var entry WALEntry
			if err := json.Unmarshal(scanner.Bytes(), &entry); err != nil {
				w.logger.Error("Failed to unmarshal WAL entry",
					logger.Field{Key: "error", Value: err})
				continue
			}
			if err := handler(&entry); err != nil {
				w.logger.Error("Failed to replay WAL entry",
					logger.Field{Key: "error", Value: err},
					logger.Field{Key: "taskID", Value: entry.TaskID})
				continue
			}
			entriesReplayed++
		}
		// Report scan failures (e.g. oversized or truncated lines) instead of
		// silently ending the file early.
		if err := scanner.Err(); err != nil {
			w.logger.Error("Error scanning WAL file during replay",
				logger.Field{Key: "error", Value: err},
				logger.Field{Key: "file", Value: path})
		}
		file.Close()
	}
	w.logger.Info("WAL replay complete",
		logger.Field{Key: "entries", Value: entriesReplayed})
	return nil
}
// Checkpoint serializes state and appends it to the WAL as a CHECKPOINT
// entry, subject to the same queuing and cancellation rules as WriteEntry.
func (w *WriteAheadLog) Checkpoint(ctx context.Context, state map[string]any) error {
	stateData, err := json.Marshal(state)
	if err != nil {
		return fmt.Errorf("failed to marshal checkpoint state: %w", err)
	}
	return w.WriteEntry(ctx, &WALEntry{
		EntryType: WALEntryCheckpoint,
		TaskID:    "checkpoint",
		Payload:   stateData,
	})
}
// TruncateOldLogs deletes all but the keepRecent most recently modified WAL
// segment files. Intended to be called after a successful checkpoint. A
// negative keepRecent is treated as 0 (previously it caused an index panic).
//
// NOTE(review): with keepRecent == 0 this can delete the segment currently
// being written to — callers should keep at least one file.
func (w *WriteAheadLog) TruncateOldLogs(keepRecent int) error {
	if keepRecent < 0 {
		keepRecent = 0
	}
	files, err := filepath.Glob(filepath.Join(w.dir, "wal-*.log"))
	if err != nil {
		return fmt.Errorf("failed to list WAL files: %w", err)
	}
	type fileInfo struct {
		path    string
		modTime time.Time
	}
	fileInfos := make([]fileInfo, 0, len(files))
	for _, path := range files {
		stat, err := os.Stat(path)
		if err != nil {
			continue // file vanished or unreadable; nothing to remove
		}
		fileInfos = append(fileInfos, fileInfo{path: path, modTime: stat.ModTime()})
	}
	// Newest first, so every index >= keepRecent is an expendable old file.
	sort.Slice(fileInfos, func(i, j int) bool {
		return fileInfos[j].modTime.Before(fileInfos[i].modTime)
	})
	removed := 0
	for i := keepRecent; i < len(fileInfos); i++ {
		if err := os.Remove(fileInfos[i].path); err != nil {
			w.logger.Error("Failed to remove old WAL file",
				logger.Field{Key: "error", Value: err},
				logger.Field{Key: "file", Value: fileInfos[i].path})
			continue
		}
		removed++
	}
	w.logger.Info("Truncated old WAL files",
		logger.Field{Key: "removed", Value: removed})
	return nil
}
// Shutdown gracefully stops the WAL: it signals both workers, waits for them
// (bounded by ctx), then flushes, fsyncs, and closes the current segment.
// If ctx expires first, the workers are left running and the file stays open.
//
// NOTE(review): it closes w.shutdown unconditionally, so it must be called
// at most once; a second call would panic.
func (w *WriteAheadLog) Shutdown(ctx context.Context) error {
	close(w.shutdown)

	// Wait for the workers with a context-based timeout.
	done := make(chan struct{})
	go func() {
		defer close(done)
		w.wg.Wait()
	}()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-done:
	}

	w.mu.Lock()
	defer w.mu.Unlock()
	if w.currentWriter != nil {
		_ = w.currentWriter.Flush()
	}
	if w.currentFile != nil {
		_ = w.currentFile.Sync()
		_ = w.currentFile.Close()
	}
	w.logger.Info("WAL shutdown complete")
	return nil
}
// GetStats reports a snapshot of WAL state: last sequence ID, current
// segment size on disk (excluding any still-buffered bytes), total segment
// file count, and the number of entries awaiting the background writer.
func (w *WriteAheadLog) GetStats() map[string]any {
	w.mu.Lock()
	defer w.mu.Unlock()

	stats := map[string]any{
		"current_sequence_id": w.sequenceID,
		"entries_backlog":     len(w.entries),
	}
	var size int64
	if w.currentFile != nil {
		if info, err := w.currentFile.Stat(); err == nil {
			size = info.Size()
		}
	}
	stats["current_file_size"] = size
	// Glob errors are ignored here, matching a best-effort stats call.
	files, _ := filepath.Glob(filepath.Join(w.dir, "wal-*.log"))
	stats["total_files"] = len(files)
	return stats
}