diff --git a/consumer.go b/consumer.go index 16c5b4f..7d8cf0f 100644 --- a/consumer.go +++ b/consumer.go @@ -91,8 +91,6 @@ func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.C } ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue}) result := c.ProcessTask(ctx, &task) - result.Topic = msg.Queue - result.TaskID = taskID err = c.MessageResponseCallback(ctx, result) if err != nil { log.Printf("Error on message callback: %v", err) @@ -124,7 +122,10 @@ func (c *Consumer) MessageResponseCallback(ctx context.Context, result Result) e // ProcessTask handles a received task message and invokes the appropriate handler. func (c *Consumer) ProcessTask(ctx context.Context, msg *Task) Result { - return c.handler(ctx, msg) + result := c.handler(ctx, msg) + result.Topic = msg.Topic + result.TaskID = msg.ID + return result } // AttemptConnect tries to establish a connection to the server, with TLS or without, based on the configuration. diff --git a/pool.go b/pool.go index 1f3906f..5c03dd4 100644 --- a/pool.go +++ b/pool.go @@ -3,7 +3,9 @@ package mq import ( "context" "fmt" + "net" "sync" + "sync/atomic" "github.com/oarkflow/mq/utils" ) @@ -13,40 +15,50 @@ type QueueTask struct { payload *Task } -type Callback func(result Result, err error) +type Callback func(ctx context.Context, result Result) error type Pool struct { totalMemoryUsed int64 completedTasks int errorCount, maxMemoryLoad int64 - totalTasks, numOfWorkers int + totalTasks int + numOfWorkers int32 // Change to int32 for atomic operations taskQueue chan QueueTask wg sync.WaitGroup paused bool stop chan struct{} handler Handler callback Callback + conn net.Conn + workerAdjust chan int // Channel for adjusting workers dynamically } func NewPool( numOfWorkers, taskQueueSize int, maxMemoryLoad int64, - handler Handler, callback Callback) *Pool { - return &Pool{ - numOfWorkers: numOfWorkers, + handler Handler, + callback Callback, conn net.Conn) *Pool { + pool := &Pool{ + numOfWorkers: int32(numOfWorkers), taskQueue: make(chan QueueTask, taskQueueSize), stop: make(chan struct{}), maxMemoryLoad: maxMemoryLoad, handler: handler, callback: callback, + conn: conn, + workerAdjust: make(chan int), } + pool.Start(int(numOfWorkers)) + return pool } -func (wp *Pool) Start() { - for i := 0; i < wp.numOfWorkers; i++ { +func (wp *Pool) Start(numWorkers int) { + for i := 0; i < numWorkers; i++ { wp.wg.Add(1) go wp.worker() } + atomic.StoreInt32(&wp.numOfWorkers, int32(numWorkers)) + go wp.monitorWorkerAdjustments() // Monitor worker changes } func (wp *Pool) worker() { @@ -60,25 +72,59 @@ func (wp *Pool) worker() { taskSize := int64(utils.SizeOf(task.payload)) wp.totalMemoryUsed += taskSize wp.totalTasks++ - if wp.handler != nil { - result := wp.handler(task.ctx, task.payload) - if result.Error != nil { + + result := wp.handler(task.ctx, task.payload) + + if result.Error != nil { + wp.errorCount++ + } else { + wp.completedTasks++ + } + + if wp.callback != nil { + if err := wp.callback(task.ctx, result); err != nil { wp.errorCount++ - } else { - wp.completedTasks++ - } - if wp.callback != nil { - wp.callback(result, result.Error) } } wp.totalMemoryUsed -= taskSize + case <-wp.stop: return } } } +func (wp *Pool) monitorWorkerAdjustments() { + for { + select { + case adjustment := <-wp.workerAdjust: + currentWorkers := atomic.LoadInt32(&wp.numOfWorkers) + newWorkerCount := int(currentWorkers) + adjustment + if newWorkerCount > 0 { + wp.adjustWorkers(newWorkerCount) + } + case <-wp.stop: + return + } + } +} + +func (wp *Pool) adjustWorkers(newWorkerCount int) { + currentWorkers := int(atomic.LoadInt32(&wp.numOfWorkers)) + if newWorkerCount > currentWorkers { + for i := 0; i < newWorkerCount-currentWorkers; i++ { + wp.wg.Add(1) + go wp.worker() + } + } else if newWorkerCount < currentWorkers { + for i := 0; i < currentWorkers-newWorkerCount; i++ { + wp.stop <- struct{}{} + } + } + atomic.StoreInt32(&wp.numOfWorkers, int32(newWorkerCount)) +} + func (wp *Pool) AddTask(ctx context.Context, payload *Task) error { task := QueueTask{ctx: ctx, payload: payload} taskSize := int64(utils.SizeOf(payload)) @@ -107,6 +153,13 @@ func (wp *Pool) Stop() { wp.wg.Wait() } +func (wp *Pool) AdjustWorkerCount(newWorkerCount int) { + adjustment := newWorkerCount - int(atomic.LoadInt32(&wp.numOfWorkers)) + if adjustment != 0 { + wp.workerAdjust <- adjustment + } +} + func (wp *Pool) PrintMetrics() { fmt.Printf("Total Tasks: %d, Completed Tasks: %d, Error Count: %d, Total Memory Used: %d bytes\n", wp.totalTasks, wp.completedTasks, wp.errorCount, wp.totalMemoryUsed)