feat: [wip] - implement scheduler

Author: sujit
Date: 2024-10-15 22:12:41 +05:45
parent 6d2b4e6df7
commit bc5aa4c6f5

6 changed files with 242 additions and 43 deletions


@@ -140,7 +140,7 @@ func (c *Consumer) ConsumeMessage(ctx context.Context, msg *codec.Message, conn
 		return
 	}
 	ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue})
-	if err := c.pool.AddTask(ctx, &task); err != nil {
+	if err := c.pool.AddTask(ctx, &task, 1); err != nil {
 		c.sendDenyMessage(ctx, task.ID, msg.Queue, err)
 		return
 	}


@@ -357,7 +357,7 @@ func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result {
 	if ok {
 		ctxx = mq.SetHeaders(ctxx, headers.AsMap())
 	}
-	if err := tm.pool.AddTask(ctxx, task); err != nil {
+	if err := tm.pool.AddTask(ctxx, task, 1); err != nil {
 		return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: initialNode, Status: "FAILED", Error: err}
 	}
 	return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: initialNode, Status: "PENDING"}

examples/priority.go (new file, +22)

@@ -0,0 +1,22 @@
+package main
+
+import (
+	"context"
+	"time"
+
+	"github.com/oarkflow/mq"
+	"github.com/oarkflow/mq/examples/tasks"
+)
+
+func main() {
+	pool := mq.NewPool(2, 5, 1000, tasks.SchedulerHandler, tasks.SchedulerCallback)
+	// Add tasks with different priorities
+	pool.AddTask(context.Background(), &mq.Task{ID: "Low Priority Task"}, 1)    // Lowest priority
+	pool.AddTask(context.Background(), &mq.Task{ID: "Medium Priority Task"}, 5) // Medium priority
+	pool.AddTask(context.Background(), &mq.Task{ID: "High Priority Task"}, 10)  // Highest priority
+	// Let tasks run
+	time.Sleep(5 * time.Second)
+	pool.PrintMetrics()
+	pool.Stop()
+}

examples/scheduler.go (new file, +44)

@@ -0,0 +1,44 @@
+package main
+
+import (
+	"context"
+	"time"
+
+	"github.com/oarkflow/mq"
+	"github.com/oarkflow/mq/examples/tasks"
+)
+
+func main() {
+	pool := mq.NewPool(3, 5, 1000, tasks.SchedulerHandler, tasks.SchedulerCallback)
+	pool.AddTask(context.Background(), &mq.Task{ID: "Task 1"}, 1)
+	pool.AddTask(context.Background(), &mq.Task{ID: "Task 2"}, 5)
+	pool.AddTask(context.Background(), &mq.Task{ID: "Task 3"}, 3)
+	// Adding scheduled tasks
+	pool.Scheduler.AddTask(
+		context.Background(),
+		tasks.SchedulerHandler,
+		&mq.Task{ID: "Scheduled Task 1"},
+		3*time.Second,
+		mq.SchedulerConfig{WithCallback: tasks.SchedulerCallback, WithOverlap: true},
+	)
+	pool.Scheduler.AddTask(
+		context.Background(),
+		tasks.SchedulerHandler,
+		&mq.Task{ID: "Scheduled Task 2"},
+		5*time.Second,
+		mq.SchedulerConfig{WithCallback: tasks.SchedulerCallback, WithOverlap: false},
+	)
+	// Let tasks run for a while
+	time.Sleep(10 * time.Second)
+	// Removing a scheduled task
+	pool.Scheduler.RemoveTask("Scheduled Task 1")
+	// Let remaining tasks run for a bit
+	time.Sleep(5 * time.Second)
+	pool.PrintMetrics()
+	pool.Stop()
+}


@@ -0,0 +1,22 @@
+package tasks
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/oarkflow/mq"
+)
+
+func SchedulerHandler(ctx context.Context, task *mq.Task) mq.Result {
+	fmt.Printf("Processing task: %s\n", task.ID)
+	return mq.Result{Error: nil}
+}
+
+func SchedulerCallback(ctx context.Context, result mq.Result) error {
+	if result.Error != nil {
+		fmt.Println("Task failed!")
+	} else {
+		fmt.Println("Task completed successfully.")
+	}
+	return nil
+}

pool.go (193 lines changed)

@@ -1,34 +1,139 @@
 package mq
 import (
+	"container/heap"
 	"context"
 	"fmt"
 	"sync"
 	"sync/atomic"
+	"time"
 	"github.com/oarkflow/mq/utils"
 )
 type QueueTask struct {
 	ctx context.Context
 	payload *Task
+	priority int
+}
+// PriorityQueue implements heap.Interface and holds QueueTasks.
+type PriorityQueue []*QueueTask
+func (pq PriorityQueue) Len() int { return len(pq) }
+func (pq PriorityQueue) Less(i, j int) bool {
+	return pq[i].priority > pq[j].priority // Higher priority first
+}
+func (pq PriorityQueue) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] }
+func (pq *PriorityQueue) Push(x interface{}) {
+	item := x.(*QueueTask)
+	*pq = append(*pq, item)
+}
+func (pq *PriorityQueue) Pop() interface{} {
+	old := *pq
+	n := len(old)
+	item := old[n-1]
+	*pq = old[0 : n-1]
+	return item
 }
 type Callback func(ctx context.Context, result Result) error
+type SchedulerConfig struct {
+	WithCallback Callback
+	WithOverlap bool // true allows overlapping, false waits for previous execution to complete
+}
+type ScheduledTask struct {
+	ctx context.Context
+	handler Handler
+	payload *Task
+	interval time.Duration
+	config SchedulerConfig
+	stop chan struct{}
+	execution chan struct{} // Channel to signal task execution status
+}
+// Scheduler manages scheduled tasks.
+type Scheduler struct {
+	tasks []ScheduledTask
+	mu sync.Mutex
+}
+func (s *Scheduler) Start() {
+	for _, task := range s.tasks {
+		go s.schedule(task)
+	}
+}
+func (s *Scheduler) schedule(task ScheduledTask) {
+	ticker := time.NewTicker(task.interval)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-ticker.C:
+			if task.config.WithOverlap || len(task.execution) == 0 { // Check if task can execute
+				task.execution <- struct{}{}
+				go func() {
+					defer func() { <-task.execution }() // Free the channel for the next execution
+					result := task.handler(task.ctx, task.payload)
+					if task.config.WithCallback != nil {
+						task.config.WithCallback(task.ctx, result)
+					}
+					fmt.Printf("Executed scheduled task: %s\n", task.payload.ID)
+				}()
+			}
+		case <-task.stop:
+			return
+		}
+	}
+}
+func (s *Scheduler) AddTask(ctx context.Context, handler Handler, payload *Task, interval time.Duration, config SchedulerConfig) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	stop := make(chan struct{})
+	execution := make(chan struct{}, 1) // Buffer for one execution signal
+	s.tasks = append(s.tasks, ScheduledTask{ctx: ctx, handler: handler, payload: payload, interval: interval, stop: stop, execution: execution, config: config})
+	go s.schedule(s.tasks[len(s.tasks)-1]) // Start scheduling immediately
+}
+func (s *Scheduler) RemoveTask(payloadID string) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	for i, task := range s.tasks {
+		if task.payload.ID == payloadID {
+			close(task.stop) // Stop the task
+			// Remove the task from the slice
+			s.tasks = append(s.tasks[:i], s.tasks[i+1:]...)
+			break
+		}
+	}
+}
 type Pool struct {
-	taskQueue chan QueueTask
+	taskQueue PriorityQueue
+	taskQueueLock sync.Mutex
 	stop chan struct{}
 	handler Handler
 	callback Callback
-	workerAdjust chan int // Channel for adjusting workers dynamically
+	workerAdjust chan int
 	wg sync.WaitGroup
 	totalMemoryUsed int64
 	completedTasks int
 	errorCount, maxMemoryLoad int64
 	totalTasks int
-	numOfWorkers int32 // Change to int32 for atomic operations
+	numOfWorkers int32
 	paused bool
+	Scheduler Scheduler
+	totalScheduledTasks int
 }
 func NewPool(
@@ -37,14 +142,15 @@ func NewPool(
 	handler Handler,
 	callback Callback) *Pool {
 	pool := &Pool{
-		numOfWorkers: int32(numOfWorkers),
-		taskQueue: make(chan QueueTask, taskQueueSize),
+		taskQueue: make(PriorityQueue, 0, taskQueueSize),
 		stop: make(chan struct{}),
 		maxMemoryLoad: maxMemoryLoad,
 		handler: handler,
 		callback: callback,
 		workerAdjust: make(chan int),
 	}
+	heap.Init(&pool.taskQueue) // Initialize the priority queue as a heap
+	pool.Scheduler.Start()     // Start the scheduler
 	pool.Start(numOfWorkers)
 	return pool
 }
@@ -62,32 +168,37 @@ func (wp *Pool) worker() {
 	defer wp.wg.Done()
 	for {
 		select {
-		case task := <-wp.taskQueue:
-			if wp.paused {
-				continue
-			}
-			taskSize := int64(utils.SizeOf(task.payload))
-			wp.totalMemoryUsed += taskSize
-			wp.totalTasks++
-			result := wp.handler(task.ctx, task.payload)
-			if result.Error != nil {
-				wp.errorCount++
-			} else {
-				wp.completedTasks++
-			}
-			if wp.callback != nil {
-				if err := wp.callback(task.ctx, result); err != nil {
-					wp.errorCount++
-				}
-			}
-			wp.totalMemoryUsed -= taskSize
 		case <-wp.stop:
 			return
+		default:
+			wp.taskQueueLock.Lock()
+			if len(wp.taskQueue) > 0 && !wp.paused {
+				// Pop the highest priority task
+				task := heap.Pop(&wp.taskQueue).(*QueueTask)
+				wp.taskQueueLock.Unlock()
+				taskSize := int64(utils.SizeOf(task.payload))
+				wp.totalMemoryUsed += taskSize
+				wp.totalTasks++
+				result := wp.handler(task.ctx, task.payload)
+				if result.Error != nil {
+					wp.errorCount++
+				} else {
+					wp.completedTasks++
+				}
+				if wp.callback != nil {
+					if err := wp.callback(task.ctx, result); err != nil {
+						wp.errorCount++
+					}
+				}
+				wp.totalMemoryUsed -= taskSize
+			} else {
+				wp.taskQueueLock.Unlock()
+			}
 		}
 	}
 }
@@ -122,19 +233,19 @@ func (wp *Pool) adjustWorkers(newWorkerCount int) {
 	atomic.StoreInt32(&wp.numOfWorkers, int32(newWorkerCount))
 }
-func (wp *Pool) AddTask(ctx context.Context, payload *Task) error {
-	task := QueueTask{ctx: ctx, payload: payload}
+func (wp *Pool) AddTask(ctx context.Context, payload *Task, priority int) error {
+	task := &QueueTask{ctx: ctx, payload: payload, priority: priority}
 	taskSize := int64(utils.SizeOf(payload))
 	if wp.totalMemoryUsed+taskSize > wp.maxMemoryLoad && wp.maxMemoryLoad > 0 {
 		return fmt.Errorf("max memory load reached, cannot add task of size %d", taskSize)
 	}
-	select {
-	case wp.taskQueue <- task:
-		return nil
-	default:
-		return fmt.Errorf("task queue is full, cannot add task")
-	}
+	wp.taskQueueLock.Lock()
+	heap.Push(&wp.taskQueue, task) // Add task to priority queue
+	wp.taskQueueLock.Unlock()
+	return nil
 }
 func (wp *Pool) Pause() {
@@ -158,6 +269,6 @@ func (wp *Pool) AdjustWorkerCount(newWorkerCount int) {
 }
 func (wp *Pool) PrintMetrics() {
-	fmt.Printf("Total Tasks: %d, Completed Tasks: %d, Error Count: %d, Total Memory Used: %d bytes\n",
-		wp.totalTasks, wp.completedTasks, wp.errorCount, wp.totalMemoryUsed)
+	fmt.Printf("Total Tasks: %d, Completed Tasks: %d, Error Count: %d, Total Memory Used: %d bytes, Scheduled Tasks: %d\n",
+		wp.totalTasks, wp.completedTasks, wp.errorCount, wp.totalMemoryUsed, len(wp.Scheduler.tasks))
 }
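
For readers unfamiliar with container/heap, the PriorityQueue added in pool.go is a max-heap: Less compares priorities with ">", so heap.Pop always yields the pending task with the highest priority. Below is a minimal standalone sketch of the same pattern; the item type, IDs, and maxHeap name are illustrative and not part of this commit.

package main

import (
	"container/heap"
	"fmt"
)

// item mirrors the priority field carried by QueueTask in this commit (hypothetical type).
type item struct {
	id       string
	priority int
}

// maxHeap follows the same pattern as PriorityQueue: Less uses ">" so the
// highest priority is popped first.
type maxHeap []*item

func (h maxHeap) Len() int            { return len(h) }
func (h maxHeap) Less(i, j int) bool  { return h[i].priority > h[j].priority }
func (h maxHeap) Swap(i, j int)       { h[i], h[j] = h[j], h[i] }
func (h *maxHeap) Push(x interface{}) { *h = append(*h, x.(*item)) }
func (h *maxHeap) Pop() interface{} {
	old := *h
	n := len(old)
	it := old[n-1]
	*h = old[:n-1]
	return it
}

func main() {
	h := &maxHeap{}
	heap.Init(h)
	heap.Push(h, &item{id: "low", priority: 1})
	heap.Push(h, &item{id: "high", priority: 10})
	heap.Push(h, &item{id: "medium", priority: 5})
	for h.Len() > 0 {
		fmt.Println(heap.Pop(h).(*item).id) // prints: high, medium, low
	}
}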
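The WithOverlap flag in SchedulerConfig is enforced with a one-slot buffered channel acting as a try-lock: when overlap is disabled, a tick is skipped while the previous execution still holds the slot. The following self-contained sketch shows that guard in isolation; the timings and names here are illustrative, not taken from the repository.

package main

import (
	"fmt"
	"time"
)

func main() {
	execution := make(chan struct{}, 1) // one-slot buffer: at most one run in flight
	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()

	done := time.After(2 * time.Second)
	for {
		select {
		case <-ticker.C:
			if len(execution) == 0 { // same guard as the WithOverlap == false path
				execution <- struct{}{}
				go func() {
					defer func() { <-execution }()     // release the slot when finished
					time.Sleep(500 * time.Millisecond) // simulated work longer than the tick
					fmt.Println("run finished")
				}()
			} else {
				fmt.Println("tick skipped: previous run still in progress")
			}
		case <-done:
			return
		}
	}
}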