mirror of https://github.com/oarkflow/mq.git (synced 2025-10-06 00:16:49 +08:00)
feat: [wip] - implement scheduler
@@ -140,7 +140,7 @@ func (c *Consumer) ConsumeMessage(ctx context.Context, msg *codec.Message, conn
 		return
 	}
 	ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue})
-	if err := c.pool.AddTask(ctx, &task); err != nil {
+	if err := c.pool.AddTask(ctx, &task, 1); err != nil {
 		c.sendDenyMessage(ctx, task.ID, msg.Queue, err)
 		return
 	}
@@ -357,7 +357,7 @@ func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result {
 	if ok {
 		ctxx = mq.SetHeaders(ctxx, headers.AsMap())
 	}
-	if err := tm.pool.AddTask(ctxx, task); err != nil {
+	if err := tm.pool.AddTask(ctxx, task, 1); err != nil {
 		return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: initialNode, Status: "FAILED", Error: err}
 	}
 	return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: initialNode, Status: "PENDING"}
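Note on the two hunks above: AddTask now takes a third priority argument, and both existing call sites pass 1, the lowest priority used anywhere in this commit, so their behavior is unchanged until callers opt into higher values. A minimal sketch of preserving the old two-argument call shape behind a wrapper; the package, constant, and function names here are hypothetical, not part of this commit:

package helpers

import (
	"context"

	"github.com/oarkflow/mq"
)

// defaultPriority mirrors the literal 1 now passed at both call sites.
const defaultPriority = 1

// addTaskDefault keeps the old two-argument call shape by defaulting the
// new priority parameter. Hypothetical helper, not part of this commit.
func addTaskDefault(ctx context.Context, pool *mq.Pool, task *mq.Task) error {
	return pool.AddTask(ctx, task, defaultPriority)
}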
examples/priority.go (new file, 22 lines)
@@ -0,0 +1,22 @@
+package main
+
+import (
+	"context"
+	"time"
+
+	"github.com/oarkflow/mq"
+	"github.com/oarkflow/mq/examples/tasks"
+)
+
+func main() {
+	pool := mq.NewPool(2, 5, 1000, tasks.SchedulerHandler, tasks.SchedulerCallback)
+
+	// Add tasks with different priorities
+	pool.AddTask(context.Background(), &mq.Task{ID: "Low Priority Task"}, 1)    // Lowest priority
+	pool.AddTask(context.Background(), &mq.Task{ID: "Medium Priority Task"}, 5) // Medium priority
+	pool.AddTask(context.Background(), &mq.Task{ID: "High Priority Task"}, 10)  // Highest priority
+	// Let tasks run
+	time.Sleep(5 * time.Second)
+	pool.PrintMetrics()
+	pool.Stop()
+}
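Since PriorityQueue.Less in pool.go orders by descending priority, workers should drain the three tasks above as High (10), then Medium (5), then Low (1), regardless of insertion order. A self-contained sketch of that ordering with container/heap, using the same Less rule; the item and pq types are illustrative only:

package main

import (
	"container/heap"
	"fmt"
)

// item mirrors QueueTask's priority field, for illustration only.
type item struct {
	id       string
	priority int
}

// pq mirrors PriorityQueue: higher priority values are popped first.
type pq []item

func (q pq) Len() int            { return len(q) }
func (q pq) Less(i, j int) bool  { return q[i].priority > q[j].priority }
func (q pq) Swap(i, j int)       { q[i], q[j] = q[j], q[i] }
func (q *pq) Push(x interface{}) { *q = append(*q, x.(item)) }
func (q *pq) Pop() interface{} {
	old := *q
	n := len(old)
	it := old[n-1]
	*q = old[:n-1]
	return it
}

func main() {
	q := &pq{}
	heap.Init(q)
	heap.Push(q, item{"Low Priority Task", 1})
	heap.Push(q, item{"Medium Priority Task", 5})
	heap.Push(q, item{"High Priority Task", 10})
	for q.Len() > 0 {
		fmt.Println(heap.Pop(q).(item).id) // High, Medium, Low
	}
}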
examples/scheduler.go (new file, 44 lines)
@@ -0,0 +1,44 @@
+package main
+
+import (
+	"context"
+	"time"
+
+	"github.com/oarkflow/mq"
+	"github.com/oarkflow/mq/examples/tasks"
+)
+
+func main() {
+	pool := mq.NewPool(3, 5, 1000, tasks.SchedulerHandler, tasks.SchedulerCallback)
+	pool.AddTask(context.Background(), &mq.Task{ID: "Task 1"}, 1)
+	pool.AddTask(context.Background(), &mq.Task{ID: "Task 2"}, 5)
+	pool.AddTask(context.Background(), &mq.Task{ID: "Task 3"}, 3)
+
+	// Adding scheduled tasks
+	pool.Scheduler.AddTask(
+		context.Background(),
+		tasks.SchedulerHandler,
+		&mq.Task{ID: "Scheduled Task 1"},
+		3*time.Second,
+		mq.SchedulerConfig{WithCallback: tasks.SchedulerCallback, WithOverlap: true},
+	)
+
+	pool.Scheduler.AddTask(
+		context.Background(),
+		tasks.SchedulerHandler,
+		&mq.Task{ID: "Scheduled Task 2"},
+		5*time.Second,
+		mq.SchedulerConfig{WithCallback: tasks.SchedulerCallback, WithOverlap: false},
+	)
+
+	// Let tasks run for a while
+	time.Sleep(10 * time.Second)
+
+	// Removing a scheduled task
+	pool.Scheduler.RemoveTask("Scheduled Task 1")
+
+	// Let remaining tasks run for a bit
+	time.Sleep(5 * time.Second)
+	pool.PrintMetrics()
+	pool.Stop()
+}
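WithOverlap only matters once a run outlives its interval: with WithOverlap: false the single-slot execution channel is still occupied at the next tick, so that tick is skipped, while true lets runs overlap. A sketch that makes the difference observable, assuming a deliberately slow handler; slowHandler and the timings are hypothetical, not part of the examples package:

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/oarkflow/mq"
)

// slowHandler is hypothetical: it outlives the 1-second interval below,
// which is what makes the WithOverlap setting observable.
func slowHandler(ctx context.Context, task *mq.Task) mq.Result {
	fmt.Printf("start %s\n", task.ID)
	time.Sleep(3 * time.Second)
	fmt.Printf("done %s\n", task.ID)
	return mq.Result{}
}

func main() {
	pool := mq.NewPool(1, 5, 1000, slowHandler, nil)
	// WithOverlap: false -> ticks arriving mid-run are skipped;
	// WithOverlap: true  -> runs may pile up concurrently.
	pool.Scheduler.AddTask(
		context.Background(),
		slowHandler,
		&mq.Task{ID: "slow"},
		time.Second,
		mq.SchedulerConfig{WithOverlap: false},
	)
	time.Sleep(10 * time.Second)
	pool.Stop()
}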
examples/tasks/scheduler.go (new file, 22 lines)
@@ -0,0 +1,22 @@
+package tasks
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/oarkflow/mq"
+)
+
+func SchedulerHandler(ctx context.Context, task *mq.Task) mq.Result {
+	fmt.Printf("Processing task: %s\n", task.ID)
+	return mq.Result{Error: nil}
+}
+
+func SchedulerCallback(ctx context.Context, result mq.Result) error {
+	if result.Error != nil {
+		fmt.Println("Task failed!")
+	} else {
+		fmt.Println("Task completed successfully.")
+	}
+	return nil
+}
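SchedulerCallback branches only on Result.Error, so a handler reports failure by returning a Result with a non-nil Error. A minimal hypothetical handler, not in this commit, that always drives the callback into its failure branch:

package tasks

import (
	"context"
	"fmt"

	"github.com/oarkflow/mq"
)

// FailingHandler is a hypothetical handler (not part of this commit) that
// always reports failure, sending SchedulerCallback down its error branch.
func FailingHandler(ctx context.Context, task *mq.Task) mq.Result {
	return mq.Result{Error: fmt.Errorf("task %s failed", task.ID)}
}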
pool.go (193 lines changed)
@@ -1,34 +1,139 @@
 package mq

 import (
+	"container/heap"
 	"context"
 	"fmt"
 	"sync"
 	"sync/atomic"
 	"time"

 	"github.com/oarkflow/mq/utils"
 )

 type QueueTask struct {
-	ctx     context.Context
-	payload *Task
+	ctx      context.Context
+	payload  *Task
+	priority int
 }
+
+// PriorityQueue implements heap.Interface and holds QueueTasks.
+type PriorityQueue []*QueueTask
+
+func (pq PriorityQueue) Len() int { return len(pq) }
+
+func (pq PriorityQueue) Less(i, j int) bool {
+	return pq[i].priority > pq[j].priority // Higher priority first
+}
+
+func (pq PriorityQueue) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] }
+
+func (pq *PriorityQueue) Push(x interface{}) {
+	item := x.(*QueueTask)
+	*pq = append(*pq, item)
+}
+
+func (pq *PriorityQueue) Pop() interface{} {
+	old := *pq
+	n := len(old)
+	item := old[n-1]
+	*pq = old[0 : n-1]
+	return item
+}

 type Callback func(ctx context.Context, result Result) error

+type SchedulerConfig struct {
+	WithCallback Callback
+	WithOverlap  bool // true allows overlapping, false waits for previous execution to complete
+}
+
+type ScheduledTask struct {
+	ctx       context.Context
+	handler   Handler
+	payload   *Task
+	interval  time.Duration
+	config    SchedulerConfig
+	stop      chan struct{}
+	execution chan struct{} // Channel to signal task execution status
+}
+
+// Scheduler manages scheduled tasks.
+type Scheduler struct {
+	tasks []ScheduledTask
+	mu    sync.Mutex
+}
+
+func (s *Scheduler) Start() {
+	for _, task := range s.tasks {
+		go s.schedule(task)
+	}
+}
+
+func (s *Scheduler) schedule(task ScheduledTask) {
+	ticker := time.NewTicker(task.interval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ticker.C:
+			if task.config.WithOverlap || len(task.execution) == 0 { // Check if task can execute
+				task.execution <- struct{}{}
+				go func() {
+					defer func() { <-task.execution }() // Free the channel for the next execution
+					result := task.handler(task.ctx, task.payload)
+					if task.config.WithCallback != nil {
+						task.config.WithCallback(task.ctx, result)
+					}
+					fmt.Printf("Executed scheduled task: %s\n", task.payload.ID)
+				}()
+			}
+		case <-task.stop:
+			return
+		}
+	}
+}
+
+func (s *Scheduler) AddTask(ctx context.Context, handler Handler, payload *Task, interval time.Duration, config SchedulerConfig) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	stop := make(chan struct{})
+	execution := make(chan struct{}, 1) // Buffer for one execution signal
+	s.tasks = append(s.tasks, ScheduledTask{ctx: ctx, handler: handler, payload: payload, interval: interval, stop: stop, execution: execution, config: config})
+	go s.schedule(s.tasks[len(s.tasks)-1]) // Start scheduling immediately
+}
+
+func (s *Scheduler) RemoveTask(payloadID string) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	for i, task := range s.tasks {
+		if task.payload.ID == payloadID {
+			close(task.stop) // Stop the task
+			// Remove the task from the slice
+			s.tasks = append(s.tasks[:i], s.tasks[i+1:]...)
+			break
+		}
+	}
+}
+
 type Pool struct {
-	taskQueue                 chan QueueTask
+	taskQueue                 PriorityQueue
+	taskQueueLock             sync.Mutex
 	stop                      chan struct{}
 	handler                   Handler
 	callback                  Callback
-	workerAdjust              chan int // Channel for adjusting workers dynamically
+	workerAdjust              chan int
 	wg                        sync.WaitGroup
 	totalMemoryUsed           int64
 	completedTasks            int
 	errorCount, maxMemoryLoad int64
 	totalTasks                int
-	numOfWorkers              int32 // Change to int32 for atomic operations
+	numOfWorkers              int32
 	paused                    bool
+	Scheduler                 Scheduler
+	totalScheduledTasks       int
 }

 func NewPool(
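Review note on Scheduler.schedule above: the gate task.config.WithOverlap || len(task.execution) == 0 inspects the buffered channel and then sends to it as a separate step, so when WithOverlap is true and the one-slot buffer is still full, the send blocks the ticker loop until the previous run drains it. A non-blocking select/default send folds the check and the send into one atomic step; a self-contained sketch of that pattern, with arbitrary timings:

package main

import (
	"fmt"
	"time"
)

// Sketch only: the select/default form of the overlap gate. The decision
// and the slot acquisition happen in one step, so the ticker loop can
// never stall on a full buffer.
func main() {
	execution := make(chan struct{}, 1) // one in-flight slot, as in ScheduledTask
	ticker := time.NewTicker(300 * time.Millisecond)
	defer ticker.Stop()

	deadline := time.After(2 * time.Second)
	for {
		select {
		case <-ticker.C:
			select {
			case execution <- struct{}{}: // slot acquired: start a run
				go func() {
					defer func() { <-execution }()     // free the slot when done
					time.Sleep(700 * time.Millisecond) // simulated slow handler
					fmt.Println("run finished")
				}()
			default:
				fmt.Println("previous run still in flight; tick skipped")
			}
		case <-deadline:
			return
		}
	}
}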
@@ -37,14 +142,15 @@ func NewPool(
 	handler Handler,
 	callback Callback) *Pool {
 	pool := &Pool{
 		numOfWorkers:  int32(numOfWorkers),
-		taskQueue:     make(chan QueueTask, taskQueueSize),
+		taskQueue:     make(PriorityQueue, 0, taskQueueSize),
 		stop:          make(chan struct{}),
 		maxMemoryLoad: maxMemoryLoad,
 		handler:       handler,
 		callback:      callback,
 		workerAdjust:  make(chan int),
 	}
+	heap.Init(&pool.taskQueue) // Initialize the priority queue as a heap
+	pool.Scheduler.Start()     // Start the scheduler
 	pool.Start(numOfWorkers)
 	return pool
 }
@@ -62,32 +168,37 @@ func (wp *Pool) worker() {
 	defer wp.wg.Done()
 	for {
 		select {
-		case task := <-wp.taskQueue:
-			if wp.paused {
-				continue
-			}
-			taskSize := int64(utils.SizeOf(task.payload))
-			wp.totalMemoryUsed += taskSize
-			wp.totalTasks++
-
-			result := wp.handler(task.ctx, task.payload)
-
-			if result.Error != nil {
-				wp.errorCount++
-			} else {
-				wp.completedTasks++
-			}
-
-			if wp.callback != nil {
-				if err := wp.callback(task.ctx, result); err != nil {
-					wp.errorCount++
-				}
-			}
-
-			wp.totalMemoryUsed -= taskSize
-
 		case <-wp.stop:
 			return
+		default:
+			wp.taskQueueLock.Lock()
+			if len(wp.taskQueue) > 0 && !wp.paused {
+				// Pop the highest priority task
+				task := heap.Pop(&wp.taskQueue).(*QueueTask)
+				wp.taskQueueLock.Unlock()
+
+				taskSize := int64(utils.SizeOf(task.payload))
+				wp.totalMemoryUsed += taskSize
+				wp.totalTasks++
+
+				result := wp.handler(task.ctx, task.payload)
+
+				if result.Error != nil {
+					wp.errorCount++
+				} else {
+					wp.completedTasks++
+				}
+
+				if wp.callback != nil {
+					if err := wp.callback(task.ctx, result); err != nil {
+						wp.errorCount++
+					}
+				}
+
+				wp.totalMemoryUsed -= taskSize
+			} else {
+				wp.taskQueueLock.Unlock()
+			}
 		}
 	}
 }
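Review note on the worker loop above: with the channel receive replaced by a default branch, an idle worker spins, repeatedly taking and releasing taskQueueLock while the heap is empty. A conventional alternative is a sync.Cond paired with the same mutex, so workers sleep until a producer signals; a self-contained sketch of that pattern, where waitQueue stands in for the pool and is not part of this commit:

package main

import (
	"fmt"
	"sync"
)

// waitQueue sketches how the pool could let idle workers block instead of
// busy-polling: a sync.Cond shares the queue's mutex, and push signals it.
type waitQueue struct {
	mu    sync.Mutex
	cond  *sync.Cond
	items []int // stands in for the PriorityQueue
}

func newWaitQueue() *waitQueue {
	q := &waitQueue{}
	q.cond = sync.NewCond(&q.mu)
	return q
}

func (q *waitQueue) push(v int) {
	q.mu.Lock()
	q.items = append(q.items, v)
	q.mu.Unlock()
	q.cond.Signal() // wake one waiting worker
}

func (q *waitQueue) pop() int {
	q.mu.Lock()
	defer q.mu.Unlock()
	for len(q.items) == 0 {
		q.cond.Wait() // releases mu while asleep, reacquires on wake
	}
	v := q.items[0]
	q.items = q.items[1:]
	return v
}

func main() {
	q := newWaitQueue()
	done := make(chan struct{})
	go func() { fmt.Println(q.pop()); close(done) }()
	q.push(42)
	<-done
}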
@@ -122,19 +233,19 @@ func (wp *Pool) adjustWorkers(newWorkerCount int) {
 	atomic.StoreInt32(&wp.numOfWorkers, int32(newWorkerCount))
 }

-func (wp *Pool) AddTask(ctx context.Context, payload *Task) error {
-	task := QueueTask{ctx: ctx, payload: payload}
+func (wp *Pool) AddTask(ctx context.Context, payload *Task, priority int) error {
+	task := &QueueTask{ctx: ctx, payload: payload, priority: priority}
 	taskSize := int64(utils.SizeOf(payload))

 	if wp.totalMemoryUsed+taskSize > wp.maxMemoryLoad && wp.maxMemoryLoad > 0 {
 		return fmt.Errorf("max memory load reached, cannot add task of size %d", taskSize)
 	}

-	select {
-	case wp.taskQueue <- task:
-		return nil
-	default:
-		return fmt.Errorf("task queue is full, cannot add task")
-	}
+	wp.taskQueueLock.Lock()
+	heap.Push(&wp.taskQueue, task) // Add task to priority queue
+	wp.taskQueueLock.Unlock()
+
+	return nil
 }

 func (wp *Pool) Pause() {
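Review note on AddTask above: the deleted select on the channel is what produced the "task queue is full" error, so after this change taskQueueSize only sets the heap's initial capacity and the queue can grow without bound; only the memory check remains. A sketch of restoring an explicit bound on top of a heap; boundedHeap, intHeap, and maxSize are hypothetical, not part of this commit:

package main

import (
	"container/heap"
	"fmt"
	"sync"
)

// boundedHeap sketches re-adding backpressure: reject pushes once the
// heap reaches a configured size, mirroring the old channel behaviour.
type boundedHeap struct {
	mu      sync.Mutex
	items   intHeap
	maxSize int // hypothetical bound, analogous to taskQueueSize
}

// intHeap is a max-heap of priorities, matching PriorityQueue's Less.
type intHeap []int

func (h intHeap) Len() int            { return len(h) }
func (h intHeap) Less(i, j int) bool  { return h[i] > h[j] }
func (h intHeap) Swap(i, j int)       { h[i], h[j] = h[j], h[i] }
func (h *intHeap) Push(x interface{}) { *h = append(*h, x.(int)) }
func (h *intHeap) Pop() interface{} {
	old := *h
	n := len(old)
	v := old[n-1]
	*h = old[:n-1]
	return v
}

func (b *boundedHeap) add(priority int) error {
	b.mu.Lock()
	defer b.mu.Unlock()
	if len(b.items) >= b.maxSize {
		return fmt.Errorf("task queue is full, cannot add task")
	}
	heap.Push(&b.items, priority)
	return nil
}

func main() {
	b := &boundedHeap{maxSize: 1}
	fmt.Println(b.add(5)) // <nil>
	fmt.Println(b.add(7)) // task queue is full, cannot add task
}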
@@ -158,6 +269,6 @@ func (wp *Pool) AdjustWorkerCount(newWorkerCount int) {
 }

 func (wp *Pool) PrintMetrics() {
-	fmt.Printf("Total Tasks: %d, Completed Tasks: %d, Error Count: %d, Total Memory Used: %d bytes\n",
-		wp.totalTasks, wp.completedTasks, wp.errorCount, wp.totalMemoryUsed)
+	fmt.Printf("Total Tasks: %d, Completed Tasks: %d, Error Count: %d, Total Memory Used: %d bytes, Scheduled Tasks: %d\n",
+		wp.totalTasks, wp.completedTasks, wp.errorCount, wp.totalMemoryUsed, len(wp.Scheduler.tasks))
 }