Mirror of https://github.com/oarkflow/mq.git (synced 2025-10-07 17:00:57 +08:00)

Commit: update: use oarkflow/json
@@ -8,14 +8,16 @@ import (
	"syscall"
	"time"

	v1 "github.com/oarkflow/mq/v1"
	"github.com/oarkflow/json"

	v1 "github.com/oarkflow/mq"
)

func main() {
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer stop()
	pool := v1.NewPool(5,
		v1.WithTaskStorage(v1.NewInMemoryTaskStorage()),
		v1.WithTaskStorage(v1.NewMemoryTaskStorage(10*time.Minute)),
		v1.WithHandler(func(ctx context.Context, payload *v1.Task) v1.Result {
			v1.Logger.Info().Str("taskID", payload.ID).Msg("Processing task payload")
			time.Sleep(500 * time.Millisecond)
@@ -43,14 +45,14 @@ func main() {
		metrics := pool.Metrics()
		v1.Logger.Info().Msgf("Metrics: %+v", metrics)
		pool.Stop()
		v1.Logger.Info().Msgf("Dead Letter Queue has %d tasks", len(v1.DLQ.Task()))
		v1.Logger.Info().Msgf("Dead Letter Queue has %d tasks", len(pool.DLQ.Task()))
	}()

	go func() {
		for i := 0; i < 50; i++ {
			task := &v1.Task{
				ID: "",
				Payload: fmt.Sprintf("Task Payload %d", i),
				Payload: json.RawMessage(fmt.Sprintf("Task Payload %d", i)),
			}
			if err := pool.EnqueueTask(context.Background(), task, rand.Intn(10)); err != nil {
				v1.Logger.Error().Err(err).Msg("Failed to enqueue task")
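With github.com/oarkflow/json, the example's Task.Payload becomes a json.RawMessage rather than a plain string. Below is a minimal sketch (illustrative only, not part of this commit) of marshalling a struct so the payload stays valid JSON, assuming the v1.Task, v1.Pool, and EnqueueTask shapes shown in the example; the orderPayload type and enqueueOrder helper are hypothetical:

// enqueueOrder marshals a small struct with oarkflow/json and enqueues it at priority 5.
type orderPayload struct {
	OrderID int    `json:"order_id"`
	Note    string `json:"note"`
}

func enqueueOrder(ctx context.Context, pool *v1.Pool, i int) error {
	raw, err := json.Marshal(orderPayload{OrderID: i, Note: "demo"}) // used here like encoding/json
	if err != nil {
		return err
	}
	task := &v1.Task{Payload: json.RawMessage(raw)} // ID left empty; EnqueueTask assigns one via NewID()
	return pool.EnqueueTask(ctx, task, 5)
}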
pool.go | 320
@@ -3,14 +3,17 @@ package mq

import (
	"container/heap"
	"context"
	"errors"
	"fmt"
	"log"
	"math/rand"
	"net/http"
	"sync"
	"sync/atomic"
	"time"

	"github.com/oarkflow/mq/utils"

	"github.com/oarkflow/log"
)

type Callback func(ctx context.Context, result Result) error
@@ -25,35 +28,134 @@ type Metrics struct {
	ExecutionTime int64
}

type Pool struct {
	taskStorage TaskStorage
	scheduler *Scheduler
	stop chan struct{}
	taskNotify chan struct{}
	workerAdjust chan int
	handler Handler
	completionCallback CompletionCallback
	taskAvailableCond *sync.Cond
	callback Callback
	taskQueue PriorityQueue
	overflowBuffer []*QueueTask
	metrics Metrics
	wg sync.WaitGroup
	taskCompletionNotifier sync.WaitGroup
	timeout time.Duration
	batchSize int
	maxMemoryLoad int64
	idleTimeout time.Duration
	backoffDuration time.Duration
	maxRetries int
	overflowBufferLock sync.RWMutex
	taskQueueLock sync.Mutex
	numOfWorkers int32
	paused bool
	logger *log.Logger
	gracefulShutdown bool
type Plugin interface {
	Initialize(config interface{}) error
	BeforeTask(task *QueueTask)
	AfterTask(task *QueueTask, result Result)
}

// New fields for enhancements:
type DefaultPlugin struct{}

func (dp *DefaultPlugin) Initialize(config interface{}) error { return nil }
func (dp *DefaultPlugin) BeforeTask(task *QueueTask) {
	Logger.Info().Str("taskID", task.payload.ID).Msg("BeforeTask plugin invoked")
}
func (dp *DefaultPlugin) AfterTask(task *QueueTask, result Result) {
	Logger.Info().Str("taskID", task.payload.ID).Msg("AfterTask plugin invoked")
}

type DeadLetterQueue struct {
	tasks []*QueueTask
	mu sync.Mutex
}

func NewDeadLetterQueue() *DeadLetterQueue {
	return &DeadLetterQueue{
		tasks: make([]*QueueTask, 0),
	}
}

func (dlq *DeadLetterQueue) Task() []*QueueTask {
	return dlq.tasks
}

func (dlq *DeadLetterQueue) Add(task *QueueTask) {
	dlq.mu.Lock()
	defer dlq.mu.Unlock()
	dlq.tasks = append(dlq.tasks, task)
	Logger.Warn().Str("taskID", task.payload.ID).Msg("Task added to Dead Letter Queue")
}

type InMemoryMetricsRegistry struct {
	metrics map[string]int64
	mu sync.RWMutex
}

func NewInMemoryMetricsRegistry() *InMemoryMetricsRegistry {
	return &InMemoryMetricsRegistry{
		metrics: make(map[string]int64),
	}
}

func (m *InMemoryMetricsRegistry) Register(metricName string, value interface{}) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if v, ok := value.(int64); ok {
		m.metrics[metricName] = v
		Logger.Info().Str("metric", metricName).Msgf("Registered metric: %d", v)
	}
}

func (m *InMemoryMetricsRegistry) Increment(metricName string) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.metrics[metricName]++
}

func (m *InMemoryMetricsRegistry) Get(metricName string) interface{} {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.metrics[metricName]
}

type WarningThresholds struct {
	HighMemory int64
	LongExecution time.Duration
}

type DynamicConfig struct {
	Timeout time.Duration
	BatchSize int
	MaxMemoryLoad int64
	IdleTimeout time.Duration
	BackoffDuration time.Duration
	MaxRetries int
	ReloadInterval time.Duration
	WarningThreshold WarningThresholds
}

var Config = &DynamicConfig{
	Timeout: 10 * time.Second,
	BatchSize: 1,
	MaxMemoryLoad: 100 * 1024 * 1024,
	IdleTimeout: 5 * time.Minute,
	BackoffDuration: 2 * time.Second,
	MaxRetries: 3,
	ReloadInterval: 30 * time.Second,
	WarningThreshold: WarningThresholds{
		HighMemory: 1 * 1024 * 1024,
		LongExecution: 2 * time.Second,
	},
}

type Pool struct {
	taskStorage TaskStorage
	scheduler *Scheduler
	stop chan struct{}
	taskNotify chan struct{}
	workerAdjust chan int
	handler Handler
	completionCallback CompletionCallback
	taskAvailableCond *sync.Cond
	callback Callback
	DLQ *DeadLetterQueue
	taskQueue PriorityQueue
	overflowBuffer []*QueueTask
	metrics Metrics
	wg sync.WaitGroup
	taskCompletionNotifier sync.WaitGroup
	timeout time.Duration
	batchSize int
	maxMemoryLoad int64
	idleTimeout time.Duration
	backoffDuration time.Duration
	maxRetries int
	overflowBufferLock sync.RWMutex
	taskQueueLock sync.Mutex
	numOfWorkers int32
	paused bool
	logger log.Logger
	gracefulShutdown bool
	thresholds ThresholdConfig
	diagnosticsEnabled bool
	metricsRegistry MetricsRegistry
@@ -61,33 +163,83 @@ type Pool struct {
	circuitBreakerOpen bool
	circuitBreakerFailureCount int32
	gracefulShutdownTimeout time.Duration
	plugins []Plugin
}

func NewPool(numOfWorkers int, opts ...PoolOption) *Pool {
	pool := &Pool{
		stop: make(chan struct{}),
		taskNotify: make(chan struct{}, numOfWorkers),
		batchSize: 1,
		timeout: 10 * time.Second,
		idleTimeout: 5 * time.Minute,
		backoffDuration: 2 * time.Second,
		maxRetries: 3, // Set max retries for failed tasks
		logger: log.Default(),
		stop: make(chan struct{}),
		taskNotify: make(chan struct{}, numOfWorkers),
		batchSize: Config.BatchSize,
		timeout: Config.Timeout,
		idleTimeout: Config.IdleTimeout,
		backoffDuration: Config.BackoffDuration,
		maxRetries: Config.MaxRetries,
		logger: Logger,
		DLQ: NewDeadLetterQueue(),
		metricsRegistry: NewInMemoryMetricsRegistry(),
		diagnosticsEnabled: true,
		gracefulShutdownTimeout: 10 * time.Second,
	}
	pool.scheduler = NewScheduler(pool)
	pool.taskAvailableCond = sync.NewCond(&sync.Mutex{})
	for _, opt := range opts {
		opt(pool)
	}
	if len(pool.taskQueue) == 0 {
	if pool.taskQueue == nil {
		pool.taskQueue = make(PriorityQueue, 0, 10)
	}
	heap.Init(&pool.taskQueue)
	pool.scheduler.Start()
	pool.Start(numOfWorkers)
	startConfigReloader(pool)
	go pool.dynamicWorkerScaler()
	go pool.startHealthServer()
	return pool
}

func validateDynamicConfig(c *DynamicConfig) error {
	if c.Timeout <= 0 {
		return errors.New("Timeout must be positive")
	}
	if c.BatchSize <= 0 {
		return errors.New("BatchSize must be > 0")
	}
	if c.MaxMemoryLoad <= 0 {
		return errors.New("MaxMemoryLoad must be > 0")
	}
	return nil
}

func startConfigReloader(pool *Pool) {
	go func() {
		ticker := time.NewTicker(Config.ReloadInterval)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				if err := validateDynamicConfig(Config); err != nil {
					Logger.Error().Err(err).Msg("Invalid dynamic config, skipping reload")
					continue
				}
				pool.timeout = Config.Timeout
				pool.batchSize = Config.BatchSize
				pool.maxMemoryLoad = Config.MaxMemoryLoad
				pool.idleTimeout = Config.IdleTimeout
				pool.backoffDuration = Config.BackoffDuration
				pool.maxRetries = Config.MaxRetries
				pool.thresholds = ThresholdConfig{
					HighMemory: Config.WarningThreshold.HighMemory,
					LongExecution: Config.WarningThreshold.LongExecution,
				}
				Logger.Info().Msg("Dynamic configuration reloaded")
			case <-pool.stop:
				return
			}
		}
	}()
}

func (wp *Pool) Start(numWorkers int) {
	storedTasks, err := wp.taskStorage.GetAllTasks()
	if err == nil {
@@ -144,13 +296,13 @@ func (wp *Pool) processNextBatch() {
			wp.handleTask(task)
		}
	}
	// @TODO - Why was this done?
	//if len(tasks) > 0 {
	//	wp.taskCompletionNotifier.Done()
	//}
}

func (wp *Pool) handleTask(task *QueueTask) {
	if err := validateTaskInput(task.payload); err != nil {
		wp.logger.Error().Str("taskID", task.payload.ID).Msgf("Validation failed: %v", err)
		return
	}
	ctx, cancel := context.WithTimeout(task.ctx, wp.timeout)
	defer cancel()
	taskSize := int64(utils.SizeOf(task.payload))
@@ -163,15 +315,15 @@ func (wp *Pool) handleTask(task *QueueTask) {

	// Warning thresholds check
	if wp.thresholds.LongExecution > 0 && executionTime > int64(wp.thresholds.LongExecution.Milliseconds()) {
		wp.logger.Printf("Warning: Task %s exceeded execution time threshold: %d ms", task.payload.ID, executionTime)
		wp.logger.Warn().Str("taskID", task.payload.ID).Msgf("Exceeded execution time threshold: %d ms", executionTime)
	}
	if wp.thresholds.HighMemory > 0 && taskSize > wp.thresholds.HighMemory {
		wp.logger.Printf("Warning: Task %s memory usage %d exceeded threshold", task.payload.ID, taskSize)
		wp.logger.Warn().Str("taskID", task.payload.ID).Msgf("Memory usage %d exceeded threshold", taskSize)
	}

	if result.Error != nil {
		atomic.AddInt64(&wp.metrics.ErrorCount, 1)
		wp.logger.Printf("Error processing task %s: %v", task.payload.ID, result.Error)
		wp.logger.Error().Str("taskID", task.payload.ID).Msgf("Error processing task: %v", result.Error)
		wp.backoffAndStore(task)

		// Circuit breaker check
@@ -179,12 +331,12 @@ func (wp *Pool) handleTask(task *QueueTask) {
			newCount := atomic.AddInt32(&wp.circuitBreakerFailureCount, 1)
			if newCount >= int32(wp.circuitBreaker.FailureThreshold) {
				wp.circuitBreakerOpen = true
				wp.logger.Println("Circuit breaker opened due to errors")
				wp.logger.Warn().Msg("Circuit breaker opened due to errors")
				go func() {
					time.Sleep(wp.circuitBreaker.ResetTimeout)
					atomic.StoreInt32(&wp.circuitBreakerFailureCount, 0)
					wp.circuitBreakerOpen = false
					wp.logger.Println("Circuit breaker reset")
					wp.logger.Info().Msg("Circuit breaker reset to closed state")
				}()
			}
		}
@@ -198,17 +350,17 @@ func (wp *Pool) handleTask(task *QueueTask) {

	// Diagnostics logging if enabled
	if wp.diagnosticsEnabled {
		wp.logger.Printf("Task %s executed in %d ms", task.payload.ID, executionTime)
		wp.logger.Info().Str("taskID", task.payload.ID).Msgf("Task executed in %d ms", executionTime)
	}

	if wp.callback != nil {
		if err := wp.callback(ctx, result); err != nil {
			atomic.AddInt64(&wp.metrics.ErrorCount, 1)
			wp.logger.Printf("Error in callback for task %s: %v", task.payload.ID, err)
			wp.logger.Error().Str("taskID", task.payload.ID).Msgf("Callback error: %v", err)
		}
	}
	_ = wp.taskStorage.DeleteTask(task.payload.ID)
	atomic.AddInt64(&wp.metrics.TotalMemoryUsed, -taskSize)
	wp.metricsRegistry.Register("task_execution_time", executionTime)
}

func (wp *Pool) backoffAndStore(task *QueueTask) {
@@ -219,10 +371,11 @@ func (wp *Pool) backoffAndStore(task *QueueTask) {
		backoff := wp.backoffDuration * (1 << (task.retryCount - 1))
		jitter := time.Duration(rand.Int63n(int64(backoff) / 2))
		sleepDuration := backoff + jitter
		wp.logger.Printf("Task %s retry %d: sleeping for %s", task.payload.ID, task.retryCount, sleepDuration)
		wp.logger.Info().Str("taskID", task.payload.ID).Msgf("Retry %d: sleeping for %s", task.retryCount, sleepDuration)
		time.Sleep(sleepDuration)
	} else {
		wp.logger.Printf("Task %s failed after maximum retries", task.payload.ID)
		wp.logger.Error().Str("taskID", task.payload.ID).Msg("Task failed after maximum retries")
		wp.DLQ.Add(task)
	}
}

@@ -282,7 +435,9 @@ func (wp *Pool) EnqueueTask(ctx context.Context, payload *Task, priority int) er
	if wp.circuitBreaker.Enabled && wp.circuitBreakerOpen {
		return fmt.Errorf("circuit breaker open, task rejected")
	}

	if err := validateTaskInput(payload); err != nil {
		return fmt.Errorf("invalid task input: %w", err)
	}
	if payload.ID == "" {
		payload.ID = NewID()
	}
@@ -344,9 +499,8 @@ func (wp *Pool) startOverflowDrainer() {
func (wp *Pool) drainOverflowBuffer() {
	wp.overflowBufferLock.Lock()
	overflowTasks := wp.overflowBuffer
	wp.overflowBuffer = nil // Clear buffer
	wp.overflowBuffer = nil
	wp.overflowBufferLock.Unlock()

	for _, task := range overflowTasks {
		select {
		case wp.taskNotify <- struct{}{}:
@@ -363,8 +517,6 @@ func (wp *Pool) Stop() {
	wp.gracefulShutdown = true
	wp.Pause()
	close(wp.stop)

	// Graceful shutdown with timeout support
	done := make(chan struct{})
	go func() {
		wp.wg.Wait()
@@ -373,9 +525,8 @@ func (wp *Pool) Stop() {
	}()
	select {
	case <-done:
		// All workers finished gracefully.
	case <-time.After(wp.gracefulShutdownTimeout):
		wp.logger.Println("Graceful shutdown timeout reached")
		wp.logger.Warn().Msg("Graceful shutdown timeout reached")
	}
	if wp.completionCallback != nil {
		wp.completionCallback()
@@ -395,3 +546,54 @@ func (wp *Pool) Metrics() Metrics {
}

func (wp *Pool) Scheduler() *Scheduler { return wp.scheduler }

func (wp *Pool) dynamicWorkerScaler() {
	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			wp.taskQueueLock.Lock()
			queueLen := len(wp.taskQueue)
			wp.taskQueueLock.Unlock()
			newWorkers := queueLen/5 + 1
			wp.logger.Info().Msgf("Auto-scaling: queue length %d, adjusting workers to %d", queueLen, newWorkers)
			wp.AdjustWorkerCount(newWorkers)
		case <-wp.stop:
			return
		}
	}
}

func (wp *Pool) startHealthServer() {
	mux := http.NewServeMux()
	mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
		status := "OK"
		if wp.gracefulShutdown {
			status = "shutting down"
		}
		fmt.Fprintf(w, "status: %s\nworkers: %d\nqueueLength: %d\n",
			status, atomic.LoadInt32(&wp.numOfWorkers), len(wp.taskQueue))
	})
	server := &http.Server{
		Addr: ":8080",
		Handler: mux,
	}
	go func() {
		wp.logger.Info().Msg("Starting health server on :8080")
		if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
			wp.logger.Error().Err(err).Msg("Health server failed")
		}
	}()

	go func() {
		<-wp.stop
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := server.Shutdown(ctx); err != nil {
			wp.logger.Error().Err(err).Msg("Health server shutdown failed")
		} else {
			wp.logger.Info().Msg("Health server shutdown gracefully")
		}
	}()
}
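The Plugin interface introduced above (Initialize, BeforeTask, AfterTask) provides hooks around each task, and tasks that exhaust their retries now land in the pool's DLQ. A rough sketch of a custom plugin under those definitions follows; the countingPlugin type is hypothetical and, because QueueTask.payload is unexported, it is written as if it lives inside package mq:

// countingPlugin counts completed tasks; safe for concurrent workers via atomic.
type countingPlugin struct{ processed int64 }

func (cp *countingPlugin) Initialize(config interface{}) error { return nil }

func (cp *countingPlugin) BeforeTask(task *QueueTask) {}

func (cp *countingPlugin) AfterTask(task *QueueTask, result Result) {
	n := atomic.AddInt64(&cp.processed, 1)
	if task == nil || task.payload == nil { // scheduler hooks may pass a pooled, empty QueueTask
		return
	}
	if result.Error != nil {
		Logger.Error().Str("taskID", task.payload.ID).Msgf("task failed: %v", result.Error)
		return
	}
	Logger.Info().Str("taskID", task.payload.ID).Msgf("task completed (%d so far)", n)
}

It would be registered like any other option, e.g. NewPool(4, WithPlugin(&countingPlugin{})), using the WithPlugin option defined in the hunk below.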
@@ -6,13 +6,14 @@ import (

// New type definitions for enhancements
type ThresholdConfig struct {
	HighMemory int64 // e.g. in bytes
	LongExecution time.Duration // e.g. warning if task execution exceeds
	HighMemory int64
	LongExecution time.Duration
}

type MetricsRegistry interface {
	Register(metricName string, value interface{})
	// ...other methods as needed...
	Increment(metricName string)
	Get(metricName string) interface{}
}

type CircuitBreakerConfig struct {
@@ -25,7 +26,6 @@ type PoolOption func(*Pool)

func WithTaskQueueSize(size int) PoolOption {
	return func(p *Pool) {
		// Initialize the task queue with the specified size
		p.taskQueue = make(PriorityQueue, 0, size)
	}
}
@@ -72,8 +72,6 @@ func WithTaskStorage(storage TaskStorage) PoolOption {
	}
}

// New option functions:

func WithWarningThresholds(thresholds ThresholdConfig) PoolOption {
	return func(p *Pool) {
		p.thresholds = thresholds
@@ -103,3 +101,9 @@ func WithGracefulShutdown(timeout time.Duration) PoolOption {
		p.gracefulShutdownTimeout = timeout
	}
}

func WithPlugin(plugin Plugin) PoolOption {
	return func(p *Pool) {
		p.plugins = append(p.plugins, plugin)
	}
}
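Taken together, the new options compose with the existing ones. A usage sketch built only from option names that appear in this diff (the thresholds and timeouts are arbitrary example values):

pool := NewPool(4,
	WithTaskQueueSize(128),
	WithWarningThresholds(ThresholdConfig{
		HighMemory:    8 * 1024 * 1024, // warn when a payload is larger than ~8 MiB
		LongExecution: 2 * time.Second, // warn when a task runs longer than 2s
	}),
	WithGracefulShutdown(30*time.Second),
	WithPlugin(&DefaultPlugin{}),
)
defer pool.Stop()

A real pool would also set a handler (WithHandler, as in the example at the top) before tasks are enqueued.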
publisher.go | 13
@@ -10,6 +10,7 @@ import (
	"github.com/oarkflow/json"

	"github.com/oarkflow/json/jsonparser"

	"github.com/oarkflow/mq/codec"
	"github.com/oarkflow/mq/consts"
)
@@ -77,16 +78,18 @@ func (p *Publisher) Publish(ctx context.Context, task Task, queue string) error
	if err != nil {
		return fmt.Errorf("failed to connect to broker: %w", err)
	}
	defer conn.Close()
	defer func() {
		_ = conn.Close()
	}()
	return p.send(ctx, queue, task, conn, consts.PUBLISH)
}

func (p *Publisher) onClose(ctx context.Context, conn net.Conn) error {
func (p *Publisher) onClose(_ context.Context, conn net.Conn) error {
	fmt.Println("Publisher Connection closed", p.id, conn.RemoteAddr())
	return nil
}

func (p *Publisher) onError(ctx context.Context, conn net.Conn, err error) {
func (p *Publisher) onError(_ context.Context, conn net.Conn, err error) {
	fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr())
}

@@ -99,7 +102,9 @@ func (p *Publisher) Request(ctx context.Context, task Task, queue string) Result
		err = fmt.Errorf("failed to connect to broker: %w", err)
		return Result{Error: err}
	}
	defer conn.Close()
	defer func() {
		_ = conn.Close()
	}()
	err = p.send(ctx, queue, task, conn, consts.PUBLISH)
	resultCh := make(chan Result)
	go func() {
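Replacing defer conn.Close() with a small closure makes the discarded error explicit, which keeps error-check linters quiet without changing behaviour. The same idiom in isolation (publishOnce is a hypothetical helper, not part of this commit):

// publishOnce dials, always closes the connection, and deliberately ignores the Close error.
func publishOnce(addr string, send func(net.Conn) error) error {
	conn, err := net.Dial("tcp", addr)
	if err != nil {
		return fmt.Errorf("failed to connect to broker: %w", err)
	}
	defer func() {
		_ = conn.Close() // error intentionally discarded, as in Publish and Request above
	}()
	return send(conn)
}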
queue.go | 11
@@ -45,38 +45,33 @@ func (b *Broker) NewQueue(name string) *Queue {
type QueueTask struct {
	ctx context.Context
	payload *Task
	retryCount int
	priority int
	index int // The index in the heap
	retryCount int
	index int
}

type PriorityQueue []*QueueTask

func (pq PriorityQueue) Len() int { return len(pq) }

func (pq PriorityQueue) Less(i, j int) bool {
	return pq[i].priority > pq[j].priority
}

func (pq PriorityQueue) Swap(i, j int) {
	pq[i], pq[j] = pq[j], pq[i]
	pq[i].index = i
	pq[j].index = j
}

func (pq *PriorityQueue) Push(x interface{}) {
	n := len(*pq)
	task := x.(*QueueTask)
	task.index = n
	*pq = append(*pq, task)
}

func (pq *PriorityQueue) Pop() interface{} {
	old := *pq
	n := len(old)
	task := old[n-1]
	old[n-1] = nil // avoid memory leak
	task.index = -1 // for safety
	task.index = -1
	*pq = old[0 : n-1]
	return task
}
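Because Less compares priorities with >, this queue is a max-heap: heap.Pop returns the highest-priority task first, and Swap keeps each task's index current so heap.Fix and heap.Remove stay cheap. A self-contained sketch of the same pattern with its own types (not the mq ones), runnable as a standalone program:

package main

import (
	"container/heap"
	"fmt"
)

type item struct {
	name     string
	priority int
	index    int
}

type maxPQ []*item

func (pq maxPQ) Len() int           { return len(pq) }
func (pq maxPQ) Less(i, j int) bool { return pq[i].priority > pq[j].priority }
func (pq maxPQ) Swap(i, j int) {
	pq[i], pq[j] = pq[j], pq[i]
	pq[i].index = i
	pq[j].index = j
}

func (pq *maxPQ) Push(x interface{}) {
	it := x.(*item)
	it.index = len(*pq)
	*pq = append(*pq, it)
}

func (pq *maxPQ) Pop() interface{} {
	old := *pq
	n := len(old)
	it := old[n-1]
	old[n-1] = nil // avoid memory leak, mirroring PriorityQueue.Pop above
	it.index = -1
	*pq = old[:n-1]
	return it
}

func main() {
	pq := make(maxPQ, 0, 4)
	heap.Init(&pq)
	heap.Push(&pq, &item{name: "low", priority: 1})
	heap.Push(&pq, &item{name: "high", priority: 9})
	heap.Push(&pq, &item{name: "mid", priority: 5})
	for pq.Len() > 0 {
		fmt.Println(heap.Pop(&pq).(*item).name) // prints high, mid, low
	}
}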
scheduler.go | 172
@@ -2,14 +2,19 @@ package mq

import (
	"context"
	"errors"
	"fmt"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/oarkflow/log"
)

var Logger = log.DefaultLogger

type ScheduleOptions struct {
	Handler Handler
	Callback Callback
@@ -20,7 +25,7 @@ type ScheduleOptions struct {

type SchedulerOption func(*ScheduleOptions)

// Helper functions to create SchedulerOptions
// WithSchedulerHandler Helper functions to create SchedulerOptions
func WithSchedulerHandler(handler Handler) SchedulerOption {
	return func(opts *ScheduleOptions) {
		opts.Handler = handler
@@ -177,16 +182,14 @@ func parseCronSpec(cronSpec string) (CronSchedule, error) {
}

func cronFieldToString(field string, fieldName string) (string, error) {
	switch field {
	case "*":
	if field == "*" {
		return fmt.Sprintf("every %s", fieldName), nil
	default:
		values, err := parseCronValue(field)
		if err != nil {
			return "", fmt.Errorf("invalid %s field: %s", fieldName, err.Error())
		}
		return fmt.Sprintf("%s %s", strings.Join(values, ", "), fieldName), nil
	}
	values, err := parseCronValue(field)
	if err != nil {
		return "", fmt.Errorf("invalid %s field: %s", fieldName, err.Error())
	}
	return fmt.Sprintf("%s %s", strings.Join(values, ", "), fieldName), nil
}

func parseCronValue(field string) ([]string, error) {
@@ -223,6 +226,8 @@ type Scheduler struct {
}

func (s *Scheduler) Start() {
	s.mu.Lock()
	defer s.mu.Unlock()
	for _, task := range s.tasks {
		go s.schedule(task)
	}
@@ -279,46 +284,81 @@ func (s *Scheduler) schedule(task ScheduledTask) {
	}
}

func startSpan(operation string) (context.Context, func()) {
	startTime := time.Now()
	Logger.Info().Str("operation", operation).Msg("Span started")
	ctx := context.WithValue(context.Background(), "traceID", fmt.Sprintf("%d", startTime.UnixNano()))
	return ctx, func() {
		duration := time.Since(startTime)
		Logger.Info().Str("operation", operation).Msgf("Span ended; duration: %v", duration)
	}
}

func acquireDistributedLock(taskID string) bool {
	Logger.Info().Str("taskID", taskID).Msg("Acquiring distributed lock (stub)")
	return true
}

func releaseDistributedLock(taskID string) {
	Logger.Info().Str("taskID", taskID).Msg("Releasing distributed lock (stub)")
}

var taskPool = sync.Pool{
	New: func() interface{} { return new(Task) },
}
var queueTaskPool = sync.Pool{
	New: func() interface{} { return new(QueueTask) },
}

func getQueueTask() *QueueTask {
	return queueTaskPool.Get().(*QueueTask)
}

// Enhance executeTask with circuit breaker and diagnostics logging support.
func (s *Scheduler) executeTask(task ScheduledTask) {
	if !task.config.Overlap && !s.pool.gracefulShutdown {
		// Prevent overlapping execution if not allowed.
		// ...existing code...
	}
	go func() {
		// Recover from panic to keep scheduler running.
		_, cancelSpan := startSpan("executeTask")
		defer cancelSpan()
		defer func() {
			if r := recover(); r != nil {
				fmt.Printf("Recovered from panic in scheduled task %s: %v\n", task.payload.ID, r)
				Logger.Error().Str("taskID", task.payload.ID).Msgf("Recovered from panic: %v", r)
			}
		}()
		start := time.Now()
		for _, plug := range s.pool.plugins {
			plug.BeforeTask(getQueueTask())
		}
		if !acquireDistributedLock(task.payload.ID) {
			Logger.Warn().Str("taskID", task.payload.ID).Msg("Failed to acquire distributed lock")
			return
		}
		defer releaseDistributedLock(task.payload.ID)
		result := task.handler(task.ctx, task.payload)
		execTime := time.Since(start).Milliseconds()
		// Diagnostics logging if enabled.
		if s.pool.diagnosticsEnabled {
			s.pool.logger.Printf("Scheduled task %s executed in %d ms", task.payload.ID, execTime)
			Logger.Info().Str("taskID", task.payload.ID).Msgf("Executed in %d ms", execTime)
		}
		// Circuit breaker check similar to pool.
		if result.Error != nil && s.pool.circuitBreaker.Enabled {
			newCount := atomic.AddInt32(&s.pool.circuitBreakerFailureCount, 1)
			if newCount >= int32(s.pool.circuitBreaker.FailureThreshold) {
				s.pool.circuitBreakerOpen = true
				s.pool.logger.Println("Scheduler: circuit breaker opened due to errors")
				Logger.Warn().Msg("Circuit breaker opened due to errors")
				go func() {
					time.Sleep(s.pool.circuitBreaker.ResetTimeout)
					atomic.StoreInt32(&s.pool.circuitBreakerFailureCount, 0)
					s.pool.circuitBreakerOpen = false
					s.pool.logger.Println("Scheduler: circuit breaker reset")
					Logger.Info().Msg("Circuit breaker reset to closed state")
				}()
			}
		}
		// Invoke callback if defined.
		if task.config.Callback != nil {
			_ = task.config.Callback(task.ctx, result)
		}
		task.executionHistory = append(task.executionHistory, ExecutionHistory{Timestamp: time.Now(), Result: result})
		fmt.Printf("Executed scheduled task: %s\n", task.payload.ID)
		for _, plug := range s.pool.plugins {
			plug.AfterTask(getQueueTask(), result)
		}
		Logger.Info().Str("taskID", task.payload.ID).Msg("Scheduled task executed")
	}()
}

@@ -354,7 +394,7 @@ func (task ScheduledTask) getNextRunTime(now time.Time) time.Time {
func (task ScheduledTask) getNextCronRunTime(now time.Time) time.Time {
	cronSpecs, err := parseCronSpec(task.schedule.CronSpec)
	if err != nil {
		fmt.Println(fmt.Sprintf("Invalid CRON spec format: %s", err.Error()))
		Logger.Error().Err(err).Msg("Invalid CRON spec")
		return now
	}
	nextRun := now
@@ -367,40 +407,38 @@ func (task ScheduledTask) getNextCronRunTime(now time.Time) time.Time {
}

func (task ScheduledTask) applyCronField(t time.Time, fieldSpec string, unit string) time.Time {
	switch fieldSpec {
	case "*":
		return t
	default:
		value, _ := strconv.Atoi(fieldSpec)
		switch unit {
		case "minute":
			if t.Minute() > value {
				t = t.Add(time.Hour)
			}
			t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), value, 0, 0, t.Location())
		case "hour":
			if t.Hour() > value {
				t = t.AddDate(0, 0, 1)
			}
			t = time.Date(t.Year(), t.Month(), t.Day(), value, t.Minute(), 0, 0, t.Location())
		case "day":
			if t.Day() > value {
				t = t.AddDate(0, 1, 0)
			}
			t = time.Date(t.Year(), t.Month(), value, t.Hour(), t.Minute(), 0, 0, t.Location())
		case "month":
			if int(t.Month()) > value {
				t = t.AddDate(1, 0, 0)
			}
			t = time.Date(t.Year(), time.Month(value), t.Day(), t.Hour(), t.Minute(), 0, 0, t.Location())
		case "weekday":
			weekday := time.Weekday(value)
			for t.Weekday() != weekday {
				t = t.AddDate(0, 0, 1)
			}
		}
	if fieldSpec == "*" {
		return t
	}
	value, _ := strconv.Atoi(fieldSpec)
	switch unit {
	case "minute":
		if t.Minute() > value {
			t = t.Add(time.Hour)
		}
		t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), value, 0, 0, t.Location())
	case "hour":
		if t.Hour() > value {
			t = t.AddDate(0, 0, 1)
		}
		t = time.Date(t.Year(), t.Month(), t.Day(), value, t.Minute(), 0, 0, t.Location())
	case "day":
		if t.Day() > value {
			t = t.AddDate(0, 1, 0)
		}
		t = time.Date(t.Year(), t.Month(), value, t.Hour(), t.Minute(), 0, 0, t.Location())
	case "month":
		if int(t.Month()) > value {
			t = t.AddDate(1, 0, 0)
		}
		t = time.Date(t.Year(), time.Month(value), t.Day(), t.Hour(), t.Minute(), 0, 0, t.Location())
	case "weekday":
		weekday := time.Weekday(value)
		for t.Weekday() != weekday {
			t = t.AddDate(0, 0, 1)
		}
	}
	return t
}

func nextWeekday(t time.Time, weekday time.Weekday) time.Time {
@@ -410,12 +448,21 @@ func nextWeekday(t time.Time, weekday time.Weekday) time.Time {
	}
	return t.AddDate(0, 0, daysUntil)
}
func validateTaskInput(task *Task) error {
	if task.Payload == nil {
		return errors.New("task payload cannot be nil")
	}
	Logger.Info().Str("taskID", task.ID).Msg("Task validated")
	return nil
}

func (s *Scheduler) AddTask(ctx context.Context, payload *Task, opts ...SchedulerOption) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Create a default options instance
	if err := validateTaskInput(payload); err != nil {
		Logger.Error().Err(err).Msg("Invalid task input")
		return
	}
	options := defaultSchedulerOptions()
	for _, opt := range opts {
		opt(options)
@@ -427,9 +474,7 @@ func (s *Scheduler) AddTask(ctx context.Context, payload *Task, opts ...Schedule
		options.Callback = s.pool.callback
	}
	stop := make(chan struct{})

	// Create a new ScheduledTask using the provided options
	s.tasks = append(s.tasks, ScheduledTask{
	newTask := ScheduledTask{
		ctx: ctx,
		handler: options.Handler,
		payload: payload,
@@ -442,10 +487,9 @@ func (s *Scheduler) AddTask(ctx context.Context, payload *Task, opts ...Schedule
			Interval: options.Interval,
			Recurring: options.Recurring,
		},
	})

	// Start scheduling the task
	go s.schedule(s.tasks[len(s.tasks)-1])
}
	s.tasks = append(s.tasks, newTask)
	go s.schedule(newTask)
}

func (s *Scheduler) RemoveTask(payloadID string) {
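AddTask now validates the payload up front and builds the ScheduledTask explicitly before appending and scheduling it. A usage sketch built from names visible in this diff (Scheduler obtained via pool.Scheduler(), WithSchedulerHandler, Task, Result); the handler signature is assumed to match the func(ctx, *Task) Result shape used with WithHandler in the example at the top, and the payload contents are arbitrary:

payload := &Task{
	ID:      "nightly-report",
	Payload: json.RawMessage(`{"report":"daily"}`),
}
pool.Scheduler().AddTask(context.Background(), payload,
	WithSchedulerHandler(func(ctx context.Context, t *Task) Result {
		Logger.Info().Str("taskID", t.ID).Msg("running scheduled task")
		return Result{}
	}),
)

Interval, recurrence, and overlap come from the other SchedulerOption helpers that populate ScheduleOptions (Interval, Recurring, Overlap), which this hunk does not show.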