diff --git a/consumer.go b/consumer.go index 9947645..05a28e1 100644 --- a/consumer.go +++ b/consumer.go @@ -140,7 +140,7 @@ func (c *Consumer) ConsumeMessage(ctx context.Context, msg *codec.Message, conn return } ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue}) - if err := c.pool.AddTask(ctx, &task); err != nil { + if err := c.pool.AddTask(ctx, &task, 1); err != nil { c.sendDenyMessage(ctx, task.ID, msg.Queue, err) return } diff --git a/dag/dag.go b/dag/dag.go index 1a5db9a..843cef7 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -357,7 +357,7 @@ func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result { if ok { ctxx = mq.SetHeaders(ctxx, headers.AsMap()) } - if err := tm.pool.AddTask(ctxx, task); err != nil { + if err := tm.pool.AddTask(ctxx, task, 1); err != nil { return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: initialNode, Status: "FAILED", Error: err} } return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: initialNode, Status: "PENDING"} diff --git a/examples/priority.go b/examples/priority.go new file mode 100644 index 0000000..6efc770 --- /dev/null +++ b/examples/priority.go @@ -0,0 +1,22 @@ +package main + +import ( + "context" + "time" + + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/examples/tasks" +) + +func main() { + pool := mq.NewPool(2, 5, 1000, tasks.SchedulerHandler, tasks.SchedulerCallback) + + // Add tasks with different priorities + pool.AddTask(context.Background(), &mq.Task{ID: "Low Priority Task"}, 1) // Lowest priority + pool.AddTask(context.Background(), &mq.Task{ID: "Medium Priority Task"}, 5) // Medium priority + pool.AddTask(context.Background(), &mq.Task{ID: "High Priority Task"}, 10) // Highest priority + // Let tasks run + time.Sleep(5 * time.Second) + pool.PrintMetrics() + pool.Stop() +} diff --git a/examples/scheduler.go b/examples/scheduler.go new file mode 100644 index 0000000..d01ba8c --- /dev/null +++ b/examples/scheduler.go @@ -0,0 +1,44 @@ +package main + +import ( + "context" + "time" + + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/examples/tasks" +) + +func main() { + pool := mq.NewPool(3, 5, 1000, tasks.SchedulerHandler, tasks.SchedulerCallback) + pool.AddTask(context.Background(), &mq.Task{ID: "Task 1"}, 1) + pool.AddTask(context.Background(), &mq.Task{ID: "Task 2"}, 5) + pool.AddTask(context.Background(), &mq.Task{ID: "Task 3"}, 3) + + // Adding scheduled tasks + pool.Scheduler.AddTask( + context.Background(), + tasks.SchedulerHandler, + &mq.Task{ID: "Scheduled Task 1"}, + 3*time.Second, + mq.SchedulerConfig{WithCallback: tasks.SchedulerCallback, WithOverlap: true}, + ) + + pool.Scheduler.AddTask( + context.Background(), + tasks.SchedulerHandler, + &mq.Task{ID: "Scheduled Task 2"}, + 5*time.Second, + mq.SchedulerConfig{WithCallback: tasks.SchedulerCallback, WithOverlap: false}, + ) + + // Let tasks run for a while + time.Sleep(10 * time.Second) + + // Removing a scheduled task + pool.Scheduler.RemoveTask("Scheduled Task 1") + + // Let remaining tasks run for a bit + time.Sleep(5 * time.Second) + pool.PrintMetrics() + pool.Stop() +} diff --git a/examples/tasks/scheduler.go b/examples/tasks/scheduler.go new file mode 100644 index 0000000..4392094 --- /dev/null +++ b/examples/tasks/scheduler.go @@ -0,0 +1,22 @@ +package tasks + +import ( + "context" + "fmt" + + "github.com/oarkflow/mq" +) + +func SchedulerHandler(ctx context.Context, task *mq.Task) mq.Result { + fmt.Printf("Processing task: %s\n", task.ID) + return mq.Result{Error: nil} +} + +func SchedulerCallback(ctx context.Context, result mq.Result) error { + if result.Error != nil { + fmt.Println("Task failed!") + } else { + fmt.Println("Task completed successfully.") + } + return nil +} diff --git a/pool.go b/pool.go index 8f8f755..2560bc5 100644 --- a/pool.go +++ b/pool.go @@ -1,34 +1,139 @@ package mq import ( + "container/heap" "context" "fmt" "sync" "sync/atomic" + "time" "github.com/oarkflow/mq/utils" ) type QueueTask struct { - ctx context.Context - payload *Task + ctx context.Context + payload *Task + priority int +} + +// PriorityQueue implements heap.Interface and holds QueueTasks. +type PriorityQueue []*QueueTask + +func (pq PriorityQueue) Len() int { return len(pq) } + +func (pq PriorityQueue) Less(i, j int) bool { + return pq[i].priority > pq[j].priority // Higher priority first +} + +func (pq PriorityQueue) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] } + +func (pq *PriorityQueue) Push(x interface{}) { + item := x.(*QueueTask) + *pq = append(*pq, item) +} + +func (pq *PriorityQueue) Pop() interface{} { + old := *pq + n := len(old) + item := old[n-1] + *pq = old[0 : n-1] + return item } type Callback func(ctx context.Context, result Result) error +type SchedulerConfig struct { + WithCallback Callback + WithOverlap bool // true allows overlapping, false waits for previous execution to complete +} + +type ScheduledTask struct { + ctx context.Context + handler Handler + payload *Task + interval time.Duration + config SchedulerConfig + stop chan struct{} + execution chan struct{} // Channel to signal task execution status +} + +// Scheduler manages scheduled tasks. +type Scheduler struct { + tasks []ScheduledTask + mu sync.Mutex +} + +func (s *Scheduler) Start() { + for _, task := range s.tasks { + go s.schedule(task) + } +} + +func (s *Scheduler) schedule(task ScheduledTask) { + ticker := time.NewTicker(task.interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if task.config.WithOverlap || len(task.execution) == 0 { // Check if task can execute + task.execution <- struct{}{} + go func() { + defer func() { <-task.execution }() // Free the channel for the next execution + result := task.handler(task.ctx, task.payload) + if task.config.WithCallback != nil { + task.config.WithCallback(task.ctx, result) + } + fmt.Printf("Executed scheduled task: %s\n", task.payload.ID) + }() + } + case <-task.stop: + return + } + } +} + +func (s *Scheduler) AddTask(ctx context.Context, handler Handler, payload *Task, interval time.Duration, config SchedulerConfig) { + s.mu.Lock() + defer s.mu.Unlock() + + stop := make(chan struct{}) + execution := make(chan struct{}, 1) // Buffer for one execution signal + s.tasks = append(s.tasks, ScheduledTask{ctx: ctx, handler: handler, payload: payload, interval: interval, stop: stop, execution: execution, config: config}) + go s.schedule(s.tasks[len(s.tasks)-1]) // Start scheduling immediately +} + +func (s *Scheduler) RemoveTask(payloadID string) { + s.mu.Lock() + defer s.mu.Unlock() + + for i, task := range s.tasks { + if task.payload.ID == payloadID { + close(task.stop) // Stop the task + // Remove the task from the slice + s.tasks = append(s.tasks[:i], s.tasks[i+1:]...) + break + } + } +} + type Pool struct { - taskQueue chan QueueTask + taskQueue PriorityQueue + taskQueueLock sync.Mutex stop chan struct{} handler Handler callback Callback - workerAdjust chan int // Channel for adjusting workers dynamically + workerAdjust chan int wg sync.WaitGroup totalMemoryUsed int64 completedTasks int errorCount, maxMemoryLoad int64 totalTasks int - numOfWorkers int32 // Change to int32 for atomic operations + numOfWorkers int32 paused bool + Scheduler Scheduler + totalScheduledTasks int } func NewPool( @@ -37,14 +142,15 @@ func NewPool( handler Handler, callback Callback) *Pool { pool := &Pool{ - numOfWorkers: int32(numOfWorkers), - taskQueue: make(chan QueueTask, taskQueueSize), + taskQueue: make(PriorityQueue, 0, taskQueueSize), stop: make(chan struct{}), maxMemoryLoad: maxMemoryLoad, handler: handler, callback: callback, workerAdjust: make(chan int), } + heap.Init(&pool.taskQueue) // Initialize the priority queue as a heap + pool.Scheduler.Start() // Start the scheduler pool.Start(numOfWorkers) return pool } @@ -62,32 +168,37 @@ func (wp *Pool) worker() { defer wp.wg.Done() for { select { - case task := <-wp.taskQueue: - if wp.paused { - continue - } - taskSize := int64(utils.SizeOf(task.payload)) - wp.totalMemoryUsed += taskSize - wp.totalTasks++ - - result := wp.handler(task.ctx, task.payload) - - if result.Error != nil { - wp.errorCount++ - } else { - wp.completedTasks++ - } - - if wp.callback != nil { - if err := wp.callback(task.ctx, result); err != nil { - wp.errorCount++ - } - } - - wp.totalMemoryUsed -= taskSize - case <-wp.stop: return + default: + wp.taskQueueLock.Lock() + if len(wp.taskQueue) > 0 && !wp.paused { + // Pop the highest priority task + task := heap.Pop(&wp.taskQueue).(*QueueTask) + wp.taskQueueLock.Unlock() + + taskSize := int64(utils.SizeOf(task.payload)) + wp.totalMemoryUsed += taskSize + wp.totalTasks++ + + result := wp.handler(task.ctx, task.payload) + + if result.Error != nil { + wp.errorCount++ + } else { + wp.completedTasks++ + } + + if wp.callback != nil { + if err := wp.callback(task.ctx, result); err != nil { + wp.errorCount++ + } + } + + wp.totalMemoryUsed -= taskSize + } else { + wp.taskQueueLock.Unlock() + } } } } @@ -122,19 +233,19 @@ func (wp *Pool) adjustWorkers(newWorkerCount int) { atomic.StoreInt32(&wp.numOfWorkers, int32(newWorkerCount)) } -func (wp *Pool) AddTask(ctx context.Context, payload *Task) error { - task := QueueTask{ctx: ctx, payload: payload} +func (wp *Pool) AddTask(ctx context.Context, payload *Task, priority int) error { + task := &QueueTask{ctx: ctx, payload: payload, priority: priority} taskSize := int64(utils.SizeOf(payload)) + if wp.totalMemoryUsed+taskSize > wp.maxMemoryLoad && wp.maxMemoryLoad > 0 { return fmt.Errorf("max memory load reached, cannot add task of size %d", taskSize) } - select { - case wp.taskQueue <- task: - return nil - default: - return fmt.Errorf("task queue is full, cannot add task") - } + wp.taskQueueLock.Lock() + heap.Push(&wp.taskQueue, task) // Add task to priority queue + wp.taskQueueLock.Unlock() + + return nil } func (wp *Pool) Pause() { @@ -158,6 +269,6 @@ func (wp *Pool) AdjustWorkerCount(newWorkerCount int) { } func (wp *Pool) PrintMetrics() { - fmt.Printf("Total Tasks: %d, Completed Tasks: %d, Error Count: %d, Total Memory Used: %d bytes\n", - wp.totalTasks, wp.completedTasks, wp.errorCount, wp.totalMemoryUsed) + fmt.Printf("Total Tasks: %d, Completed Tasks: %d, Error Count: %d, Total Memory Used: %d bytes, Scheduled Tasks: %d\n", + wp.totalTasks, wp.completedTasks, wp.errorCount, wp.totalMemoryUsed, len(wp.Scheduler.tasks)) }