feat: add task completion

This commit is contained in:
sujit
2024-10-23 14:29:23 +05:45
parent 6caac4030f
commit 56d72d0706
2 changed files with 57 additions and 34 deletions

85
pool.go
View File

@@ -8,11 +8,12 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/oarkflow/mq/utils" "github.com/oarkflow/mq/utils"
) )
type Callback func(ctx context.Context, result Result) error type Callback func(ctx context.Context, result Result) error
type CompletionCallback func() // Called when all tasks are completed
type Metrics struct { type Metrics struct {
TotalTasks int64 TotalTasks int64
@@ -23,25 +24,27 @@ type Metrics struct {
} }
type Pool struct { type Pool struct {
taskStorage TaskStorage taskStorage TaskStorage
taskQueue PriorityQueue taskQueue PriorityQueue
taskQueueLock sync.Mutex taskQueueLock sync.Mutex
stop chan struct{} stop chan struct{}
taskNotify chan struct{} taskNotify chan struct{}
workerAdjust chan int workerAdjust chan int
wg sync.WaitGroup wg sync.WaitGroup
maxMemoryLoad int64 maxMemoryLoad int64
numOfWorkers int32 numOfWorkers int32
metrics Metrics metrics Metrics
paused bool paused bool
scheduler *Scheduler scheduler *Scheduler
overflowBufferLock sync.RWMutex overflowBufferLock sync.RWMutex
overflowBuffer []*QueueTask overflowBuffer []*QueueTask
taskAvailableCond *sync.Cond taskAvailableCond *sync.Cond
handler Handler handler Handler
callback Callback callback Callback
batchSize int batchSize int
timeout time.Duration timeout time.Duration
completionCallback CompletionCallback
taskCompletionNotifier sync.WaitGroup
} }
func NewPool(numOfWorkers int, opts ...PoolOption) *Pool { func NewPool(numOfWorkers int, opts ...PoolOption) *Pool {
@@ -86,11 +89,9 @@ func (wp *Pool) Start(numWorkers int) {
func (wp *Pool) worker() { func (wp *Pool) worker() {
defer wp.wg.Done() defer wp.wg.Done()
for { for {
wp.taskAvailableCond.L.Lock()
for len(wp.taskQueue) == 0 && !wp.paused { for len(wp.taskQueue) == 0 && !wp.paused {
wp.taskAvailableCond.Wait() wp.Dispatch(wp.taskAvailableCond.Wait)
} }
wp.taskAvailableCond.L.Unlock()
select { select {
case <-wp.stop: case <-wp.stop:
return return
@@ -122,12 +123,17 @@ func (wp *Pool) processNextBatch() {
wp.handleTask(task) wp.handleTask(task)
} }
} }
// Check if all tasks are completed
if len(tasks) > 0 {
wp.taskCompletionNotifier.Done()
}
} }
func (wp *Pool) handleTask(task *QueueTask) { func (wp *Pool) handleTask(task *QueueTask) {
ctx, cancel := context.WithTimeout(task.ctx, 10*time.Second) // Timeout for task ctx, cancel := context.WithTimeout(task.ctx, wp.timeout) // Timeout for task
defer cancel() defer cancel()
taskSize := int64(utils.SizeOf(task.payload)) taskSize := int64(utils.SizeOf(task.payload))
wp.metrics.TotalMemoryUsed += taskSize wp.metrics.TotalMemoryUsed += taskSize
wp.metrics.TotalTasks++ wp.metrics.TotalTasks++
@@ -194,24 +200,29 @@ func (wp *Pool) EnqueueTask(ctx context.Context, payload *Task, priority int) er
return fmt.Errorf("max memory load reached, task stored in overflow buffer of size %d", taskSize) return fmt.Errorf("max memory load reached, task stored in overflow buffer of size %d", taskSize)
} }
heap.Push(&wp.taskQueue, task) heap.Push(&wp.taskQueue, task)
wp.taskAvailableCond.L.Lock() wp.Dispatch(wp.taskAvailableCond.Signal)
wp.taskAvailableCond.Signal() wp.taskCompletionNotifier.Add(1) // Increment the counter for task completion
wp.taskAvailableCond.L.Unlock()
return nil return nil
} }
func (wp *Pool) Dispatch(event func()) {
wp.taskAvailableCond.L.Lock()
event()
wp.taskAvailableCond.L.Unlock()
}
func (wp *Pool) Pause() { func (wp *Pool) Pause() {
wp.paused = true wp.paused = true
wp.taskAvailableCond.L.Lock() wp.Dispatch(wp.taskAvailableCond.Broadcast)
wp.taskAvailableCond.Broadcast() // Notify all waiting workers }
wp.taskAvailableCond.L.Unlock()
func (wp *Pool) SetBatchSize(size int) {
wp.batchSize = size
} }
func (wp *Pool) Resume() { func (wp *Pool) Resume() {
wp.paused = false wp.paused = false
wp.taskAvailableCond.L.Lock() wp.Dispatch(wp.taskAvailableCond.Broadcast)
wp.taskAvailableCond.Broadcast() // Notify all waiting workers
wp.taskAvailableCond.L.Unlock()
} }
func (wp *Pool) storeInOverflow(task *QueueTask) { func (wp *Pool) storeInOverflow(task *QueueTask) {
@@ -251,6 +262,12 @@ func (wp *Pool) drainOverflowBuffer() {
func (wp *Pool) Stop() { func (wp *Pool) Stop() {
close(wp.stop) close(wp.stop)
wp.wg.Wait() wp.wg.Wait()
// Wait for all tasks to complete
wp.taskCompletionNotifier.Wait()
if wp.completionCallback != nil {
wp.completionCallback()
}
} }
func (wp *Pool) AdjustWorkerCount(newWorkerCount int) { func (wp *Pool) AdjustWorkerCount(newWorkerCount int) {

View File

@@ -19,6 +19,12 @@ func WithTaskTimeout(t time.Duration) PoolOption {
} }
} }
func WithCompletionCallback(callback func()) PoolOption {
return func(p *Pool) {
p.completionCallback = callback
}
}
func WithMaxMemoryLoad(maxMemoryLoad int64) PoolOption { func WithMaxMemoryLoad(maxMemoryLoad int64) PoolOption {
return func(p *Pool) { return func(p *Pool) {
p.maxMemoryLoad = maxMemoryLoad p.maxMemoryLoad = maxMemoryLoad