update: use oarkflow/json

This commit is contained in:
Oarkflow
2025-03-29 14:31:00 +05:45
parent d0774ba2e8
commit abc07e9360
7 changed files with 397 additions and 1434 deletions

View File

@@ -8,14 +8,16 @@ import (
"syscall" "syscall"
"time" "time"
v1 "github.com/oarkflow/mq/v1" "github.com/oarkflow/json"
v1 "github.com/oarkflow/mq"
) )
func main() { func main() {
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop() defer stop()
pool := v1.NewPool(5, pool := v1.NewPool(5,
v1.WithTaskStorage(v1.NewInMemoryTaskStorage()), v1.WithTaskStorage(v1.NewMemoryTaskStorage(10*time.Minute)),
v1.WithHandler(func(ctx context.Context, payload *v1.Task) v1.Result { v1.WithHandler(func(ctx context.Context, payload *v1.Task) v1.Result {
v1.Logger.Info().Str("taskID", payload.ID).Msg("Processing task payload") v1.Logger.Info().Str("taskID", payload.ID).Msg("Processing task payload")
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
@@ -43,14 +45,14 @@ func main() {
metrics := pool.Metrics() metrics := pool.Metrics()
v1.Logger.Info().Msgf("Metrics: %+v", metrics) v1.Logger.Info().Msgf("Metrics: %+v", metrics)
pool.Stop() pool.Stop()
v1.Logger.Info().Msgf("Dead Letter Queue has %d tasks", len(v1.DLQ.Task())) v1.Logger.Info().Msgf("Dead Letter Queue has %d tasks", len(pool.DLQ.Task()))
}() }()
go func() { go func() {
for i := 0; i < 50; i++ { for i := 0; i < 50; i++ {
task := &v1.Task{ task := &v1.Task{
ID: "", ID: "",
Payload: fmt.Sprintf("Task Payload %d", i), Payload: json.RawMessage(fmt.Sprintf("Task Payload %d", i)),
} }
if err := pool.EnqueueTask(context.Background(), task, rand.Intn(10)); err != nil { if err := pool.EnqueueTask(context.Background(), task, rand.Intn(10)); err != nil {
v1.Logger.Error().Err(err).Msg("Failed to enqueue task") v1.Logger.Error().Err(err).Msg("Failed to enqueue task")

320
pool.go
View File

@@ -3,14 +3,17 @@ package mq
import ( import (
"container/heap" "container/heap"
"context" "context"
"errors"
"fmt" "fmt"
"log"
"math/rand" "math/rand"
"net/http"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/oarkflow/mq/utils" "github.com/oarkflow/mq/utils"
"github.com/oarkflow/log"
) )
type Callback func(ctx context.Context, result Result) error type Callback func(ctx context.Context, result Result) error
@@ -25,35 +28,134 @@ type Metrics struct {
ExecutionTime int64 ExecutionTime int64
} }
type Pool struct { type Plugin interface {
taskStorage TaskStorage Initialize(config interface{}) error
scheduler *Scheduler BeforeTask(task *QueueTask)
stop chan struct{} AfterTask(task *QueueTask, result Result)
taskNotify chan struct{} }
workerAdjust chan int
handler Handler
completionCallback CompletionCallback
taskAvailableCond *sync.Cond
callback Callback
taskQueue PriorityQueue
overflowBuffer []*QueueTask
metrics Metrics
wg sync.WaitGroup
taskCompletionNotifier sync.WaitGroup
timeout time.Duration
batchSize int
maxMemoryLoad int64
idleTimeout time.Duration
backoffDuration time.Duration
maxRetries int
overflowBufferLock sync.RWMutex
taskQueueLock sync.Mutex
numOfWorkers int32
paused bool
logger *log.Logger
gracefulShutdown bool
// New fields for enhancements: type DefaultPlugin struct{}
func (dp *DefaultPlugin) Initialize(config interface{}) error { return nil }
func (dp *DefaultPlugin) BeforeTask(task *QueueTask) {
Logger.Info().Str("taskID", task.payload.ID).Msg("BeforeTask plugin invoked")
}
func (dp *DefaultPlugin) AfterTask(task *QueueTask, result Result) {
Logger.Info().Str("taskID", task.payload.ID).Msg("AfterTask plugin invoked")
}
type DeadLetterQueue struct {
tasks []*QueueTask
mu sync.Mutex
}
func NewDeadLetterQueue() *DeadLetterQueue {
return &DeadLetterQueue{
tasks: make([]*QueueTask, 0),
}
}
func (dlq *DeadLetterQueue) Task() []*QueueTask {
return dlq.tasks
}
func (dlq *DeadLetterQueue) Add(task *QueueTask) {
dlq.mu.Lock()
defer dlq.mu.Unlock()
dlq.tasks = append(dlq.tasks, task)
Logger.Warn().Str("taskID", task.payload.ID).Msg("Task added to Dead Letter Queue")
}
type InMemoryMetricsRegistry struct {
metrics map[string]int64
mu sync.RWMutex
}
func NewInMemoryMetricsRegistry() *InMemoryMetricsRegistry {
return &InMemoryMetricsRegistry{
metrics: make(map[string]int64),
}
}
func (m *InMemoryMetricsRegistry) Register(metricName string, value interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
if v, ok := value.(int64); ok {
m.metrics[metricName] = v
Logger.Info().Str("metric", metricName).Msgf("Registered metric: %d", v)
}
}
func (m *InMemoryMetricsRegistry) Increment(metricName string) {
m.mu.Lock()
defer m.mu.Unlock()
m.metrics[metricName]++
}
func (m *InMemoryMetricsRegistry) Get(metricName string) interface{} {
m.mu.RLock()
defer m.mu.RUnlock()
return m.metrics[metricName]
}
type WarningThresholds struct {
HighMemory int64
LongExecution time.Duration
}
type DynamicConfig struct {
Timeout time.Duration
BatchSize int
MaxMemoryLoad int64
IdleTimeout time.Duration
BackoffDuration time.Duration
MaxRetries int
ReloadInterval time.Duration
WarningThreshold WarningThresholds
}
var Config = &DynamicConfig{
Timeout: 10 * time.Second,
BatchSize: 1,
MaxMemoryLoad: 100 * 1024 * 1024,
IdleTimeout: 5 * time.Minute,
BackoffDuration: 2 * time.Second,
MaxRetries: 3,
ReloadInterval: 30 * time.Second,
WarningThreshold: WarningThresholds{
HighMemory: 1 * 1024 * 1024,
LongExecution: 2 * time.Second,
},
}
type Pool struct {
taskStorage TaskStorage
scheduler *Scheduler
stop chan struct{}
taskNotify chan struct{}
workerAdjust chan int
handler Handler
completionCallback CompletionCallback
taskAvailableCond *sync.Cond
callback Callback
DLQ *DeadLetterQueue
taskQueue PriorityQueue
overflowBuffer []*QueueTask
metrics Metrics
wg sync.WaitGroup
taskCompletionNotifier sync.WaitGroup
timeout time.Duration
batchSize int
maxMemoryLoad int64
idleTimeout time.Duration
backoffDuration time.Duration
maxRetries int
overflowBufferLock sync.RWMutex
taskQueueLock sync.Mutex
numOfWorkers int32
paused bool
logger log.Logger
gracefulShutdown bool
thresholds ThresholdConfig thresholds ThresholdConfig
diagnosticsEnabled bool diagnosticsEnabled bool
metricsRegistry MetricsRegistry metricsRegistry MetricsRegistry
@@ -61,33 +163,83 @@ type Pool struct {
circuitBreakerOpen bool circuitBreakerOpen bool
circuitBreakerFailureCount int32 circuitBreakerFailureCount int32
gracefulShutdownTimeout time.Duration gracefulShutdownTimeout time.Duration
plugins []Plugin
} }
func NewPool(numOfWorkers int, opts ...PoolOption) *Pool { func NewPool(numOfWorkers int, opts ...PoolOption) *Pool {
pool := &Pool{ pool := &Pool{
stop: make(chan struct{}), stop: make(chan struct{}),
taskNotify: make(chan struct{}, numOfWorkers), taskNotify: make(chan struct{}, numOfWorkers),
batchSize: 1, batchSize: Config.BatchSize,
timeout: 10 * time.Second, timeout: Config.Timeout,
idleTimeout: 5 * time.Minute, idleTimeout: Config.IdleTimeout,
backoffDuration: 2 * time.Second, backoffDuration: Config.BackoffDuration,
maxRetries: 3, // Set max retries for failed tasks maxRetries: Config.MaxRetries,
logger: log.Default(), logger: Logger,
DLQ: NewDeadLetterQueue(),
metricsRegistry: NewInMemoryMetricsRegistry(),
diagnosticsEnabled: true,
gracefulShutdownTimeout: 10 * time.Second,
} }
pool.scheduler = NewScheduler(pool) pool.scheduler = NewScheduler(pool)
pool.taskAvailableCond = sync.NewCond(&sync.Mutex{}) pool.taskAvailableCond = sync.NewCond(&sync.Mutex{})
for _, opt := range opts { for _, opt := range opts {
opt(pool) opt(pool)
} }
if len(pool.taskQueue) == 0 { if pool.taskQueue == nil {
pool.taskQueue = make(PriorityQueue, 0, 10) pool.taskQueue = make(PriorityQueue, 0, 10)
} }
heap.Init(&pool.taskQueue) heap.Init(&pool.taskQueue)
pool.scheduler.Start() pool.scheduler.Start()
pool.Start(numOfWorkers) pool.Start(numOfWorkers)
startConfigReloader(pool)
go pool.dynamicWorkerScaler()
go pool.startHealthServer()
return pool return pool
} }
func validateDynamicConfig(c *DynamicConfig) error {
if c.Timeout <= 0 {
return errors.New("Timeout must be positive")
}
if c.BatchSize <= 0 {
return errors.New("BatchSize must be > 0")
}
if c.MaxMemoryLoad <= 0 {
return errors.New("MaxMemoryLoad must be > 0")
}
return nil
}
func startConfigReloader(pool *Pool) {
go func() {
ticker := time.NewTicker(Config.ReloadInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := validateDynamicConfig(Config); err != nil {
Logger.Error().Err(err).Msg("Invalid dynamic config, skipping reload")
continue
}
pool.timeout = Config.Timeout
pool.batchSize = Config.BatchSize
pool.maxMemoryLoad = Config.MaxMemoryLoad
pool.idleTimeout = Config.IdleTimeout
pool.backoffDuration = Config.BackoffDuration
pool.maxRetries = Config.MaxRetries
pool.thresholds = ThresholdConfig{
HighMemory: Config.WarningThreshold.HighMemory,
LongExecution: Config.WarningThreshold.LongExecution,
}
Logger.Info().Msg("Dynamic configuration reloaded")
case <-pool.stop:
return
}
}
}()
}
func (wp *Pool) Start(numWorkers int) { func (wp *Pool) Start(numWorkers int) {
storedTasks, err := wp.taskStorage.GetAllTasks() storedTasks, err := wp.taskStorage.GetAllTasks()
if err == nil { if err == nil {
@@ -144,13 +296,13 @@ func (wp *Pool) processNextBatch() {
wp.handleTask(task) wp.handleTask(task)
} }
} }
// @TODO - Why was this done?
//if len(tasks) > 0 {
// wp.taskCompletionNotifier.Done()
//}
} }
func (wp *Pool) handleTask(task *QueueTask) { func (wp *Pool) handleTask(task *QueueTask) {
if err := validateTaskInput(task.payload); err != nil {
wp.logger.Error().Str("taskID", task.payload.ID).Msgf("Validation failed: %v", err)
return
}
ctx, cancel := context.WithTimeout(task.ctx, wp.timeout) ctx, cancel := context.WithTimeout(task.ctx, wp.timeout)
defer cancel() defer cancel()
taskSize := int64(utils.SizeOf(task.payload)) taskSize := int64(utils.SizeOf(task.payload))
@@ -163,15 +315,15 @@ func (wp *Pool) handleTask(task *QueueTask) {
// Warning thresholds check // Warning thresholds check
if wp.thresholds.LongExecution > 0 && executionTime > int64(wp.thresholds.LongExecution.Milliseconds()) { if wp.thresholds.LongExecution > 0 && executionTime > int64(wp.thresholds.LongExecution.Milliseconds()) {
wp.logger.Printf("Warning: Task %s exceeded execution time threshold: %d ms", task.payload.ID, executionTime) wp.logger.Warn().Str("taskID", task.payload.ID).Msgf("Exceeded execution time threshold: %d ms", executionTime)
} }
if wp.thresholds.HighMemory > 0 && taskSize > wp.thresholds.HighMemory { if wp.thresholds.HighMemory > 0 && taskSize > wp.thresholds.HighMemory {
wp.logger.Printf("Warning: Task %s memory usage %d exceeded threshold", task.payload.ID, taskSize) wp.logger.Warn().Str("taskID", task.payload.ID).Msgf("Memory usage %d exceeded threshold", taskSize)
} }
if result.Error != nil { if result.Error != nil {
atomic.AddInt64(&wp.metrics.ErrorCount, 1) atomic.AddInt64(&wp.metrics.ErrorCount, 1)
wp.logger.Printf("Error processing task %s: %v", task.payload.ID, result.Error) wp.logger.Error().Str("taskID", task.payload.ID).Msgf("Error processing task: %v", result.Error)
wp.backoffAndStore(task) wp.backoffAndStore(task)
// Circuit breaker check // Circuit breaker check
@@ -179,12 +331,12 @@ func (wp *Pool) handleTask(task *QueueTask) {
newCount := atomic.AddInt32(&wp.circuitBreakerFailureCount, 1) newCount := atomic.AddInt32(&wp.circuitBreakerFailureCount, 1)
if newCount >= int32(wp.circuitBreaker.FailureThreshold) { if newCount >= int32(wp.circuitBreaker.FailureThreshold) {
wp.circuitBreakerOpen = true wp.circuitBreakerOpen = true
wp.logger.Println("Circuit breaker opened due to errors") wp.logger.Warn().Msg("Circuit breaker opened due to errors")
go func() { go func() {
time.Sleep(wp.circuitBreaker.ResetTimeout) time.Sleep(wp.circuitBreaker.ResetTimeout)
atomic.StoreInt32(&wp.circuitBreakerFailureCount, 0) atomic.StoreInt32(&wp.circuitBreakerFailureCount, 0)
wp.circuitBreakerOpen = false wp.circuitBreakerOpen = false
wp.logger.Println("Circuit breaker reset") wp.logger.Info().Msg("Circuit breaker reset to closed state")
}() }()
} }
} }
@@ -198,17 +350,17 @@ func (wp *Pool) handleTask(task *QueueTask) {
// Diagnostics logging if enabled // Diagnostics logging if enabled
if wp.diagnosticsEnabled { if wp.diagnosticsEnabled {
wp.logger.Printf("Task %s executed in %d ms", task.payload.ID, executionTime) wp.logger.Info().Str("taskID", task.payload.ID).Msgf("Task executed in %d ms", executionTime)
} }
if wp.callback != nil { if wp.callback != nil {
if err := wp.callback(ctx, result); err != nil { if err := wp.callback(ctx, result); err != nil {
atomic.AddInt64(&wp.metrics.ErrorCount, 1) atomic.AddInt64(&wp.metrics.ErrorCount, 1)
wp.logger.Printf("Error in callback for task %s: %v", task.payload.ID, err) wp.logger.Error().Str("taskID", task.payload.ID).Msgf("Callback error: %v", err)
} }
} }
_ = wp.taskStorage.DeleteTask(task.payload.ID) _ = wp.taskStorage.DeleteTask(task.payload.ID)
atomic.AddInt64(&wp.metrics.TotalMemoryUsed, -taskSize) atomic.AddInt64(&wp.metrics.TotalMemoryUsed, -taskSize)
wp.metricsRegistry.Register("task_execution_time", executionTime)
} }
func (wp *Pool) backoffAndStore(task *QueueTask) { func (wp *Pool) backoffAndStore(task *QueueTask) {
@@ -219,10 +371,11 @@ func (wp *Pool) backoffAndStore(task *QueueTask) {
backoff := wp.backoffDuration * (1 << (task.retryCount - 1)) backoff := wp.backoffDuration * (1 << (task.retryCount - 1))
jitter := time.Duration(rand.Int63n(int64(backoff) / 2)) jitter := time.Duration(rand.Int63n(int64(backoff) / 2))
sleepDuration := backoff + jitter sleepDuration := backoff + jitter
wp.logger.Printf("Task %s retry %d: sleeping for %s", task.payload.ID, task.retryCount, sleepDuration) wp.logger.Info().Str("taskID", task.payload.ID).Msgf("Retry %d: sleeping for %s", task.retryCount, sleepDuration)
time.Sleep(sleepDuration) time.Sleep(sleepDuration)
} else { } else {
wp.logger.Printf("Task %s failed after maximum retries", task.payload.ID) wp.logger.Error().Str("taskID", task.payload.ID).Msg("Task failed after maximum retries")
wp.DLQ.Add(task)
} }
} }
@@ -282,7 +435,9 @@ func (wp *Pool) EnqueueTask(ctx context.Context, payload *Task, priority int) er
if wp.circuitBreaker.Enabled && wp.circuitBreakerOpen { if wp.circuitBreaker.Enabled && wp.circuitBreakerOpen {
return fmt.Errorf("circuit breaker open, task rejected") return fmt.Errorf("circuit breaker open, task rejected")
} }
if err := validateTaskInput(payload); err != nil {
return fmt.Errorf("invalid task input: %w", err)
}
if payload.ID == "" { if payload.ID == "" {
payload.ID = NewID() payload.ID = NewID()
} }
@@ -344,9 +499,8 @@ func (wp *Pool) startOverflowDrainer() {
func (wp *Pool) drainOverflowBuffer() { func (wp *Pool) drainOverflowBuffer() {
wp.overflowBufferLock.Lock() wp.overflowBufferLock.Lock()
overflowTasks := wp.overflowBuffer overflowTasks := wp.overflowBuffer
wp.overflowBuffer = nil // Clear buffer wp.overflowBuffer = nil
wp.overflowBufferLock.Unlock() wp.overflowBufferLock.Unlock()
for _, task := range overflowTasks { for _, task := range overflowTasks {
select { select {
case wp.taskNotify <- struct{}{}: case wp.taskNotify <- struct{}{}:
@@ -363,8 +517,6 @@ func (wp *Pool) Stop() {
wp.gracefulShutdown = true wp.gracefulShutdown = true
wp.Pause() wp.Pause()
close(wp.stop) close(wp.stop)
// Graceful shutdown with timeout support
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
wp.wg.Wait() wp.wg.Wait()
@@ -373,9 +525,8 @@ func (wp *Pool) Stop() {
}() }()
select { select {
case <-done: case <-done:
// All workers finished gracefully.
case <-time.After(wp.gracefulShutdownTimeout): case <-time.After(wp.gracefulShutdownTimeout):
wp.logger.Println("Graceful shutdown timeout reached") wp.logger.Warn().Msg("Graceful shutdown timeout reached")
} }
if wp.completionCallback != nil { if wp.completionCallback != nil {
wp.completionCallback() wp.completionCallback()
@@ -395,3 +546,54 @@ func (wp *Pool) Metrics() Metrics {
} }
func (wp *Pool) Scheduler() *Scheduler { return wp.scheduler } func (wp *Pool) Scheduler() *Scheduler { return wp.scheduler }
func (wp *Pool) dynamicWorkerScaler() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
wp.taskQueueLock.Lock()
queueLen := len(wp.taskQueue)
wp.taskQueueLock.Unlock()
newWorkers := queueLen/5 + 1
wp.logger.Info().Msgf("Auto-scaling: queue length %d, adjusting workers to %d", queueLen, newWorkers)
wp.AdjustWorkerCount(newWorkers)
case <-wp.stop:
return
}
}
}
func (wp *Pool) startHealthServer() {
mux := http.NewServeMux()
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
status := "OK"
if wp.gracefulShutdown {
status = "shutting down"
}
fmt.Fprintf(w, "status: %s\nworkers: %d\nqueueLength: %d\n",
status, atomic.LoadInt32(&wp.numOfWorkers), len(wp.taskQueue))
})
server := &http.Server{
Addr: ":8080",
Handler: mux,
}
go func() {
wp.logger.Info().Msg("Starting health server on :8080")
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
wp.logger.Error().Err(err).Msg("Health server failed")
}
}()
go func() {
<-wp.stop
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
wp.logger.Error().Err(err).Msg("Health server shutdown failed")
} else {
wp.logger.Info().Msg("Health server shutdown gracefully")
}
}()
}

View File

@@ -6,13 +6,14 @@ import (
// New type definitions for enhancements // New type definitions for enhancements
type ThresholdConfig struct { type ThresholdConfig struct {
HighMemory int64 // e.g. in bytes HighMemory int64
LongExecution time.Duration // e.g. warning if task execution exceeds LongExecution time.Duration
} }
type MetricsRegistry interface { type MetricsRegistry interface {
Register(metricName string, value interface{}) Register(metricName string, value interface{})
// ...other methods as needed... Increment(metricName string)
Get(metricName string) interface{}
} }
type CircuitBreakerConfig struct { type CircuitBreakerConfig struct {
@@ -25,7 +26,6 @@ type PoolOption func(*Pool)
func WithTaskQueueSize(size int) PoolOption { func WithTaskQueueSize(size int) PoolOption {
return func(p *Pool) { return func(p *Pool) {
// Initialize the task queue with the specified size
p.taskQueue = make(PriorityQueue, 0, size) p.taskQueue = make(PriorityQueue, 0, size)
} }
} }
@@ -72,8 +72,6 @@ func WithTaskStorage(storage TaskStorage) PoolOption {
} }
} }
// New option functions:
func WithWarningThresholds(thresholds ThresholdConfig) PoolOption { func WithWarningThresholds(thresholds ThresholdConfig) PoolOption {
return func(p *Pool) { return func(p *Pool) {
p.thresholds = thresholds p.thresholds = thresholds
@@ -103,3 +101,9 @@ func WithGracefulShutdown(timeout time.Duration) PoolOption {
p.gracefulShutdownTimeout = timeout p.gracefulShutdownTimeout = timeout
} }
} }
func WithPlugin(plugin Plugin) PoolOption {
return func(p *Pool) {
p.plugins = append(p.plugins, plugin)
}
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/oarkflow/json" "github.com/oarkflow/json"
"github.com/oarkflow/json/jsonparser" "github.com/oarkflow/json/jsonparser"
"github.com/oarkflow/mq/codec" "github.com/oarkflow/mq/codec"
"github.com/oarkflow/mq/consts" "github.com/oarkflow/mq/consts"
) )
@@ -77,16 +78,18 @@ func (p *Publisher) Publish(ctx context.Context, task Task, queue string) error
if err != nil { if err != nil {
return fmt.Errorf("failed to connect to broker: %w", err) return fmt.Errorf("failed to connect to broker: %w", err)
} }
defer conn.Close() defer func() {
_ = conn.Close()
}()
return p.send(ctx, queue, task, conn, consts.PUBLISH) return p.send(ctx, queue, task, conn, consts.PUBLISH)
} }
func (p *Publisher) onClose(ctx context.Context, conn net.Conn) error { func (p *Publisher) onClose(_ context.Context, conn net.Conn) error {
fmt.Println("Publisher Connection closed", p.id, conn.RemoteAddr()) fmt.Println("Publisher Connection closed", p.id, conn.RemoteAddr())
return nil return nil
} }
func (p *Publisher) onError(ctx context.Context, conn net.Conn, err error) { func (p *Publisher) onError(_ context.Context, conn net.Conn, err error) {
fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr()) fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr())
} }
@@ -99,7 +102,9 @@ func (p *Publisher) Request(ctx context.Context, task Task, queue string) Result
err = fmt.Errorf("failed to connect to broker: %w", err) err = fmt.Errorf("failed to connect to broker: %w", err)
return Result{Error: err} return Result{Error: err}
} }
defer conn.Close() defer func() {
_ = conn.Close()
}()
err = p.send(ctx, queue, task, conn, consts.PUBLISH) err = p.send(ctx, queue, task, conn, consts.PUBLISH)
resultCh := make(chan Result) resultCh := make(chan Result)
go func() { go func() {

View File

@@ -45,38 +45,33 @@ func (b *Broker) NewQueue(name string) *Queue {
type QueueTask struct { type QueueTask struct {
ctx context.Context ctx context.Context
payload *Task payload *Task
retryCount int
priority int priority int
index int // The index in the heap retryCount int
index int
} }
type PriorityQueue []*QueueTask type PriorityQueue []*QueueTask
func (pq PriorityQueue) Len() int { return len(pq) } func (pq PriorityQueue) Len() int { return len(pq) }
func (pq PriorityQueue) Less(i, j int) bool { func (pq PriorityQueue) Less(i, j int) bool {
return pq[i].priority > pq[j].priority return pq[i].priority > pq[j].priority
} }
func (pq PriorityQueue) Swap(i, j int) { func (pq PriorityQueue) Swap(i, j int) {
pq[i], pq[j] = pq[j], pq[i] pq[i], pq[j] = pq[j], pq[i]
pq[i].index = i pq[i].index = i
pq[j].index = j pq[j].index = j
} }
func (pq *PriorityQueue) Push(x interface{}) { func (pq *PriorityQueue) Push(x interface{}) {
n := len(*pq) n := len(*pq)
task := x.(*QueueTask) task := x.(*QueueTask)
task.index = n task.index = n
*pq = append(*pq, task) *pq = append(*pq, task)
} }
func (pq *PriorityQueue) Pop() interface{} { func (pq *PriorityQueue) Pop() interface{} {
old := *pq old := *pq
n := len(old) n := len(old)
task := old[n-1] task := old[n-1]
old[n-1] = nil // avoid memory leak task.index = -1
task.index = -1 // for safety
*pq = old[0 : n-1] *pq = old[0 : n-1]
return task return task
} }

View File

@@ -2,14 +2,19 @@ package mq
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/oarkflow/log"
) )
var Logger = log.DefaultLogger
type ScheduleOptions struct { type ScheduleOptions struct {
Handler Handler Handler Handler
Callback Callback Callback Callback
@@ -20,7 +25,7 @@ type ScheduleOptions struct {
type SchedulerOption func(*ScheduleOptions) type SchedulerOption func(*ScheduleOptions)
// Helper functions to create SchedulerOptions // WithSchedulerHandler Helper functions to create SchedulerOptions
func WithSchedulerHandler(handler Handler) SchedulerOption { func WithSchedulerHandler(handler Handler) SchedulerOption {
return func(opts *ScheduleOptions) { return func(opts *ScheduleOptions) {
opts.Handler = handler opts.Handler = handler
@@ -177,16 +182,14 @@ func parseCronSpec(cronSpec string) (CronSchedule, error) {
} }
func cronFieldToString(field string, fieldName string) (string, error) { func cronFieldToString(field string, fieldName string) (string, error) {
switch field { if field == "*" {
case "*":
return fmt.Sprintf("every %s", fieldName), nil return fmt.Sprintf("every %s", fieldName), nil
default:
values, err := parseCronValue(field)
if err != nil {
return "", fmt.Errorf("invalid %s field: %s", fieldName, err.Error())
}
return fmt.Sprintf("%s %s", strings.Join(values, ", "), fieldName), nil
} }
values, err := parseCronValue(field)
if err != nil {
return "", fmt.Errorf("invalid %s field: %s", fieldName, err.Error())
}
return fmt.Sprintf("%s %s", strings.Join(values, ", "), fieldName), nil
} }
func parseCronValue(field string) ([]string, error) { func parseCronValue(field string) ([]string, error) {
@@ -223,6 +226,8 @@ type Scheduler struct {
} }
func (s *Scheduler) Start() { func (s *Scheduler) Start() {
s.mu.Lock()
defer s.mu.Unlock()
for _, task := range s.tasks { for _, task := range s.tasks {
go s.schedule(task) go s.schedule(task)
} }
@@ -279,46 +284,81 @@ func (s *Scheduler) schedule(task ScheduledTask) {
} }
} }
func startSpan(operation string) (context.Context, func()) {
startTime := time.Now()
Logger.Info().Str("operation", operation).Msg("Span started")
ctx := context.WithValue(context.Background(), "traceID", fmt.Sprintf("%d", startTime.UnixNano()))
return ctx, func() {
duration := time.Since(startTime)
Logger.Info().Str("operation", operation).Msgf("Span ended; duration: %v", duration)
}
}
func acquireDistributedLock(taskID string) bool {
Logger.Info().Str("taskID", taskID).Msg("Acquiring distributed lock (stub)")
return true
}
func releaseDistributedLock(taskID string) {
Logger.Info().Str("taskID", taskID).Msg("Releasing distributed lock (stub)")
}
var taskPool = sync.Pool{
New: func() interface{} { return new(Task) },
}
var queueTaskPool = sync.Pool{
New: func() interface{} { return new(QueueTask) },
}
func getQueueTask() *QueueTask {
return queueTaskPool.Get().(*QueueTask)
}
// Enhance executeTask with circuit breaker and diagnostics logging support. // Enhance executeTask with circuit breaker and diagnostics logging support.
func (s *Scheduler) executeTask(task ScheduledTask) { func (s *Scheduler) executeTask(task ScheduledTask) {
if !task.config.Overlap && !s.pool.gracefulShutdown {
// Prevent overlapping execution if not allowed.
// ...existing code...
}
go func() { go func() {
// Recover from panic to keep scheduler running. _, cancelSpan := startSpan("executeTask")
defer cancelSpan()
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
fmt.Printf("Recovered from panic in scheduled task %s: %v\n", task.payload.ID, r) Logger.Error().Str("taskID", task.payload.ID).Msgf("Recovered from panic: %v", r)
} }
}() }()
start := time.Now() start := time.Now()
for _, plug := range s.pool.plugins {
plug.BeforeTask(getQueueTask())
}
if !acquireDistributedLock(task.payload.ID) {
Logger.Warn().Str("taskID", task.payload.ID).Msg("Failed to acquire distributed lock")
return
}
defer releaseDistributedLock(task.payload.ID)
result := task.handler(task.ctx, task.payload) result := task.handler(task.ctx, task.payload)
execTime := time.Since(start).Milliseconds() execTime := time.Since(start).Milliseconds()
// Diagnostics logging if enabled.
if s.pool.diagnosticsEnabled { if s.pool.diagnosticsEnabled {
s.pool.logger.Printf("Scheduled task %s executed in %d ms", task.payload.ID, execTime) Logger.Info().Str("taskID", task.payload.ID).Msgf("Executed in %d ms", execTime)
} }
// Circuit breaker check similar to pool.
if result.Error != nil && s.pool.circuitBreaker.Enabled { if result.Error != nil && s.pool.circuitBreaker.Enabled {
newCount := atomic.AddInt32(&s.pool.circuitBreakerFailureCount, 1) newCount := atomic.AddInt32(&s.pool.circuitBreakerFailureCount, 1)
if newCount >= int32(s.pool.circuitBreaker.FailureThreshold) { if newCount >= int32(s.pool.circuitBreaker.FailureThreshold) {
s.pool.circuitBreakerOpen = true s.pool.circuitBreakerOpen = true
s.pool.logger.Println("Scheduler: circuit breaker opened due to errors") Logger.Warn().Msg("Circuit breaker opened due to errors")
go func() { go func() {
time.Sleep(s.pool.circuitBreaker.ResetTimeout) time.Sleep(s.pool.circuitBreaker.ResetTimeout)
atomic.StoreInt32(&s.pool.circuitBreakerFailureCount, 0) atomic.StoreInt32(&s.pool.circuitBreakerFailureCount, 0)
s.pool.circuitBreakerOpen = false s.pool.circuitBreakerOpen = false
s.pool.logger.Println("Scheduler: circuit breaker reset") Logger.Info().Msg("Circuit breaker reset to closed state")
}() }()
} }
} }
// Invoke callback if defined.
if task.config.Callback != nil { if task.config.Callback != nil {
_ = task.config.Callback(task.ctx, result) _ = task.config.Callback(task.ctx, result)
} }
task.executionHistory = append(task.executionHistory, ExecutionHistory{Timestamp: time.Now(), Result: result}) task.executionHistory = append(task.executionHistory, ExecutionHistory{Timestamp: time.Now(), Result: result})
fmt.Printf("Executed scheduled task: %s\n", task.payload.ID) for _, plug := range s.pool.plugins {
plug.AfterTask(getQueueTask(), result)
}
Logger.Info().Str("taskID", task.payload.ID).Msg("Scheduled task executed")
}() }()
} }
@@ -354,7 +394,7 @@ func (task ScheduledTask) getNextRunTime(now time.Time) time.Time {
func (task ScheduledTask) getNextCronRunTime(now time.Time) time.Time { func (task ScheduledTask) getNextCronRunTime(now time.Time) time.Time {
cronSpecs, err := parseCronSpec(task.schedule.CronSpec) cronSpecs, err := parseCronSpec(task.schedule.CronSpec)
if err != nil { if err != nil {
fmt.Println(fmt.Sprintf("Invalid CRON spec format: %s", err.Error())) Logger.Error().Err(err).Msg("Invalid CRON spec")
return now return now
} }
nextRun := now nextRun := now
@@ -367,40 +407,38 @@ func (task ScheduledTask) getNextCronRunTime(now time.Time) time.Time {
} }
func (task ScheduledTask) applyCronField(t time.Time, fieldSpec string, unit string) time.Time { func (task ScheduledTask) applyCronField(t time.Time, fieldSpec string, unit string) time.Time {
switch fieldSpec { if fieldSpec == "*" {
case "*":
return t
default:
value, _ := strconv.Atoi(fieldSpec)
switch unit {
case "minute":
if t.Minute() > value {
t = t.Add(time.Hour)
}
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), value, 0, 0, t.Location())
case "hour":
if t.Hour() > value {
t = t.AddDate(0, 0, 1)
}
t = time.Date(t.Year(), t.Month(), t.Day(), value, t.Minute(), 0, 0, t.Location())
case "day":
if t.Day() > value {
t = t.AddDate(0, 1, 0)
}
t = time.Date(t.Year(), t.Month(), value, t.Hour(), t.Minute(), 0, 0, t.Location())
case "month":
if int(t.Month()) > value {
t = t.AddDate(1, 0, 0)
}
t = time.Date(t.Year(), time.Month(value), t.Day(), t.Hour(), t.Minute(), 0, 0, t.Location())
case "weekday":
weekday := time.Weekday(value)
for t.Weekday() != weekday {
t = t.AddDate(0, 0, 1)
}
}
return t return t
} }
value, _ := strconv.Atoi(fieldSpec)
switch unit {
case "minute":
if t.Minute() > value {
t = t.Add(time.Hour)
}
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), value, 0, 0, t.Location())
case "hour":
if t.Hour() > value {
t = t.AddDate(0, 0, 1)
}
t = time.Date(t.Year(), t.Month(), t.Day(), value, t.Minute(), 0, 0, t.Location())
case "day":
if t.Day() > value {
t = t.AddDate(0, 1, 0)
}
t = time.Date(t.Year(), t.Month(), value, t.Hour(), t.Minute(), 0, 0, t.Location())
case "month":
if int(t.Month()) > value {
t = t.AddDate(1, 0, 0)
}
t = time.Date(t.Year(), time.Month(value), t.Day(), t.Hour(), t.Minute(), 0, 0, t.Location())
case "weekday":
weekday := time.Weekday(value)
for t.Weekday() != weekday {
t = t.AddDate(0, 0, 1)
}
}
return t
} }
func nextWeekday(t time.Time, weekday time.Weekday) time.Time { func nextWeekday(t time.Time, weekday time.Weekday) time.Time {
@@ -410,12 +448,21 @@ func nextWeekday(t time.Time, weekday time.Weekday) time.Time {
} }
return t.AddDate(0, 0, daysUntil) return t.AddDate(0, 0, daysUntil)
} }
func validateTaskInput(task *Task) error {
if task.Payload == nil {
return errors.New("task payload cannot be nil")
}
Logger.Info().Str("taskID", task.ID).Msg("Task validated")
return nil
}
func (s *Scheduler) AddTask(ctx context.Context, payload *Task, opts ...SchedulerOption) { func (s *Scheduler) AddTask(ctx context.Context, payload *Task, opts ...SchedulerOption) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if err := validateTaskInput(payload); err != nil {
// Create a default options instance Logger.Error().Err(err).Msg("Invalid task input")
return
}
options := defaultSchedulerOptions() options := defaultSchedulerOptions()
for _, opt := range opts { for _, opt := range opts {
opt(options) opt(options)
@@ -427,9 +474,7 @@ func (s *Scheduler) AddTask(ctx context.Context, payload *Task, opts ...Schedule
options.Callback = s.pool.callback options.Callback = s.pool.callback
} }
stop := make(chan struct{}) stop := make(chan struct{})
newTask := ScheduledTask{
// Create a new ScheduledTask using the provided options
s.tasks = append(s.tasks, ScheduledTask{
ctx: ctx, ctx: ctx,
handler: options.Handler, handler: options.Handler,
payload: payload, payload: payload,
@@ -442,10 +487,9 @@ func (s *Scheduler) AddTask(ctx context.Context, payload *Task, opts ...Schedule
Interval: options.Interval, Interval: options.Interval,
Recurring: options.Recurring, Recurring: options.Recurring,
}, },
}) }
s.tasks = append(s.tasks, newTask)
// Start scheduling the task go s.schedule(newTask)
go s.schedule(s.tasks[len(s.tasks)-1])
} }
func (s *Scheduler) RemoveTask(payloadID string) { func (s *Scheduler) RemoveTask(payloadID string) {

1289
v1/v1.go

File diff suppressed because it is too large Load Diff