diff --git a/consumer.go b/consumer.go index 32b720b..541c288 100644 --- a/consumer.go +++ b/consumer.go @@ -232,7 +232,14 @@ func (c *Consumer) Consume(ctx context.Context) error { if err != nil { return err } - c.pool = NewPool(c.opts.numOfWorkers, c.opts.queueSize, c.opts.maxMemoryLoad, c.ProcessTask, c.OnResponse, c.opts.storage) + c.pool = NewPool( + c.opts.numOfWorkers, + WithTaskQueueSize(c.opts.queueSize), + WithMaxMemoryLoad(c.opts.maxMemoryLoad), + WithHandler(c.ProcessTask), + WithPoolCallback(c.OnResponse), + WithTaskStorage(c.opts.storage), + ) if err := c.subscribe(ctx, c.queue); err != nil { return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err) } diff --git a/dag/dag.go b/dag/dag.go index 48d9303..72435cb 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -3,12 +3,13 @@ package dag import ( "context" "fmt" - "github.com/oarkflow/mq/sio" "log" "net/http" "sync" "time" + "github.com/oarkflow/mq/sio" + "golang.org/x/time/rate" "github.com/oarkflow/mq" @@ -143,7 +144,14 @@ func NewDAG(name, key string, opts ...mq.Option) *DAG { d.server = mq.NewBroker(opts...) d.opts = opts options := d.server.Options() - d.pool = mq.NewPool(options.NumOfWorkers(), options.QueueSize(), options.MaxMemoryLoad(), d.ProcessTask, callback, options.Storage()) + d.pool = mq.NewPool( + options.NumOfWorkers(), + mq.WithTaskQueueSize(options.QueueSize()), + mq.WithMaxMemoryLoad(options.MaxMemoryLoad()), + mq.WithHandler(d.ProcessTask), + mq.WithPoolCallback(callback), + mq.WithTaskStorage(options.Storage()), + ) d.pool.Start(d.server.Options().NumOfWorkers()) go d.listenForTaskCleanup() return d diff --git a/examples/priority.go b/examples/priority.go index 1140531..f7c89b1 100644 --- a/examples/priority.go +++ b/examples/priority.go @@ -9,7 +9,13 @@ import ( ) func main() { - pool := mq.NewPool(2, 5, 1000, tasks.SchedulerHandler, tasks.SchedulerCallback, mq.NewMemoryTaskStorage(10*time.Minute)) + pool := mq.NewPool(2, + mq.WithTaskQueueSize(5), + mq.WithMaxMemoryLoad(1000), + mq.WithHandler(tasks.SchedulerHandler), + mq.WithPoolCallback(tasks.SchedulerCallback), + mq.WithTaskStorage(mq.NewMemoryTaskStorage(10*time.Minute)), + ) for i := 0; i < 100; i++ { if i%10 == 0 { diff --git a/examples/scheduler.go b/examples/scheduler.go index f2c8e11..31d171b 100644 --- a/examples/scheduler.go +++ b/examples/scheduler.go @@ -11,7 +11,13 @@ import ( func main() { handler := tasks.SchedulerHandler callback := tasks.SchedulerCallback - pool := mq.NewPool(3, 5, 1000, handler, callback, mq.NewMemoryTaskStorage(10*time.Minute)) + pool := mq.NewPool(3, + mq.WithTaskQueueSize(5), + mq.WithMaxMemoryLoad(1000), + mq.WithHandler(handler), + mq.WithPoolCallback(callback), + mq.WithTaskStorage(mq.NewMemoryTaskStorage(10*time.Minute)), + ) ctx := context.Background() pool.EnqueueTask(context.Background(), &mq.Task{ID: "Task 1"}, 1) time.Sleep(1 * time.Second) diff --git a/pool.go b/pool.go index 67c77a3..0167f95 100644 --- a/pool.go +++ b/pool.go @@ -21,9 +21,6 @@ type Metrics struct { TotalScheduled int64 } -type PoolOption struct { -} - type Pool struct { taskStorage TaskStorage taskQueue PriorityQueue @@ -42,21 +39,23 @@ type Pool struct { taskAvailableCond *sync.Cond handler Handler callback Callback + batchSize int } -func NewPool(numOfWorkers, taskQueueSize int, maxMemoryLoad int64, handler Handler, callback Callback, storage TaskStorage) *Pool { +func NewPool(numOfWorkers int, opts ...PoolOption) *Pool { pool := &Pool{ - taskQueue: make(PriorityQueue, 0, taskQueueSize), - stop: make(chan struct{}), - taskNotify: make(chan struct{}, numOfWorkers), // Buffer for workers - maxMemoryLoad: maxMemoryLoad, - handler: handler, - callback: callback, - taskStorage: storage, - workerAdjust: make(chan int), + stop: make(chan struct{}), + taskNotify: make(chan struct{}, numOfWorkers), + batchSize: 1, } pool.scheduler = NewScheduler(pool) - pool.taskAvailableCond = sync.NewCond(&sync.Mutex{}) // Initialize condition variable + pool.taskAvailableCond = sync.NewCond(&sync.Mutex{}) + for _, opt := range opts { + opt(pool) + } + if len(pool.taskQueue) == 0 { + pool.taskQueue = make(PriorityQueue, 0, 10) + } heap.Init(&pool.taskQueue) pool.scheduler.Start() pool.Start(numOfWorkers) @@ -84,37 +83,41 @@ func (wp *Pool) Start(numWorkers int) { func (wp *Pool) worker() { defer wp.wg.Done() for { - wp.taskAvailableCond.L.Lock() // Lock the condition variable mutex - for len(wp.taskQueue) == 0 && !wp.paused { // Wait if there are no tasks and not paused + wp.taskAvailableCond.L.Lock() + for len(wp.taskQueue) == 0 && !wp.paused { wp.taskAvailableCond.Wait() } - wp.taskAvailableCond.L.Unlock() // Unlock the condition variable mutex - + wp.taskAvailableCond.L.Unlock() select { case <-wp.stop: return default: - wp.processNextTask() // Process next task if there are any + wp.processNextBatch() } } } -func (wp *Pool) processNextTask() { +func (wp *Pool) processNextBatch() { wp.taskQueueLock.Lock() - var task *QueueTask - if len(wp.taskQueue) > 0 && !wp.paused { - task = heap.Pop(&wp.taskQueue).(*QueueTask) + defer wp.taskQueueLock.Unlock() + tasks := make([]*QueueTask, 0, wp.batchSize) + for len(wp.taskQueue) > 0 && !wp.paused && len(tasks) < wp.batchSize { + task := heap.Pop(&wp.taskQueue).(*QueueTask) + tasks = append(tasks, task) } - wp.taskQueueLock.Unlock() - if task == nil && !wp.paused { - var err error - task, err = wp.taskStorage.FetchNextTask() - if err != nil { - return + if len(tasks) == 0 && !wp.paused { + for len(tasks) < wp.batchSize { + task, err := wp.taskStorage.FetchNextTask() + if err != nil { + break + } + tasks = append(tasks, task) } } - if task != nil { - wp.handleTask(task) + for _, task := range tasks { + if task != nil { + wp.handleTask(task) + } } } @@ -133,9 +136,7 @@ func (wp *Pool) handleTask(task *QueueTask) { wp.metrics.ErrorCount++ } } - if err := wp.taskStorage.DeleteTask(task.payload.ID); err != nil { - // Handle deletion error - } + _ = wp.taskStorage.DeleteTask(task.payload.ID) wp.metrics.TotalMemoryUsed -= taskSize } @@ -181,34 +182,26 @@ func (wp *Pool) EnqueueTask(ctx context.Context, payload *Task, priority int) er defer wp.taskQueueLock.Unlock() taskSize := int64(utils.SizeOf(payload)) if wp.metrics.TotalMemoryUsed+taskSize > wp.maxMemoryLoad && wp.maxMemoryLoad > 0 { - return fmt.Errorf("max memory load reached, cannot add task of size %d", taskSize) + wp.storeInOverflow(task) + return fmt.Errorf("max memory load reached, task stored in overflow buffer of size %d", taskSize) } heap.Push(&wp.taskQueue, task) - - // Notify one worker that a task has been added wp.taskAvailableCond.L.Lock() wp.taskAvailableCond.Signal() wp.taskAvailableCond.L.Unlock() - return nil } -func (wp *Pool) Pause() { - wp.paused = true -} +func (wp *Pool) Pause() { wp.paused = true } -func (wp *Pool) Resume() { - wp.paused = false -} +func (wp *Pool) Resume() { wp.paused = false } -// Overflow Handling func (wp *Pool) storeInOverflow(task *QueueTask) { wp.overflowBufferLock.Lock() wp.overflowBuffer = append(wp.overflowBuffer, task) wp.overflowBufferLock.Unlock() } -// Drains tasks from the overflow buffer func (wp *Pool) startOverflowDrainer() { for { wp.drainOverflowBuffer() @@ -216,7 +209,7 @@ func (wp *Pool) startOverflowDrainer() { case <-wp.stop: return default: - time.Sleep(50 * time.Millisecond) + time.Sleep(100 * time.Millisecond) } } } @@ -224,17 +217,14 @@ func (wp *Pool) startOverflowDrainer() { func (wp *Pool) drainOverflowBuffer() { wp.overflowBufferLock.Lock() defer wp.overflowBufferLock.Unlock() - for len(wp.overflowBuffer) > 0 { select { case wp.taskNotify <- struct{}{}: - // Move the first task from the overflow buffer to the queue wp.taskQueueLock.Lock() heap.Push(&wp.taskQueue, wp.overflowBuffer[0]) wp.overflowBuffer = wp.overflowBuffer[1:] wp.taskQueueLock.Unlock() default: - // Stop if taskNotify is full return } } @@ -257,6 +247,4 @@ func (wp *Pool) Metrics() Metrics { return wp.metrics } -func (wp *Pool) Scheduler() *Scheduler { - return wp.scheduler -} +func (wp *Pool) Scheduler() *Scheduler { return wp.scheduler } diff --git a/pool_options.go b/pool_options.go new file mode 100644 index 0000000..9d255b2 --- /dev/null +++ b/pool_options.go @@ -0,0 +1,40 @@ +package mq + +type PoolOption func(*Pool) + +func WithTaskQueueSize(size int) PoolOption { + return func(p *Pool) { + // Initialize the task queue with the specified size + p.taskQueue = make(PriorityQueue, 0, size) + } +} + +func WithMaxMemoryLoad(maxMemoryLoad int64) PoolOption { + return func(p *Pool) { + p.maxMemoryLoad = maxMemoryLoad + } +} + +func WithBatchSize(batchSize int) PoolOption { + return func(p *Pool) { + p.batchSize = batchSize + } +} + +func WithHandler(handler Handler) PoolOption { + return func(p *Pool) { + p.handler = handler + } +} + +func WithPoolCallback(callback Callback) PoolOption { + return func(p *Pool) { + p.callback = callback + } +} + +func WithTaskStorage(storage TaskStorage) PoolOption { + return func(p *Pool) { + p.taskStorage = storage + } +}