feat: update
162  dag/dag.go
@@ -15,6 +15,8 @@ import (
	"github.com/oarkflow/json"
	"golang.org/x/time/rate"

	dagstorage "github.com/oarkflow/mq/dag/storage"

	"github.com/oarkflow/mq"
	"github.com/oarkflow/mq/logger"
	"github.com/oarkflow/mq/sio"
@@ -125,6 +127,159 @@ type DAG struct {
	globalMiddlewares []mq.Handler
	nodeMiddlewares   map[string][]mq.Handler
	middlewaresMu     sync.RWMutex

	// Task storage for persistence
	taskStorage dagstorage.TaskStorage
}

// SetTaskStorage sets the task storage used for persistence.
func (d *DAG) SetTaskStorage(storage dagstorage.TaskStorage) {
	d.taskStorage = storage
}

// GetTaskStorage returns the current task storage.
func (d *DAG) GetTaskStorage() dagstorage.TaskStorage {
	return d.taskStorage
}

// GetTasks retrieves tasks for this DAG, optionally filtered by status.
func (d *DAG) GetTasks(ctx context.Context, status *dagstorage.TaskStatus, limit int, offset int) ([]*dagstorage.PersistentTask, error) {
	if d.taskStorage == nil {
		return nil, fmt.Errorf("task storage not configured")
	}

	if status != nil {
		return d.taskStorage.GetTasksByStatus(ctx, d.key, *status)
	}
	return d.taskStorage.GetTasksByDAG(ctx, d.key, limit, offset)
}

// GetTaskActivityLogs retrieves activity logs for this DAG.
func (d *DAG) GetTaskActivityLogs(ctx context.Context, limit int, offset int) ([]*dagstorage.TaskActivityLog, error) {
	if d.taskStorage == nil {
		return nil, fmt.Errorf("task storage not configured")
	}

	return d.taskStorage.GetActivityLogsByDAG(ctx, d.key, limit, offset)
}

// RecoverTasks loads pending and running tasks from storage and resumes them.
func (d *DAG) RecoverTasks(ctx context.Context) error {
	if d.taskStorage == nil {
		return fmt.Errorf("task storage not configured")
	}

	d.Logger().Info("Starting task recovery", logger.Field{Key: "dagID", Value: d.key})

	// Get all resumable tasks for this DAG
	resumableTasks, err := d.taskStorage.GetResumableTasks(ctx, d.key)
	if err != nil {
		return fmt.Errorf("failed to get resumable tasks: %w", err)
	}

	d.Logger().Info("Found tasks to recover", logger.Field{Key: "count", Value: len(resumableTasks)})

	// Resume each task from its last known position
	for _, task := range resumableTasks {
		if err := d.resumeTaskFromStorage(ctx, task); err != nil {
			d.Logger().Error("Failed to resume task",
				logger.Field{Key: "taskID", Value: task.ID},
				logger.Field{Key: "error", Value: err.Error()})
			continue
		}
		d.Logger().Info("Successfully resumed task",
			logger.Field{Key: "taskID", Value: task.ID},
			logger.Field{Key: "currentNode", Value: task.CurrentNodeID})
	}

	return nil
}

// resumeTaskFromStorage resumes a single task from its stored position.
func (d *DAG) resumeTaskFromStorage(ctx context.Context, task *dagstorage.PersistentTask) error {
	// Determine the node to resume from
	resumeNodeID := task.CurrentNodeID
	if resumeNodeID == "" {
		resumeNodeID = task.NodeID // Fall back to the original node
	}

	// Verify that the resume node still exists in the DAG
	if _, exists := d.nodes.Get(resumeNodeID); !exists {
		return fmt.Errorf("resume node %s not found in DAG", resumeNodeID)
	}

	// Create a task manager for this task if one does not exist yet
	manager, exists := d.taskManager.Get(task.ID)
	if !exists {
		resultCh := make(chan mq.Result, 1)
		manager = NewTaskManager(d, task.ID, resultCh, d.iteratorNodes.Clone(), d.taskStorage)
		d.taskManager.Set(task.ID, manager)
	}

	// Re-enqueue pending tasks and resume running tasks from the stored node
	// using the TaskManager's ProcessTask.
	if task.Status == dagstorage.TaskStatusPending || task.Status == dagstorage.TaskStatusRunning {
		manager.ProcessTask(ctx, resumeNodeID, task.Payload)
	}

	return nil
}

// ConfigureMemoryStorage configures the DAG to use in-memory storage.
func (d *DAG) ConfigureMemoryStorage() {
	d.taskStorage = dagstorage.NewMemoryTaskStorage()
}

// ConfigurePostgresStorage configures the DAG to use PostgreSQL storage.
func (d *DAG) ConfigurePostgresStorage(dsn string, opts ...dagstorage.StorageOption) error {
	config := &dagstorage.TaskStorageConfig{
		Type:            "postgres",
		DSN:             dsn,
		MaxOpenConns:    10,
		MaxIdleConns:    5,
		ConnMaxLifetime: 5 * time.Minute,
	}

	// Apply options
	for _, opt := range opts {
		opt(config)
	}

	storage, err := dagstorage.NewSQLTaskStorage(config)
	if err != nil {
		return fmt.Errorf("failed to create postgres storage: %w", err)
	}

	d.taskStorage = storage
	return nil
}

// ConfigureSQLiteStorage configures the DAG to use SQLite storage.
func (d *DAG) ConfigureSQLiteStorage(dbPath string, opts ...dagstorage.StorageOption) error {
	config := &dagstorage.TaskStorageConfig{
		Type:            "sqlite",
		DSN:             dbPath,
		MaxOpenConns:    1, // SQLite works best with a single connection
		MaxIdleConns:    1,
		ConnMaxLifetime: 0, // No limit for SQLite
	}

	// Apply options
	for _, opt := range opts {
		opt(config)
	}

	storage, err := dagstorage.NewSQLTaskStorage(config)
	if err != nil {
		return fmt.Errorf("failed to create sqlite storage: %w", err)
	}

	d.taskStorage = storage
	return nil
}

// SetPreProcessHook configures a function to be called before each node is processed.
@@ -280,6 +435,7 @@ func NewDAG(name, key string, finalResultCallback func(taskID string, result mq.
		nextNodesCache:  make(map[string][]*Node),
		prevNodesCache:  make(map[string][]*Node),
		nodeMiddlewares: make(map[string][]mq.Handler),
		taskStorage:     dagstorage.NewMemoryTaskStorage(), // Initialize default memory storage
	}

	opts = append(opts,
@@ -603,7 +759,7 @@ func (tm *DAG) processTaskInternal(ctx context.Context, task *mq.Task) mq.Result
	manager, ok := tm.taskManager.Get(task.ID)
	resultCh := make(chan mq.Result, 1)
	if !ok {
-		manager = NewTaskManager(tm, task.ID, resultCh, tm.iteratorNodes.Clone())
+		manager = NewTaskManager(tm, task.ID, resultCh, tm.iteratorNodes.Clone(), tm.taskStorage)
		tm.taskManager.Set(task.ID, manager)
		tm.Logger().Info("Processing task",
			logger.Field{Key: "taskID", Value: task.ID},
@@ -717,7 +873,7 @@ func (tm *DAG) ProcessTaskNew(ctx context.Context, task *mq.Task) mq.Result {
	manager, ok := tm.taskManager.Get(task.ID)
	resultCh := make(chan mq.Result, 1)
	if !ok {
-		manager = NewTaskManager(tm, task.ID, resultCh, tm.iteratorNodes.Clone())
+		manager = NewTaskManager(tm, task.ID, resultCh, tm.iteratorNodes.Clone(), tm.taskStorage)
		tm.taskManager.Set(task.ID, manager)
		tm.Logger().Info("Processing task",
			logger.Field{Key: "taskID", Value: task.ID},
@@ -1055,7 +1211,7 @@ func (tm *DAG) ScheduleTask(ctx context.Context, payload []byte, opts ...mq.Sche
	manager, ok := tm.taskManager.Get(taskID)
	resultCh := make(chan mq.Result, 1)
	if !ok {
-		manager = NewTaskManager(tm, taskID, resultCh, tm.iteratorNodes.Clone())
+		manager = NewTaskManager(tm, taskID, resultCh, tm.iteratorNodes.Clone(), tm.taskStorage)
		tm.taskManager.Set(taskID, manager)
	} else {
		manager.resultCh = resultCh
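A minimal usage sketch for the persistence and recovery API added above (the DSN, handler wiring, and error handling here are illustrative, not part of this commit):

	flow := dag.NewDAG("orders", "orders-key", func(taskID string, result mq.Result) {})
	// Swap the default in-memory storage for PostgreSQL.
	if err := flow.ConfigurePostgresStorage("postgres://user:pass@localhost:5432/mq?sslmode=disable"); err != nil {
		log.Fatal(err)
	}
	// On startup, resume any tasks that were pending or running when the process last stopped.
	if err := flow.RecoverTasks(context.Background()); err != nil {
		log.Printf("task recovery failed: %v", err)
	}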
117  dag/storage/interface.go  (new file)
@@ -0,0 +1,117 @@
package storage

import (
	"context"
	"time"

	"github.com/oarkflow/json"
)

// TaskStatus represents the status of a task.
type TaskStatus string

const (
	TaskStatusPending   TaskStatus = "pending"
	TaskStatusRunning   TaskStatus = "running"
	TaskStatusCompleted TaskStatus = "completed"
	TaskStatusFailed    TaskStatus = "failed"
	TaskStatusCancelled TaskStatus = "cancelled"
)

// PersistentTask represents a task that can be stored persistently.
type PersistentTask struct {
	ID              string          `json:"id" db:"id"`
	DAGID           string          `json:"dag_id" db:"dag_id"`
	NodeID          string          `json:"node_id" db:"node_id"`
	CurrentNodeID   string          `json:"current_node_id" db:"current_node_id"`   // Node where the task is currently processing
	SubDAGPath      string          `json:"sub_dag_path" db:"sub_dag_path"`         // Path through nested DAGs (e.g., "subdag1.subdag2")
	ProcessingState string          `json:"processing_state" db:"processing_state"` // Current processing state (pending, processing, waiting, etc.)
	Payload         json.RawMessage `json:"payload" db:"payload"`
	Status          TaskStatus      `json:"status" db:"status"`
	CreatedAt       time.Time       `json:"created_at" db:"created_at"`
	UpdatedAt       time.Time       `json:"updated_at" db:"updated_at"`
	StartedAt       *time.Time      `json:"started_at,omitempty" db:"started_at"`
	CompletedAt     *time.Time      `json:"completed_at,omitempty" db:"completed_at"`
	Error           string          `json:"error,omitempty" db:"error"`
	RetryCount      int             `json:"retry_count" db:"retry_count"`
	MaxRetries      int             `json:"max_retries" db:"max_retries"`
	Priority        int             `json:"priority" db:"priority"`
}

// TaskActivityLog represents an activity log entry for a task.
type TaskActivityLog struct {
	ID        string          `json:"id" db:"id"`
	TaskID    string          `json:"task_id" db:"task_id"`
	DAGID     string          `json:"dag_id" db:"dag_id"`
	NodeID    string          `json:"node_id" db:"node_id"`
	Action    string          `json:"action" db:"action"`
	Message   string          `json:"message" db:"message"`
	Data      json.RawMessage `json:"data,omitempty" db:"data"`
	Level     string          `json:"level" db:"level"`
	CreatedAt time.Time       `json:"created_at" db:"created_at"`
}

// TaskStorage defines the interface for task storage operations.
type TaskStorage interface {
	// Task operations
	SaveTask(ctx context.Context, task *PersistentTask) error
	GetTask(ctx context.Context, taskID string) (*PersistentTask, error)
	GetTasksByDAG(ctx context.Context, dagID string, limit int, offset int) ([]*PersistentTask, error)
	GetTasksByStatus(ctx context.Context, dagID string, status TaskStatus) ([]*PersistentTask, error)
	UpdateTaskStatus(ctx context.Context, taskID string, status TaskStatus, errorMsg string) error
	DeleteTask(ctx context.Context, taskID string) error
	DeleteTasksByDAG(ctx context.Context, dagID string) error

	// Activity logging
	LogActivity(ctx context.Context, log *TaskActivityLog) error
	GetActivityLogs(ctx context.Context, taskID string, limit int, offset int) ([]*TaskActivityLog, error)
	GetActivityLogsByDAG(ctx context.Context, dagID string, limit int, offset int) ([]*TaskActivityLog, error)

	// Batch operations
	SaveTasks(ctx context.Context, tasks []*PersistentTask) error
	GetPendingTasks(ctx context.Context, dagID string, limit int) ([]*PersistentTask, error)

	// Recovery operations
	GetResumableTasks(ctx context.Context, dagID string) ([]*PersistentTask, error)

	// Cleanup operations
	CleanupOldTasks(ctx context.Context, dagID string, olderThan time.Time) error
	CleanupOldActivityLogs(ctx context.Context, dagID string, olderThan time.Time) error

	// Health check
	Ping(ctx context.Context) error
	Close() error
}

// TaskStorageConfig holds configuration for task storage.
type TaskStorageConfig struct {
	Type            string        // "memory", "postgres", "sqlite"
	DSN             string        // Database connection string
	MaxOpenConns    int           // Maximum open connections
	MaxIdleConns    int           // Maximum idle connections
	ConnMaxLifetime time.Duration // Connection max lifetime
}

// StorageOption is a function that configures a TaskStorageConfig.
type StorageOption func(*TaskStorageConfig)

// WithMaxOpenConns sets the maximum number of open connections.
func WithMaxOpenConns(maxOpen int) StorageOption {
	return func(config *TaskStorageConfig) {
		config.MaxOpenConns = maxOpen
	}
}

// WithMaxIdleConns sets the maximum number of idle connections.
func WithMaxIdleConns(maxIdle int) StorageOption {
	return func(config *TaskStorageConfig) {
		config.MaxIdleConns = maxIdle
	}
}

// WithConnMaxLifetime sets the maximum lifetime of connections.
func WithConnMaxLifetime(lifetime time.Duration) StorageOption {
	return func(config *TaskStorageConfig) {
		config.ConnMaxLifetime = lifetime
	}
}
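The With* helpers above follow the functional-options pattern: each returns a closure that mutates the config after the defaults have been set. A sketch of combining them, assuming flow is a *dag.DAG and dsn a PostgreSQL connection string:

	err := flow.ConfigurePostgresStorage(dsn,
		dagstorage.WithMaxOpenConns(25),
		dagstorage.WithMaxIdleConns(10),
		dagstorage.WithConnMaxLifetime(30*time.Minute),
	)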
399  dag/storage/memory.go  (new file)
@@ -0,0 +1,399 @@
package storage

import (
	"context"
	"fmt"
	"sort"
	"sync"
	"time"

	"github.com/oarkflow/xid/wuid"
)

// MemoryTaskStorage implements TaskStorage using in-memory maps.
type MemoryTaskStorage struct {
	mu           sync.RWMutex
	tasks        map[string]*PersistentTask
	activityLogs map[string][]*TaskActivityLog // taskID -> logs
	dagTasks     map[string][]string           // dagID -> taskIDs
}

// NewMemoryTaskStorage creates a new memory-based task storage.
func NewMemoryTaskStorage() *MemoryTaskStorage {
	return &MemoryTaskStorage{
		tasks:        make(map[string]*PersistentTask),
		activityLogs: make(map[string][]*TaskActivityLog),
		dagTasks:     make(map[string][]string),
	}
}

// SaveTask saves a task to memory.
func (m *MemoryTaskStorage) SaveTask(ctx context.Context, task *PersistentTask) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if task.ID == "" {
		task.ID = wuid.New().String()
	}
	if task.CreatedAt.IsZero() {
		task.CreatedAt = time.Now()
	}
	task.UpdatedAt = time.Now()

	_, existed := m.tasks[task.ID]
	m.tasks[task.ID] = task

	// Add to the DAG index, avoiding a duplicate entry when a task is re-saved
	if !existed {
		m.dagTasks[task.DAGID] = append(m.dagTasks[task.DAGID], task.ID)
	}

	return nil
}

// GetTask retrieves a task by ID.
func (m *MemoryTaskStorage) GetTask(ctx context.Context, taskID string) (*PersistentTask, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	task, exists := m.tasks[taskID]
	if !exists {
		return nil, fmt.Errorf("task not found: %s", taskID)
	}

	// Return a copy to prevent external modifications
	taskCopy := *task
	return &taskCopy, nil
}

// GetTasksByDAG retrieves tasks for a specific DAG.
func (m *MemoryTaskStorage) GetTasksByDAG(ctx context.Context, dagID string, limit int, offset int) ([]*PersistentTask, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	taskIDs, exists := m.dagTasks[dagID]
	if !exists {
		return []*PersistentTask{}, nil
	}

	tasks := make([]*PersistentTask, 0, len(taskIDs))
	for _, taskID := range taskIDs {
		if task, exists := m.tasks[taskID]; exists {
			taskCopy := *task
			tasks = append(tasks, &taskCopy)
		}
	}

	// Sort by creation time (newest first)
	sort.Slice(tasks, func(i, j int) bool {
		return tasks[i].CreatedAt.After(tasks[j].CreatedAt)
	})

	// Apply pagination
	start := offset
	end := offset + limit
	if start > len(tasks) {
		return []*PersistentTask{}, nil
	}
	if end > len(tasks) {
		end = len(tasks)
	}

	return tasks[start:end], nil
}

// GetTasksByStatus retrieves tasks by status for a specific DAG.
func (m *MemoryTaskStorage) GetTasksByStatus(ctx context.Context, dagID string, status TaskStatus) ([]*PersistentTask, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	taskIDs, exists := m.dagTasks[dagID]
	if !exists {
		return []*PersistentTask{}, nil
	}

	tasks := make([]*PersistentTask, 0)
	for _, taskID := range taskIDs {
		if task, exists := m.tasks[taskID]; exists && task.Status == status {
			taskCopy := *task
			tasks = append(tasks, &taskCopy)
		}
	}

	return tasks, nil
}

// UpdateTaskStatus updates the status of a task.
func (m *MemoryTaskStorage) UpdateTaskStatus(ctx context.Context, taskID string, status TaskStatus, errorMsg string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	task, exists := m.tasks[taskID]
	if !exists {
		return fmt.Errorf("task not found: %s", taskID)
	}

	task.Status = status
	task.UpdatedAt = time.Now()

	if status == TaskStatusCompleted || status == TaskStatusFailed {
		now := time.Now()
		task.CompletedAt = &now
	}

	if errorMsg != "" {
		task.Error = errorMsg
	}

	return nil
}

// DeleteTask deletes a task.
func (m *MemoryTaskStorage) DeleteTask(ctx context.Context, taskID string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	task, exists := m.tasks[taskID]
	if !exists {
		return fmt.Errorf("task not found: %s", taskID)
	}

	delete(m.tasks, taskID)
	delete(m.activityLogs, taskID)

	// Remove from the DAG index
	if taskIDs, exists := m.dagTasks[task.DAGID]; exists {
		for i, id := range taskIDs {
			if id == taskID {
				m.dagTasks[task.DAGID] = append(taskIDs[:i], taskIDs[i+1:]...)
				break
			}
		}
	}

	return nil
}

// DeleteTasksByDAG deletes all tasks for a specific DAG.
func (m *MemoryTaskStorage) DeleteTasksByDAG(ctx context.Context, dagID string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	taskIDs, exists := m.dagTasks[dagID]
	if !exists {
		return nil
	}

	for _, taskID := range taskIDs {
		delete(m.tasks, taskID)
		delete(m.activityLogs, taskID)
	}

	delete(m.dagTasks, dagID)
	return nil
}

// LogActivity logs an activity for a task.
func (m *MemoryTaskStorage) LogActivity(ctx context.Context, logEntry *TaskActivityLog) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if logEntry.ID == "" {
		logEntry.ID = wuid.New().String()
	}
	if logEntry.CreatedAt.IsZero() {
		logEntry.CreatedAt = time.Now()
	}

	m.activityLogs[logEntry.TaskID] = append(m.activityLogs[logEntry.TaskID], logEntry)
	return nil
}

// GetActivityLogs retrieves activity logs for a task.
func (m *MemoryTaskStorage) GetActivityLogs(ctx context.Context, taskID string, limit int, offset int) ([]*TaskActivityLog, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	logs, exists := m.activityLogs[taskID]
	if !exists {
		return []*TaskActivityLog{}, nil
	}

	// Sort a copy by creation time (newest first)
	sortedLogs := make([]*TaskActivityLog, len(logs))
	copy(sortedLogs, logs)
	sort.Slice(sortedLogs, func(i, j int) bool {
		return sortedLogs[i].CreatedAt.After(sortedLogs[j].CreatedAt)
	})

	// Apply pagination
	start := offset
	end := offset + limit
	if start > len(sortedLogs) {
		return []*TaskActivityLog{}, nil
	}
	if end > len(sortedLogs) {
		end = len(sortedLogs)
	}

	return sortedLogs[start:end], nil
}

// GetActivityLogsByDAG retrieves activity logs for all tasks in a DAG.
func (m *MemoryTaskStorage) GetActivityLogsByDAG(ctx context.Context, dagID string, limit int, offset int) ([]*TaskActivityLog, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	taskIDs, exists := m.dagTasks[dagID]
	if !exists {
		return []*TaskActivityLog{}, nil
	}

	allLogs := make([]*TaskActivityLog, 0)
	for _, taskID := range taskIDs {
		if logs, exists := m.activityLogs[taskID]; exists {
			allLogs = append(allLogs, logs...)
		}
	}

	// Sort by creation time (newest first)
	sort.Slice(allLogs, func(i, j int) bool {
		return allLogs[i].CreatedAt.After(allLogs[j].CreatedAt)
	})

	// Apply pagination
	start := offset
	end := offset + limit
	if start > len(allLogs) {
		return []*TaskActivityLog{}, nil
	}
	if end > len(allLogs) {
		end = len(allLogs)
	}

	return allLogs[start:end], nil
}

// SaveTasks saves multiple tasks.
func (m *MemoryTaskStorage) SaveTasks(ctx context.Context, tasks []*PersistentTask) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	for _, task := range tasks {
		if task.ID == "" {
			task.ID = wuid.New().String()
		}
		if task.CreatedAt.IsZero() {
			task.CreatedAt = time.Now()
		}
		task.UpdatedAt = time.Now()

		_, existed := m.tasks[task.ID]
		m.tasks[task.ID] = task

		// Add to the DAG index, avoiding a duplicate entry when a task is re-saved
		if !existed {
			m.dagTasks[task.DAGID] = append(m.dagTasks[task.DAGID], task.ID)
		}
	}

	return nil
}

// GetPendingTasks retrieves pending tasks for a DAG.
func (m *MemoryTaskStorage) GetPendingTasks(ctx context.Context, dagID string, limit int) ([]*PersistentTask, error) {
	tasks, err := m.GetTasksByStatus(ctx, dagID, TaskStatusPending)
	if err != nil {
		return nil, err
	}
	// Honor the limit, which the underlying status query does not apply
	if limit > 0 && len(tasks) > limit {
		tasks = tasks[:limit]
	}
	return tasks, nil
}

// CleanupOldTasks removes tasks older than the specified time.
func (m *MemoryTaskStorage) CleanupOldTasks(ctx context.Context, dagID string, olderThan time.Time) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	taskIDs, exists := m.dagTasks[dagID]
	if !exists {
		return nil
	}

	updatedTaskIDs := make([]string, 0)
	for _, taskID := range taskIDs {
		if task, exists := m.tasks[taskID]; exists {
			if task.CreatedAt.Before(olderThan) {
				delete(m.tasks, taskID)
				delete(m.activityLogs, taskID)
			} else {
				updatedTaskIDs = append(updatedTaskIDs, taskID)
			}
		}
	}

	m.dagTasks[dagID] = updatedTaskIDs
	return nil
}

// CleanupOldActivityLogs removes activity logs older than the specified time.
func (m *MemoryTaskStorage) CleanupOldActivityLogs(ctx context.Context, dagID string, olderThan time.Time) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	taskIDs, exists := m.dagTasks[dagID]
	if !exists {
		return nil
	}

	for _, taskID := range taskIDs {
		if logs, exists := m.activityLogs[taskID]; exists {
			updatedLogs := make([]*TaskActivityLog, 0)
			for _, log := range logs {
				if log.CreatedAt.After(olderThan) {
					updatedLogs = append(updatedLogs, log)
				}
			}
			if len(updatedLogs) > 0 {
				m.activityLogs[taskID] = updatedLogs
			} else {
				delete(m.activityLogs, taskID)
			}
		}
	}

	return nil
}

// GetResumableTasks returns tasks that can be resumed (pending or running).
func (m *MemoryTaskStorage) GetResumableTasks(ctx context.Context, dagID string) ([]*PersistentTask, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	taskIDs, exists := m.dagTasks[dagID]
	if !exists {
		return []*PersistentTask{}, nil
	}

	var resumableTasks []*PersistentTask
	for _, taskID := range taskIDs {
		if task, exists := m.tasks[taskID]; exists {
			if task.Status == TaskStatusPending || task.Status == TaskStatusRunning {
				// Return a copy to prevent external modifications
				taskCopy := *task
				resumableTasks = append(resumableTasks, &taskCopy)
			}
		}
	}

	return resumableTasks, nil
}

// Ping checks if the storage is healthy.
func (m *MemoryTaskStorage) Ping(ctx context.Context) error {
	return nil // Memory storage is always healthy
}

// Close closes the storage (a no-op for memory).
func (m *MemoryTaskStorage) Close() error {
	return nil
}
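Note that the memory backend hands out copies, so mutating a retrieved task does not touch the stored record; the change must be saved back explicitly. A short sketch (ctx is assumed to be a context.Context):

	store := storage.NewMemoryTaskStorage()
	_ = store.SaveTask(ctx, &storage.PersistentTask{ID: "t1", DAGID: "d1", Status: storage.TaskStatusPending})
	task, _ := store.GetTask(ctx, "t1")
	task.Status = storage.TaskStatusRunning // modifies only the copy
	_ = store.SaveTask(ctx, task)           // persist the change explicitly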
640  dag/storage/sql.go  (new file)
@@ -0,0 +1,640 @@
package storage

import (
	"context"
	"database/sql"
	"fmt"
	"strings"
	"time"

	_ "github.com/lib/pq"           // PostgreSQL driver
	_ "github.com/mattn/go-sqlite3" // SQLite driver
	"github.com/oarkflow/json"
	"github.com/oarkflow/squealx"
	"github.com/oarkflow/xid/wuid"
)

// SQLTaskStorage implements TaskStorage using SQL databases.
type SQLTaskStorage struct {
	db     *squealx.DB
	config *TaskStorageConfig
}

// NewSQLTaskStorage creates a new SQL-based task storage.
func NewSQLTaskStorage(config *TaskStorageConfig) (*SQLTaskStorage, error) {
	db, err := squealx.Open(config.Type, config.DSN, "task-storage")
	if err != nil {
		return nil, fmt.Errorf("failed to open database: %w", err)
	}

	// Configure the connection pool
	if config.MaxOpenConns > 0 {
		db.SetMaxOpenConns(config.MaxOpenConns)
	}
	if config.MaxIdleConns > 0 {
		db.SetMaxIdleConns(config.MaxIdleConns)
	}
	if config.ConnMaxLifetime > 0 {
		db.SetConnMaxLifetime(config.ConnMaxLifetime)
	}

	storage := &SQLTaskStorage{
		db:     db,
		config: config,
	}

	// Create tables
	if err := storage.createTables(context.Background()); err != nil {
		db.Close()
		return nil, fmt.Errorf("failed to create tables: %w", err)
	}

	return storage, nil
}

// createTables creates the necessary database tables.
func (s *SQLTaskStorage) createTables(ctx context.Context) error {
	tasksTable := `
	CREATE TABLE IF NOT EXISTS dag_tasks (
		id TEXT PRIMARY KEY,
		dag_id TEXT NOT NULL,
		node_id TEXT NOT NULL,
		current_node_id TEXT,
		sub_dag_path TEXT,
		processing_state TEXT,
		payload TEXT,
		status TEXT NOT NULL,
		created_at TIMESTAMP NOT NULL,
		updated_at TIMESTAMP NOT NULL,
		started_at TIMESTAMP,
		completed_at TIMESTAMP,
		error TEXT,
		retry_count INTEGER DEFAULT 0,
		max_retries INTEGER DEFAULT 3,
		priority INTEGER DEFAULT 0
	)`

	activityLogsTable := `
	CREATE TABLE IF NOT EXISTS dag_task_activity_logs (
		id TEXT PRIMARY KEY,
		task_id TEXT NOT NULL,
		dag_id TEXT NOT NULL,
		node_id TEXT NOT NULL,
		action TEXT NOT NULL,
		message TEXT,
		data TEXT,
		level TEXT NOT NULL,
		created_at TIMESTAMP NOT NULL,
		FOREIGN KEY (task_id) REFERENCES dag_tasks(id) ON DELETE CASCADE
	)`

	// Create indexes for better performance
	indexes := []string{
		`CREATE INDEX IF NOT EXISTS idx_dag_tasks_dag_id ON dag_tasks(dag_id)`,
		`CREATE INDEX IF NOT EXISTS idx_dag_tasks_status ON dag_tasks(status)`,
		`CREATE INDEX IF NOT EXISTS idx_dag_tasks_created_at ON dag_tasks(created_at)`,
		`CREATE INDEX IF NOT EXISTS idx_activity_logs_task_id ON dag_task_activity_logs(task_id)`,
		`CREATE INDEX IF NOT EXISTS idx_activity_logs_dag_id ON dag_task_activity_logs(dag_id)`,
		`CREATE INDEX IF NOT EXISTS idx_activity_logs_created_at ON dag_task_activity_logs(created_at)`,
	}

	// Execute table creation
	if _, err := s.db.ExecContext(ctx, tasksTable); err != nil {
		return fmt.Errorf("failed to create tasks table: %w", err)
	}

	if _, err := s.db.ExecContext(ctx, activityLogsTable); err != nil {
		return fmt.Errorf("failed to create activity logs table: %w", err)
	}

	// Execute index creation
	for _, index := range indexes {
		if _, err := s.db.ExecContext(ctx, index); err != nil {
			return fmt.Errorf("failed to create index: %w", err)
		}
	}

	return nil
}

// SaveTask saves a task to the database.
func (s *SQLTaskStorage) SaveTask(ctx context.Context, task *PersistentTask) error {
	if task.ID == "" {
		task.ID = wuid.New().String()
	}
	if task.CreatedAt.IsZero() {
		task.CreatedAt = time.Now()
	}
	task.UpdatedAt = time.Now()

	query := `
	INSERT INTO dag_tasks (id, dag_id, node_id, current_node_id, sub_dag_path, processing_state,
		payload, status, created_at, updated_at, started_at, completed_at,
		error, retry_count, max_retries, priority)
	VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
	ON CONFLICT(id) DO UPDATE SET
		node_id = excluded.node_id,
		current_node_id = excluded.current_node_id,
		sub_dag_path = excluded.sub_dag_path,
		processing_state = excluded.processing_state,
		payload = excluded.payload,
		status = excluded.status,
		updated_at = excluded.updated_at,
		started_at = excluded.started_at,
		completed_at = excluded.completed_at,
		error = excluded.error,
		retry_count = excluded.retry_count,
		max_retries = excluded.max_retries,
		priority = excluded.priority`

	_, err := s.db.ExecContext(ctx, s.placeholderQuery(query),
		task.ID, task.DAGID, task.NodeID, task.CurrentNodeID, task.SubDAGPath, task.ProcessingState,
		string(task.Payload), task.Status, task.CreatedAt, task.UpdatedAt, task.StartedAt, task.CompletedAt,
		task.Error, task.RetryCount, task.MaxRetries, task.Priority)

	return err
}

// GetTask retrieves a task by ID.
func (s *SQLTaskStorage) GetTask(ctx context.Context, taskID string) (*PersistentTask, error) {
	query := `
	SELECT id, dag_id, node_id, current_node_id, sub_dag_path, processing_state,
		payload, status, created_at, updated_at, started_at, completed_at,
		error, retry_count, max_retries, priority
	FROM dag_tasks WHERE id = ?`

	var task PersistentTask
	var payload sql.NullString
	var currentNodeID, subDAGPath, processingState sql.NullString
	var startedAt, completedAt sql.NullTime
	var errMsg sql.NullString

	err := s.db.QueryRowContext(ctx, s.placeholderQuery(query), taskID).Scan(
		&task.ID, &task.DAGID, &task.NodeID, &currentNodeID, &subDAGPath, &processingState,
		&payload, &task.Status, &task.CreatedAt, &task.UpdatedAt, &startedAt, &completedAt,
		&errMsg, &task.RetryCount, &task.MaxRetries, &task.Priority)

	if err != nil {
		if err == sql.ErrNoRows {
			return nil, fmt.Errorf("task not found: %s", taskID)
		}
		return nil, err
	}

	// Handle nullable fields
	if currentNodeID.Valid {
		task.CurrentNodeID = currentNodeID.String
	}
	if subDAGPath.Valid {
		task.SubDAGPath = subDAGPath.String
	}
	if processingState.Valid {
		task.ProcessingState = processingState.String
	}
	if payload.Valid {
		task.Payload = []byte(payload.String)
	}
	if startedAt.Valid {
		task.StartedAt = &startedAt.Time
	}
	if completedAt.Valid {
		task.CompletedAt = &completedAt.Time
	}
	if errMsg.Valid {
		task.Error = errMsg.String
	}

	return &task, nil
}

// GetTasksByDAG retrieves tasks for a specific DAG.
func (s *SQLTaskStorage) GetTasksByDAG(ctx context.Context, dagID string, limit int, offset int) ([]*PersistentTask, error) {
	query := `
	SELECT id, dag_id, node_id, payload, status, created_at, updated_at,
		started_at, completed_at, error, retry_count, max_retries, priority
	FROM dag_tasks
	WHERE dag_id = ?
	ORDER BY created_at DESC
	LIMIT ? OFFSET ?`

	rows, err := s.db.QueryContext(ctx, s.placeholderQuery(query), dagID, limit, offset)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	tasks := make([]*PersistentTask, 0)
	for rows.Next() {
		var task PersistentTask
		var payload sql.NullString
		var startedAt, completedAt sql.NullTime
		var errMsg sql.NullString

		err := rows.Scan(
			&task.ID, &task.DAGID, &task.NodeID, &payload, &task.Status,
			&task.CreatedAt, &task.UpdatedAt, &startedAt, &completedAt,
			&errMsg, &task.RetryCount, &task.MaxRetries, &task.Priority)

		if err != nil {
			return nil, err
		}

		if payload.Valid {
			task.Payload = []byte(payload.String)
		}
		if startedAt.Valid {
			task.StartedAt = &startedAt.Time
		}
		if completedAt.Valid {
			task.CompletedAt = &completedAt.Time
		}
		if errMsg.Valid {
			task.Error = errMsg.String
		}

		tasks = append(tasks, &task)
	}

	return tasks, rows.Err()
}

// GetTasksByStatus retrieves tasks by status for a specific DAG.
func (s *SQLTaskStorage) GetTasksByStatus(ctx context.Context, dagID string, status TaskStatus) ([]*PersistentTask, error) {
	query := `
	SELECT id, dag_id, node_id, payload, status, created_at, updated_at,
		started_at, completed_at, error, retry_count, max_retries, priority
	FROM dag_tasks
	WHERE dag_id = ? AND status = ?
	ORDER BY created_at DESC`

	rows, err := s.db.QueryContext(ctx, s.placeholderQuery(query), dagID, status)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	tasks := make([]*PersistentTask, 0)
	for rows.Next() {
		var task PersistentTask
		var payload sql.NullString
		var startedAt, completedAt sql.NullTime
		var errMsg sql.NullString

		err := rows.Scan(
			&task.ID, &task.DAGID, &task.NodeID, &payload, &task.Status,
			&task.CreatedAt, &task.UpdatedAt, &startedAt, &completedAt,
			&errMsg, &task.RetryCount, &task.MaxRetries, &task.Priority)

		if err != nil {
			return nil, err
		}

		if payload.Valid {
			task.Payload = []byte(payload.String)
		}
		if startedAt.Valid {
			task.StartedAt = &startedAt.Time
		}
		if completedAt.Valid {
			task.CompletedAt = &completedAt.Time
		}
		if errMsg.Valid {
			task.Error = errMsg.String
		}

		tasks = append(tasks, &task)
	}

	return tasks, rows.Err()
}

// UpdateTaskStatus updates the status of a task.
func (s *SQLTaskStorage) UpdateTaskStatus(ctx context.Context, taskID string, status TaskStatus, errorMsg string) error {
	now := time.Now()
	// Only stamp completed_at when the task reaches a terminal state
	var completedAt *time.Time
	if status == TaskStatusCompleted || status == TaskStatusFailed {
		completedAt = &now
	}

	query := `
	UPDATE dag_tasks
	SET status = ?, updated_at = ?, completed_at = ?, error = ?
	WHERE id = ?`

	_, err := s.db.ExecContext(ctx, s.placeholderQuery(query), status, now, completedAt, errorMsg, taskID)
	return err
}

// DeleteTask deletes a task.
func (s *SQLTaskStorage) DeleteTask(ctx context.Context, taskID string) error {
	query := `DELETE FROM dag_tasks WHERE id = ?`
	_, err := s.db.ExecContext(ctx, s.placeholderQuery(query), taskID)
	return err
}

// DeleteTasksByDAG deletes all tasks for a specific DAG.
func (s *SQLTaskStorage) DeleteTasksByDAG(ctx context.Context, dagID string) error {
	query := `DELETE FROM dag_tasks WHERE dag_id = ?`
	_, err := s.db.ExecContext(ctx, s.placeholderQuery(query), dagID)
	return err
}

// LogActivity logs an activity for a task.
func (s *SQLTaskStorage) LogActivity(ctx context.Context, logEntry *TaskActivityLog) error {
	if logEntry.ID == "" {
		logEntry.ID = wuid.New().String()
	}
	if logEntry.CreatedAt.IsZero() {
		logEntry.CreatedAt = time.Now()
	}

	query := `
	INSERT INTO dag_task_activity_logs (id, task_id, dag_id, node_id, action, message, data, level, created_at)
	VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`

	_, err := s.db.ExecContext(ctx, s.placeholderQuery(query),
		logEntry.ID, logEntry.TaskID, logEntry.DAGID, logEntry.NodeID,
		logEntry.Action, logEntry.Message, string(logEntry.Data), logEntry.Level, logEntry.CreatedAt)

	return err
}

// GetActivityLogs retrieves activity logs for a task.
func (s *SQLTaskStorage) GetActivityLogs(ctx context.Context, taskID string, limit int, offset int) ([]*TaskActivityLog, error) {
	query := `
	SELECT id, task_id, dag_id, node_id, action, message, data, level, created_at
	FROM dag_task_activity_logs
	WHERE task_id = ?
	ORDER BY created_at DESC
	LIMIT ? OFFSET ?`

	rows, err := s.db.QueryContext(ctx, s.placeholderQuery(query), taskID, limit, offset)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	logs := make([]*TaskActivityLog, 0)
	for rows.Next() {
		var log TaskActivityLog
		var message, data sql.NullString

		err := rows.Scan(
			&log.ID, &log.TaskID, &log.DAGID, &log.NodeID, &log.Action,
			&message, &data, &log.Level, &log.CreatedAt)

		if err != nil {
			return nil, err
		}

		if message.Valid {
			log.Message = message.String
		}
		if data.Valid {
			log.Data = []byte(data.String)
		}

		logs = append(logs, &log)
	}

	return logs, rows.Err()
}

// GetActivityLogsByDAG retrieves activity logs for all tasks in a DAG.
func (s *SQLTaskStorage) GetActivityLogsByDAG(ctx context.Context, dagID string, limit int, offset int) ([]*TaskActivityLog, error) {
	query := `
	SELECT id, task_id, dag_id, node_id, action, message, data, level, created_at
	FROM dag_task_activity_logs
	WHERE dag_id = ?
	ORDER BY created_at DESC
	LIMIT ? OFFSET ?`

	rows, err := s.db.QueryContext(ctx, s.placeholderQuery(query), dagID, limit, offset)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	logs := make([]*TaskActivityLog, 0)
	for rows.Next() {
		var log TaskActivityLog
		var message, data sql.NullString

		err := rows.Scan(
			&log.ID, &log.TaskID, &log.DAGID, &log.NodeID, &log.Action,
			&message, &data, &log.Level, &log.CreatedAt)

		if err != nil {
			return nil, err
		}

		if message.Valid {
			log.Message = message.String
		}
		if data.Valid {
			log.Data = []byte(data.String)
		}

		logs = append(logs, &log)
	}

	return logs, rows.Err()
}

// SaveTasks saves multiple tasks in a single transaction.
func (s *SQLTaskStorage) SaveTasks(ctx context.Context, tasks []*PersistentTask) error {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer tx.Rollback()

	query := `
	INSERT INTO dag_tasks (id, dag_id, node_id, payload, status, created_at, updated_at,
		started_at, completed_at, error, retry_count, max_retries, priority)
	VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
	ON CONFLICT(id) DO UPDATE SET
		node_id = excluded.node_id,
		payload = excluded.payload,
		status = excluded.status,
		updated_at = excluded.updated_at,
		started_at = excluded.started_at,
		completed_at = excluded.completed_at,
		error = excluded.error,
		retry_count = excluded.retry_count,
		max_retries = excluded.max_retries,
		priority = excluded.priority`

	for _, task := range tasks {
		if task.ID == "" {
			task.ID = wuid.New().String()
		}
		if task.CreatedAt.IsZero() {
			task.CreatedAt = time.Now()
		}
		task.UpdatedAt = time.Now()

		_, err := tx.ExecContext(ctx, s.placeholderQuery(query),
			task.ID, task.DAGID, task.NodeID, string(task.Payload), task.Status,
			task.CreatedAt, task.UpdatedAt, task.StartedAt, task.CompletedAt,
			task.Error, task.RetryCount, task.MaxRetries, task.Priority)

		if err != nil {
			return err
		}
	}

	return tx.Commit()
}

// GetPendingTasks retrieves pending tasks for a DAG.
func (s *SQLTaskStorage) GetPendingTasks(ctx context.Context, dagID string, limit int) ([]*PersistentTask, error) {
	query := `
	SELECT id, dag_id, node_id, payload, status, created_at, updated_at,
		started_at, completed_at, error, retry_count, max_retries, priority
	FROM dag_tasks
	WHERE dag_id = ? AND status = ?
	ORDER BY priority DESC, created_at ASC
	LIMIT ?`

	rows, err := s.db.QueryContext(ctx, s.placeholderQuery(query), dagID, TaskStatusPending, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	tasks := make([]*PersistentTask, 0)
	for rows.Next() {
		var task PersistentTask
		var payload sql.NullString
		var startedAt, completedAt sql.NullTime
		var errMsg sql.NullString

		err := rows.Scan(
			&task.ID, &task.DAGID, &task.NodeID, &payload, &task.Status,
			&task.CreatedAt, &task.UpdatedAt, &startedAt, &completedAt,
			&errMsg, &task.RetryCount, &task.MaxRetries, &task.Priority)

		if err != nil {
			return nil, err
		}

		if payload.Valid {
			task.Payload = []byte(payload.String)
		}
		if startedAt.Valid {
			task.StartedAt = &startedAt.Time
		}
		if completedAt.Valid {
			task.CompletedAt = &completedAt.Time
		}
		if errMsg.Valid {
			task.Error = errMsg.String
		}

		tasks = append(tasks, &task)
	}

	return tasks, rows.Err()
}

// CleanupOldTasks removes tasks older than the specified time.
func (s *SQLTaskStorage) CleanupOldTasks(ctx context.Context, dagID string, olderThan time.Time) error {
	query := `DELETE FROM dag_tasks WHERE dag_id = ? AND created_at < ?`
	_, err := s.db.ExecContext(ctx, s.placeholderQuery(query), dagID, olderThan)
	return err
}

// CleanupOldActivityLogs removes activity logs older than the specified time.
func (s *SQLTaskStorage) CleanupOldActivityLogs(ctx context.Context, dagID string, olderThan time.Time) error {
	query := `DELETE FROM dag_task_activity_logs WHERE dag_id = ? AND created_at < ?`
	_, err := s.db.ExecContext(ctx, s.placeholderQuery(query), dagID, olderThan)
	return err
}

// GetResumableTasks returns tasks that can be resumed (pending or running).
func (s *SQLTaskStorage) GetResumableTasks(ctx context.Context, dagID string) ([]*PersistentTask, error) {
	query := `
	SELECT id, dag_id, node_id, current_node_id, sub_dag_path, processing_state,
		payload, status, created_at, updated_at, started_at, completed_at,
		error, retry_count, max_retries, priority
	FROM dag_tasks
	WHERE dag_id = ? AND status IN (?, ?)
	ORDER BY created_at ASC`

	rows, err := s.db.QueryContext(ctx, s.placeholderQuery(query), dagID, TaskStatusPending, TaskStatusRunning)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var tasks []*PersistentTask
	for rows.Next() {
		var task PersistentTask
		var payload sql.NullString
		var currentNodeID, subDAGPath, processingState sql.NullString
		var startedAt, completedAt sql.NullTime
		var errMsg sql.NullString

		err := rows.Scan(
			&task.ID, &task.DAGID, &task.NodeID, &currentNodeID, &subDAGPath, &processingState,
			&payload, &task.Status, &task.CreatedAt, &task.UpdatedAt, &startedAt, &completedAt,
			&errMsg, &task.RetryCount, &task.MaxRetries, &task.Priority)
		if err != nil {
			return nil, err
		}

		// Handle nullable fields
		if payload.Valid {
			task.Payload = json.RawMessage(payload.String)
		}
		if currentNodeID.Valid {
			task.CurrentNodeID = currentNodeID.String
		}
		if subDAGPath.Valid {
			task.SubDAGPath = subDAGPath.String
		}
		if processingState.Valid {
			task.ProcessingState = processingState.String
		}
		if startedAt.Valid {
			task.StartedAt = &startedAt.Time
		}
		if completedAt.Valid {
			task.CompletedAt = &completedAt.Time
		}
		if errMsg.Valid {
			task.Error = errMsg.String
		}

		tasks = append(tasks, &task)
	}

	return tasks, rows.Err()
}

// Ping checks if the database is healthy.
func (s *SQLTaskStorage) Ping(ctx context.Context) error {
	return s.db.PingContext(ctx)
}

// Close closes the database connection.
func (s *SQLTaskStorage) Close() error {
	return s.db.Close()
}

// placeholderQuery converts ? placeholders to the format expected by the database.
// PostgreSQL uses numbered placeholders ($1, $2, ...), so each ? is replaced in order;
// replacing every ? with the same $1 would bind all arguments to the first parameter.
func (s *SQLTaskStorage) placeholderQuery(query string) string {
	if s.config.Type != "postgres" {
		return query // SQLite uses ?
	}
	var b strings.Builder
	n := 0
	for _, r := range query {
		if r == '?' {
			n++
			b.WriteString(fmt.Sprintf("$%d", n))
			continue
		}
		b.WriteRune(r)
	}
	return b.String()
}

// GetDB returns the underlying database connection.
func (s *SQLTaskStorage) GetDB() *sql.DB {
	return s.db.DB()
}
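With the numbered-placeholder conversion in placeholderQuery, every query can be written once with ? placeholders and work on both backends: WHERE dag_id = ? AND status = ? stays as-is for SQLite and becomes WHERE dag_id = $1 AND status = $2 for PostgreSQL. Each ? must therefore correspond to exactly one argument, in order, which is why all Exec/Query calls above route their SQL through placeholderQuery.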
21  dag/storage/wal_storage.go  (new file)
@@ -0,0 +1,21 @@
package storage

import (
	"sync"
)

// WALMemoryTaskStorage implements TaskStorage with WAL support on top of the memory backend.
type WALMemoryTaskStorage struct {
	*MemoryTaskStorage
	walManager interface{} // WAL manager; typed as interface{} to avoid an import cycle
	walStorage interface{} // WAL storage; typed as interface{} to avoid an import cycle
	mu         sync.RWMutex
}

// WALSQLTaskStorage implements TaskStorage with WAL support on top of the SQL backend.
type WALSQLTaskStorage struct {
	*SQLTaskStorage
	walManager interface{} // WAL manager; typed as interface{} to avoid an import cycle
	walStorage interface{} // WAL storage; typed as interface{} to avoid an import cycle
	mu         sync.RWMutex
}
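These WAL wrappers are declared as stubs in this commit; the write-ahead logic itself is not included. A hypothetical override showing how such a wrapper could intercept writes before delegating to the embedded backend (walAppend is an assumed helper, not an existing API):

	func (w *WALMemoryTaskStorage) SaveTask(ctx context.Context, task *PersistentTask) error {
		w.mu.Lock()
		defer w.mu.Unlock()
		// Hypothetical: record the mutation in the WAL before applying it.
		// if err := w.walAppend(task); err != nil {
		// 	return err
		// }
		return w.MemoryTaskStorage.SaveTask(ctx, task)
	}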
275  dag/storage_test.go  (new file)
@@ -0,0 +1,275 @@
package dag

import (
	"context"
	"testing"
	"time"

	"github.com/oarkflow/json"
	"github.com/oarkflow/mq"
	dagstorage "github.com/oarkflow/mq/dag/storage"
)

func TestDAGWithMemoryStorage(t *testing.T) {
	// Create a new DAG
	dag := NewDAG("test-dag", "test-key", func(taskID string, result mq.Result) {
		t.Logf("Task completed: %s", taskID)
	})

	// Configure memory storage
	dag.ConfigureMemoryStorage()

	// Verify storage is configured
	if dag.GetTaskStorage() == nil {
		t.Fatal("Task storage should be configured")
	}

	// Create a simple task
	ctx := context.Background()
	payload := json.RawMessage(`{"test": "data"}`)

	// Test task storage directly
	task := &dagstorage.PersistentTask{
		ID:        "test-task-1",
		DAGID:     "test-key",
		NodeID:    "test-node",
		Status:    dagstorage.TaskStatusPending,
		Payload:   payload,
		CreatedAt: time.Now(),
		UpdatedAt: time.Now(),
	}

	// Save task
	err := dag.GetTaskStorage().SaveTask(ctx, task)
	if err != nil {
		t.Fatalf("Failed to save task: %v", err)
	}

	// Retrieve task
	retrieved, err := dag.GetTaskStorage().GetTask(ctx, "test-task-1")
	if err != nil {
		t.Fatalf("Failed to retrieve task: %v", err)
	}

	if retrieved.ID != "test-task-1" {
		t.Errorf("Expected task ID 'test-task-1', got '%s'", retrieved.ID)
	}

	if retrieved.DAGID != "test-key" {
		t.Errorf("Expected DAG ID 'test-key', got '%s'", retrieved.DAGID)
	}

	// Test DAG isolation: tasks from a different DAG should not be accessible.
	// (This would require a more complex test with multiple DAGs.)

	// Test activity logging
	logEntry := &dagstorage.TaskActivityLog{
		TaskID:    "test-task-1",
		DAGID:     "test-key",
		NodeID:    "test-node",
		Action:    "test_action",
		Message:   "Test activity",
		Level:     "info",
		CreatedAt: time.Now(),
	}

	err = dag.GetTaskStorage().LogActivity(ctx, logEntry)
	if err != nil {
		t.Fatalf("Failed to log activity: %v", err)
	}

	// Retrieve activity logs
	logs, err := dag.GetTaskStorage().GetActivityLogs(ctx, "test-task-1", 10, 0)
	if err != nil {
		t.Fatalf("Failed to retrieve activity logs: %v", err)
	}

	if len(logs) != 1 {
		t.Errorf("Expected 1 activity log, got %d", len(logs))
	}

	if len(logs) > 0 && logs[0].Action != "test_action" {
		t.Errorf("Expected action 'test_action', got '%s'", logs[0].Action)
	}
}

func TestDAGTaskRecovery(t *testing.T) {
	// Create a new DAG
	dag := NewDAG("recovery-test", "recovery-key", func(taskID string, result mq.Result) {
		t.Logf("Task completed: %s", taskID)
	})

	// Configure memory storage
	dag.ConfigureMemoryStorage()

	// Add a test node to the DAG
	testNode := &Node{
		ID:        "test-node",
		Label:     "Test Node",
		processor: &TestProcessor{},
	}
	dag.nodes.Set("test-node", testNode)

	ctx := context.Background()

	// Create and save a task
	task := &dagstorage.PersistentTask{
		ID:              "recovery-task-1",
		DAGID:           "recovery-key",
		NodeID:          "test-node",
		CurrentNodeID:   "test-node",
		SubDAGPath:      "",
		ProcessingState: "processing",
		Status:          dagstorage.TaskStatusRunning,
		Payload:         json.RawMessage(`{"test": "recovery data"}`),
		CreatedAt:       time.Now(),
		UpdatedAt:       time.Now(),
	}

	// Save the task
	err := dag.GetTaskStorage().SaveTask(ctx, task)
	if err != nil {
		t.Fatalf("Failed to save task: %v", err)
	}

	// Verify the task was saved with recovery information
	retrieved, err := dag.GetTaskStorage().GetTask(ctx, "recovery-task-1")
	if err != nil {
		t.Fatalf("Failed to retrieve task: %v", err)
	}

	if retrieved.CurrentNodeID != "test-node" {
		t.Errorf("Expected current node 'test-node', got '%s'", retrieved.CurrentNodeID)
	}

	if retrieved.ProcessingState != "processing" {
		t.Errorf("Expected processing state 'processing', got '%s'", retrieved.ProcessingState)
	}

	// Test recovery functionality
	err = dag.RecoverTasks(ctx)
	if err != nil {
		t.Fatalf("Failed to recover tasks: %v", err)
	}

	// Verify that a task manager was created during recovery
	manager, exists := dag.taskManager.Get("recovery-task-1")
	if !exists {
		t.Fatal("Task manager should have been created during recovery")
	}

	if manager == nil {
		t.Fatal("Task manager should not be nil")
	}
}

// TestProcessor is a simple processor for testing.
type TestProcessor struct {
	key  string
	tags []string
}

func (p *TestProcessor) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
	return mq.Result{
		Payload: task.Payload,
		Status:  mq.Completed,
		Ctx:     ctx,
		TaskID:  task.ID,
	}
}

func (p *TestProcessor) Consume(ctx context.Context) error {
	return nil
}

func (p *TestProcessor) Pause(ctx context.Context) error {
	return nil
}

func (p *TestProcessor) Resume(ctx context.Context) error {
	return nil
}

func (p *TestProcessor) Stop(ctx context.Context) error {
	return nil
}

func (p *TestProcessor) Close() error {
	return nil
}

func (p *TestProcessor) GetKey() string {
	return p.key
}

func (p *TestProcessor) SetKey(key string) {
	p.key = key
}

func (p *TestProcessor) GetType() string {
	return "test"
}

func (p *TestProcessor) SetConfig(payload Payload) {
	// No-op for test
}

func (p *TestProcessor) SetTags(tags ...string) {
	p.tags = tags
}

func (p *TestProcessor) GetTags() []string {
	return p.tags
}

func TestDAGSubDAGRecovery(t *testing.T) {
	// Create a DAG representing a complex workflow with sub-DAGs
	dag := NewDAG("complex-dag", "complex-key", func(taskID string, result mq.Result) {
		t.Logf("Complex task completed: %s", taskID)
	})

	// Configure memory storage
	dag.ConfigureMemoryStorage()

	ctx := context.Background()

	// Simulate a task that was mid-processing in sub-DAG 3, node D
	task := &dagstorage.PersistentTask{
		ID:              "complex-task-1",
		DAGID:           "complex-key",
		NodeID:          "start-node",
		CurrentNodeID:   "node-d",
		SubDAGPath:      "subdag3",
		ProcessingState: "processing",
		Status:          dagstorage.TaskStatusRunning,
		Payload:         json.RawMessage(`{"complex": "workflow data"}`),
		CreatedAt:       time.Now().Add(-5 * time.Minute), // Simulate a task that started 5 minutes ago
		UpdatedAt:       time.Now(),
	}

	// Save the task
	err := dag.GetTaskStorage().SaveTask(ctx, task)
	if err != nil {
		t.Fatalf("Failed to save complex task: %v", err)
	}

	// Test that we can retrieve resumable tasks
	resumableTasks, err := dag.GetTaskStorage().GetResumableTasks(ctx, "complex-key")
	if err != nil {
		t.Fatalf("Failed to get resumable tasks: %v", err)
	}

	if len(resumableTasks) != 1 {
		t.Errorf("Expected 1 resumable task, got %d", len(resumableTasks))
	}

	if len(resumableTasks) > 0 {
		rt := resumableTasks[0]
		if rt.CurrentNodeID != "node-d" {
			t.Errorf("Expected resumable task current node 'node-d', got '%s'", rt.CurrentNodeID)
		}
		if rt.SubDAGPath != "subdag3" {
			t.Errorf("Expected resumable task sub-DAG path 'subdag3', got '%s'", rt.SubDAGPath)
		}
	}
}
@@ -12,8 +12,9 @@ import (
	"github.com/oarkflow/json"

	"github.com/oarkflow/mq"
	dagstorage "github.com/oarkflow/mq/dag/storage" // Import dag storage package with alias
	"github.com/oarkflow/mq/logger"
-	"github.com/oarkflow/mq/storage"
+	mqstorage "github.com/oarkflow/mq/storage"
	"github.com/oarkflow/mq/storage/memory"
)

@@ -30,7 +31,7 @@ func (te TaskError) Error() string {
// TaskState holds state and intermediate results for a given task (identified by a node ID).
type TaskState struct {
	UpdatedAt time.Time
-	targetResults storage.IMap[string, mq.Result]
+	targetResults mqstorage.IMap[string, mq.Result]
	NodeID string
	Status mq.Status
	Result mq.Result
@@ -76,13 +77,13 @@ type TaskManagerConfig struct {

type TaskManager struct {
	createdAt time.Time
-	taskStates         storage.IMap[string, *TaskState]
-	parentNodes        storage.IMap[string, string]
-	childNodes         storage.IMap[string, int]
-	deferredTasks      storage.IMap[string, *task]
-	iteratorNodes      storage.IMap[string, []Edge]
-	currentNodePayload storage.IMap[string, json.RawMessage]
-	currentNodeResult  storage.IMap[string, mq.Result]
+	taskStates         mqstorage.IMap[string, *TaskState]
+	parentNodes        mqstorage.IMap[string, string]
+	childNodes         mqstorage.IMap[string, int]
+	deferredTasks      mqstorage.IMap[string, *task]
+	iteratorNodes      mqstorage.IMap[string, []Edge]
+	currentNodePayload mqstorage.IMap[string, json.RawMessage]
+	currentNodeResult  mqstorage.IMap[string, mq.Result]
	taskQueue chan *task
	result *mq.Result
	resultQueue chan nodeResult
@@ -96,9 +97,10 @@ type TaskManager struct {
	pauseMu sync.Mutex
	pauseCh chan struct{}
	wg      sync.WaitGroup
	storage dagstorage.TaskStorage // TaskStorage for persistence
}

-func NewTaskManager(dag *DAG, taskID string, resultCh chan mq.Result, iteratorNodes storage.IMap[string, []Edge]) *TaskManager {
+func NewTaskManager(dag *DAG, taskID string, resultCh chan mq.Result, iteratorNodes mqstorage.IMap[string, []Edge], taskStorage dagstorage.TaskStorage) *TaskManager {
	config := TaskManagerConfig{
		MaxRetries:  3,
		BaseBackoff: time.Second,
@@ -121,6 +123,7 @@ func NewTaskManager(dag *DAG, taskID string, resultCh chan mq.Result, iteratorNo
		baseBackoff:     config.BaseBackoff,
		recoveryHandler: config.RecoveryHandler,
		iteratorNodes:   iteratorNodes,
		storage:         taskStorage,
	}

	tm.wg.Add(3)
@@ -144,7 +147,27 @@ func (tm *TaskManager) enqueueTask(ctx context.Context, startNode, taskID string
		tm.taskStates.Set(startNode, newTaskState(startNode))
	}
	t := newTask(ctx, taskID, startNode, payload)

	// Persist the task to storage before it enters the queue
	if tm.storage != nil {
		persistentTask := &dagstorage.PersistentTask{
			ID:              taskID,
			DAGID:           tm.dag.key,
			NodeID:          startNode,
			CurrentNodeID:   startNode,
			ProcessingState: "enqueued",
			Status:          dagstorage.TaskStatusPending,
			Payload:         payload,
			CreatedAt:       time.Now(),
			UpdatedAt:       time.Now(),
			MaxRetries:      tm.maxRetries,
		}
		if err := tm.storage.SaveTask(ctx, persistentTask); err != nil {
			tm.dag.Logger().Error("Failed to persist task", logger.Field{Key: "taskID", Value: taskID}, logger.Field{Key: "error", Value: err.Error()})
		} else {
			// Log task creation activity
			tm.logActivity(ctx, taskID, startNode, "task_created", "Task enqueued for processing", nil)
		}
	}
	select {
	case tm.taskQueue <- t:
		// Successfully enqueued
@@ -342,6 +365,18 @@ func (tm *TaskManager) processNode(exec *task) {
|
||||
state.Status = mq.Processing
|
||||
state.UpdatedAt = time.Now()
|
||||
tm.currentNodePayload.Clear()
|
||||
// Update task status in storage
|
||||
if tm.storage != nil {
|
||||
// Update task position and status
|
||||
if err := tm.updateTaskPosition(exec.ctx, exec.taskID, pureNodeID, "processing"); err != nil {
|
||||
tm.dag.Logger().Error("Failed to update task position", logger.Field{Key: "taskID", Value: exec.taskID}, logger.Field{Key: "error", Value: err.Error()})
|
||||
}
|
||||
if err := tm.storage.UpdateTaskStatus(exec.ctx, exec.taskID, dagstorage.TaskStatusRunning, ""); err != nil {
|
||||
tm.dag.Logger().Error("Failed to update task status", logger.Field{Key: "taskID", Value: exec.taskID}, logger.Field{Key: "error", Value: err.Error()})
|
||||
}
|
||||
// Log node processing start
|
||||
tm.logActivity(exec.ctx, exec.taskID, pureNodeID, "node_processing_started", "Node processing started", nil)
|
||||
}
|
||||
tm.currentNodeResult.Clear()
|
||||
tm.currentNodePayload.Set(exec.nodeID, exec.payload)
|
||||
|
||||
@@ -578,6 +613,36 @@ func (tm *TaskManager) handleNext(ctx context.Context, node *Node, state *TaskSt
|
||||
if result.Status == "" {
|
||||
result.Status = state.Status
|
||||
}
|
||||
// Update task status in storage based on final result
|
||||
if tm.storage != nil {
|
||||
var status dagstorage.TaskStatus
|
||||
var errorMsg string
|
||||
var action string
|
||||
var message string
|
||||
|
||||
if result.Error != nil {
|
||||
status = dagstorage.TaskStatusFailed
|
||||
errorMsg = result.Error.Error()
|
||||
action = "node_failed"
|
||||
message = fmt.Sprintf("Node %s failed: %s", state.NodeID, errorMsg)
|
||||
} else if state.Status == mq.Completed {
|
||||
status = dagstorage.TaskStatusCompleted
|
||||
action = "node_completed"
|
||||
message = fmt.Sprintf("Node %s completed successfully", state.NodeID)
|
||||
} else {
|
||||
status = dagstorage.TaskStatusRunning
|
||||
action = "node_processing"
|
||||
message = fmt.Sprintf("Node %s processing", state.NodeID)
|
||||
}
|
||||
|
||||
if err := tm.storage.UpdateTaskStatus(ctx, tm.taskID, status, errorMsg); err != nil {
|
||||
tm.dag.Logger().Error("Failed to update task status", logger.Field{Key: "taskID", Value: tm.taskID}, logger.Field{Key: "error", Value: err.Error()})
|
||||
}
|
||||
|
||||
// Log node completion/failure
|
||||
tm.logActivity(ctx, tm.taskID, state.NodeID, action, message, result.Payload)
|
||||
}
|
||||
|
||||
tm.enqueueResult(nodeResult{
|
||||
ctx: ctx,
|
||||
nodeID: state.NodeID,
|
||||
@@ -902,3 +967,49 @@ func (tm *TaskManager) getErrorMessage(err error) string {
|
||||
}
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
// logActivity logs an activity for a task
|
||||
func (tm *TaskManager) logActivity(ctx context.Context, taskID, nodeID, action, message string, data json.RawMessage) {
|
||||
if tm.storage == nil {
|
||||
return
|
||||
}
|
||||
|
||||
logEntry := &dagstorage.TaskActivityLog{
|
||||
TaskID: taskID,
|
||||
DAGID: tm.dag.key,
|
||||
NodeID: nodeID,
|
||||
Action: action,
|
||||
Message: message,
|
||||
Data: data,
|
||||
Level: "info",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := tm.storage.LogActivity(ctx, logEntry); err != nil {
|
||||
tm.dag.Logger().Error("Failed to log activity",
|
||||
logger.Field{Key: "taskID", Value: taskID},
|
||||
logger.Field{Key: "action", Value: action},
|
||||
logger.Field{Key: "error", Value: err.Error()})
|
||||
}
|
||||
}
|
||||
|
||||
// updateTaskPosition updates the current position of a task in the DAG
|
||||
func (tm *TaskManager) updateTaskPosition(ctx context.Context, taskID, currentNodeID, processingState string) error {
|
||||
if tm.storage == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the current task
|
||||
task, err := tm.storage.GetTask(ctx, taskID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get task for position update: %w", err)
|
||||
}
|
||||
|
||||
// Update position fields
|
||||
task.CurrentNodeID = currentNodeID
|
||||
task.ProcessingState = processingState
|
||||
task.UpdatedAt = time.Now()
|
||||
|
||||
// Save the updated task
|
||||
return tm.storage.SaveTask(ctx, task)
|
||||
}
|
||||
221
dag/wal/recovery.go
Normal file
@@ -0,0 +1,221 @@
package wal

import (
	"context"
	"fmt"
	"time"

	"github.com/oarkflow/json"

	"github.com/oarkflow/mq/dag/storage"
	"github.com/oarkflow/mq/logger"
)

// WALRecovery handles crash recovery for the WAL
type WALRecovery struct {
	storage WALStorage
	logger  logger.Logger
	config  *WALConfig
}

// WALRecoveryManager manages the recovery process
type WALRecoveryManager struct {
	recovery *WALRecovery
	config   *WALConfig
}

// NewWALRecovery creates a new WAL recovery instance
func NewWALRecovery(storage WALStorage, logger logger.Logger, config *WALConfig) *WALRecovery {
	return &WALRecovery{
		storage: storage,
		logger:  logger,
		config:  config,
	}
}

// NewWALRecoveryManager creates a new WAL recovery manager
func NewWALRecoveryManager(config *WALConfig, storage WALStorage, logger logger.Logger) *WALRecoveryManager {
	return &WALRecoveryManager{
		recovery: NewWALRecovery(storage, logger, config),
		config:   config,
	}
}

// Recover performs crash recovery by replaying unflushed WAL entries
func (r *WALRecovery) Recover(ctx context.Context) error {
	r.logger.Info("Starting WAL recovery process", logger.Field{Key: "component", Value: "wal_recovery"})

	// Get all unflushed entries
	entries, err := r.storage.GetUnflushedEntries(ctx)
	if err != nil {
		return fmt.Errorf("failed to get unflushed entries: %w", err)
	}

	if len(entries) == 0 {
		r.logger.Info("No unflushed entries found, recovery complete", logger.Field{Key: "component", Value: "wal_recovery"})
		return nil
	}

	r.logger.Info("Found unflushed entries to recover", logger.Field{Key: "count", Value: len(entries)}, logger.Field{Key: "component", Value: "wal_recovery"})

	// Validate entries before recovery
	validEntries := make([]WALEntry, 0, len(entries))
	for _, entry := range entries {
		if err := r.validateEntry(&entry); err != nil {
			r.logger.Warn("Skipping invalid entry", logger.Field{Key: "entry_id", Value: entry.ID}, logger.Field{Key: "error", Value: err.Error()}, logger.Field{Key: "component", Value: "wal_recovery"})
			continue
		}
		validEntries = append(validEntries, entry)
	}

	if len(validEntries) == 0 {
		r.logger.Info("No valid entries to recover", logger.Field{Key: "component", Value: "wal_recovery"})
		return nil
	}

	// Apply valid entries in order
	appliedCount := 0
	failedCount := 0

	for _, entry := range validEntries {
		if err := r.applyEntry(ctx, &entry); err != nil {
			r.logger.Error("Failed to apply entry", logger.Field{Key: "entry_id", Value: entry.ID}, logger.Field{Key: "error", Value: err.Error()}, logger.Field{Key: "component", Value: "wal_recovery"})
			failedCount++
			continue
		}
		appliedCount++
	}

	r.logger.Info("Recovery complete", logger.Field{Key: "applied", Value: appliedCount}, logger.Field{Key: "failed", Value: failedCount}, logger.Field{Key: "component", Value: "wal_recovery"})
	return nil
}

// validateEntry validates a WAL entry before recovery
func (r *WALRecovery) validateEntry(entry *WALEntry) error {
	if entry.ID == "" {
		return fmt.Errorf("entry ID is empty")
	}

	if entry.Type == "" {
		return fmt.Errorf("entry type is empty")
	}

	if len(entry.Data) == 0 {
		return fmt.Errorf("entry data is empty")
	}

	if entry.Timestamp.IsZero() {
		return fmt.Errorf("entry timestamp is zero")
	}

	// Reject entries older than the configured recovery window
	if r.config != nil && r.config.RecoveryTimeout > 0 {
		if time.Since(entry.Timestamp) > r.config.RecoveryTimeout {
			return fmt.Errorf("entry is too old: %v", time.Since(entry.Timestamp))
		}
	}

	return nil
}

// applyEntry applies a single WAL entry to the underlying storage
func (r *WALRecovery) applyEntry(ctx context.Context, entry *WALEntry) error {
	switch entry.Type {
	case WALEntryTypeTaskUpdate:
		return r.applyTaskUpdate(ctx, entry)
	case WALEntryTypeActivityLog:
		return r.applyActivityLog(ctx, entry)
	case WALEntryTypeTaskDelete:
		return r.applyTaskDelete(ctx, entry)
	default:
		return fmt.Errorf("unknown entry type: %s", entry.Type)
	}
}

// applyTaskUpdate replays a task update entry against the underlying storage
func (r *WALRecovery) applyTaskUpdate(ctx context.Context, entry *WALEntry) error {
	var task storage.PersistentTask
	if err := json.Unmarshal(entry.Data, &task); err != nil {
		return fmt.Errorf("failed to unmarshal task from WAL entry: %w", err)
	}
	return r.storage.SaveTask(ctx, &task)
}

// applyActivityLog replays an activity log entry against the underlying storage
func (r *WALRecovery) applyActivityLog(ctx context.Context, entry *WALEntry) error {
	var log storage.TaskActivityLog
	if err := json.Unmarshal(entry.Data, &log); err != nil {
		return fmt.Errorf("failed to unmarshal activity log from WAL entry: %w", err)
	}
	return r.storage.LogActivity(ctx, &log)
}

// applyTaskDelete applies a task delete entry
func (r *WALRecovery) applyTaskDelete(ctx context.Context, entry *WALEntry) error {
	// Task deletion would need to be implemented in the storage interface
	return fmt.Errorf("task delete recovery not implemented")
}

// Cleanup removes old recovery data
func (r *WALRecovery) Cleanup(ctx context.Context, olderThan time.Time) error {
	return r.storage.DeleteOldSegments(ctx, olderThan)
}

// GetRecoveryStats returns recovery statistics
func (r *WALRecovery) GetRecoveryStats(ctx context.Context) (*RecoveryStats, error) {
	entries, err := r.storage.GetUnflushedEntries(ctx)
	if err != nil {
		return nil, err
	}

	return &RecoveryStats{
		TotalEntries:     len(entries),
		PendingEntries:   len(entries),
		LastRecoveryTime: time.Now(),
	}, nil
}

// RecoveryStats contains recovery statistics
type RecoveryStats struct {
	TotalEntries     int
	AppliedEntries   int
	FailedEntries    int
	PendingEntries   int
	LastRecoveryTime time.Time
}

// PerformRecovery performs the full recovery process
func (rm *WALRecoveryManager) PerformRecovery(ctx context.Context) error {
	startTime := time.Now()
	rm.recovery.logger.Info("Starting WAL recovery manager process", logger.Field{Key: "component", Value: "wal_recovery_manager"})

	// Perform recovery
	if err := rm.recovery.Recover(ctx); err != nil {
		return fmt.Errorf("recovery failed: %w", err)
	}

	// Clean up old data if retention is configured
	if rm.config.SegmentRetention > 0 {
		cleanupTime := time.Now().Add(-rm.config.SegmentRetention)
		if err := rm.recovery.Cleanup(ctx, cleanupTime); err != nil {
			rm.recovery.logger.Warn("Failed to cleanup old recovery data", logger.Field{Key: "error", Value: err.Error()}, logger.Field{Key: "component", Value: "wal_recovery_manager"})
		}
	}

	duration := time.Since(startTime)
	rm.recovery.logger.Info("WAL recovery manager completed", logger.Field{Key: "duration", Value: duration.String()}, logger.Field{Key: "component", Value: "wal_recovery_manager"})

	return nil
}

// GetRecoveryStatus returns the current recovery status
func (rm *WALRecoveryManager) GetRecoveryStatus(ctx context.Context) (*RecoveryStatus, error) {
	stats, err := rm.recovery.GetRecoveryStats(ctx)
	if err != nil {
		return nil, err
	}

	return &RecoveryStatus{
		Stats:            *stats,
		Config:           *rm.config,
		IsRecoveryNeeded: stats.PendingEntries > 0,
	}, nil
}

// RecoveryStatus represents the current recovery status
type RecoveryStatus struct {
	Stats            RecoveryStats
	Config           WALConfig
	IsRecoveryNeeded bool
}
437
dag/wal/storage.go
Normal file
@@ -0,0 +1,437 @@
package wal

import (
	"context"
	"database/sql"
	"fmt"
	"sync"
	"time"

	"github.com/oarkflow/json"
	"github.com/oarkflow/mq/dag/storage"
)

// WALStorageImpl implements the WALStorage interface
type WALStorageImpl struct {
	underlying storage.TaskStorage
	db         *sql.DB
	config     *WALConfig

	// WAL tables
	walEntriesTable  string
	walSegmentsTable string

	mu sync.RWMutex
}

// NewWALStorage creates a new WAL-enabled storage wrapper
func NewWALStorage(underlying storage.TaskStorage, db *sql.DB, config *WALConfig) *WALStorageImpl {
	if config == nil {
		config = DefaultWALConfig()
	}

	return &WALStorageImpl{
		underlying:       underlying,
		db:               db,
		config:           config,
		walEntriesTable:  "wal_entries",
		walSegmentsTable: "wal_segments",
	}
}

// InitializeTables creates the necessary WAL tables.
// Note: the `?` placeholders and `ON CONFLICT` clauses used throughout this
// file assume an SQLite-compatible database; other drivers may require
// different placeholder syntax.
func (ws *WALStorageImpl) InitializeTables(ctx context.Context) error {
	// Create WAL entries table
	entriesQuery := fmt.Sprintf(`
		CREATE TABLE IF NOT EXISTS %s (
			id VARCHAR(255) PRIMARY KEY,
			type VARCHAR(50) NOT NULL,
			timestamp TIMESTAMP NOT NULL,
			sequence_id BIGINT NOT NULL,
			data JSONB,
			metadata JSONB,
			checksum VARCHAR(255),
			segment_id VARCHAR(255),
			created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
			UNIQUE(sequence_id)
		)`, ws.walEntriesTable)

	// Create WAL segments table. Indexes are created separately below so the
	// DDL stays portable: inline INDEX clauses are MySQL-specific and are
	// rejected by SQLite and Postgres.
	segmentsQuery := fmt.Sprintf(`
		CREATE TABLE IF NOT EXISTS %s (
			id VARCHAR(255) PRIMARY KEY,
			start_seq_id BIGINT NOT NULL,
			end_seq_id BIGINT NOT NULL,
			status VARCHAR(20) NOT NULL DEFAULT 'active',
			checksum VARCHAR(255),
			created_at TIMESTAMP NOT NULL,
			flushed_at TIMESTAMP
		)`, ws.walSegmentsTable)

	// Execute table creation
	if _, err := ws.db.ExecContext(ctx, entriesQuery); err != nil {
		return fmt.Errorf("failed to create WAL entries table: %w", err)
	}

	if _, err := ws.db.ExecContext(ctx, segmentsQuery); err != nil {
		return fmt.Errorf("failed to create WAL segments table: %w", err)
	}

	// Create indexes for better performance
	indexes := []string{
		fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_wal_entries_sequence ON %s (sequence_id)", ws.walEntriesTable),
		fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_wal_entries_timestamp ON %s (timestamp)", ws.walEntriesTable),
		fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_wal_entries_segment ON %s (segment_id)", ws.walEntriesTable),
		fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_wal_segments_status ON %s (status)", ws.walSegmentsTable),
		fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_wal_segments_range ON %s (start_seq_id, end_seq_id)", ws.walSegmentsTable),
	}

	for _, idx := range indexes {
		if _, err := ws.db.ExecContext(ctx, idx); err != nil {
			return fmt.Errorf("failed to create index: %w", err)
		}
	}

	return nil
}

// SaveWALEntry saves a single WAL entry
func (ws *WALStorageImpl) SaveWALEntry(ctx context.Context, entry *WALEntry) error {
	query := fmt.Sprintf(`
		INSERT INTO %s (id, type, timestamp, sequence_id, data, metadata, checksum, segment_id)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?)
		ON CONFLICT (sequence_id) DO NOTHING`, ws.walEntriesTable)

	metadataJSON, _ := json.Marshal(entry.Metadata)

	_, err := ws.db.ExecContext(ctx, query,
		entry.ID,
		string(entry.Type),
		entry.Timestamp,
		entry.SequenceID,
		string(entry.Data),
		string(metadataJSON),
		entry.Checksum,
		entry.ID, // Use the entry ID as segment ID for single entries
	)

	if err != nil {
		return fmt.Errorf("failed to save WAL entry: %w", err)
	}

	return nil
}

// SaveWALEntries saves multiple WAL entries in a batch
func (ws *WALStorageImpl) SaveWALEntries(ctx context.Context, entries []WALEntry) error {
	if len(entries) == 0 {
		return nil
	}

	tx, err := ws.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	defer tx.Rollback()

	query := fmt.Sprintf(`
		INSERT INTO %s (id, type, timestamp, sequence_id, data, metadata, checksum, segment_id)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?)
		ON CONFLICT (sequence_id) DO NOTHING`, ws.walEntriesTable)

	stmt, err := tx.PrepareContext(ctx, query)
	if err != nil {
		return fmt.Errorf("failed to prepare statement: %w", err)
	}
	defer stmt.Close()

	for _, entry := range entries {
		metadataJSON, _ := json.Marshal(entry.Metadata)

		_, err = stmt.ExecContext(ctx,
			entry.ID,
			string(entry.Type),
			entry.Timestamp,
			entry.SequenceID,
			string(entry.Data),
			string(metadataJSON),
			entry.Checksum,
			entry.ID, // Use the entry ID as segment ID
		)
		if err != nil {
			return fmt.Errorf("failed to save WAL entry %s: %w", entry.ID, err)
		}
	}

	if err = tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}

	return nil
}

// SaveWALSegment saves a complete WAL segment
func (ws *WALStorageImpl) SaveWALSegment(ctx context.Context, segment *WALSegment) error {
	tx, err := ws.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	defer tx.Rollback()

	// Save segment metadata
	segmentQuery := fmt.Sprintf(`
		INSERT INTO %s (id, start_seq_id, end_seq_id, status, checksum, created_at, flushed_at)
		VALUES (?, ?, ?, ?, ?, ?, ?)
		ON CONFLICT (id) DO UPDATE SET
			status = EXCLUDED.status,
			flushed_at = EXCLUDED.flushed_at`, ws.walSegmentsTable)

	// Pass NULL for flushed_at when the segment has not been flushed yet
	var flushedAt interface{}
	if segment.FlushedAt != nil {
		flushedAt = *segment.FlushedAt
	}

	_, err = tx.ExecContext(ctx, segmentQuery,
		segment.ID,
		segment.StartSeqID,
		segment.EndSeqID,
		string(segment.Status),
		segment.Checksum,
		segment.CreatedAt,
		flushedAt,
	)
	if err != nil {
		return fmt.Errorf("failed to save WAL segment: %w", err)
	}

	// Save all entries in the segment
	if len(segment.Entries) > 0 {
		entriesQuery := fmt.Sprintf(`
			INSERT INTO %s (id, type, timestamp, sequence_id, data, metadata, checksum, segment_id)
			VALUES (?, ?, ?, ?, ?, ?, ?, ?)
			ON CONFLICT (sequence_id) DO NOTHING`, ws.walEntriesTable)

		stmt, err := tx.PrepareContext(ctx, entriesQuery)
		if err != nil {
			return fmt.Errorf("failed to prepare entries statement: %w", err)
		}
		defer stmt.Close()

		for _, entry := range segment.Entries {
			metadataJSON, _ := json.Marshal(entry.Metadata)

			_, err = stmt.ExecContext(ctx,
				entry.ID,
				string(entry.Type),
				entry.Timestamp,
				entry.SequenceID,
				string(entry.Data),
				string(metadataJSON),
				entry.Checksum,
				segment.ID,
			)
			if err != nil {
				return fmt.Errorf("failed to save entry %s: %w", entry.ID, err)
			}
		}
	}

	if err = tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit segment transaction: %w", err)
	}

	return nil
}

// GetWALSegments retrieves WAL segments created since a given time
func (ws *WALStorageImpl) GetWALSegments(ctx context.Context, since time.Time) ([]WALSegment, error) {
	query := fmt.Sprintf(`
		SELECT id, start_seq_id, end_seq_id, status, checksum, created_at, flushed_at
		FROM %s
		WHERE created_at >= ?
		ORDER BY created_at ASC`, ws.walSegmentsTable)

	rows, err := ws.db.QueryContext(ctx, query, since)
	if err != nil {
		return nil, fmt.Errorf("failed to query WAL segments: %w", err)
	}
	defer rows.Close()

	var segments []WALSegment
	for rows.Next() {
		var segment WALSegment
		var flushedAt sql.NullTime

		err := rows.Scan(
			&segment.ID,
			&segment.StartSeqID,
			&segment.EndSeqID,
			&segment.Status,
			&segment.Checksum,
			&segment.CreatedAt,
			&flushedAt,
		)
		if err != nil {
			return nil, fmt.Errorf("failed to scan segment: %w", err)
		}

		if flushedAt.Valid {
			segment.FlushedAt = &flushedAt.Time
		}

		segments = append(segments, segment)
	}

	return segments, rows.Err()
}

// GetUnflushedEntries retrieves entries that haven't been applied yet
func (ws *WALStorageImpl) GetUnflushedEntries(ctx context.Context) ([]WALEntry, error) {
	query := fmt.Sprintf(`
		SELECT we.id, we.type, we.timestamp, we.sequence_id, we.data, we.metadata, we.checksum
		FROM %s we
		LEFT JOIN %s ws ON we.segment_id = ws.id
		WHERE ws.status IN ('active', 'flushing') OR ws.id IS NULL
		ORDER BY we.sequence_id ASC`, ws.walEntriesTable, ws.walSegmentsTable)

	rows, err := ws.db.QueryContext(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("failed to query unflushed entries: %w", err)
	}
	defer rows.Close()

	var entries []WALEntry
	for rows.Next() {
		var entry WALEntry
		var metadata sql.NullString

		err := rows.Scan(
			&entry.ID,
			&entry.Type,
			&entry.Timestamp,
			&entry.SequenceID,
			&entry.Data,
			&metadata,
			&entry.Checksum,
		)
		if err != nil {
			return nil, fmt.Errorf("failed to scan entry: %w", err)
		}

		if metadata.Valid && metadata.String != "" {
			json.Unmarshal([]byte(metadata.String), &entry.Metadata)
		}

		entries = append(entries, entry)
	}

	return entries, rows.Err()
}

// DeleteOldSegments deletes WAL segments older than the specified time
func (ws *WALStorageImpl) DeleteOldSegments(ctx context.Context, olderThan time.Time) error {
	// First, delete entries from old segments
	deleteEntriesQuery := fmt.Sprintf(`
		DELETE FROM %s
		WHERE segment_id IN (
			SELECT id FROM %s
			WHERE status = 'flushed' AND flushed_at < ?
		)`, ws.walEntriesTable, ws.walSegmentsTable)

	// Then delete the segments themselves
	deleteSegmentsQuery := fmt.Sprintf(`
		DELETE FROM %s
		WHERE status = 'flushed' AND flushed_at < ?`, ws.walSegmentsTable)

	tx, err := ws.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	defer tx.Rollback()

	// Delete entries
	if _, err = tx.ExecContext(ctx, deleteEntriesQuery, olderThan); err != nil {
		return fmt.Errorf("failed to delete old entries: %w", err)
	}

	// Delete segments
	if _, err = tx.ExecContext(ctx, deleteSegmentsQuery, olderThan); err != nil {
		return fmt.Errorf("failed to delete old segments: %w", err)
	}

	if err = tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit cleanup transaction: %w", err)
	}

	return nil
}

// SaveTask delegates to the underlying storage
func (ws *WALStorageImpl) SaveTask(ctx context.Context, task *storage.PersistentTask) error {
	return ws.underlying.SaveTask(ctx, task)
}

// LogActivity delegates to the underlying storage
func (ws *WALStorageImpl) LogActivity(ctx context.Context, log *storage.TaskActivityLog) error {
	return ws.underlying.LogActivity(ctx, log)
}

// WALEnabledStorage wraps any TaskStorage with WAL functionality
type WALEnabledStorage struct {
	storage.TaskStorage
	walManager *WALManager
}

// NewWALEnabledStorage creates a new WAL-enabled storage wrapper
func NewWALEnabledStorage(underlying storage.TaskStorage, walManager *WALManager) *WALEnabledStorage {
	return &WALEnabledStorage{
		TaskStorage: underlying,
		walManager:  walManager,
	}
}

// SaveTask saves a task through the WAL
func (wes *WALEnabledStorage) SaveTask(ctx context.Context, task *storage.PersistentTask) error {
	// Serialize the task to JSON for the WAL
	taskData, err := json.Marshal(task)
	if err != nil {
		return fmt.Errorf("failed to marshal task: %w", err)
	}

	// Write to the WAL first
	if err := wes.walManager.WriteEntry(ctx, WALEntryTypeTaskUpdate, taskData, map[string]interface{}{
		"task_id": task.ID,
		"dag_id":  task.DAGID,
	}); err != nil {
		return fmt.Errorf("failed to write task to WAL: %w", err)
	}

	// Delegate to the underlying storage (also applied during flush)
	return wes.TaskStorage.SaveTask(ctx, task)
}

// LogActivity logs activity through the WAL
func (wes *WALEnabledStorage) LogActivity(ctx context.Context, log *storage.TaskActivityLog) error {
	// Serialize the log to JSON for the WAL
	logData, err := json.Marshal(log)
	if err != nil {
		return fmt.Errorf("failed to marshal activity log: %w", err)
	}

	// Write to the WAL first
	if err := wes.walManager.WriteEntry(ctx, WALEntryTypeActivityLog, logData, map[string]interface{}{
		"task_id": log.TaskID,
		"dag_id":  log.DAGID,
		"action":  log.Action,
	}); err != nil {
		return fmt.Errorf("failed to write activity log to WAL: %w", err)
	}

	// Delegate to the underlying storage (also applied during flush)
	return wes.TaskStorage.LogActivity(ctx, log)
}
473
dag/wal/wal.go
Normal file
@@ -0,0 +1,473 @@
package wal

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"github.com/oarkflow/json"
	"github.com/oarkflow/mq/dag/storage"
)

// WALEntryType represents the type of WAL entry
type WALEntryType string

const (
	WALEntryTypeTaskUpdate  WALEntryType = "task_update"
	WALEntryTypeActivityLog WALEntryType = "activity_log"
	WALEntryTypeTaskDelete  WALEntryType = "task_delete"
	WALEntryTypeBatchUpdate WALEntryType = "batch_update"
)

// WALEntry represents a single entry in the Write-Ahead Log
type WALEntry struct {
	ID         string                 `json:"id"`
	Type       WALEntryType           `json:"type"`
	Timestamp  time.Time              `json:"timestamp"`
	SequenceID uint64                 `json:"sequence_id"`
	Data       json.RawMessage        `json:"data"`
	Metadata   map[string]interface{} `json:"metadata,omitempty"`
	Checksum   string                 `json:"checksum"`
}

// WALSegment represents a segment of WAL entries
type WALSegment struct {
	ID         string        `json:"id"`
	StartSeqID uint64        `json:"start_seq_id"`
	EndSeqID   uint64        `json:"end_seq_id"`
	Entries    []WALEntry    `json:"entries"`
	CreatedAt  time.Time     `json:"created_at"`
	FlushedAt  *time.Time    `json:"flushed_at,omitempty"`
	Status     SegmentStatus `json:"status"`
	Checksum   string        `json:"checksum"`
}

// SegmentStatus represents the status of a WAL segment
type SegmentStatus string

const (
	SegmentStatusActive   SegmentStatus = "active"
	SegmentStatusFlushing SegmentStatus = "flushing"
	SegmentStatusFlushed  SegmentStatus = "flushed"
	SegmentStatusFailed   SegmentStatus = "failed"
)

// WALConfig holds configuration for the WAL system
type WALConfig struct {
	// Buffer configuration
	MaxBufferSize   int           `json:"max_buffer_size"`   // Maximum entries in buffer before flush
	FlushInterval   time.Duration `json:"flush_interval"`    // How often to flush the buffer
	MaxFlushRetries int           `json:"max_flush_retries"` // Max retries for failed flushes

	// Segment configuration
	MaxSegmentSize   int           `json:"max_segment_size"`  // Maximum entries per segment
	SegmentRetention time.Duration `json:"segment_retention"` // How long to keep flushed segments

	// Performance tuning
	WorkerCount int `json:"worker_count"` // Number of flush workers
	BatchSize   int `json:"batch_size"`   // Batch size for database operations

	// Recovery configuration
	EnableRecovery  bool          `json:"enable_recovery"`  // Enable WAL recovery on startup
	RecoveryTimeout time.Duration `json:"recovery_timeout"` // Timeout for recovery operations

	// Monitoring
	EnableMetrics   bool          `json:"enable_metrics"`   // Enable metrics collection
	MetricsInterval time.Duration `json:"metrics_interval"` // Metrics collection interval
}

// DefaultWALConfig returns the default WAL configuration
func DefaultWALConfig() *WALConfig {
	return &WALConfig{
		MaxBufferSize:    1000,
		FlushInterval:    5 * time.Second,
		MaxFlushRetries:  3,
		MaxSegmentSize:   5000,
		SegmentRetention: 24 * time.Hour,
		WorkerCount:      2,
		BatchSize:        100,
		EnableRecovery:   true,
		RecoveryTimeout:  30 * time.Second,
		EnableMetrics:    true,
		MetricsInterval:  10 * time.Second,
	}
}

// WALManager manages the Write-Ahead Log system
type WALManager struct {
	config        *WALConfig
	storage       WALStorage
	buffer        chan WALEntry
	segments      map[string]*WALSegment
	currentSeqID  uint64
	activeSegment *WALSegment

	// Control channels
	flushTrigger chan struct{}
	shutdown     chan struct{}
	done         chan struct{}

	// Workers
	flushWorkers []chan WALSegment

	// Synchronization
	mu sync.RWMutex
	wg sync.WaitGroup

	// Metrics
	metrics *WALMetrics

	// Callbacks
	onFlush    func(segment *WALSegment) error
	onRecovery func(entries []WALEntry) error
}

// WALStorage defines the interface for WAL storage operations
type WALStorage interface {
	// WAL operations
	SaveWALEntry(ctx context.Context, entry *WALEntry) error
	SaveWALEntries(ctx context.Context, entries []WALEntry) error
	SaveWALSegment(ctx context.Context, segment *WALSegment) error

	// Recovery operations
	GetWALSegments(ctx context.Context, since time.Time) ([]WALSegment, error)
	GetUnflushedEntries(ctx context.Context) ([]WALEntry, error)

	// Cleanup operations
	DeleteOldSegments(ctx context.Context, olderThan time.Time) error

	// Task operations (delegated to the underlying storage)
	SaveTask(ctx context.Context, task *storage.PersistentTask) error
	LogActivity(ctx context.Context, log *storage.TaskActivityLog) error
}

// WALMetrics holds metrics for the WAL system
type WALMetrics struct {
	EntriesBuffered    int64
	EntriesFlushed     int64
	FlushOperations    int64
	FlushErrors        int64
	RecoveryOperations int64
	RecoveryErrors     int64
	AverageFlushTime   time.Duration
	LastFlushTime      time.Time
}

// NewWALManager creates a new WAL manager
func NewWALManager(config *WALConfig, storage WALStorage) *WALManager {
	if config == nil {
		config = DefaultWALConfig()
	}

	wm := &WALManager{
		config:       config,
		storage:      storage,
		buffer:       make(chan WALEntry, config.MaxBufferSize*2), // Extra capacity for bursts
		segments:     make(map[string]*WALSegment),
		flushTrigger: make(chan struct{}, 1),
		shutdown:     make(chan struct{}),
		done:         make(chan struct{}),
		flushWorkers: make([]chan WALSegment, config.WorkerCount),
		metrics:      &WALMetrics{},
	}

	// Initialize flush workers
	for i := 0; i < config.WorkerCount; i++ {
		wm.flushWorkers[i] = make(chan WALSegment, 10)
		wm.wg.Add(1)
		go wm.flushWorker(i, wm.flushWorkers[i])
	}

	// Start the main processing loop
	wm.wg.Add(1)
	go wm.processLoop()

	return wm
}

// WriteEntry writes an entry to the WAL
func (wm *WALManager) WriteEntry(ctx context.Context, entryType WALEntryType, data json.RawMessage, metadata map[string]interface{}) error {
	entry := WALEntry{
		ID:         generateID(),
		Type:       entryType,
		Timestamp:  time.Now(),
		SequenceID: atomic.AddUint64(&wm.currentSeqID, 1),
		Data:       data,
		Metadata:   metadata,
		Checksum:   calculateChecksum(data),
	}

	select {
	case wm.buffer <- entry:
		atomic.AddInt64(&wm.metrics.EntriesBuffered, 1)
		return nil
	case <-ctx.Done():
		return ctx.Err()
	case <-wm.shutdown:
		return fmt.Errorf("WAL manager is shutting down")
	default:
		// Buffer is full: trigger an immediate flush and report backpressure
		select {
		case wm.flushTrigger <- struct{}{}:
		default:
		}
		return fmt.Errorf("WAL buffer is full")
	}
}

// Flush forces an immediate flush of the WAL buffer
func (wm *WALManager) Flush(ctx context.Context) error {
	select {
	case wm.flushTrigger <- struct{}{}:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	case <-wm.shutdown:
		return fmt.Errorf("WAL manager is shutting down")
	}
}

// Shutdown gracefully shuts down the WAL manager
func (wm *WALManager) Shutdown(ctx context.Context) error {
	close(wm.shutdown)

	// Wait for processing to complete or context timeout
	select {
	case <-wm.done:
		wm.wg.Wait()
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// GetMetrics returns a snapshot of the current WAL metrics
func (wm *WALManager) GetMetrics() WALMetrics {
	return *wm.metrics
}

// SetFlushCallback sets the callback to be called after each flush
func (wm *WALManager) SetFlushCallback(callback func(segment *WALSegment) error) {
	wm.onFlush = callback
}

// SetRecoveryCallback sets the callback to be called during recovery
func (wm *WALManager) SetRecoveryCallback(callback func(entries []WALEntry) error) {
	wm.onRecovery = callback
}

// processLoop is the main processing loop for the WAL manager
func (wm *WALManager) processLoop() {
	defer wm.wg.Done()
	defer close(wm.done)

	ticker := time.NewTicker(wm.config.FlushInterval)
	defer ticker.Stop()

	for {
		select {
		case <-wm.shutdown:
			wm.flushAllBuffers(context.Background())
			// Close worker channels so flush workers exit and wg.Wait() can return
			for _, ch := range wm.flushWorkers {
				close(ch)
			}
			return

		case <-ticker.C:
			wm.triggerFlush()

		case <-wm.flushTrigger:
			wm.triggerFlush()

		case entry := <-wm.buffer:
			wm.addToSegment(entry)
		}
	}
}

// triggerFlush rotates the active segment and hands it to a flush worker
func (wm *WALManager) triggerFlush() {
	wm.mu.Lock()
	defer wm.mu.Unlock()

	if wm.activeSegment != nil && len(wm.activeSegment.Entries) > 0 {
		segment := wm.activeSegment
		wm.activeSegment = wm.createNewSegment()

		// Send to a worker for processing
		workerIndex := int(segment.StartSeqID) % len(wm.flushWorkers)
		select {
		case wm.flushWorkers[workerIndex] <- *segment:
		default:
			// Worker queue is full; process asynchronously instead
			go wm.flushSegment(context.Background(), segment)
		}
	}
}

// flushWorker processes segments for a specific worker
func (wm *WALManager) flushWorker(workerID int, segmentCh <-chan WALSegment) {
	defer wm.wg.Done()

	for segment := range segmentCh {
		if err := wm.flushSegment(context.Background(), &segment); err != nil {
			// Count the failure; retry logic could be added here
			atomic.AddInt64(&wm.metrics.FlushErrors, 1)
		}
	}
}

// flushSegment flushes a segment to storage
func (wm *WALManager) flushSegment(ctx context.Context, segment *WALSegment) error {
	start := time.Now()

	// Update segment status
	segment.Status = SegmentStatusFlushing
	now := time.Now()
	segment.FlushedAt = &now

	// Save the segment to storage
	if err := wm.storage.SaveWALSegment(ctx, segment); err != nil {
		segment.Status = SegmentStatusFailed
		return fmt.Errorf("failed to save WAL segment: %w", err)
	}

	// Apply entries to the underlying storage based on type
	if err := wm.applyEntries(ctx, segment.Entries); err != nil {
		segment.Status = SegmentStatusFailed
		return fmt.Errorf("failed to apply WAL entries: %w", err)
	}

	// Mark the segment as flushed
	segment.Status = SegmentStatusFlushed

	// Update metrics (AverageFlushTime currently records the most recent flush duration)
	atomic.AddInt64(&wm.metrics.EntriesFlushed, int64(len(segment.Entries)))
	atomic.AddInt64(&wm.metrics.FlushOperations, 1)
	atomic.StoreInt64((*int64)(&wm.metrics.AverageFlushTime), int64(time.Since(start)))
	wm.metrics.LastFlushTime = time.Now()

	// Call the flush callback if set
	if wm.onFlush != nil {
		if err := wm.onFlush(segment); err != nil {
			// Callback errors are intentionally ignored so they don't fail the flush
		}
	}

	return nil
}

// applyEntries applies WAL entries to the underlying storage
func (wm *WALManager) applyEntries(ctx context.Context, entries []WALEntry) error {
	// Group entries by type for batch processing
	taskUpdates := make([]storage.PersistentTask, 0)
	activityLogs := make([]storage.TaskActivityLog, 0)

	for _, entry := range entries {
		switch entry.Type {
		case WALEntryTypeTaskUpdate:
			var task storage.PersistentTask
			if err := json.Unmarshal(entry.Data, &task); err != nil {
				continue // Skip invalid entries
			}
			taskUpdates = append(taskUpdates, task)

		case WALEntryTypeActivityLog:
			var log storage.TaskActivityLog
			if err := json.Unmarshal(entry.Data, &log); err != nil {
				continue // Skip invalid entries
			}
			activityLogs = append(activityLogs, log)
		}
	}

	// Save tasks in batches
	if len(taskUpdates) > 0 {
		for i := 0; i < len(taskUpdates); i += wm.config.BatchSize {
			end := i + wm.config.BatchSize
			if end > len(taskUpdates) {
				end = len(taskUpdates)
			}
			batch := taskUpdates[i:end]

			for _, task := range batch {
				if err := wm.storage.SaveTask(ctx, &task); err != nil {
					return fmt.Errorf("failed to save task %s: %w", task.ID, err)
				}
			}
		}
	}

	// Save activity logs in batches
	if len(activityLogs) > 0 {
		for i := 0; i < len(activityLogs); i += wm.config.BatchSize {
			end := i + wm.config.BatchSize
			if end > len(activityLogs) {
				end = len(activityLogs)
			}
			batch := activityLogs[i:end]

			for _, log := range batch {
				if err := wm.storage.LogActivity(ctx, &log); err != nil {
					return fmt.Errorf("failed to save activity log for task %s: %w", log.TaskID, err)
				}
			}
		}
	}

	return nil
}

// addToSegment adds an entry to the current active segment
func (wm *WALManager) addToSegment(entry WALEntry) {
	wm.mu.Lock()
	defer wm.mu.Unlock()

	// Create a new segment if needed
	if wm.activeSegment == nil || len(wm.activeSegment.Entries) >= wm.config.MaxSegmentSize {
		wm.activeSegment = wm.createNewSegment()
	}

	wm.activeSegment.Entries = append(wm.activeSegment.Entries, entry)
	wm.activeSegment.EndSeqID = entry.SequenceID
}

// createNewSegment creates a new WAL segment
func (wm *WALManager) createNewSegment() *WALSegment {
	segmentID := generateID()
	startSeqID := atomic.LoadUint64(&wm.currentSeqID) + 1

	segment := &WALSegment{
		ID:         segmentID,
		StartSeqID: startSeqID,
		EndSeqID:   startSeqID,
		Entries:    make([]WALEntry, 0, wm.config.MaxSegmentSize),
		CreatedAt:  time.Now(),
		Status:     SegmentStatusActive,
	}

	wm.segments[segmentID] = segment
	return segment
}

// flushAllBuffers flushes the remaining active segment during shutdown
func (wm *WALManager) flushAllBuffers(ctx context.Context) {
	wm.mu.Lock()
	defer wm.mu.Unlock()

	if wm.activeSegment != nil && len(wm.activeSegment.Entries) > 0 {
		wm.flushSegment(ctx, wm.activeSegment)
	}
}

// Helper functions

func generateID() string {
	return fmt.Sprintf("wal_%d", time.Now().UnixNano())
}

func calculateChecksum(data []byte) string {
	// Simple additive checksum; production code should use a proper hash (e.g., CRC32 or SHA-256)
	sum := 0
	for _, b := range data {
		sum += int(b)
	}
	return fmt.Sprintf("%x", sum)
}
248
dag/wal_factory.go
Normal file
@@ -0,0 +1,248 @@
package dag

import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"time"

	"github.com/oarkflow/mq/dag/storage"
	"github.com/oarkflow/mq/dag/wal"
	"github.com/oarkflow/mq/logger"
)

// WALEnabledStorageFactory creates WAL-enabled storage instances
type WALEnabledStorageFactory struct {
	logger logger.Logger
}

// NewWALEnabledStorageFactory creates a new WAL-enabled storage factory
func NewWALEnabledStorageFactory(logger logger.Logger) *WALEnabledStorageFactory {
	return &WALEnabledStorageFactory{
		logger: logger,
	}
}

// CreateMemoryStorage creates a WAL-enabled memory storage
func (f *WALEnabledStorageFactory) CreateMemoryStorage(walConfig *wal.WALConfig) (storage.TaskStorage, *wal.WALManager, error) {
	if walConfig == nil {
		walConfig = wal.DefaultWALConfig()
	}

	// Create the underlying memory storage
	memoryStorage := storage.NewMemoryTaskStorage()

	// For memory storage, use an in-memory SQLite database for the WAL.
	// Note: this requires a SQLite driver (e.g., a blank import of
	// github.com/mattn/go-sqlite3) to be registered by the application.
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		return nil, nil, fmt.Errorf("failed to create in-memory database: %w", err)
	}

	// Create the WAL storage implementation
	walStorage := wal.NewWALStorage(memoryStorage, db, walConfig)

	// Initialize the WAL tables
	if err := walStorage.InitializeTables(context.Background()); err != nil {
		return nil, nil, fmt.Errorf("failed to initialize WAL tables: %w", err)
	}

	// Create the WAL manager
	walManager := wal.NewWALManager(walConfig, walStorage)

	// Create the WAL-enabled storage wrapper
	walEnabledStorage := &WALEnabledStorageWrapper{
		underlying: memoryStorage,
		walManager: walManager,
	}

	return walEnabledStorage, walManager, nil
}

// CreateSQLStorage creates a WAL-enabled SQL storage
func (f *WALEnabledStorageFactory) CreateSQLStorage(config *storage.TaskStorageConfig, walConfig *wal.WALConfig) (storage.TaskStorage, *wal.WALManager, error) {
	if config == nil {
		config = &storage.TaskStorageConfig{
			Type: "postgres", // Default to postgres
		}
	}

	if walConfig == nil {
		walConfig = wal.DefaultWALConfig()
	}

	// Create the underlying SQL storage
	sqlStorage, err := storage.NewSQLTaskStorage(config)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to create SQL storage: %w", err)
	}

	// Get the database connection from the SQL storage
	db := sqlStorage.GetDB()

	// Create the WAL storage implementation
	walStorage := wal.NewWALStorage(sqlStorage, db, walConfig)

	// Initialize the WAL tables
	if err := walStorage.InitializeTables(context.Background()); err != nil {
		return nil, nil, fmt.Errorf("failed to initialize WAL tables: %w", err)
	}

	// Create the WAL manager
	walManager := wal.NewWALManager(walConfig, walStorage)

	// Create the WAL-enabled storage wrapper
	walEnabledStorage := &WALEnabledStorageWrapper{
		underlying: sqlStorage,
		walManager: walManager,
	}

	return walEnabledStorage, walManager, nil
}

// WALEnabledStorageWrapper wraps any TaskStorage with WAL functionality
type WALEnabledStorageWrapper struct {
	underlying storage.TaskStorage
	walManager *wal.WALManager
}

// SaveTask saves a task through the WAL
func (w *WALEnabledStorageWrapper) SaveTask(ctx context.Context, task *storage.PersistentTask) error {
	// Serialize the task to JSON for the WAL
	taskData, err := json.Marshal(task)
	if err != nil {
		return fmt.Errorf("failed to marshal task: %w", err)
	}

	// Write to the WAL first
	if err := w.walManager.WriteEntry(ctx, wal.WALEntryTypeTaskUpdate, taskData, map[string]interface{}{
		"task_id": task.ID,
		"dag_id":  task.DAGID,
	}); err != nil {
		return fmt.Errorf("failed to write task to WAL: %w", err)
	}

	// Delegate to the underlying storage (also applied during flush)
	return w.underlying.SaveTask(ctx, task)
}

// LogActivity logs activity through the WAL
func (w *WALEnabledStorageWrapper) LogActivity(ctx context.Context, logEntry *storage.TaskActivityLog) error {
	// Serialize the log to JSON for the WAL
	logData, err := json.Marshal(logEntry)
	if err != nil {
		return fmt.Errorf("failed to marshal activity log: %w", err)
	}

	// Write to the WAL first
	if err := w.walManager.WriteEntry(ctx, wal.WALEntryTypeActivityLog, logData, map[string]interface{}{
		"task_id": logEntry.TaskID,
		"dag_id":  logEntry.DAGID,
		"action":  logEntry.Action,
	}); err != nil {
		return fmt.Errorf("failed to write activity log to WAL: %w", err)
	}

	// Delegate to the underlying storage (also applied during flush)
	return w.underlying.LogActivity(ctx, logEntry)
}

// GetTask retrieves a task
func (w *WALEnabledStorageWrapper) GetTask(ctx context.Context, taskID string) (*storage.PersistentTask, error) {
	return w.underlying.GetTask(ctx, taskID)
}

// GetTasksByDAG retrieves tasks by DAG ID
func (w *WALEnabledStorageWrapper) GetTasksByDAG(ctx context.Context, dagID string, limit int, offset int) ([]*storage.PersistentTask, error) {
	return w.underlying.GetTasksByDAG(ctx, dagID, limit, offset)
}

// GetTasksByStatus retrieves tasks by status
func (w *WALEnabledStorageWrapper) GetTasksByStatus(ctx context.Context, dagID string, status storage.TaskStatus) ([]*storage.PersistentTask, error) {
	return w.underlying.GetTasksByStatus(ctx, dagID, status)
}

// UpdateTaskStatus updates task status
func (w *WALEnabledStorageWrapper) UpdateTaskStatus(ctx context.Context, taskID string, status storage.TaskStatus, errorMsg string) error {
	return w.underlying.UpdateTaskStatus(ctx, taskID, status, errorMsg)
}

// DeleteTask deletes a task
func (w *WALEnabledStorageWrapper) DeleteTask(ctx context.Context, taskID string) error {
	return w.underlying.DeleteTask(ctx, taskID)
}

// DeleteTasksByDAG deletes tasks by DAG ID
func (w *WALEnabledStorageWrapper) DeleteTasksByDAG(ctx context.Context, dagID string) error {
	return w.underlying.DeleteTasksByDAG(ctx, dagID)
}

// GetActivityLogs retrieves activity logs
func (w *WALEnabledStorageWrapper) GetActivityLogs(ctx context.Context, taskID string, limit int, offset int) ([]*storage.TaskActivityLog, error) {
	return w.underlying.GetActivityLogs(ctx, taskID, limit, offset)
}

// GetActivityLogsByDAG retrieves activity logs by DAG ID
func (w *WALEnabledStorageWrapper) GetActivityLogsByDAG(ctx context.Context, dagID string, limit int, offset int) ([]*storage.TaskActivityLog, error) {
	return w.underlying.GetActivityLogsByDAG(ctx, dagID, limit, offset)
}

// SaveTasks saves multiple tasks
func (w *WALEnabledStorageWrapper) SaveTasks(ctx context.Context, tasks []*storage.PersistentTask) error {
	return w.underlying.SaveTasks(ctx, tasks)
}

// GetPendingTasks retrieves pending tasks
func (w *WALEnabledStorageWrapper) GetPendingTasks(ctx context.Context, dagID string, limit int) ([]*storage.PersistentTask, error) {
	return w.underlying.GetPendingTasks(ctx, dagID, limit)
}

// GetResumableTasks retrieves resumable tasks
func (w *WALEnabledStorageWrapper) GetResumableTasks(ctx context.Context, dagID string) ([]*storage.PersistentTask, error) {
	return w.underlying.GetResumableTasks(ctx, dagID)
}

// CleanupOldTasks cleans up old tasks
func (w *WALEnabledStorageWrapper) CleanupOldTasks(ctx context.Context, dagID string, olderThan time.Time) error {
	return w.underlying.CleanupOldTasks(ctx, dagID, olderThan)
}

// CleanupOldActivityLogs cleans up old activity logs
func (w *WALEnabledStorageWrapper) CleanupOldActivityLogs(ctx context.Context, dagID string, olderThan time.Time) error {
	return w.underlying.CleanupOldActivityLogs(ctx, dagID, olderThan)
}

// Ping checks storage connectivity
func (w *WALEnabledStorageWrapper) Ping(ctx context.Context) error {
	return w.underlying.Ping(ctx)
}

// Close closes the storage
func (w *WALEnabledStorageWrapper) Close() error {
	return w.underlying.Close()
}

// Flush forces an immediate flush of the WAL buffer
func (w *WALEnabledStorageWrapper) Flush(ctx context.Context) error {
	return w.walManager.Flush(ctx)
}

// Shutdown gracefully shuts down the WAL-enabled storage
func (w *WALEnabledStorageWrapper) Shutdown(ctx context.Context) error {
	// Flush any remaining entries; a flush failure should not block shutdown
	if err := w.walManager.Flush(ctx); err != nil {
		// Continue: the WAL manager flushes again during its own shutdown
	}

	// Shut down the WAL manager
	if err := w.walManager.Shutdown(ctx); err != nil {
		return fmt.Errorf("failed to shutdown WAL manager: %w", err)
	}

	return nil
}

// GetWALMetrics returns WAL performance metrics
func (w *WALEnabledStorageWrapper) GetWALMetrics() wal.WALMetrics {
	return w.walManager.GetMetrics()
}
179
examples/WAL_README.md
Normal file
@@ -0,0 +1,179 @@
# WAL (Write-Ahead Logging) System

This directory contains a robust, enterprise-grade WAL system implementation designed to prevent frequent task-logging operations from overloading the database.

## Overview

The WAL system provides:
- **Buffered Logging**: High-frequency logging operations are buffered in memory
- **Batch Processing**: The buffer is flushed to the database in periodic batches for optimal performance
- **Crash Recovery**: Unflushed entries are automatically recovered on system restart
- **Performance Metrics**: Real-time monitoring of WAL operations
- **Graceful Shutdown**: Ensures data consistency during shutdown
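
To make the buffering and batch-flush idea concrete, here is a minimal, self-contained Go sketch (illustrative only; the real `wal.Manager` layers segments, retries, recovery, and metrics on top of this core loop):

```go
package main

import (
	"fmt"
	"time"
)

type entry struct{ payload string }

// bufferedWriter illustrates the WAL core idea: writes land in a bounded
// channel, and a background worker flushes them in batches, either when
// batchSize entries accumulate or when the flush interval elapses.
type bufferedWriter struct {
	buf chan entry
}

func newBufferedWriter(size, batchSize int, interval time.Duration, flush func([]entry)) *bufferedWriter {
	w := &bufferedWriter{buf: make(chan entry, size)}
	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		batch := make([]entry, 0, batchSize)
		for {
			select {
			case e := <-w.buf:
				batch = append(batch, e)
				if len(batch) >= batchSize {
					flush(batch) // size-triggered flush
					batch = batch[:0]
				}
			case <-ticker.C:
				if len(batch) > 0 {
					flush(batch) // time-triggered flush of a partial batch
					batch = batch[:0]
				}
			}
		}
	}()
	return w
}

// Write blocks only when the buffer is full, so callers stay fast.
func (w *bufferedWriter) Write(e entry) { w.buf <- e }

func main() {
	w := newBufferedWriter(1000, 100, 2*time.Second, func(batch []entry) {
		// In the real system this would be one batched database write.
		fmt.Printf("flushing %d entries in one batch\n", len(batch))
	})
	for i := 0; i < 250; i++ {
		w.Write(entry{payload: fmt.Sprintf("log-%d", i)})
	}
	time.Sleep(3 * time.Second) // let the timer flush drain the tail
}
```
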
## Key Components

### 1. WAL Manager (`dag/wal/wal.go`)
Core WAL functionality with buffering, segment management, and flush operations.

### 2. WAL Storage (`dag/wal/storage.go`)
Database persistence layer for WAL entries and segments.

### 3. WAL Recovery (`dag/wal/recovery.go`)
Crash recovery mechanisms to replay unflushed entries.

### 4. WAL Factory (`dag/wal_factory.go`)
Factory for creating WAL-enabled storage instances.

## Usage Example

```go
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/oarkflow/mq"
	"github.com/oarkflow/mq/dag"
	"github.com/oarkflow/mq/dag/storage"
	"github.com/oarkflow/mq/dag/wal"
	"github.com/oarkflow/mq/logger"
)

func main() {
	// Create logger
	l := logger.NewDefaultLogger()

	// Create WAL-enabled storage factory
	factory := dag.NewWALEnabledStorageFactory(l)

	// Configure WAL
	walConfig := &wal.WALConfig{
		MaxBufferSize:    5000,            // Buffer up to 5000 entries
		FlushInterval:    2 * time.Second, // Flush every 2 seconds
		MaxFlushRetries:  3,               // Retry failed flushes
		MaxSegmentSize:   10000,           // 10K entries per segment
		SegmentRetention: 48 * time.Hour,  // Keep segments for 48 hours
		WorkerCount:      4,               // 4 flush workers
		BatchSize:        500,             // Batch 500 operations
		EnableRecovery:   true,            // Enable crash recovery
		EnableMetrics:    true,            // Enable metrics
	}

	// Create WAL-enabled storage (named walStorage so it does not shadow
	// the imported storage package)
	walStorage, walManager, err := factory.CreateMemoryStorage(walConfig)
	if err != nil {
		panic(err)
	}
	defer walStorage.Close()

	// Create DAG with WAL-enabled storage
	d := dag.NewDAG("My DAG", "my-dag", func(taskID string, result mq.Result) {
		// Handle final results
	})

	// Set the WAL-enabled storage
	d.SetTaskStorage(walStorage)

	// Now all logging operations will be buffered and batched
	ctx := context.Background()

	// Create and log activities - these will be buffered
	for i := 0; i < 1000; i++ {
		task := &storage.PersistentTask{
			ID:     fmt.Sprintf("task-%d", i),
			DAGID:  "my-dag",
			Status: storage.TaskStatusRunning,
		}

		// This will be buffered, not written immediately to the DB
		d.GetTaskStorage().SaveTask(ctx, task)

		// Activity logging will also be buffered
		activity := &storage.TaskActivityLog{
			TaskID:  task.ID,
			DAGID:   "my-dag",
			Action:  "processing",
			Message: "Task is being processed",
		}
		d.GetTaskStorage().LogActivity(ctx, activity)
	}

	// Get performance metrics
	metrics := walManager.GetMetrics()
	fmt.Printf("Buffered: %d, Flushed: %d\n", metrics.EntriesBuffered, metrics.EntriesFlushed)
}
```

## Configuration Options

### WALConfig Fields

- `MaxBufferSize`: Maximum entries to buffer before flush (default: 1000)
- `FlushInterval`: How often to flush the buffer (default: 5s)
- `MaxFlushRetries`: Maximum retries for failed flushes (default: 3)
- `MaxSegmentSize`: Maximum entries per segment (default: 5000)
- `SegmentRetention`: How long to keep flushed segments (default: 24h)
- `WorkerCount`: Number of flush workers (default: 2)
- `BatchSize`: Batch size for database operations (default: 100)
- `EnableRecovery`: Enable crash recovery (default: true)
- `RecoveryTimeout`: Timeout for recovery operations (default: 30s)
- `EnableMetrics`: Enable metrics collection (default: true)
- `MetricsInterval`: Metrics collection interval (default: 10s)
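
As a starting point before tuning, here is a sketch of a `WALConfig` that simply spells out the documented defaults:

```go
cfg := &wal.WALConfig{
	MaxBufferSize:    1000,            // flush once 1000 entries are buffered
	FlushInterval:    5 * time.Second, // or every 5 seconds, whichever comes first
	MaxFlushRetries:  3,
	MaxSegmentSize:   5000,
	SegmentRetention: 24 * time.Hour,
	WorkerCount:      2,
	BatchSize:        100,
	EnableRecovery:   true,
	RecoveryTimeout:  30 * time.Second,
	EnableMetrics:    true,
	MetricsInterval:  10 * time.Second,
}
```
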
## Performance Benefits

1. **Reduced Database Load**: Buffering prevents thousands of individual INSERT operations
2. **Batch Processing**: Database operations are performed in optimized batches
3. **Async Processing**: Logging doesn't block the main application flow
4. **Configurable Buffering**: Tune the buffer size to your throughput needs
5. **Crash Recovery**: Unflushed entries survive even if the system crashes

## Integration with Task Manager

The WAL system integrates seamlessly with the existing task manager:

```go
// The task manager will automatically use WAL buffering
// when WAL-enabled storage is configured
taskManager := NewTaskManager(dag, taskID, resultCh, iterators, walStorage)

// All activity logging will be buffered
taskManager.logActivity(ctx, "processing", "Task started processing")
```

## Monitoring

Get real-time metrics about WAL performance:

```go
metrics := walManager.GetMetrics()
fmt.Printf("Entries Buffered: %d\n", metrics.EntriesBuffered)
fmt.Printf("Entries Flushed: %d\n", metrics.EntriesFlushed)
fmt.Printf("Flush Operations: %d\n", metrics.FlushOperations)
fmt.Printf("Average Flush Time: %v\n", metrics.AverageFlushTime)
```

## Best Practices

1. **Tune Buffer Size**: Set it based on your expected logging frequency
2. **Monitor Metrics**: Keep an eye on buffer usage and flush performance
3. **Configure Retention**: Set a segment retention appropriate for your needs
4. **Use Recovery**: Always enable recovery for production deployments
5. **Batch Size**: Optimize the batch size for your database's capabilities

## Database Support

The WAL system supports:
- PostgreSQL
- SQLite
- MySQL (via the storage interface)
- In-memory storage (for testing/development)
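
Any backend that implements the `storage.TaskStorage` interface can sit beneath the WAL wrapper. As a hedged sketch, wiring a Postgres-backed store might look like the following; the constructor names `NewPostgresStorage` and `NewWALEnabledStorage` are assumptions for illustration only, so check `dag/wal_factory.go` and `dag/storage` for the actual constructors in your version:

```go
// Hypothetical constructor names; only SetTaskStorage is a confirmed API here.
// dsn is your database connection string.
pg, err := storage.NewPostgresStorage(dsn) // assumed Postgres-backed TaskStorage
if err != nil {
	panic(err)
}
walStorage, walManager, err := dag.NewWALEnabledStorage(pg, walConfig, l) // assumed wrapper constructor
if err != nil {
	panic(err)
}
d.SetTaskStorage(walStorage) // any storage.TaskStorage works here
_ = walManager
```
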
## Error Handling

The WAL system includes comprehensive error handling:
- Failed flushes are automatically retried
- The recovery process validates entries before replay
- Graceful degradation if storage is unavailable
- Detailed logging for troubleshooting
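
As an illustrative sketch of the retry behavior described above (not the actual `wal` package internals), a bounded retry loop honoring a `MaxFlushRetries`-style limit could look like:

```go
// Illustrative only: the real retry policy lives inside the wal package.
func flushWithRetries(ctx context.Context, flush func(context.Context) error, maxRetries int) error {
	backoff := 100 * time.Millisecond
	var err error
	for attempt := 0; attempt <= maxRetries; attempt++ {
		if err = flush(ctx); err == nil {
			return nil // flush succeeded
		}
		select {
		case <-ctx.Done():
			return ctx.Err() // stop retrying if the caller cancels
		case <-time.After(backoff):
			backoff *= 2 // exponential backoff between attempts
		}
	}
	return fmt.Errorf("flush failed after %d retries: %w", maxRetries, err)
}
```
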
@@ -28,6 +28,7 @@ func main() {
	flow := dag.NewDAG("Sample DAG", "sample-dag", func(taskID string, result mq.Result) {
		fmt.Printf("DAG Final result for task %s: %s\n", taskID, string(result.Payload))
	})
	flow.ConfigureMemoryStorage()
	flow.AddNode(dag.Function, "GetData", "GetData", &GetData{}, true)
	flow.AddNode(dag.Function, "Loop", "Loop", &Loop{})
	flow.AddNode(dag.Function, "ValidateAge", "ValidateAge", &ValidateAge{})
442
examples/middleware/middleware_example_main.go
Normal file
@@ -0,0 +1,442 @@
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"time"

	"github.com/oarkflow/mq"
	"github.com/oarkflow/mq/dag"
)

// User represents a user with roles
type User struct {
	ID    string   `json:"id"`
	Name  string   `json:"name"`
	Roles []string `json:"roles"`
}

// HasRole checks if the user has a specific role
func (u *User) HasRole(role string) bool {
	for _, r := range u.Roles {
		if r == role {
			return true
		}
	}
	return false
}

// HasAnyRole checks if the user has any of the specified roles
func (u *User) HasAnyRole(roles ...string) bool {
	for _, requiredRole := range roles {
		if u.HasRole(requiredRole) {
			return true
		}
	}
	return false
}

// LoggingMiddleware logs the start and end of task processing
func LoggingMiddleware(ctx context.Context, task *mq.Task) mq.Result {
	log.Printf("Middleware: Starting processing for node %s, task %s", task.Topic, task.ID)
	start := time.Now()

	// For middleware, we return a successful result to continue to the next middleware/processor.
	// The actual processing happens after all middlewares.
	result := mq.Result{
		Status:  mq.Completed,
		Ctx:     ctx,
		Payload: task.Payload, // Pass through the payload
	}

	log.Printf("Middleware: Completed in %v", time.Since(start))
	return result
}

// ValidationMiddleware validates the task payload
func ValidationMiddleware(ctx context.Context, task *mq.Task) mq.Result {
	log.Printf("ValidationMiddleware: Validating payload for node %s", task.Topic)

	// Check if payload is empty
	if len(task.Payload) == 0 {
		return mq.Result{
			Status: mq.Failed,
			Error:  fmt.Errorf("empty payload not allowed"),
			Ctx:    ctx,
		}
	}

	log.Printf("ValidationMiddleware: Payload validation passed")
	return mq.Result{
		Status:  mq.Completed,
		Ctx:     ctx,
		Payload: task.Payload,
	}
}

// RoleCheckMiddleware checks if the user has the required roles for accessing a sub-DAG
func RoleCheckMiddleware(requiredRoles ...string) mq.Handler {
	return func(ctx context.Context, task *mq.Task) mq.Result {
		log.Printf("RoleCheckMiddleware: Checking roles %v for node %s", requiredRoles, task.Topic)

		// Extract user from payload
		var payload map[string]interface{}
		if err := json.Unmarshal(task.Payload, &payload); err != nil {
			return mq.Result{
				Status: mq.Failed,
				Error:  fmt.Errorf("invalid payload format: %v", err),
				Ctx:    ctx,
			}
		}

		userData, exists := payload["user"]
		if !exists {
			return mq.Result{
				Status: mq.Failed,
				Error:  fmt.Errorf("user information not found in payload"),
				Ctx:    ctx,
			}
		}

		userBytes, err := json.Marshal(userData)
		if err != nil {
			return mq.Result{
				Status: mq.Failed,
				Error:  fmt.Errorf("invalid user data: %v", err),
				Ctx:    ctx,
			}
		}

		var user User
		if err := json.Unmarshal(userBytes, &user); err != nil {
			return mq.Result{
				Status: mq.Failed,
				Error:  fmt.Errorf("invalid user format: %v", err),
				Ctx:    ctx,
			}
		}

		if !user.HasAnyRole(requiredRoles...) {
			return mq.Result{
				Status: mq.Failed,
				Error:  fmt.Errorf("user %s does not have required roles %v. User roles: %v", user.Name, requiredRoles, user.Roles),
				Ctx:    ctx,
			}
		}

		log.Printf("RoleCheckMiddleware: User %s authorized for roles %v", user.Name, requiredRoles)
		return mq.Result{
			Status:  mq.Completed,
			Ctx:     ctx,
			Payload: task.Payload,
		}
	}
}

// TimingMiddleware measures execution time
func TimingMiddleware(ctx context.Context, task *mq.Task) mq.Result {
	log.Printf("TimingMiddleware: Starting timing for node %s", task.Topic)

	// Add timing info to the context
	ctx = context.WithValue(ctx, "start_time", time.Now())

	return mq.Result{
		Status:  mq.Completed,
		Ctx:     ctx,
		Payload: task.Payload,
	}
}

// ExampleProcessor simulates some work
type ExampleProcessor struct {
	dag.Operation
}

func (p *ExampleProcessor) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
	log.Printf("Processor: Processing task %s on node %s", task.ID, task.Topic)

	// Simulate some processing time
	time.Sleep(100 * time.Millisecond)

	// Parse the payload as JSON
	var payload map[string]interface{}
	if err := json.Unmarshal(task.Payload, &payload); err != nil {
		return mq.Result{
			Status: mq.Failed,
			Error:  fmt.Errorf("invalid payload: %v", err),
			Ctx:    ctx,
		}
	}

	// Add processing information
	payload["processed_by"] = "ExampleProcessor"
	payload["processing_time"] = time.Now().Format(time.RFC3339)

	resultPayload, err := json.Marshal(payload)
	if err != nil {
		return mq.Result{
			Status: mq.Failed,
			Error:  fmt.Errorf("failed to marshal result: %v", err),
			Ctx:    ctx,
		}
	}

	return mq.Result{
		Status:  mq.Completed,
		Payload: resultPayload,
		Ctx:     ctx,
	}
}

// AdminProcessor handles admin-specific tasks
type AdminProcessor struct {
	dag.Operation
}

func (p *AdminProcessor) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
	log.Printf("AdminProcessor: Processing admin task %s on node %s", task.ID, task.Topic)

	// Simulate admin processing
	time.Sleep(200 * time.Millisecond)

	// Parse the payload as JSON
	var payload map[string]interface{}
	if err := json.Unmarshal(task.Payload, &payload); err != nil {
		return mq.Result{
			Status: mq.Failed,
			Error:  fmt.Errorf("invalid payload: %v", err),
			Ctx:    ctx,
		}
	}

	// Add admin-specific processing information
	payload["processed_by"] = "AdminProcessor"
	payload["admin_action"] = "validated_and_processed"
	payload["processing_time"] = time.Now().Format(time.RFC3339)

	resultPayload, err := json.Marshal(payload)
	if err != nil {
		return mq.Result{
			Status: mq.Failed,
			Error:  fmt.Errorf("failed to marshal result: %v", err),
			Ctx:    ctx,
		}
	}

	return mq.Result{
		Status:  mq.Completed,
		Payload: resultPayload,
		Ctx:     ctx,
	}
}

// UserProcessor handles user-specific tasks
type UserProcessor struct {
	dag.Operation
}

func (p *UserProcessor) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
	log.Printf("UserProcessor: Processing user task %s on node %s", task.ID, task.Topic)

	// Simulate user processing
	time.Sleep(150 * time.Millisecond)

	// Parse the payload as JSON
	var payload map[string]interface{}
	if err := json.Unmarshal(task.Payload, &payload); err != nil {
		return mq.Result{
			Status: mq.Failed,
			Error:  fmt.Errorf("invalid payload: %v", err),
			Ctx:    ctx,
		}
	}

	// Add user-specific processing information
	payload["processed_by"] = "UserProcessor"
	payload["user_action"] = "authenticated_and_processed"
	payload["processing_time"] = time.Now().Format(time.RFC3339)

	resultPayload, err := json.Marshal(payload)
	if err != nil {
		return mq.Result{
			Status: mq.Failed,
			Error:  fmt.Errorf("failed to marshal result: %v", err),
			Ctx:    ctx,
		}
	}

	return mq.Result{
		Status:  mq.Completed,
		Payload: resultPayload,
		Ctx:     ctx,
	}
}

// GuestProcessor handles guest-specific tasks
type GuestProcessor struct {
	dag.Operation
}

func (p *GuestProcessor) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
	log.Printf("GuestProcessor: Processing guest task %s on node %s", task.ID, task.Topic)

	// Simulate guest processing
	time.Sleep(100 * time.Millisecond)

	// Parse the payload as JSON
	var payload map[string]interface{}
	if err := json.Unmarshal(task.Payload, &payload); err != nil {
		return mq.Result{
			Status: mq.Failed,
			Error:  fmt.Errorf("invalid payload: %v", err),
			Ctx:    ctx,
		}
	}

	// Add guest-specific processing information
	payload["processed_by"] = "GuestProcessor"
	payload["guest_action"] = "limited_access_processed"
	payload["processing_time"] = time.Now().Format(time.RFC3339)

	resultPayload, err := json.Marshal(payload)
	if err != nil {
		return mq.Result{
			Status: mq.Failed,
			Error:  fmt.Errorf("failed to marshal result: %v", err),
			Ctx:    ctx,
		}
	}

	return mq.Result{
		Status:  mq.Completed,
		Payload: resultPayload,
		Ctx:     ctx,
	}
}

// createAdminSubDAG creates a sub-DAG for admin operations
func createAdminSubDAG() *dag.DAG {
	adminDAG := dag.NewDAG("Admin Sub-DAG", "admin-subdag", func(taskID string, result mq.Result) {
		log.Printf("Admin Sub-DAG completed for task %s: %s", taskID, string(result.Payload))
	})

	adminDAG.AddNode(dag.Function, "Admin Validate", "admin_validate", &AdminProcessor{Operation: dag.Operation{Type: dag.Function}}, true)
	adminDAG.AddNode(dag.Function, "Admin Process", "admin_process", &AdminProcessor{Operation: dag.Operation{Type: dag.Function}})
	adminDAG.AddNode(dag.Function, "Admin Finalize", "admin_finalize", &AdminProcessor{Operation: dag.Operation{Type: dag.Function}})

	adminDAG.AddEdge(dag.Simple, "Validate to Process", "admin_validate", "admin_process")
	adminDAG.AddEdge(dag.Simple, "Process to Finalize", "admin_process", "admin_finalize")

	return adminDAG
}

// createUserSubDAG creates a sub-DAG for user operations
func createUserSubDAG() *dag.DAG {
	userDAG := dag.NewDAG("User Sub-DAG", "user-subdag", func(taskID string, result mq.Result) {
		log.Printf("User Sub-DAG completed for task %s: %s", taskID, string(result.Payload))
	})

	userDAG.AddNode(dag.Function, "User Auth", "user_auth", &UserProcessor{Operation: dag.Operation{Type: dag.Function}}, true)
	userDAG.AddNode(dag.Function, "User Process", "user_process", &UserProcessor{Operation: dag.Operation{Type: dag.Function}})
	userDAG.AddNode(dag.Function, "User Notify", "user_notify", &UserProcessor{Operation: dag.Operation{Type: dag.Function}})

	userDAG.AddEdge(dag.Simple, "Auth to Process", "user_auth", "user_process")
	userDAG.AddEdge(dag.Simple, "Process to Notify", "user_process", "user_notify")

	return userDAG
}

// createGuestSubDAG creates a sub-DAG for guest operations
func createGuestSubDAG() *dag.DAG {
	guestDAG := dag.NewDAG("Guest Sub-DAG", "guest-subdag", func(taskID string, result mq.Result) {
		log.Printf("Guest Sub-DAG completed for task %s: %s", taskID, string(result.Payload))
	})

	guestDAG.AddNode(dag.Function, "Guest Welcome", "guest_welcome", &GuestProcessor{Operation: dag.Operation{Type: dag.Function}}, true)
	guestDAG.AddNode(dag.Function, "Guest Info", "guest_info", &GuestProcessor{Operation: dag.Operation{Type: dag.Function}})

	guestDAG.AddEdge(dag.Simple, "Welcome to Info", "guest_welcome", "guest_info")

	return guestDAG
}

func main() {
	// Create the main DAG
	flow := dag.NewDAG("Role-Based Access Control DAG", "rbac-dag", func(taskID string, result mq.Result) {
		log.Printf("Main DAG completed for task %s: %s", taskID, string(result.Payload))
	})

	// Add entry point
	flow.AddNode(dag.Function, "Entry Point", "entry", &ExampleProcessor{Operation: dag.Operation{Type: dag.Function}}, true)

	// Add sub-DAGs with role-based access
	flow.AddDAGNode(dag.Function, "Admin Operations", "admin_ops", createAdminSubDAG())
	flow.AddDAGNode(dag.Function, "User Operations", "user_ops", createUserSubDAG())
	flow.AddDAGNode(dag.Function, "Guest Operations", "guest_ops", createGuestSubDAG())

	// Add edges from entry to sub-DAGs
	flow.AddEdge(dag.Simple, "Entry to Admin", "entry", "admin_ops")
	flow.AddEdge(dag.Simple, "Entry to User", "entry", "user_ops")
	flow.AddEdge(dag.Simple, "Entry to Guest", "entry", "guest_ops")

	// Add global middlewares
	flow.Use(LoggingMiddleware, ValidationMiddleware)

	// Add role-based middlewares for sub-DAGs
	flow.UseNodeMiddlewares(
		dag.NodeMiddleware{
			Node:        "admin_ops",
			Middlewares: []mq.Handler{RoleCheckMiddleware("admin", "superuser")},
		},
		dag.NodeMiddleware{
			Node:        "user_ops",
			Middlewares: []mq.Handler{RoleCheckMiddleware("user", "admin", "superuser")},
		},
		dag.NodeMiddleware{
			Node:        "guest_ops",
			Middlewares: []mq.Handler{RoleCheckMiddleware("guest", "user", "admin", "superuser")},
		},
	)

	if flow.Error != nil {
		panic(flow.Error)
	}

	// Define test users with different roles
	users := []User{
		{ID: "1", Name: "Alice", Roles: []string{"admin", "superuser"}},
		{ID: "2", Name: "Bob", Roles: []string{"user"}},
		{ID: "3", Name: "Charlie", Roles: []string{"guest"}},
		{ID: "4", Name: "Dave", Roles: []string{"user", "admin"}},
		{ID: "5", Name: "Eve", Roles: []string{}}, // No roles
	}

	// Test each user
	for _, user := range users {
		log.Printf("\n=== Testing user: %s (Roles: %v) ===", user.Name, user.Roles)

		// Create payload with user information
		payload := map[string]interface{}{
			"user":    user,
			"message": fmt.Sprintf("Request from %s", user.Name),
			"data":    "test data",
		}

		payloadBytes, err := json.Marshal(payload)
		if err != nil {
			log.Printf("Failed to marshal payload for user %s: %v", user.Name, err)
			continue
		}

		log.Printf("Processing request for user %s with payload: %s", user.Name, string(payloadBytes))

		result := flow.Process(context.Background(), payloadBytes)
		if result.Error != nil {
			log.Printf("❌ DAG processing failed for user %s: %v", user.Name, result.Error)
		} else {
			log.Printf("✅ DAG processing completed successfully for user %s: %s", user.Name, string(result.Payload))
		}
	}
}
@@ -1,134 +0,0 @@
package main

import (
	"context"
	"fmt"
	"log"
	"time"

	"github.com/oarkflow/mq"
	"github.com/oarkflow/mq/dag"
)

// LoggingMiddleware logs the start and end of task processing
func LoggingMiddleware(ctx context.Context, task *mq.Task) mq.Result {
	log.Printf("Middleware: Starting processing for node %s, task %s", task.Topic, task.ID)
	start := time.Now()

	// For middleware, we return a successful result to continue to the next middleware/processor.
	// The actual processing happens after all middlewares.
	result := mq.Result{
		Status:  mq.Completed,
		Ctx:     ctx,
		Payload: task.Payload, // Pass through the payload
	}

	log.Printf("Middleware: Completed in %v", time.Since(start))
	return result
}

// ValidationMiddleware validates the task payload
func ValidationMiddleware(ctx context.Context, task *mq.Task) mq.Result {
	log.Printf("ValidationMiddleware: Validating payload for node %s", task.Topic)

	// Check if payload is empty
	if len(task.Payload) == 0 {
		return mq.Result{
			Status: mq.Failed,
			Error:  fmt.Errorf("empty payload not allowed"),
			Ctx:    ctx,
		}
	}

	log.Printf("ValidationMiddleware: Payload validation passed")
	return mq.Result{
		Status:  mq.Completed,
		Ctx:     ctx,
		Payload: task.Payload,
	}
}

// TimingMiddleware measures execution time
func TimingMiddleware(ctx context.Context, task *mq.Task) mq.Result {
	log.Printf("TimingMiddleware: Starting timing for node %s", task.Topic)

	// Add timing info to the context
	ctx = context.WithValue(ctx, "start_time", time.Now())

	return mq.Result{
		Status:  mq.Completed,
		Ctx:     ctx,
		Payload: task.Payload,
	}
}

// ExampleProcessor simulates some work
type ExampleProcessor struct {
	dag.Operation
}

func (p *ExampleProcessor) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
	log.Printf("Processor: Processing task %s on node %s", task.ID, task.Topic)

	// Simulate some processing time
	time.Sleep(100 * time.Millisecond)

	// Check if timing middleware was used
	if startTime, ok := ctx.Value("start_time").(time.Time); ok {
		duration := time.Since(startTime)
		log.Printf("Processor: Task completed in %v", duration)
	}

	result := fmt.Sprintf("Processed: %s", string(task.Payload))
	return mq.Result{
		Status:  mq.Completed,
		Payload: []byte(result),
		Ctx:     ctx,
	}
}

func main() {
	// Create a new DAG
	flow := dag.NewDAG("Middleware Example", "middleware-example", func(taskID string, result mq.Result) {
		log.Printf("Final result for task %s: %s", taskID, string(result.Payload))
	})

	// Add nodes
	flow.AddNode(dag.Function, "Process A", "process_a", &ExampleProcessor{Operation: dag.Operation{Type: dag.Function}}, true)
	flow.AddNode(dag.Function, "Process B", "process_b", &ExampleProcessor{Operation: dag.Operation{Type: dag.Function}})
	flow.AddNode(dag.Function, "Process C", "process_c", &ExampleProcessor{Operation: dag.Operation{Type: dag.Function}})

	// Add edges
	flow.AddEdge(dag.Simple, "A to B", "process_a", "process_b")
	flow.AddEdge(dag.Simple, "B to C", "process_b", "process_c")

	// Add global middlewares that apply to all nodes
	flow.Use(LoggingMiddleware, ValidationMiddleware)

	// Add node-specific middlewares
	flow.UseNodeMiddlewares(
		dag.NodeMiddleware{
			Node:        "process_a",
			Middlewares: []mq.Handler{TimingMiddleware},
		},
		dag.NodeMiddleware{
			Node:        "process_b",
			Middlewares: []mq.Handler{TimingMiddleware},
		},
	)

	if flow.Error != nil {
		panic(flow.Error)
	}

	// Test the DAG with middleware
	data := []byte(`{"message": "Hello from middleware example"}`)
	log.Printf("Starting DAG processing with payload: %s", string(data))

	result := flow.Process(context.Background(), data)
	if result.Error != nil {
		log.Printf("DAG processing failed: %v", result.Error)
	} else {
		log.Printf("DAG processing completed successfully: %s", string(result.Payload))
	}
}
122
examples/task_recovery_example.go
Normal file
@@ -0,0 +1,122 @@
package main

import (
	"context"
	"fmt"
	"log"
	"time"

	"github.com/oarkflow/json"
	"github.com/oarkflow/mq"
	"github.com/oarkflow/mq/dag"
	dagstorage "github.com/oarkflow/mq/dag/storage"
)

// RecoveryProcessor demonstrates a simple processor for the recovery example
type RecoveryProcessor struct {
	nodeName string
}

func (p *RecoveryProcessor) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
	log.Printf("Processing task %s in node %s", task.ID, p.nodeName)

	// Simulate some processing time
	time.Sleep(100 * time.Millisecond)

	return mq.Result{
		Payload: task.Payload,
		Status:  mq.Completed,
		Ctx:     ctx,
		TaskID:  task.ID,
	}
}

func (p *RecoveryProcessor) Consume(ctx context.Context) error { return nil }
func (p *RecoveryProcessor) Pause(ctx context.Context) error   { return nil }
func (p *RecoveryProcessor) Resume(ctx context.Context) error  { return nil }
func (p *RecoveryProcessor) Stop(ctx context.Context) error    { return nil }
func (p *RecoveryProcessor) Close() error                      { return nil }
func (p *RecoveryProcessor) GetKey() string                    { return p.nodeName }
func (p *RecoveryProcessor) SetKey(key string)                 { p.nodeName = key }
func (p *RecoveryProcessor) GetType() string                   { return "recovery" }
func (p *RecoveryProcessor) SetConfig(payload dag.Payload)     {}
func (p *RecoveryProcessor) SetTags(tags ...string)            {}
func (p *RecoveryProcessor) GetTags() []string                 { return nil }

func demonstrateTaskRecovery() {
	ctx := context.Background()

	// Create a DAG with 5 nodes (simulating a complex workflow)
	dagInstance := dag.NewDAG("complex-workflow", "workflow-1", func(taskID string, result mq.Result) {
		log.Printf("Workflow completed for task: %s", taskID)
	})

	// Configure memory storage for this example
	dagInstance.ConfigureMemoryStorage()

	// Add nodes to simulate a complex workflow; the node ID must match the
	// IDs used in the edges below, and only "start" is the entry node
	nodes := []string{"start", "validate", "process", "enrich", "finalize"}

	for _, nodeName := range nodes {
		dagInstance.AddNode(dag.Function, fmt.Sprintf("Node %s", nodeName), nodeName, &RecoveryProcessor{nodeName: nodeName}, nodeName == "start")
	}

	// Connect the nodes in sequence
	for i := 0; i < len(nodes)-1; i++ {
		dagInstance.AddEdge(dag.Simple, fmt.Sprintf("Connect %s to %s", nodes[i], nodes[i+1]), nodes[i], nodes[i+1])
	}

	// Simulate a task that was running and got interrupted
	runningTask := &dagstorage.PersistentTask{
		ID:              "interrupted-task-123",
		DAGID:           "workflow-1",
		NodeID:          "start",      // Original starting node
		CurrentNodeID:   "process",    // Task was processing this node when interrupted
		SubDAGPath:      "",           // No sub-DAGs in this example
		ProcessingState: "processing", // Was actively processing
		Status:          dagstorage.TaskStatusRunning,
		Payload:         json.RawMessage(`{"user_id": 12345, "action": "process_data"}`),
		CreatedAt:       time.Now().Add(-10 * time.Minute), // Started 10 minutes ago
		UpdatedAt:       time.Now().Add(-2 * time.Minute),  // Last updated 2 minutes ago
	}

	// Save the interrupted task to storage
	err := dagInstance.GetTaskStorage().SaveTask(ctx, runningTask)
	if err != nil {
		log.Fatal("Failed to save interrupted task:", err)
	}

	log.Println("✅ Simulated an interrupted task that was processing the 'process' node")

	// Simulate system restart - recover tasks
	log.Println("🔄 Simulating system restart...")
	time.Sleep(1 * time.Second) // Simulate restart delay

	log.Println("🚀 Starting task recovery...")
	err = dagInstance.RecoverTasks(ctx)
	if err != nil {
		log.Fatal("Failed to recover tasks:", err)
	}

	log.Println("✅ Task recovery completed successfully!")

	// Verify the task was recovered
	recoveredTasks, err := dagInstance.GetTaskStorage().GetResumableTasks(ctx, "workflow-1")
	if err != nil {
		log.Fatal("Failed to get recovered tasks:", err)
	}

	log.Printf("📊 Found %d recovered tasks", len(recoveredTasks))
	for _, task := range recoveredTasks {
		log.Printf("🔄 Recovered task: %s, Current Node: %s, Status: %s",
			task.ID, task.CurrentNodeID, task.Status)
	}

	log.Println("🎉 Task recovery demonstration completed!")
	log.Println("💡 In a real scenario, the recovered task would continue processing from the 'process' node")
}

func main() {
	fmt.Println("=== DAG Task Recovery Example ===")
	demonstrateTaskRecovery()
}
3
go.mod
@@ -5,6 +5,8 @@ go 1.24.2

require (
	github.com/gofiber/fiber/v2 v2.52.9
	github.com/gorilla/websocket v1.5.3
	github.com/lib/pq v1.10.9
	github.com/mattn/go-sqlite3 v1.14.32
	github.com/oarkflow/date v0.0.4
	github.com/oarkflow/dipper v0.0.6
	github.com/oarkflow/errors v0.0.6
@@ -14,6 +16,7 @@ require (
	github.com/oarkflow/json v0.0.28
	github.com/oarkflow/jsonschema v0.0.4
	github.com/oarkflow/log v1.0.83
	github.com/oarkflow/squealx v0.0.56
	github.com/oarkflow/xid v1.2.8
	golang.org/x/crypto v0.41.0
	golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b
6
go.sum
@@ -22,6 +22,8 @@ github.com/kaptinlin/go-i18n v0.1.4 h1:wCiwAn1LOcvymvWIVAM4m5dUAMiHunTdEubLDk4hT
github.com/kaptinlin/go-i18n v0.1.4/go.mod h1:g1fn1GvTgT4CiLE8/fFE1hboHWJ6erivrDpiDtCcFKg=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
@@ -29,6 +31,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/oarkflow/date v0.0.4 h1:EwY/wiS3CqZNBx7b2x+3kkJwVNuGk+G0dls76kL/fhU=
github.com/oarkflow/date v0.0.4/go.mod h1:xQTFc6p6O5VX6J75ZrPJbelIFGca1ASmhpgirFqL8vM=
github.com/oarkflow/dipper v0.0.6 h1:E+ak9i4R1lxx0B04CjfG5DTLTmwuWA1nrdS6KIHdUxQ=
@@ -47,6 +51,8 @@ github.com/oarkflow/jsonschema v0.0.4 h1:n5Sb7WVb7NNQzn/ei9++4VPqKXCPJhhsHeTGJkI
github.com/oarkflow/jsonschema v0.0.4/go.mod h1:AxNG3Nk7KZxnnjRJlHLmS1wE9brtARu5caTFuicCtnA=
github.com/oarkflow/log v1.0.83 h1:T/38wvjuNeVJ9PDo0wJDTnTUQZ5XeqlcvpbCItuFFJo=
github.com/oarkflow/log v1.0.83/go.mod h1:dMn57z9uq11Y264cx9c9Ac7ska9qM+EBhn4qf9CNlsM=
github.com/oarkflow/squealx v0.0.56 h1:8rPx3jWNnt4ez2P10m1Lz4HTAbvrs0MZ7jjKDJ87Vqg=
github.com/oarkflow/squealx v0.0.56/go.mod h1:J5PNHmu3fH+IgrNm8tltz0aX4drT5uZ5j3r9dW5jQ/8=
github.com/oarkflow/xid v1.2.8 h1:uCIX61Binq2RPMsqImZM6pPGzoZTmRyD6jguxF9aAA0=
github.com/oarkflow/xid v1.2.8/go.mod h1:jG4YBh+swbjlWApGWDBYnsJEa7hi3CCpmuqhB3RAxVo=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=