update: use oarkflow/json

This commit is contained in:
Oarkflow
2025-03-29 14:31:00 +05:45
parent d0774ba2e8
commit abc07e9360
7 changed files with 397 additions and 1434 deletions

View File

@@ -8,14 +8,16 @@ import (
"syscall" "syscall"
"time" "time"
v1 "github.com/oarkflow/mq/v1" "github.com/oarkflow/json"
v1 "github.com/oarkflow/mq"
) )
func main() { func main() {
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop() defer stop()
pool := v1.NewPool(5, pool := v1.NewPool(5,
v1.WithTaskStorage(v1.NewInMemoryTaskStorage()), v1.WithTaskStorage(v1.NewMemoryTaskStorage(10*time.Minute)),
v1.WithHandler(func(ctx context.Context, payload *v1.Task) v1.Result { v1.WithHandler(func(ctx context.Context, payload *v1.Task) v1.Result {
v1.Logger.Info().Str("taskID", payload.ID).Msg("Processing task payload") v1.Logger.Info().Str("taskID", payload.ID).Msg("Processing task payload")
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
@@ -43,14 +45,14 @@ func main() {
metrics := pool.Metrics() metrics := pool.Metrics()
v1.Logger.Info().Msgf("Metrics: %+v", metrics) v1.Logger.Info().Msgf("Metrics: %+v", metrics)
pool.Stop() pool.Stop()
v1.Logger.Info().Msgf("Dead Letter Queue has %d tasks", len(v1.DLQ.Task())) v1.Logger.Info().Msgf("Dead Letter Queue has %d tasks", len(pool.DLQ.Task()))
}() }()
go func() { go func() {
for i := 0; i < 50; i++ { for i := 0; i < 50; i++ {
task := &v1.Task{ task := &v1.Task{
ID: "", ID: "",
Payload: fmt.Sprintf("Task Payload %d", i), Payload: json.RawMessage(fmt.Sprintf("Task Payload %d", i)),
} }
if err := pool.EnqueueTask(context.Background(), task, rand.Intn(10)); err != nil { if err := pool.EnqueueTask(context.Background(), task, rand.Intn(10)); err != nil {
v1.Logger.Error().Err(err).Msg("Failed to enqueue task") v1.Logger.Error().Err(err).Msg("Failed to enqueue task")

266
pool.go
View File

@@ -3,14 +3,17 @@ package mq
import ( import (
"container/heap" "container/heap"
"context" "context"
"errors"
"fmt" "fmt"
"log"
"math/rand" "math/rand"
"net/http"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/oarkflow/mq/utils" "github.com/oarkflow/mq/utils"
"github.com/oarkflow/log"
) )
type Callback func(ctx context.Context, result Result) error type Callback func(ctx context.Context, result Result) error
@@ -25,6 +28,106 @@ type Metrics struct {
ExecutionTime int64 ExecutionTime int64
} }
type Plugin interface {
Initialize(config interface{}) error
BeforeTask(task *QueueTask)
AfterTask(task *QueueTask, result Result)
}
type DefaultPlugin struct{}
func (dp *DefaultPlugin) Initialize(config interface{}) error { return nil }
func (dp *DefaultPlugin) BeforeTask(task *QueueTask) {
Logger.Info().Str("taskID", task.payload.ID).Msg("BeforeTask plugin invoked")
}
func (dp *DefaultPlugin) AfterTask(task *QueueTask, result Result) {
Logger.Info().Str("taskID", task.payload.ID).Msg("AfterTask plugin invoked")
}
type DeadLetterQueue struct {
tasks []*QueueTask
mu sync.Mutex
}
func NewDeadLetterQueue() *DeadLetterQueue {
return &DeadLetterQueue{
tasks: make([]*QueueTask, 0),
}
}
func (dlq *DeadLetterQueue) Task() []*QueueTask {
return dlq.tasks
}
func (dlq *DeadLetterQueue) Add(task *QueueTask) {
dlq.mu.Lock()
defer dlq.mu.Unlock()
dlq.tasks = append(dlq.tasks, task)
Logger.Warn().Str("taskID", task.payload.ID).Msg("Task added to Dead Letter Queue")
}
type InMemoryMetricsRegistry struct {
metrics map[string]int64
mu sync.RWMutex
}
func NewInMemoryMetricsRegistry() *InMemoryMetricsRegistry {
return &InMemoryMetricsRegistry{
metrics: make(map[string]int64),
}
}
func (m *InMemoryMetricsRegistry) Register(metricName string, value interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
if v, ok := value.(int64); ok {
m.metrics[metricName] = v
Logger.Info().Str("metric", metricName).Msgf("Registered metric: %d", v)
}
}
func (m *InMemoryMetricsRegistry) Increment(metricName string) {
m.mu.Lock()
defer m.mu.Unlock()
m.metrics[metricName]++
}
func (m *InMemoryMetricsRegistry) Get(metricName string) interface{} {
m.mu.RLock()
defer m.mu.RUnlock()
return m.metrics[metricName]
}
type WarningThresholds struct {
HighMemory int64
LongExecution time.Duration
}
type DynamicConfig struct {
Timeout time.Duration
BatchSize int
MaxMemoryLoad int64
IdleTimeout time.Duration
BackoffDuration time.Duration
MaxRetries int
ReloadInterval time.Duration
WarningThreshold WarningThresholds
}
var Config = &DynamicConfig{
Timeout: 10 * time.Second,
BatchSize: 1,
MaxMemoryLoad: 100 * 1024 * 1024,
IdleTimeout: 5 * time.Minute,
BackoffDuration: 2 * time.Second,
MaxRetries: 3,
ReloadInterval: 30 * time.Second,
WarningThreshold: WarningThresholds{
HighMemory: 1 * 1024 * 1024,
LongExecution: 2 * time.Second,
},
}
type Pool struct { type Pool struct {
taskStorage TaskStorage taskStorage TaskStorage
scheduler *Scheduler scheduler *Scheduler
@@ -35,6 +138,7 @@ type Pool struct {
completionCallback CompletionCallback completionCallback CompletionCallback
taskAvailableCond *sync.Cond taskAvailableCond *sync.Cond
callback Callback callback Callback
DLQ *DeadLetterQueue
taskQueue PriorityQueue taskQueue PriorityQueue
overflowBuffer []*QueueTask overflowBuffer []*QueueTask
metrics Metrics metrics Metrics
@@ -50,10 +154,8 @@ type Pool struct {
taskQueueLock sync.Mutex taskQueueLock sync.Mutex
numOfWorkers int32 numOfWorkers int32
paused bool paused bool
logger *log.Logger logger log.Logger
gracefulShutdown bool gracefulShutdown bool
// New fields for enhancements:
thresholds ThresholdConfig thresholds ThresholdConfig
diagnosticsEnabled bool diagnosticsEnabled bool
metricsRegistry MetricsRegistry metricsRegistry MetricsRegistry
@@ -61,33 +163,83 @@ type Pool struct {
circuitBreakerOpen bool circuitBreakerOpen bool
circuitBreakerFailureCount int32 circuitBreakerFailureCount int32
gracefulShutdownTimeout time.Duration gracefulShutdownTimeout time.Duration
plugins []Plugin
} }
func NewPool(numOfWorkers int, opts ...PoolOption) *Pool { func NewPool(numOfWorkers int, opts ...PoolOption) *Pool {
pool := &Pool{ pool := &Pool{
stop: make(chan struct{}), stop: make(chan struct{}),
taskNotify: make(chan struct{}, numOfWorkers), taskNotify: make(chan struct{}, numOfWorkers),
batchSize: 1, batchSize: Config.BatchSize,
timeout: 10 * time.Second, timeout: Config.Timeout,
idleTimeout: 5 * time.Minute, idleTimeout: Config.IdleTimeout,
backoffDuration: 2 * time.Second, backoffDuration: Config.BackoffDuration,
maxRetries: 3, // Set max retries for failed tasks maxRetries: Config.MaxRetries,
logger: log.Default(), logger: Logger,
DLQ: NewDeadLetterQueue(),
metricsRegistry: NewInMemoryMetricsRegistry(),
diagnosticsEnabled: true,
gracefulShutdownTimeout: 10 * time.Second,
} }
pool.scheduler = NewScheduler(pool) pool.scheduler = NewScheduler(pool)
pool.taskAvailableCond = sync.NewCond(&sync.Mutex{}) pool.taskAvailableCond = sync.NewCond(&sync.Mutex{})
for _, opt := range opts { for _, opt := range opts {
opt(pool) opt(pool)
} }
if len(pool.taskQueue) == 0 { if pool.taskQueue == nil {
pool.taskQueue = make(PriorityQueue, 0, 10) pool.taskQueue = make(PriorityQueue, 0, 10)
} }
heap.Init(&pool.taskQueue) heap.Init(&pool.taskQueue)
pool.scheduler.Start() pool.scheduler.Start()
pool.Start(numOfWorkers) pool.Start(numOfWorkers)
startConfigReloader(pool)
go pool.dynamicWorkerScaler()
go pool.startHealthServer()
return pool return pool
} }
func validateDynamicConfig(c *DynamicConfig) error {
if c.Timeout <= 0 {
return errors.New("Timeout must be positive")
}
if c.BatchSize <= 0 {
return errors.New("BatchSize must be > 0")
}
if c.MaxMemoryLoad <= 0 {
return errors.New("MaxMemoryLoad must be > 0")
}
return nil
}
func startConfigReloader(pool *Pool) {
go func() {
ticker := time.NewTicker(Config.ReloadInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := validateDynamicConfig(Config); err != nil {
Logger.Error().Err(err).Msg("Invalid dynamic config, skipping reload")
continue
}
pool.timeout = Config.Timeout
pool.batchSize = Config.BatchSize
pool.maxMemoryLoad = Config.MaxMemoryLoad
pool.idleTimeout = Config.IdleTimeout
pool.backoffDuration = Config.BackoffDuration
pool.maxRetries = Config.MaxRetries
pool.thresholds = ThresholdConfig{
HighMemory: Config.WarningThreshold.HighMemory,
LongExecution: Config.WarningThreshold.LongExecution,
}
Logger.Info().Msg("Dynamic configuration reloaded")
case <-pool.stop:
return
}
}
}()
}
func (wp *Pool) Start(numWorkers int) { func (wp *Pool) Start(numWorkers int) {
storedTasks, err := wp.taskStorage.GetAllTasks() storedTasks, err := wp.taskStorage.GetAllTasks()
if err == nil { if err == nil {
@@ -144,13 +296,13 @@ func (wp *Pool) processNextBatch() {
wp.handleTask(task) wp.handleTask(task)
} }
} }
// @TODO - Why was this done?
//if len(tasks) > 0 {
// wp.taskCompletionNotifier.Done()
//}
} }
func (wp *Pool) handleTask(task *QueueTask) { func (wp *Pool) handleTask(task *QueueTask) {
if err := validateTaskInput(task.payload); err != nil {
wp.logger.Error().Str("taskID", task.payload.ID).Msgf("Validation failed: %v", err)
return
}
ctx, cancel := context.WithTimeout(task.ctx, wp.timeout) ctx, cancel := context.WithTimeout(task.ctx, wp.timeout)
defer cancel() defer cancel()
taskSize := int64(utils.SizeOf(task.payload)) taskSize := int64(utils.SizeOf(task.payload))
@@ -163,15 +315,15 @@ func (wp *Pool) handleTask(task *QueueTask) {
// Warning thresholds check // Warning thresholds check
if wp.thresholds.LongExecution > 0 && executionTime > int64(wp.thresholds.LongExecution.Milliseconds()) { if wp.thresholds.LongExecution > 0 && executionTime > int64(wp.thresholds.LongExecution.Milliseconds()) {
wp.logger.Printf("Warning: Task %s exceeded execution time threshold: %d ms", task.payload.ID, executionTime) wp.logger.Warn().Str("taskID", task.payload.ID).Msgf("Exceeded execution time threshold: %d ms", executionTime)
} }
if wp.thresholds.HighMemory > 0 && taskSize > wp.thresholds.HighMemory { if wp.thresholds.HighMemory > 0 && taskSize > wp.thresholds.HighMemory {
wp.logger.Printf("Warning: Task %s memory usage %d exceeded threshold", task.payload.ID, taskSize) wp.logger.Warn().Str("taskID", task.payload.ID).Msgf("Memory usage %d exceeded threshold", taskSize)
} }
if result.Error != nil { if result.Error != nil {
atomic.AddInt64(&wp.metrics.ErrorCount, 1) atomic.AddInt64(&wp.metrics.ErrorCount, 1)
wp.logger.Printf("Error processing task %s: %v", task.payload.ID, result.Error) wp.logger.Error().Str("taskID", task.payload.ID).Msgf("Error processing task: %v", result.Error)
wp.backoffAndStore(task) wp.backoffAndStore(task)
// Circuit breaker check // Circuit breaker check
@@ -179,12 +331,12 @@ func (wp *Pool) handleTask(task *QueueTask) {
newCount := atomic.AddInt32(&wp.circuitBreakerFailureCount, 1) newCount := atomic.AddInt32(&wp.circuitBreakerFailureCount, 1)
if newCount >= int32(wp.circuitBreaker.FailureThreshold) { if newCount >= int32(wp.circuitBreaker.FailureThreshold) {
wp.circuitBreakerOpen = true wp.circuitBreakerOpen = true
wp.logger.Println("Circuit breaker opened due to errors") wp.logger.Warn().Msg("Circuit breaker opened due to errors")
go func() { go func() {
time.Sleep(wp.circuitBreaker.ResetTimeout) time.Sleep(wp.circuitBreaker.ResetTimeout)
atomic.StoreInt32(&wp.circuitBreakerFailureCount, 0) atomic.StoreInt32(&wp.circuitBreakerFailureCount, 0)
wp.circuitBreakerOpen = false wp.circuitBreakerOpen = false
wp.logger.Println("Circuit breaker reset") wp.logger.Info().Msg("Circuit breaker reset to closed state")
}() }()
} }
} }
@@ -198,17 +350,17 @@ func (wp *Pool) handleTask(task *QueueTask) {
// Diagnostics logging if enabled // Diagnostics logging if enabled
if wp.diagnosticsEnabled { if wp.diagnosticsEnabled {
wp.logger.Printf("Task %s executed in %d ms", task.payload.ID, executionTime) wp.logger.Info().Str("taskID", task.payload.ID).Msgf("Task executed in %d ms", executionTime)
} }
if wp.callback != nil { if wp.callback != nil {
if err := wp.callback(ctx, result); err != nil { if err := wp.callback(ctx, result); err != nil {
atomic.AddInt64(&wp.metrics.ErrorCount, 1) atomic.AddInt64(&wp.metrics.ErrorCount, 1)
wp.logger.Printf("Error in callback for task %s: %v", task.payload.ID, err) wp.logger.Error().Str("taskID", task.payload.ID).Msgf("Callback error: %v", err)
} }
} }
_ = wp.taskStorage.DeleteTask(task.payload.ID) _ = wp.taskStorage.DeleteTask(task.payload.ID)
atomic.AddInt64(&wp.metrics.TotalMemoryUsed, -taskSize) atomic.AddInt64(&wp.metrics.TotalMemoryUsed, -taskSize)
wp.metricsRegistry.Register("task_execution_time", executionTime)
} }
func (wp *Pool) backoffAndStore(task *QueueTask) { func (wp *Pool) backoffAndStore(task *QueueTask) {
@@ -219,10 +371,11 @@ func (wp *Pool) backoffAndStore(task *QueueTask) {
backoff := wp.backoffDuration * (1 << (task.retryCount - 1)) backoff := wp.backoffDuration * (1 << (task.retryCount - 1))
jitter := time.Duration(rand.Int63n(int64(backoff) / 2)) jitter := time.Duration(rand.Int63n(int64(backoff) / 2))
sleepDuration := backoff + jitter sleepDuration := backoff + jitter
wp.logger.Printf("Task %s retry %d: sleeping for %s", task.payload.ID, task.retryCount, sleepDuration) wp.logger.Info().Str("taskID", task.payload.ID).Msgf("Retry %d: sleeping for %s", task.retryCount, sleepDuration)
time.Sleep(sleepDuration) time.Sleep(sleepDuration)
} else { } else {
wp.logger.Printf("Task %s failed after maximum retries", task.payload.ID) wp.logger.Error().Str("taskID", task.payload.ID).Msg("Task failed after maximum retries")
wp.DLQ.Add(task)
} }
} }
@@ -282,7 +435,9 @@ func (wp *Pool) EnqueueTask(ctx context.Context, payload *Task, priority int) er
if wp.circuitBreaker.Enabled && wp.circuitBreakerOpen { if wp.circuitBreaker.Enabled && wp.circuitBreakerOpen {
return fmt.Errorf("circuit breaker open, task rejected") return fmt.Errorf("circuit breaker open, task rejected")
} }
if err := validateTaskInput(payload); err != nil {
return fmt.Errorf("invalid task input: %w", err)
}
if payload.ID == "" { if payload.ID == "" {
payload.ID = NewID() payload.ID = NewID()
} }
@@ -344,9 +499,8 @@ func (wp *Pool) startOverflowDrainer() {
func (wp *Pool) drainOverflowBuffer() { func (wp *Pool) drainOverflowBuffer() {
wp.overflowBufferLock.Lock() wp.overflowBufferLock.Lock()
overflowTasks := wp.overflowBuffer overflowTasks := wp.overflowBuffer
wp.overflowBuffer = nil // Clear buffer wp.overflowBuffer = nil
wp.overflowBufferLock.Unlock() wp.overflowBufferLock.Unlock()
for _, task := range overflowTasks { for _, task := range overflowTasks {
select { select {
case wp.taskNotify <- struct{}{}: case wp.taskNotify <- struct{}{}:
@@ -363,8 +517,6 @@ func (wp *Pool) Stop() {
wp.gracefulShutdown = true wp.gracefulShutdown = true
wp.Pause() wp.Pause()
close(wp.stop) close(wp.stop)
// Graceful shutdown with timeout support
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
wp.wg.Wait() wp.wg.Wait()
@@ -373,9 +525,8 @@ func (wp *Pool) Stop() {
}() }()
select { select {
case <-done: case <-done:
// All workers finished gracefully.
case <-time.After(wp.gracefulShutdownTimeout): case <-time.After(wp.gracefulShutdownTimeout):
wp.logger.Println("Graceful shutdown timeout reached") wp.logger.Warn().Msg("Graceful shutdown timeout reached")
} }
if wp.completionCallback != nil { if wp.completionCallback != nil {
wp.completionCallback() wp.completionCallback()
@@ -395,3 +546,54 @@ func (wp *Pool) Metrics() Metrics {
} }
func (wp *Pool) Scheduler() *Scheduler { return wp.scheduler } func (wp *Pool) Scheduler() *Scheduler { return wp.scheduler }
func (wp *Pool) dynamicWorkerScaler() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
wp.taskQueueLock.Lock()
queueLen := len(wp.taskQueue)
wp.taskQueueLock.Unlock()
newWorkers := queueLen/5 + 1
wp.logger.Info().Msgf("Auto-scaling: queue length %d, adjusting workers to %d", queueLen, newWorkers)
wp.AdjustWorkerCount(newWorkers)
case <-wp.stop:
return
}
}
}
func (wp *Pool) startHealthServer() {
mux := http.NewServeMux()
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
status := "OK"
if wp.gracefulShutdown {
status = "shutting down"
}
fmt.Fprintf(w, "status: %s\nworkers: %d\nqueueLength: %d\n",
status, atomic.LoadInt32(&wp.numOfWorkers), len(wp.taskQueue))
})
server := &http.Server{
Addr: ":8080",
Handler: mux,
}
go func() {
wp.logger.Info().Msg("Starting health server on :8080")
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
wp.logger.Error().Err(err).Msg("Health server failed")
}
}()
go func() {
<-wp.stop
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
wp.logger.Error().Err(err).Msg("Health server shutdown failed")
} else {
wp.logger.Info().Msg("Health server shutdown gracefully")
}
}()
}

View File

@@ -6,13 +6,14 @@ import (
// New type definitions for enhancements // New type definitions for enhancements
type ThresholdConfig struct { type ThresholdConfig struct {
HighMemory int64 // e.g. in bytes HighMemory int64
LongExecution time.Duration // e.g. warning if task execution exceeds LongExecution time.Duration
} }
type MetricsRegistry interface { type MetricsRegistry interface {
Register(metricName string, value interface{}) Register(metricName string, value interface{})
// ...other methods as needed... Increment(metricName string)
Get(metricName string) interface{}
} }
type CircuitBreakerConfig struct { type CircuitBreakerConfig struct {
@@ -25,7 +26,6 @@ type PoolOption func(*Pool)
func WithTaskQueueSize(size int) PoolOption { func WithTaskQueueSize(size int) PoolOption {
return func(p *Pool) { return func(p *Pool) {
// Initialize the task queue with the specified size
p.taskQueue = make(PriorityQueue, 0, size) p.taskQueue = make(PriorityQueue, 0, size)
} }
} }
@@ -72,8 +72,6 @@ func WithTaskStorage(storage TaskStorage) PoolOption {
} }
} }
// New option functions:
func WithWarningThresholds(thresholds ThresholdConfig) PoolOption { func WithWarningThresholds(thresholds ThresholdConfig) PoolOption {
return func(p *Pool) { return func(p *Pool) {
p.thresholds = thresholds p.thresholds = thresholds
@@ -103,3 +101,9 @@ func WithGracefulShutdown(timeout time.Duration) PoolOption {
p.gracefulShutdownTimeout = timeout p.gracefulShutdownTimeout = timeout
} }
} }
func WithPlugin(plugin Plugin) PoolOption {
return func(p *Pool) {
p.plugins = append(p.plugins, plugin)
}
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/oarkflow/json" "github.com/oarkflow/json"
"github.com/oarkflow/json/jsonparser" "github.com/oarkflow/json/jsonparser"
"github.com/oarkflow/mq/codec" "github.com/oarkflow/mq/codec"
"github.com/oarkflow/mq/consts" "github.com/oarkflow/mq/consts"
) )
@@ -77,16 +78,18 @@ func (p *Publisher) Publish(ctx context.Context, task Task, queue string) error
if err != nil { if err != nil {
return fmt.Errorf("failed to connect to broker: %w", err) return fmt.Errorf("failed to connect to broker: %w", err)
} }
defer conn.Close() defer func() {
_ = conn.Close()
}()
return p.send(ctx, queue, task, conn, consts.PUBLISH) return p.send(ctx, queue, task, conn, consts.PUBLISH)
} }
func (p *Publisher) onClose(ctx context.Context, conn net.Conn) error { func (p *Publisher) onClose(_ context.Context, conn net.Conn) error {
fmt.Println("Publisher Connection closed", p.id, conn.RemoteAddr()) fmt.Println("Publisher Connection closed", p.id, conn.RemoteAddr())
return nil return nil
} }
func (p *Publisher) onError(ctx context.Context, conn net.Conn, err error) { func (p *Publisher) onError(_ context.Context, conn net.Conn, err error) {
fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr()) fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr())
} }
@@ -99,7 +102,9 @@ func (p *Publisher) Request(ctx context.Context, task Task, queue string) Result
err = fmt.Errorf("failed to connect to broker: %w", err) err = fmt.Errorf("failed to connect to broker: %w", err)
return Result{Error: err} return Result{Error: err}
} }
defer conn.Close() defer func() {
_ = conn.Close()
}()
err = p.send(ctx, queue, task, conn, consts.PUBLISH) err = p.send(ctx, queue, task, conn, consts.PUBLISH)
resultCh := make(chan Result) resultCh := make(chan Result)
go func() { go func() {

View File

@@ -45,38 +45,33 @@ func (b *Broker) NewQueue(name string) *Queue {
type QueueTask struct { type QueueTask struct {
ctx context.Context ctx context.Context
payload *Task payload *Task
retryCount int
priority int priority int
index int // The index in the heap retryCount int
index int
} }
type PriorityQueue []*QueueTask type PriorityQueue []*QueueTask
func (pq PriorityQueue) Len() int { return len(pq) } func (pq PriorityQueue) Len() int { return len(pq) }
func (pq PriorityQueue) Less(i, j int) bool { func (pq PriorityQueue) Less(i, j int) bool {
return pq[i].priority > pq[j].priority return pq[i].priority > pq[j].priority
} }
func (pq PriorityQueue) Swap(i, j int) { func (pq PriorityQueue) Swap(i, j int) {
pq[i], pq[j] = pq[j], pq[i] pq[i], pq[j] = pq[j], pq[i]
pq[i].index = i pq[i].index = i
pq[j].index = j pq[j].index = j
} }
func (pq *PriorityQueue) Push(x interface{}) { func (pq *PriorityQueue) Push(x interface{}) {
n := len(*pq) n := len(*pq)
task := x.(*QueueTask) task := x.(*QueueTask)
task.index = n task.index = n
*pq = append(*pq, task) *pq = append(*pq, task)
} }
func (pq *PriorityQueue) Pop() interface{} { func (pq *PriorityQueue) Pop() interface{} {
old := *pq old := *pq
n := len(old) n := len(old)
task := old[n-1] task := old[n-1]
old[n-1] = nil // avoid memory leak task.index = -1
task.index = -1 // for safety
*pq = old[0 : n-1] *pq = old[0 : n-1]
return task return task
} }

View File

@@ -2,14 +2,19 @@ package mq
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/oarkflow/log"
) )
var Logger = log.DefaultLogger
type ScheduleOptions struct { type ScheduleOptions struct {
Handler Handler Handler Handler
Callback Callback Callback Callback
@@ -20,7 +25,7 @@ type ScheduleOptions struct {
type SchedulerOption func(*ScheduleOptions) type SchedulerOption func(*ScheduleOptions)
// Helper functions to create SchedulerOptions // WithSchedulerHandler Helper functions to create SchedulerOptions
func WithSchedulerHandler(handler Handler) SchedulerOption { func WithSchedulerHandler(handler Handler) SchedulerOption {
return func(opts *ScheduleOptions) { return func(opts *ScheduleOptions) {
opts.Handler = handler opts.Handler = handler
@@ -177,17 +182,15 @@ func parseCronSpec(cronSpec string) (CronSchedule, error) {
} }
func cronFieldToString(field string, fieldName string) (string, error) { func cronFieldToString(field string, fieldName string) (string, error) {
switch field { if field == "*" {
case "*":
return fmt.Sprintf("every %s", fieldName), nil return fmt.Sprintf("every %s", fieldName), nil
default: }
values, err := parseCronValue(field) values, err := parseCronValue(field)
if err != nil { if err != nil {
return "", fmt.Errorf("invalid %s field: %s", fieldName, err.Error()) return "", fmt.Errorf("invalid %s field: %s", fieldName, err.Error())
} }
return fmt.Sprintf("%s %s", strings.Join(values, ", "), fieldName), nil return fmt.Sprintf("%s %s", strings.Join(values, ", "), fieldName), nil
} }
}
func parseCronValue(field string) ([]string, error) { func parseCronValue(field string) ([]string, error) {
var values []string var values []string
@@ -223,6 +226,8 @@ type Scheduler struct {
} }
func (s *Scheduler) Start() { func (s *Scheduler) Start() {
s.mu.Lock()
defer s.mu.Unlock()
for _, task := range s.tasks { for _, task := range s.tasks {
go s.schedule(task) go s.schedule(task)
} }
@@ -279,46 +284,81 @@ func (s *Scheduler) schedule(task ScheduledTask) {
} }
} }
func startSpan(operation string) (context.Context, func()) {
startTime := time.Now()
Logger.Info().Str("operation", operation).Msg("Span started")
ctx := context.WithValue(context.Background(), "traceID", fmt.Sprintf("%d", startTime.UnixNano()))
return ctx, func() {
duration := time.Since(startTime)
Logger.Info().Str("operation", operation).Msgf("Span ended; duration: %v", duration)
}
}
func acquireDistributedLock(taskID string) bool {
Logger.Info().Str("taskID", taskID).Msg("Acquiring distributed lock (stub)")
return true
}
func releaseDistributedLock(taskID string) {
Logger.Info().Str("taskID", taskID).Msg("Releasing distributed lock (stub)")
}
var taskPool = sync.Pool{
New: func() interface{} { return new(Task) },
}
var queueTaskPool = sync.Pool{
New: func() interface{} { return new(QueueTask) },
}
func getQueueTask() *QueueTask {
return queueTaskPool.Get().(*QueueTask)
}
// Enhance executeTask with circuit breaker and diagnostics logging support. // Enhance executeTask with circuit breaker and diagnostics logging support.
func (s *Scheduler) executeTask(task ScheduledTask) { func (s *Scheduler) executeTask(task ScheduledTask) {
if !task.config.Overlap && !s.pool.gracefulShutdown {
// Prevent overlapping execution if not allowed.
// ...existing code...
}
go func() { go func() {
// Recover from panic to keep scheduler running. _, cancelSpan := startSpan("executeTask")
defer cancelSpan()
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
fmt.Printf("Recovered from panic in scheduled task %s: %v\n", task.payload.ID, r) Logger.Error().Str("taskID", task.payload.ID).Msgf("Recovered from panic: %v", r)
} }
}() }()
start := time.Now() start := time.Now()
for _, plug := range s.pool.plugins {
plug.BeforeTask(getQueueTask())
}
if !acquireDistributedLock(task.payload.ID) {
Logger.Warn().Str("taskID", task.payload.ID).Msg("Failed to acquire distributed lock")
return
}
defer releaseDistributedLock(task.payload.ID)
result := task.handler(task.ctx, task.payload) result := task.handler(task.ctx, task.payload)
execTime := time.Since(start).Milliseconds() execTime := time.Since(start).Milliseconds()
// Diagnostics logging if enabled.
if s.pool.diagnosticsEnabled { if s.pool.diagnosticsEnabled {
s.pool.logger.Printf("Scheduled task %s executed in %d ms", task.payload.ID, execTime) Logger.Info().Str("taskID", task.payload.ID).Msgf("Executed in %d ms", execTime)
} }
// Circuit breaker check similar to pool.
if result.Error != nil && s.pool.circuitBreaker.Enabled { if result.Error != nil && s.pool.circuitBreaker.Enabled {
newCount := atomic.AddInt32(&s.pool.circuitBreakerFailureCount, 1) newCount := atomic.AddInt32(&s.pool.circuitBreakerFailureCount, 1)
if newCount >= int32(s.pool.circuitBreaker.FailureThreshold) { if newCount >= int32(s.pool.circuitBreaker.FailureThreshold) {
s.pool.circuitBreakerOpen = true s.pool.circuitBreakerOpen = true
s.pool.logger.Println("Scheduler: circuit breaker opened due to errors") Logger.Warn().Msg("Circuit breaker opened due to errors")
go func() { go func() {
time.Sleep(s.pool.circuitBreaker.ResetTimeout) time.Sleep(s.pool.circuitBreaker.ResetTimeout)
atomic.StoreInt32(&s.pool.circuitBreakerFailureCount, 0) atomic.StoreInt32(&s.pool.circuitBreakerFailureCount, 0)
s.pool.circuitBreakerOpen = false s.pool.circuitBreakerOpen = false
s.pool.logger.Println("Scheduler: circuit breaker reset") Logger.Info().Msg("Circuit breaker reset to closed state")
}() }()
} }
} }
// Invoke callback if defined.
if task.config.Callback != nil { if task.config.Callback != nil {
_ = task.config.Callback(task.ctx, result) _ = task.config.Callback(task.ctx, result)
} }
task.executionHistory = append(task.executionHistory, ExecutionHistory{Timestamp: time.Now(), Result: result}) task.executionHistory = append(task.executionHistory, ExecutionHistory{Timestamp: time.Now(), Result: result})
fmt.Printf("Executed scheduled task: %s\n", task.payload.ID) for _, plug := range s.pool.plugins {
plug.AfterTask(getQueueTask(), result)
}
Logger.Info().Str("taskID", task.payload.ID).Msg("Scheduled task executed")
}() }()
} }
@@ -354,7 +394,7 @@ func (task ScheduledTask) getNextRunTime(now time.Time) time.Time {
func (task ScheduledTask) getNextCronRunTime(now time.Time) time.Time { func (task ScheduledTask) getNextCronRunTime(now time.Time) time.Time {
cronSpecs, err := parseCronSpec(task.schedule.CronSpec) cronSpecs, err := parseCronSpec(task.schedule.CronSpec)
if err != nil { if err != nil {
fmt.Println(fmt.Sprintf("Invalid CRON spec format: %s", err.Error())) Logger.Error().Err(err).Msg("Invalid CRON spec")
return now return now
} }
nextRun := now nextRun := now
@@ -367,10 +407,9 @@ func (task ScheduledTask) getNextCronRunTime(now time.Time) time.Time {
} }
func (task ScheduledTask) applyCronField(t time.Time, fieldSpec string, unit string) time.Time { func (task ScheduledTask) applyCronField(t time.Time, fieldSpec string, unit string) time.Time {
switch fieldSpec { if fieldSpec == "*" {
case "*":
return t return t
default: }
value, _ := strconv.Atoi(fieldSpec) value, _ := strconv.Atoi(fieldSpec)
switch unit { switch unit {
case "minute": case "minute":
@@ -401,7 +440,6 @@ func (task ScheduledTask) applyCronField(t time.Time, fieldSpec string, unit str
} }
return t return t
} }
}
func nextWeekday(t time.Time, weekday time.Weekday) time.Time { func nextWeekday(t time.Time, weekday time.Weekday) time.Time {
daysUntil := (int(weekday) - int(t.Weekday()) + 7) % 7 daysUntil := (int(weekday) - int(t.Weekday()) + 7) % 7
@@ -410,12 +448,21 @@ func nextWeekday(t time.Time, weekday time.Weekday) time.Time {
} }
return t.AddDate(0, 0, daysUntil) return t.AddDate(0, 0, daysUntil)
} }
func validateTaskInput(task *Task) error {
if task.Payload == nil {
return errors.New("task payload cannot be nil")
}
Logger.Info().Str("taskID", task.ID).Msg("Task validated")
return nil
}
func (s *Scheduler) AddTask(ctx context.Context, payload *Task, opts ...SchedulerOption) { func (s *Scheduler) AddTask(ctx context.Context, payload *Task, opts ...SchedulerOption) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if err := validateTaskInput(payload); err != nil {
// Create a default options instance Logger.Error().Err(err).Msg("Invalid task input")
return
}
options := defaultSchedulerOptions() options := defaultSchedulerOptions()
for _, opt := range opts { for _, opt := range opts {
opt(options) opt(options)
@@ -427,9 +474,7 @@ func (s *Scheduler) AddTask(ctx context.Context, payload *Task, opts ...Schedule
options.Callback = s.pool.callback options.Callback = s.pool.callback
} }
stop := make(chan struct{}) stop := make(chan struct{})
newTask := ScheduledTask{
// Create a new ScheduledTask using the provided options
s.tasks = append(s.tasks, ScheduledTask{
ctx: ctx, ctx: ctx,
handler: options.Handler, handler: options.Handler,
payload: payload, payload: payload,
@@ -442,10 +487,9 @@ func (s *Scheduler) AddTask(ctx context.Context, payload *Task, opts ...Schedule
Interval: options.Interval, Interval: options.Interval,
Recurring: options.Recurring, Recurring: options.Recurring,
}, },
}) }
s.tasks = append(s.tasks, newTask)
// Start scheduling the task go s.schedule(newTask)
go s.schedule(s.tasks[len(s.tasks)-1])
} }
func (s *Scheduler) RemoveTask(payloadID string) { func (s *Scheduler) RemoveTask(payloadID string) {

1289
v1/v1.go

File diff suppressed because it is too large Load Diff