From 56d72d07068ca33eec12d8faa13479616b2fe5e3 Mon Sep 17 00:00:00 2001 From: sujit Date: Wed, 23 Oct 2024 14:29:23 +0545 Subject: [PATCH] feat: add task completion --- pool.go | 85 +++++++++++++++++++++++++++++-------------------- pool_options.go | 6 ++++ 2 files changed, 57 insertions(+), 34 deletions(-) diff --git a/pool.go b/pool.go index 945dfc2..30af7b6 100644 --- a/pool.go +++ b/pool.go @@ -8,11 +8,12 @@ import ( "sync" "sync/atomic" "time" - + "github.com/oarkflow/mq/utils" ) type Callback func(ctx context.Context, result Result) error +type CompletionCallback func() // Called when all tasks are completed type Metrics struct { TotalTasks int64 @@ -23,25 +24,27 @@ type Metrics struct { } type Pool struct { - taskStorage TaskStorage - taskQueue PriorityQueue - taskQueueLock sync.Mutex - stop chan struct{} - taskNotify chan struct{} - workerAdjust chan int - wg sync.WaitGroup - maxMemoryLoad int64 - numOfWorkers int32 - metrics Metrics - paused bool - scheduler *Scheduler - overflowBufferLock sync.RWMutex - overflowBuffer []*QueueTask - taskAvailableCond *sync.Cond - handler Handler - callback Callback - batchSize int - timeout time.Duration + taskStorage TaskStorage + taskQueue PriorityQueue + taskQueueLock sync.Mutex + stop chan struct{} + taskNotify chan struct{} + workerAdjust chan int + wg sync.WaitGroup + maxMemoryLoad int64 + numOfWorkers int32 + metrics Metrics + paused bool + scheduler *Scheduler + overflowBufferLock sync.RWMutex + overflowBuffer []*QueueTask + taskAvailableCond *sync.Cond + handler Handler + callback Callback + batchSize int + timeout time.Duration + completionCallback CompletionCallback + taskCompletionNotifier sync.WaitGroup } func NewPool(numOfWorkers int, opts ...PoolOption) *Pool { @@ -86,11 +89,9 @@ func (wp *Pool) Start(numWorkers int) { func (wp *Pool) worker() { defer wp.wg.Done() for { - wp.taskAvailableCond.L.Lock() for len(wp.taskQueue) == 0 && !wp.paused { - wp.taskAvailableCond.Wait() + wp.Dispatch(wp.taskAvailableCond.Wait) } - wp.taskAvailableCond.L.Unlock() select { case <-wp.stop: return @@ -122,12 +123,17 @@ func (wp *Pool) processNextBatch() { wp.handleTask(task) } } + + // Check if all tasks are completed + if len(tasks) > 0 { + wp.taskCompletionNotifier.Done() + } } func (wp *Pool) handleTask(task *QueueTask) { - ctx, cancel := context.WithTimeout(task.ctx, 10*time.Second) // Timeout for task + ctx, cancel := context.WithTimeout(task.ctx, wp.timeout) // Timeout for task defer cancel() - + taskSize := int64(utils.SizeOf(task.payload)) wp.metrics.TotalMemoryUsed += taskSize wp.metrics.TotalTasks++ @@ -194,24 +200,29 @@ func (wp *Pool) EnqueueTask(ctx context.Context, payload *Task, priority int) er return fmt.Errorf("max memory load reached, task stored in overflow buffer of size %d", taskSize) } heap.Push(&wp.taskQueue, task) - wp.taskAvailableCond.L.Lock() - wp.taskAvailableCond.Signal() - wp.taskAvailableCond.L.Unlock() + wp.Dispatch(wp.taskAvailableCond.Signal) + wp.taskCompletionNotifier.Add(1) // Increment the counter for task completion return nil } +func (wp *Pool) Dispatch(event func()) { + wp.taskAvailableCond.L.Lock() + event() + wp.taskAvailableCond.L.Unlock() +} + func (wp *Pool) Pause() { wp.paused = true - wp.taskAvailableCond.L.Lock() - wp.taskAvailableCond.Broadcast() // Notify all waiting workers - wp.taskAvailableCond.L.Unlock() + wp.Dispatch(wp.taskAvailableCond.Broadcast) +} + +func (wp *Pool) SetBatchSize(size int) { + wp.batchSize = size } func (wp *Pool) Resume() { wp.paused = false - wp.taskAvailableCond.L.Lock() - wp.taskAvailableCond.Broadcast() // Notify all waiting workers - wp.taskAvailableCond.L.Unlock() + wp.Dispatch(wp.taskAvailableCond.Broadcast) } func (wp *Pool) storeInOverflow(task *QueueTask) { @@ -251,6 +262,12 @@ func (wp *Pool) drainOverflowBuffer() { func (wp *Pool) Stop() { close(wp.stop) wp.wg.Wait() + + // Wait for all tasks to complete + wp.taskCompletionNotifier.Wait() + if wp.completionCallback != nil { + wp.completionCallback() + } } func (wp *Pool) AdjustWorkerCount(newWorkerCount int) { diff --git a/pool_options.go b/pool_options.go index 9f2f3ed..04bcdf3 100644 --- a/pool_options.go +++ b/pool_options.go @@ -19,6 +19,12 @@ func WithTaskTimeout(t time.Duration) PoolOption { } } +func WithCompletionCallback(callback func()) PoolOption { + return func(p *Pool) { + p.completionCallback = callback + } +} + func WithMaxMemoryLoad(maxMemoryLoad int64) PoolOption { return func(p *Pool) { p.maxMemoryLoad = maxMemoryLoad