update: use oarkflow/json
@@ -8,14 +8,16 @@ import (
 	"syscall"
 	"time"
 
-	v1 "github.com/oarkflow/mq/v1"
+	"github.com/oarkflow/json"
+
+	v1 "github.com/oarkflow/mq"
 )
 
 func main() {
 	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
 	defer stop()
 	pool := v1.NewPool(5,
-		v1.WithTaskStorage(v1.NewInMemoryTaskStorage()),
+		v1.WithTaskStorage(v1.NewMemoryTaskStorage(10*time.Minute)),
 		v1.WithHandler(func(ctx context.Context, payload *v1.Task) v1.Result {
 			v1.Logger.Info().Str("taskID", payload.ID).Msg("Processing task payload")
 			time.Sleep(500 * time.Millisecond)
@@ -43,14 +45,14 @@ func main() {
 		metrics := pool.Metrics()
 		v1.Logger.Info().Msgf("Metrics: %+v", metrics)
 		pool.Stop()
-		v1.Logger.Info().Msgf("Dead Letter Queue has %d tasks", len(v1.DLQ.Task()))
+		v1.Logger.Info().Msgf("Dead Letter Queue has %d tasks", len(pool.DLQ.Task()))
 	}()
 
 	go func() {
 		for i := 0; i < 50; i++ {
 			task := &v1.Task{
 				ID:      "",
-				Payload: fmt.Sprintf("Task Payload %d", i),
+				Payload: json.RawMessage(fmt.Sprintf("Task Payload %d", i)),
 			}
 			if err := pool.EnqueueTask(context.Background(), task, rand.Intn(10)); err != nil {
 				v1.Logger.Error().Err(err).Msg("Failed to enqueue task")
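The example above now carries its payload as a json.RawMessage rather than a plain string. A minimal sketch of round-tripping a structured payload through that field, assuming oarkflow/json mirrors the standard encoding/json Marshal/Unmarshal API (examplePayload is an illustrative type, not part of the repository):

package main

import (
	"fmt"

	"github.com/oarkflow/json"
)

// examplePayload is a hypothetical payload shape used only for illustration.
type examplePayload struct {
	OrderID string `json:"order_id"`
	Amount  int    `json:"amount"`
}

func main() {
	// Marshal a typed value into the raw bytes a Task carries.
	raw, err := json.Marshal(examplePayload{OrderID: "A-1", Amount: 42})
	if err != nil {
		panic(err)
	}
	payload := json.RawMessage(raw)

	// A handler can decode the raw bytes back into the typed form.
	var decoded examplePayload
	if err := json.Unmarshal(payload, &decoded); err != nil {
		panic(err)
	}
	fmt.Printf("decoded payload: %+v\n", decoded)
}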
pool.go (266 changed lines)
@@ -3,14 +3,17 @@ package mq
 import (
 	"container/heap"
 	"context"
+	"errors"
 	"fmt"
-	"log"
 	"math/rand"
+	"net/http"
 	"sync"
 	"sync/atomic"
 	"time"
 
 	"github.com/oarkflow/mq/utils"
+
+	"github.com/oarkflow/log"
 )
 
 type Callback func(ctx context.Context, result Result) error
@@ -25,6 +28,106 @@ type Metrics struct {
 	ExecutionTime int64
 }
 
+type Plugin interface {
+	Initialize(config interface{}) error
+	BeforeTask(task *QueueTask)
+	AfterTask(task *QueueTask, result Result)
+}
+
+type DefaultPlugin struct{}
+
+func (dp *DefaultPlugin) Initialize(config interface{}) error { return nil }
+
+func (dp *DefaultPlugin) BeforeTask(task *QueueTask) {
+	Logger.Info().Str("taskID", task.payload.ID).Msg("BeforeTask plugin invoked")
+}
+
+func (dp *DefaultPlugin) AfterTask(task *QueueTask, result Result) {
+	Logger.Info().Str("taskID", task.payload.ID).Msg("AfterTask plugin invoked")
+}
+
+type DeadLetterQueue struct {
+	tasks []*QueueTask
+	mu    sync.Mutex
+}
+
+func NewDeadLetterQueue() *DeadLetterQueue {
+	return &DeadLetterQueue{
+		tasks: make([]*QueueTask, 0),
+	}
+}
+
+func (dlq *DeadLetterQueue) Task() []*QueueTask {
+	return dlq.tasks
+}
+
+func (dlq *DeadLetterQueue) Add(task *QueueTask) {
+	dlq.mu.Lock()
+	defer dlq.mu.Unlock()
+	dlq.tasks = append(dlq.tasks, task)
+	Logger.Warn().Str("taskID", task.payload.ID).Msg("Task added to Dead Letter Queue")
+}
+
+type InMemoryMetricsRegistry struct {
+	metrics map[string]int64
+	mu      sync.RWMutex
+}
+
+func NewInMemoryMetricsRegistry() *InMemoryMetricsRegistry {
+	return &InMemoryMetricsRegistry{
+		metrics: make(map[string]int64),
+	}
+}
+
+func (m *InMemoryMetricsRegistry) Register(metricName string, value interface{}) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	if v, ok := value.(int64); ok {
+		m.metrics[metricName] = v
+		Logger.Info().Str("metric", metricName).Msgf("Registered metric: %d", v)
+	}
+}
+
+func (m *InMemoryMetricsRegistry) Increment(metricName string) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	m.metrics[metricName]++
+}
+
+func (m *InMemoryMetricsRegistry) Get(metricName string) interface{} {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+	return m.metrics[metricName]
+}
+
+type WarningThresholds struct {
+	HighMemory    int64
+	LongExecution time.Duration
+}
+
+type DynamicConfig struct {
+	Timeout          time.Duration
+	BatchSize        int
+	MaxMemoryLoad    int64
+	IdleTimeout      time.Duration
+	BackoffDuration  time.Duration
+	MaxRetries       int
+	ReloadInterval   time.Duration
+	WarningThreshold WarningThresholds
+}
+
+var Config = &DynamicConfig{
+	Timeout:         10 * time.Second,
+	BatchSize:       1,
+	MaxMemoryLoad:   100 * 1024 * 1024,
+	IdleTimeout:     5 * time.Minute,
+	BackoffDuration: 2 * time.Second,
+	MaxRetries:      3,
+	ReloadInterval:  30 * time.Second,
+	WarningThreshold: WarningThresholds{
+		HighMemory:    1 * 1024 * 1024,
+		LongExecution: 2 * time.Second,
+	},
+}
+
 type Pool struct {
 	taskStorage TaskStorage
 	scheduler   *Scheduler
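The new Plugin interface gives callers before/after hooks around task execution; DefaultPlugin above is the no-op reference implementation. A minimal sketch of a custom plugin, written as if it lived in the mq package (QueueTask's fields are unexported), with illustrative names that are not part of this commit:

package mq

import (
	"sync/atomic"
	"time"
)

// countingPlugin is a hypothetical plugin that tracks how many tasks it has seen.
type countingPlugin struct {
	started time.Time
	seen    int64
}

func (p *countingPlugin) Initialize(config interface{}) error {
	p.started = time.Now()
	return nil
}

func (p *countingPlugin) BeforeTask(task *QueueTask) {
	atomic.AddInt64(&p.seen, 1) // count tasks as they are handed to workers
}

func (p *countingPlugin) AfterTask(task *QueueTask, result Result) {
	if result.Error != nil {
		Logger.Warn().Msgf("plugin observed a failure after %d tasks", atomic.LoadInt64(&p.seen))
	}
}

A plugin like this is registered through the WithPlugin option added further down in this commit.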
@@ -35,6 +138,7 @@ type Pool struct {
 	completionCallback CompletionCallback
 	taskAvailableCond  *sync.Cond
 	callback           Callback
+	DLQ                *DeadLetterQueue
 	taskQueue          PriorityQueue
 	overflowBuffer     []*QueueTask
 	metrics            Metrics
@@ -50,10 +154,8 @@ type Pool struct {
 	taskQueueLock      sync.Mutex
 	numOfWorkers       int32
 	paused             bool
-	logger             *log.Logger
+	logger             log.Logger
 	gracefulShutdown   bool
-
-	// New fields for enhancements:
 	thresholds         ThresholdConfig
 	diagnosticsEnabled bool
 	metricsRegistry    MetricsRegistry
@@ -61,33 +163,83 @@ type Pool struct {
 	circuitBreakerOpen         bool
 	circuitBreakerFailureCount int32
 	gracefulShutdownTimeout    time.Duration
+	plugins                    []Plugin
 }
 
 func NewPool(numOfWorkers int, opts ...PoolOption) *Pool {
 	pool := &Pool{
 		stop:            make(chan struct{}),
 		taskNotify:      make(chan struct{}, numOfWorkers),
-		batchSize:       1,
-		timeout:         10 * time.Second,
-		idleTimeout:     5 * time.Minute,
-		backoffDuration: 2 * time.Second,
-		maxRetries:      3, // Set max retries for failed tasks
-		logger:          log.Default(),
+		batchSize:       Config.BatchSize,
+		timeout:         Config.Timeout,
+		idleTimeout:     Config.IdleTimeout,
+		backoffDuration: Config.BackoffDuration,
+		maxRetries:      Config.MaxRetries,
+		logger:          Logger,
+		DLQ:             NewDeadLetterQueue(),
+		metricsRegistry: NewInMemoryMetricsRegistry(),
+		diagnosticsEnabled:      true,
+		gracefulShutdownTimeout: 10 * time.Second,
 	}
 	pool.scheduler = NewScheduler(pool)
 	pool.taskAvailableCond = sync.NewCond(&sync.Mutex{})
 	for _, opt := range opts {
 		opt(pool)
 	}
-	if len(pool.taskQueue) == 0 {
+	if pool.taskQueue == nil {
 		pool.taskQueue = make(PriorityQueue, 0, 10)
 	}
 	heap.Init(&pool.taskQueue)
 	pool.scheduler.Start()
 	pool.Start(numOfWorkers)
+	startConfigReloader(pool)
+	go pool.dynamicWorkerScaler()
+	go pool.startHealthServer()
 	return pool
 }
 
+func validateDynamicConfig(c *DynamicConfig) error {
+	if c.Timeout <= 0 {
+		return errors.New("Timeout must be positive")
+	}
+	if c.BatchSize <= 0 {
+		return errors.New("BatchSize must be > 0")
+	}
+	if c.MaxMemoryLoad <= 0 {
+		return errors.New("MaxMemoryLoad must be > 0")
+	}
+	return nil
+}
+
+func startConfigReloader(pool *Pool) {
+	go func() {
+		ticker := time.NewTicker(Config.ReloadInterval)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-ticker.C:
+				if err := validateDynamicConfig(Config); err != nil {
+					Logger.Error().Err(err).Msg("Invalid dynamic config, skipping reload")
+					continue
+				}
+				pool.timeout = Config.Timeout
+				pool.batchSize = Config.BatchSize
+				pool.maxMemoryLoad = Config.MaxMemoryLoad
+				pool.idleTimeout = Config.IdleTimeout
+				pool.backoffDuration = Config.BackoffDuration
+				pool.maxRetries = Config.MaxRetries
+				pool.thresholds = ThresholdConfig{
+					HighMemory:    Config.WarningThreshold.HighMemory,
+					LongExecution: Config.WarningThreshold.LongExecution,
+				}
+				Logger.Info().Msg("Dynamic configuration reloaded")
+			case <-pool.stop:
+				return
+			}
+		}
+	}()
+}
+
 func (wp *Pool) Start(numWorkers int) {
 	storedTasks, err := wp.taskStorage.GetAllTasks()
 	if err == nil {
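startConfigReloader re-validates the package-level Config on every ReloadInterval tick and copies it into the running pool, so runtime tuning amounts to mutating that struct. A rough sketch of a caller doing so (the helper name is hypothetical; note the reloader reads Config without locking, so updates should come from a single goroutine):

package mq

import "time"

// loosenLimits is an illustrative helper: it mutates the shared Config and
// relies on the next reloader tick to validate and apply the new values.
func loosenLimits() {
	Config.Timeout = 30 * time.Second // raise the per-task timeout
	Config.MaxRetries = 5             // allow more retry attempts before a task goes to the DLQ
}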
@@ -144,13 +296,13 @@ func (wp *Pool) processNextBatch() {
 			wp.handleTask(task)
 		}
 	}
-	// @TODO - Why was this done?
-	//if len(tasks) > 0 {
-	//	wp.taskCompletionNotifier.Done()
-	//}
 }
 
 func (wp *Pool) handleTask(task *QueueTask) {
+	if err := validateTaskInput(task.payload); err != nil {
+		wp.logger.Error().Str("taskID", task.payload.ID).Msgf("Validation failed: %v", err)
+		return
+	}
 	ctx, cancel := context.WithTimeout(task.ctx, wp.timeout)
 	defer cancel()
 	taskSize := int64(utils.SizeOf(task.payload))
@@ -163,15 +315,15 @@ func (wp *Pool) handleTask(task *QueueTask) {
 
 	// Warning thresholds check
 	if wp.thresholds.LongExecution > 0 && executionTime > int64(wp.thresholds.LongExecution.Milliseconds()) {
-		wp.logger.Printf("Warning: Task %s exceeded execution time threshold: %d ms", task.payload.ID, executionTime)
+		wp.logger.Warn().Str("taskID", task.payload.ID).Msgf("Exceeded execution time threshold: %d ms", executionTime)
 	}
 	if wp.thresholds.HighMemory > 0 && taskSize > wp.thresholds.HighMemory {
-		wp.logger.Printf("Warning: Task %s memory usage %d exceeded threshold", task.payload.ID, taskSize)
+		wp.logger.Warn().Str("taskID", task.payload.ID).Msgf("Memory usage %d exceeded threshold", taskSize)
 	}
 
 	if result.Error != nil {
 		atomic.AddInt64(&wp.metrics.ErrorCount, 1)
-		wp.logger.Printf("Error processing task %s: %v", task.payload.ID, result.Error)
+		wp.logger.Error().Str("taskID", task.payload.ID).Msgf("Error processing task: %v", result.Error)
 		wp.backoffAndStore(task)
 
 		// Circuit breaker check
@@ -179,12 +331,12 @@ func (wp *Pool) handleTask(task *QueueTask) {
 			newCount := atomic.AddInt32(&wp.circuitBreakerFailureCount, 1)
 			if newCount >= int32(wp.circuitBreaker.FailureThreshold) {
 				wp.circuitBreakerOpen = true
-				wp.logger.Println("Circuit breaker opened due to errors")
+				wp.logger.Warn().Msg("Circuit breaker opened due to errors")
 				go func() {
 					time.Sleep(wp.circuitBreaker.ResetTimeout)
 					atomic.StoreInt32(&wp.circuitBreakerFailureCount, 0)
 					wp.circuitBreakerOpen = false
-					wp.logger.Println("Circuit breaker reset")
+					wp.logger.Info().Msg("Circuit breaker reset to closed state")
 				}()
 			}
 		}
@@ -198,17 +350,17 @@ func (wp *Pool) handleTask(task *QueueTask) {
 
 	// Diagnostics logging if enabled
 	if wp.diagnosticsEnabled {
-		wp.logger.Printf("Task %s executed in %d ms", task.payload.ID, executionTime)
+		wp.logger.Info().Str("taskID", task.payload.ID).Msgf("Task executed in %d ms", executionTime)
 	}
 
 	if wp.callback != nil {
 		if err := wp.callback(ctx, result); err != nil {
 			atomic.AddInt64(&wp.metrics.ErrorCount, 1)
-			wp.logger.Printf("Error in callback for task %s: %v", task.payload.ID, err)
+			wp.logger.Error().Str("taskID", task.payload.ID).Msgf("Callback error: %v", err)
 		}
 	}
 	_ = wp.taskStorage.DeleteTask(task.payload.ID)
 	atomic.AddInt64(&wp.metrics.TotalMemoryUsed, -taskSize)
+	wp.metricsRegistry.Register("task_execution_time", executionTime)
 }
 
 func (wp *Pool) backoffAndStore(task *QueueTask) {
@@ -219,10 +371,11 @@ func (wp *Pool) backoffAndStore(task *QueueTask) {
 		backoff := wp.backoffDuration * (1 << (task.retryCount - 1))
 		jitter := time.Duration(rand.Int63n(int64(backoff) / 2))
 		sleepDuration := backoff + jitter
-		wp.logger.Printf("Task %s retry %d: sleeping for %s", task.payload.ID, task.retryCount, sleepDuration)
+		wp.logger.Info().Str("taskID", task.payload.ID).Msgf("Retry %d: sleeping for %s", task.retryCount, sleepDuration)
 		time.Sleep(sleepDuration)
 	} else {
-		wp.logger.Printf("Task %s failed after maximum retries", task.payload.ID)
+		wp.logger.Error().Str("taskID", task.payload.ID).Msg("Task failed after maximum retries")
+		wp.DLQ.Add(task)
 	}
 }
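backoffAndStore sleeps for an exponentially growing delay plus up to 50% random jitter before retrying, and hands the task to the dead letter queue once the retry budget is exhausted. The delay arithmetic in isolation, as a runnable sketch:

package main

import (
	"fmt"
	"math/rand"
	"time"
)

// retryDelay mirrors the computation used in backoffAndStore above:
// base * 2^(retry-1), plus a random jitter of at most half the backoff.
func retryDelay(base time.Duration, retry int) time.Duration {
	backoff := base * (1 << (retry - 1))
	jitter := time.Duration(rand.Int63n(int64(backoff) / 2))
	return backoff + jitter
}

func main() {
	for retry := 1; retry <= 3; retry++ {
		fmt.Printf("retry %d sleeps for roughly %v\n", retry, retryDelay(2*time.Second, retry))
	}
}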
@@ -282,7 +435,9 @@ func (wp *Pool) EnqueueTask(ctx context.Context, payload *Task, priority int) error {
 	if wp.circuitBreaker.Enabled && wp.circuitBreakerOpen {
 		return fmt.Errorf("circuit breaker open, task rejected")
 	}
+	if err := validateTaskInput(payload); err != nil {
+		return fmt.Errorf("invalid task input: %w", err)
+	}
 	if payload.ID == "" {
 		payload.ID = NewID()
 	}
@@ -344,9 +499,8 @@ func (wp *Pool) startOverflowDrainer() {
 func (wp *Pool) drainOverflowBuffer() {
 	wp.overflowBufferLock.Lock()
 	overflowTasks := wp.overflowBuffer
-	wp.overflowBuffer = nil // Clear buffer
+	wp.overflowBuffer = nil
 	wp.overflowBufferLock.Unlock()
-
 	for _, task := range overflowTasks {
 		select {
 		case wp.taskNotify <- struct{}{}:
@@ -363,8 +517,6 @@ func (wp *Pool) Stop() {
 	wp.gracefulShutdown = true
 	wp.Pause()
 	close(wp.stop)
-
-	// Graceful shutdown with timeout support
 	done := make(chan struct{})
 	go func() {
 		wp.wg.Wait()
@@ -373,9 +525,8 @@ func (wp *Pool) Stop() {
 	}()
 	select {
 	case <-done:
-		// All workers finished gracefully.
 	case <-time.After(wp.gracefulShutdownTimeout):
-		wp.logger.Println("Graceful shutdown timeout reached")
+		wp.logger.Warn().Msg("Graceful shutdown timeout reached")
 	}
 	if wp.completionCallback != nil {
 		wp.completionCallback()
@@ -395,3 +546,54 @@ func (wp *Pool) Metrics() Metrics {
 }
 
 func (wp *Pool) Scheduler() *Scheduler { return wp.scheduler }
+
+func (wp *Pool) dynamicWorkerScaler() {
+	ticker := time.NewTicker(10 * time.Second)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-ticker.C:
+			wp.taskQueueLock.Lock()
+			queueLen := len(wp.taskQueue)
+			wp.taskQueueLock.Unlock()
+			newWorkers := queueLen/5 + 1
+			wp.logger.Info().Msgf("Auto-scaling: queue length %d, adjusting workers to %d", queueLen, newWorkers)
+			wp.AdjustWorkerCount(newWorkers)
+		case <-wp.stop:
+			return
+		}
+	}
+}
+
+func (wp *Pool) startHealthServer() {
+	mux := http.NewServeMux()
+	mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
+		status := "OK"
+		if wp.gracefulShutdown {
+			status = "shutting down"
+		}
+		fmt.Fprintf(w, "status: %s\nworkers: %d\nqueueLength: %d\n",
+			status, atomic.LoadInt32(&wp.numOfWorkers), len(wp.taskQueue))
+	})
+	server := &http.Server{
+		Addr:    ":8080",
+		Handler: mux,
+	}
+	go func() {
+		wp.logger.Info().Msg("Starting health server on :8080")
+		if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
+			wp.logger.Error().Err(err).Msg("Health server failed")
+		}
+	}()
+
+	go func() {
+		<-wp.stop
+		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+		if err := server.Shutdown(ctx); err != nil {
+			wp.logger.Error().Err(err).Msg("Health server shutdown failed")
+		} else {
+			wp.logger.Info().Msg("Health server shutdown gracefully")
+		}
+	}()
+}
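startHealthServer exposes a plain-text status endpoint on :8080 and shuts it down when the pool stops. A minimal probe against it (address and response shape taken from the handler above):

package main

import (
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Query the pool's health endpoint started by startHealthServer.
	resp, err := http.Get("http://localhost:8080/health")
	if err != nil {
		fmt.Println("health server unreachable:", err)
		return
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)
	fmt.Print(string(body)) // e.g. "status: OK\nworkers: 5\nqueueLength: 0\n"
}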
@@ -6,13 +6,14 @@ import (
 
 // New type definitions for enhancements
 type ThresholdConfig struct {
-	HighMemory    int64         // e.g. in bytes
-	LongExecution time.Duration // e.g. warning if task execution exceeds
+	HighMemory    int64
+	LongExecution time.Duration
 }
 
 type MetricsRegistry interface {
 	Register(metricName string, value interface{})
-	// ...other methods as needed...
+	Increment(metricName string)
+	Get(metricName string) interface{}
 }
 
 type CircuitBreakerConfig struct {
@@ -25,7 +26,6 @@ type PoolOption func(*Pool)
 
 func WithTaskQueueSize(size int) PoolOption {
 	return func(p *Pool) {
-		// Initialize the task queue with the specified size
 		p.taskQueue = make(PriorityQueue, 0, size)
 	}
 }
@@ -72,8 +72,6 @@ func WithTaskStorage(storage TaskStorage) PoolOption {
 	}
 }
 
-// New option functions:
-
 func WithWarningThresholds(thresholds ThresholdConfig) PoolOption {
 	return func(p *Pool) {
 		p.thresholds = thresholds
@@ -103,3 +101,9 @@ func WithGracefulShutdown(timeout time.Duration) PoolOption {
 		p.gracefulShutdownTimeout = timeout
 	}
 }
+
+func WithPlugin(plugin Plugin) PoolOption {
+	return func(p *Pool) {
+		p.plugins = append(p.plugins, plugin)
+	}
+}
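The new WithPlugin option composes with the existing ones. A sketch of wiring the enhancements from this commit into a pool, written in the mq package with illustrative values; a real caller would also pass WithHandler and WithTaskStorage as in the example at the top of this commit:

package mq

import "time"

// newInstrumentedPool is a hypothetical constructor showing how the new
// options compose with the thresholds and graceful-shutdown settings.
func newInstrumentedPool() *Pool {
	return NewPool(4,
		WithPlugin(&DefaultPlugin{}),
		WithWarningThresholds(ThresholdConfig{
			HighMemory:    10 * 1024 * 1024, // warn when a payload exceeds 10 MiB
			LongExecution: 3 * time.Second,  // warn when a task runs longer than 3s
		}),
		WithGracefulShutdown(15*time.Second),
	)
}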
publisher.go (13 changed lines)
@@ -10,6 +10,7 @@ import (
 	"github.com/oarkflow/json"
 
 	"github.com/oarkflow/json/jsonparser"
+
 	"github.com/oarkflow/mq/codec"
 	"github.com/oarkflow/mq/consts"
 )
@@ -77,16 +78,18 @@ func (p *Publisher) Publish(ctx context.Context, task Task, queue string) error
 	if err != nil {
 		return fmt.Errorf("failed to connect to broker: %w", err)
 	}
-	defer conn.Close()
+	defer func() {
+		_ = conn.Close()
+	}()
 	return p.send(ctx, queue, task, conn, consts.PUBLISH)
 }
 
-func (p *Publisher) onClose(ctx context.Context, conn net.Conn) error {
+func (p *Publisher) onClose(_ context.Context, conn net.Conn) error {
 	fmt.Println("Publisher Connection closed", p.id, conn.RemoteAddr())
 	return nil
 }
 
-func (p *Publisher) onError(ctx context.Context, conn net.Conn, err error) {
+func (p *Publisher) onError(_ context.Context, conn net.Conn, err error) {
 	fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr())
 }
@@ -99,7 +102,9 @@ func (p *Publisher) Request(ctx context.Context, task Task, queue string) Result
 		err = fmt.Errorf("failed to connect to broker: %w", err)
 		return Result{Error: err}
 	}
-	defer conn.Close()
+	defer func() {
+		_ = conn.Close()
+	}()
 	err = p.send(ctx, queue, task, conn, consts.PUBLISH)
 	resultCh := make(chan Result)
 	go func() {
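Both Publish and Request now close the connection inside a deferred closure that explicitly discards the Close error instead of a bare defer conn.Close(), which makes the ignored error visible to errcheck-style linters; the unused ctx parameters on onClose and onError become _ for the same reason. Behaviour is unchanged in both cases.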
queue.go (11 changed lines)
@@ -45,38 +45,33 @@ func (b *Broker) NewQueue(name string) *Queue {
 type QueueTask struct {
 	ctx        context.Context
 	payload    *Task
-	retryCount int
 	priority   int
-	index      int // The index in the heap
+	retryCount int
+	index      int
 }
 
 type PriorityQueue []*QueueTask
 
 func (pq PriorityQueue) Len() int { return len(pq) }
 
 func (pq PriorityQueue) Less(i, j int) bool {
 	return pq[i].priority > pq[j].priority
 }
 
 func (pq PriorityQueue) Swap(i, j int) {
 	pq[i], pq[j] = pq[j], pq[i]
 	pq[i].index = i
 	pq[j].index = j
 }
 
 func (pq *PriorityQueue) Push(x interface{}) {
 	n := len(*pq)
 	task := x.(*QueueTask)
 	task.index = n
 	*pq = append(*pq, task)
 }
 
 func (pq *PriorityQueue) Pop() interface{} {
 	old := *pq
 	n := len(old)
 	task := old[n-1]
-	old[n-1] = nil  // avoid memory leak
-	task.index = -1 // for safety
+	task.index = -1
 	*pq = old[0 : n-1]
 	return task
 }
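PriorityQueue only behaves correctly when driven through container/heap, since Swap and Pop maintain each task's index. A small sketch in the mq package (the helper name is illustrative, not part of this commit):

package mq

import "container/heap"

// peekHighestPriority pops the highest-priority task (per Less, larger
// priority values come out first) and immediately pushes it back, leaving
// the heap and its index bookkeeping intact.
func peekHighestPriority(pq *PriorityQueue) *QueueTask {
	if pq.Len() == 0 {
		return nil
	}
	task := heap.Pop(pq).(*QueueTask)
	heap.Push(pq, task)
	return task
}

Note that the commit also drops the old[n-1] = nil line in Pop, which previously let the garbage collector reclaim the popped slot a little sooner.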
scheduler.go (108 changed lines)
@@ -2,14 +2,19 @@ package mq
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"strconv"
 	"strings"
 	"sync"
 	"sync/atomic"
 	"time"
+
+	"github.com/oarkflow/log"
 )
 
+var Logger = log.DefaultLogger
+
 type ScheduleOptions struct {
 	Handler  Handler
 	Callback Callback
@@ -20,7 +25,7 @@ type ScheduleOptions struct {
 
 type SchedulerOption func(*ScheduleOptions)
 
-// Helper functions to create SchedulerOptions
+// WithSchedulerHandler Helper functions to create SchedulerOptions
 func WithSchedulerHandler(handler Handler) SchedulerOption {
 	return func(opts *ScheduleOptions) {
 		opts.Handler = handler
@@ -177,17 +182,15 @@ func parseCronSpec(cronSpec string) (CronSchedule, error) {
 }
 
 func cronFieldToString(field string, fieldName string) (string, error) {
-	switch field {
-	case "*":
+	if field == "*" {
 		return fmt.Sprintf("every %s", fieldName), nil
-	default:
+	}
 	values, err := parseCronValue(field)
 	if err != nil {
 		return "", fmt.Errorf("invalid %s field: %s", fieldName, err.Error())
 	}
 	return fmt.Sprintf("%s %s", strings.Join(values, ", "), fieldName), nil
 }
-}
 
 func parseCronValue(field string) ([]string, error) {
 	var values []string
@@ -223,6 +226,8 @@ type Scheduler struct {
 }
 
 func (s *Scheduler) Start() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
 	for _, task := range s.tasks {
 		go s.schedule(task)
 	}
@@ -279,46 +284,81 @@ func (s *Scheduler) schedule(task ScheduledTask) {
 	}
 }
 
+func startSpan(operation string) (context.Context, func()) {
+	startTime := time.Now()
+	Logger.Info().Str("operation", operation).Msg("Span started")
+	ctx := context.WithValue(context.Background(), "traceID", fmt.Sprintf("%d", startTime.UnixNano()))
+	return ctx, func() {
+		duration := time.Since(startTime)
+		Logger.Info().Str("operation", operation).Msgf("Span ended; duration: %v", duration)
+	}
+}
+
+func acquireDistributedLock(taskID string) bool {
+	Logger.Info().Str("taskID", taskID).Msg("Acquiring distributed lock (stub)")
+	return true
+}
+
+func releaseDistributedLock(taskID string) {
+	Logger.Info().Str("taskID", taskID).Msg("Releasing distributed lock (stub)")
+}
+
+var taskPool = sync.Pool{
+	New: func() interface{} { return new(Task) },
+}
+
+var queueTaskPool = sync.Pool{
+	New: func() interface{} { return new(QueueTask) },
+}
+
+func getQueueTask() *QueueTask {
+	return queueTaskPool.Get().(*QueueTask)
+}
+
 // Enhance executeTask with circuit breaker and diagnostics logging support.
 func (s *Scheduler) executeTask(task ScheduledTask) {
-	if !task.config.Overlap && !s.pool.gracefulShutdown {
-		// Prevent overlapping execution if not allowed.
-		// ...existing code...
-	}
 	go func() {
-		// Recover from panic to keep scheduler running.
+		_, cancelSpan := startSpan("executeTask")
+		defer cancelSpan()
 		defer func() {
 			if r := recover(); r != nil {
-				fmt.Printf("Recovered from panic in scheduled task %s: %v\n", task.payload.ID, r)
+				Logger.Error().Str("taskID", task.payload.ID).Msgf("Recovered from panic: %v", r)
 			}
 		}()
 		start := time.Now()
+		for _, plug := range s.pool.plugins {
+			plug.BeforeTask(getQueueTask())
+		}
+		if !acquireDistributedLock(task.payload.ID) {
+			Logger.Warn().Str("taskID", task.payload.ID).Msg("Failed to acquire distributed lock")
+			return
+		}
+		defer releaseDistributedLock(task.payload.ID)
 		result := task.handler(task.ctx, task.payload)
 		execTime := time.Since(start).Milliseconds()
-		// Diagnostics logging if enabled.
 		if s.pool.diagnosticsEnabled {
-			s.pool.logger.Printf("Scheduled task %s executed in %d ms", task.payload.ID, execTime)
+			Logger.Info().Str("taskID", task.payload.ID).Msgf("Executed in %d ms", execTime)
 		}
-		// Circuit breaker check similar to pool.
 		if result.Error != nil && s.pool.circuitBreaker.Enabled {
 			newCount := atomic.AddInt32(&s.pool.circuitBreakerFailureCount, 1)
 			if newCount >= int32(s.pool.circuitBreaker.FailureThreshold) {
 				s.pool.circuitBreakerOpen = true
-				s.pool.logger.Println("Scheduler: circuit breaker opened due to errors")
+				Logger.Warn().Msg("Circuit breaker opened due to errors")
 				go func() {
 					time.Sleep(s.pool.circuitBreaker.ResetTimeout)
 					atomic.StoreInt32(&s.pool.circuitBreakerFailureCount, 0)
 					s.pool.circuitBreakerOpen = false
-					s.pool.logger.Println("Scheduler: circuit breaker reset")
+					Logger.Info().Msg("Circuit breaker reset to closed state")
 				}()
 			}
 		}
-		// Invoke callback if defined.
 		if task.config.Callback != nil {
 			_ = task.config.Callback(task.ctx, result)
 		}
 		task.executionHistory = append(task.executionHistory, ExecutionHistory{Timestamp: time.Now(), Result: result})
-		fmt.Printf("Executed scheduled task: %s\n", task.payload.ID)
+		for _, plug := range s.pool.plugins {
+			plug.AfterTask(getQueueTask(), result)
+		}
+		Logger.Info().Str("taskID", task.payload.ID).Msg("Scheduled task executed")
 	}()
 }
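executeTask now wraps each run in a lightweight span, fires every registered plugin's BeforeTask hook, takes the (stubbed) distributed lock, and calls AfterTask once the handler returns. The span pattern in isolation, as a sketch in the mq package:

package mq

// instrumentedRun is an illustrative helper showing the startSpan pattern used
// by executeTask: open a span, run the work, and let the deferred closer log
// the operation's duration.
func instrumentedRun(operation string, work func()) {
	_, finish := startSpan(operation)
	defer finish()
	work()
}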
@@ -354,7 +394,7 @@ func (task ScheduledTask) getNextRunTime(now time.Time) time.Time {
 func (task ScheduledTask) getNextCronRunTime(now time.Time) time.Time {
 	cronSpecs, err := parseCronSpec(task.schedule.CronSpec)
 	if err != nil {
-		fmt.Println(fmt.Sprintf("Invalid CRON spec format: %s", err.Error()))
+		Logger.Error().Err(err).Msg("Invalid CRON spec")
 		return now
 	}
 	nextRun := now
@@ -367,10 +407,9 @@ func (task ScheduledTask) getNextCronRunTime(now time.Time) time.Time {
 }
 
 func (task ScheduledTask) applyCronField(t time.Time, fieldSpec string, unit string) time.Time {
-	switch fieldSpec {
-	case "*":
+	if fieldSpec == "*" {
 		return t
-	default:
+	}
 	value, _ := strconv.Atoi(fieldSpec)
 	switch unit {
 	case "minute":
@@ -401,7 +440,6 @@ func (task ScheduledTask) applyCronField(t time.Time, fieldSpec string, unit string) time.Time {
 	}
 	return t
 }
-}
 
 func nextWeekday(t time.Time, weekday time.Weekday) time.Time {
 	daysUntil := (int(weekday) - int(t.Weekday()) + 7) % 7
@@ -410,12 +448,21 @@ func nextWeekday(t time.Time, weekday time.Weekday) time.Time {
 	}
 	return t.AddDate(0, 0, daysUntil)
 }
+
+func validateTaskInput(task *Task) error {
+	if task.Payload == nil {
+		return errors.New("task payload cannot be nil")
+	}
+	Logger.Info().Str("taskID", task.ID).Msg("Task validated")
+	return nil
+}
 
 func (s *Scheduler) AddTask(ctx context.Context, payload *Task, opts ...SchedulerOption) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	// Create a default options instance
+	if err := validateTaskInput(payload); err != nil {
+		Logger.Error().Err(err).Msg("Invalid task input")
+		return
+	}
 	options := defaultSchedulerOptions()
 	for _, opt := range opts {
 		opt(options)
@@ -427,9 +474,7 @@ func (s *Scheduler) AddTask(ctx context.Context, payload *Task, opts ...SchedulerOption) {
 		options.Callback = s.pool.callback
 	}
 	stop := make(chan struct{})
-	// Create a new ScheduledTask using the provided options
-	s.tasks = append(s.tasks, ScheduledTask{
+	newTask := ScheduledTask{
 		ctx:     ctx,
 		handler: options.Handler,
 		payload: payload,
@@ -442,10 +487,9 @@ func (s *Scheduler) AddTask(ctx context.Context, payload *Task, opts ...SchedulerOption) {
 			Interval:  options.Interval,
 			Recurring: options.Recurring,
 		},
-	})
-	// Start scheduling the task
-	go s.schedule(s.tasks[len(s.tasks)-1])
+	}
+	s.tasks = append(s.tasks, newTask)
+	go s.schedule(newTask)
 }
 
 func (s *Scheduler) RemoveTask(payloadID string) {
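With the validation added above, AddTask now refuses a nil payload before any scheduling happens. A sketch of a caller, assuming the handler signature shown in the pool example at the top of this commit (the scheduleHeartbeat name and payload contents are illustrative):

package mq

import (
	"context"

	"github.com/oarkflow/json"
)

// scheduleHeartbeat registers a task whose payload passes validateTaskInput.
func scheduleHeartbeat(ctx context.Context, s *Scheduler) {
	task := &Task{
		ID:      NewID(),
		Payload: json.RawMessage(`{"kind":"heartbeat"}`),
	}
	s.AddTask(ctx, task, WithSchedulerHandler(func(ctx context.Context, t *Task) Result {
		Logger.Info().Str("taskID", t.ID).Msg("scheduled handler ran")
		return Result{}
	}))
}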