diff --git a/.gitignore b/.gitignore
index fb61f0f..8d0c62e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -25,3 +25,4 @@ go.work
 .qodo
 .history
 *.log
+snapshots
diff --git a/consumer.go b/consumer.go
index 4111c75..d2e0a1e 100644
--- a/consumer.go
+++ b/consumer.go
@@ -86,9 +86,7 @@ func (c *Consumer) send(ctx context.Context, conn net.Conn, msg *codec.Message)
 }
 
 func (c *Consumer) receive(ctx context.Context, conn net.Conn) (*codec.Message, error) {
-	c.connMutex.RLock()
-	defer c.connMutex.RUnlock()
-
+	// Check shutdown before attempting read
 	if atomic.LoadInt32(&c.isShutdown) == 1 {
 		return nil, fmt.Errorf("consumer is shutdown")
 	}
@@ -97,6 +95,7 @@ func (c *Consumer) receive(ctx context.Context, conn net.Conn) (*codec.Message,
 		return nil, fmt.Errorf("connection is nil")
 	}
 
+	// Don't hold lock during blocking read - this allows Close() to proceed
 	return codec.ReadMessage(ctx, conn)
 }
@@ -197,17 +196,35 @@ func (c *Consumer) Auth(ctx context.Context, username, password string) error {
 }
 
 func (c *Consumer) OnClose(_ context.Context, _ net.Conn) error {
-	fmt.Println("Consumer closed")
+	// Only log if not already shutdown (prevents spam during graceful shutdown)
+	if atomic.LoadInt32(&c.isShutdown) == 0 {
+		c.logger.Debug("Connection closed",
+			logger.Field{Key: "consumer_id", Value: c.id})
+	}
 	return nil
 }
 
 func (c *Consumer) OnError(_ context.Context, conn net.Conn, err error) {
+	// Don't trigger reconnection if shutting down
+	if atomic.LoadInt32(&c.isShutdown) == 1 {
+		return
+	}
+
 	if c.isConnectionClosed(err) {
-		log.Printf("Connection to broker closed for consumer %s at %s", c.id, conn.RemoteAddr())
+		c.logger.Warn("Connection to broker closed, triggering reconnection",
+			logger.Field{Key: "consumer_id", Value: c.id},
+			logger.Field{Key: "broker_addr", Value: conn.RemoteAddr().String()})
+
 		// Trigger reconnection if possible
-		c.reconnectCh <- struct{}{}
+		select {
+		case c.reconnectCh <- struct{}{}:
+		default:
+			// Channel full, reconnection already pending
+		}
 	} else {
-		log.Printf("Error reading from connection: %v", err)
+		c.logger.Error("Error reading from connection",
+			logger.Field{Key: "consumer_id", Value: c.id},
+			logger.Field{Key: "error", Value: err.Error()})
 	}
 }
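The `select`/`default` send in `OnError` is a coalescing signal: it never blocks, and at most one reconnect request is ever queued. A minimal sketch of the pattern in isolation (the channel name and buffer size here are assumptions for illustration, not values taken from the library):

```go
package main

import "fmt"

// requestReconnect queues a reconnect signal without blocking. If a signal is
// already pending, the default branch is taken and the request is coalesced.
func requestReconnect(reconnectCh chan struct{}) bool {
	select {
	case reconnectCh <- struct{}{}:
		return true // signal queued
	default:
		return false // reconnection already pending
	}
}

func main() {
	reconnectCh := make(chan struct{}, 1)      // a buffer of 1 is what makes the pattern work
	fmt.Println(requestReconnect(reconnectCh)) // true: first error queues the signal
	fmt.Println(requestReconnect(reconnectCh)) // false: a burst of errors collapses into one signal
	<-reconnectCh                              // the consume loop drains the signal...
	fmt.Println(requestReconnect(reconnectCh)) // ...and the next error can queue again
}
```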
@@ -645,18 +662,43 @@ func (c *Consumer) Consume(ctx context.Context) error {
 		logger.Field{Key: "consumer_id", Value: c.id},
 		logger.Field{Key: "queue", Value: c.queue})
 
-	// Main processing loop with enhanced error handling
+	// Main processing loop - will exit when context is cancelled or shutdown signal received
+	c.messageProcessingLoop(ctx)
+
+	c.logger.Info("Consumer stopped",
+		logger.Field{Key: "consumer_id", Value: c.id})
+
+	return nil
+}
+
+// messageProcessingLoop handles the actual message processing
+func (c *Consumer) messageProcessingLoop(ctx context.Context) {
+	// Create a ticker for periodic shutdown checks
+	ticker := time.NewTicker(100 * time.Millisecond)
+	defer ticker.Stop()
+
 	for {
+		// Single point of shutdown check
 		select {
 		case <-ctx.Done():
-			c.logger.Info("Context cancelled, stopping consumer",
+			c.logger.Info("Context cancelled",
 				logger.Field{Key: "consumer_id", Value: c.id})
-			return c.Close()
+			// Close connection immediately to unblock any reads
+			c.connMutex.Lock()
+			if c.conn != nil {
+				c.conn.Close()
+			}
+			c.connMutex.Unlock()
+			return
 
 		case <-c.shutdown:
 			c.logger.Info("Shutdown signal received",
 				logger.Field{Key: "consumer_id", Value: c.id})
-			return nil
+			// Close connection immediately to unblock any reads
+			c.connMutex.Lock()
+			if c.conn != nil {
+				c.conn.Close()
+			}
+			c.connMutex.Unlock()
+			return
 
 		case <-c.reconnectCh:
 			c.logger.Info("Reconnection triggered",
@@ -665,34 +707,60 @@ func (c *Consumer) Consume(ctx context.Context) error {
 				c.logger.Error("Reconnection failed, will retry based on backoff policy",
 					logger.Field{Key: "consumer_id", Value: c.id},
 					logger.Field{Key: "error", Value: err.Error()})
-				// The handleReconnection method now implements its own backoff,
-				// so we don't need to do anything here except continue the loop
 			}
 
+		case <-ticker.C:
+			// Periodic check - do nothing, just allows checking other select cases
+
+		default:
+			// Check shutdown flag before processing
+			if atomic.LoadInt32(&c.isShutdown) == 1 {
+				return
+			}
+
 			// Apply rate limiting if configured
 			if c.opts.ConsumerRateLimiter != nil {
 				c.opts.ConsumerRateLimiter.Wait()
 			}
 
-			// Process messages with timeout
+			// Process messages with timeout (non-blocking with quick return on shutdown)
			if err := c.processWithTimeout(ctx); err != nil {
+				// Check if shutdown was initiated
 				if atomic.LoadInt32(&c.isShutdown) == 1 {
-					return nil
+					return
 				}
 
-				c.logger.Error("Error processing message",
-					logger.Field{Key: "consumer_id", Value: c.id},
-					logger.Field{Key: "error", Value: err.Error()})
+				// Check if context was cancelled (graceful shutdown)
+				if err == context.Canceled || err == context.DeadlineExceeded {
+					c.logger.Info("Context cancelled during message processing",
+						logger.Field{Key: "consumer_id", Value: c.id})
+					return
+				}
 
-				// Trigger reconnection for connection errors
-				if isConnectionError(err) {
+				// Handle connection closed errors
+				if isConnectionError(err) || strings.Contains(err.Error(), "shutdown") {
+					c.logger.Debug("Connection error detected",
+						logger.Field{Key: "consumer_id", Value: c.id},
+						logger.Field{Key: "error", Value: err.Error()})
+
+					// If we're shutting down, don't try to reconnect
+					if atomic.LoadInt32(&c.isShutdown) == 1 {
+						return
+					}
+
+					// Trigger reconnection
 					select {
 					case c.reconnectCh <- struct{}{}:
 					default:
 					}
+					continue
 				}
 
+				// Log other errors but continue processing
+				c.logger.Error("Error processing message",
+					logger.Field{Key: "consumer_id", Value: c.id},
+					logger.Field{Key: "error", Value: err.Error()})
+
 				// Brief pause before retrying
 				time.Sleep(100 * time.Millisecond)
 			}
@@ -700,10 +768,13 @@ func (c *Consumer) Consume(ctx context.Context) error {
 	}
 }
 
-// processWithTimeout processes messages WITHOUT I/O timeouts for persistent broker connections
+// processWithTimeout processes messages with context awareness
 func (c *Consumer) processWithTimeout(ctx context.Context) error {
-	// Consumer should wait indefinitely for messages from broker - NO I/O timeout
-	// Only individual task processing should have timeouts, not the consumer connection
+	// Check shutdown first
+	if atomic.LoadInt32(&c.isShutdown) == 1 {
+		return fmt.Errorf("consumer is shutdown")
+	}
+
 	c.connMutex.RLock()
 	conn := c.conn
 	c.connMutex.RUnlock()
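The new `ctx.Done()` and shutdown cases close the connection while `codec.ReadMessage` may still be blocked on it; this works because closing a `net.Conn` from another goroutine makes a pending `Read` return with an error. A self-contained illustration of that behaviour using `net.Pipe` (not the broker protocol):

```go
package main

import (
	"fmt"
	"net"
	"time"
)

func main() {
	client, server := net.Pipe()
	defer server.Close()

	readErr := make(chan error, 1)
	go func() {
		buf := make([]byte, 1)
		_, err := client.Read(buf) // blocks: nothing is ever written
		readErr <- err
	}()

	time.Sleep(100 * time.Millisecond) // let the reader block
	client.Close()                     // closing the conn unblocks the pending Read

	fmt.Println("read returned:", <-readErr) // prints a "closed pipe" error almost immediately
}
```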
@@ -712,11 +783,14 @@ func (c *Consumer) processWithTimeout(ctx context.Context) error {
 		return fmt.Errorf("no connection available")
 	}
 
-	// CRITICAL: Never set any connection timeouts for broker-consumer communication
-	// The consumer must maintain a persistent connection to the broker indefinitely
-	// Read message without ANY timeout - consumer should be long-running background service
+	// Just read the message - connection will be closed on shutdown
 	err := c.readMessage(ctx, conn)
 
+	// If shutdown happened during read, return immediately
+	if atomic.LoadInt32(&c.isShutdown) == 1 {
+		return fmt.Errorf("consumer is shutdown")
+	}
+
 	// If message was processed successfully, reset reconnection attempts
 	if err == nil {
 		if atomic.LoadInt32(&c.reconnectAttempts) > 0 {
@@ -730,6 +804,11 @@ func (c *Consumer) processWithTimeout(ctx context.Context) error {
 }
 
 func (c *Consumer) handleReconnection(ctx context.Context) error {
+	// Check shutdown immediately
+	if atomic.LoadInt32(&c.isShutdown) == 1 {
+		return fmt.Errorf("consumer is shutdown")
+	}
+
 	c.reconnectMutex.Lock()
 	defer c.reconnectMutex.Unlock()
@@ -766,6 +845,11 @@ func (c *Consumer) handleReconnection(ctx context.Context) error {
 		}
 	}
 
+	// Check shutdown again after potential wait
+	if atomic.LoadInt32(&c.isShutdown) == 1 {
+		return fmt.Errorf("consumer is shutdown")
+	}
+
 	c.lastReconnectAttempt = time.Now()
 
 	// If we've exceeded reasonable attempts, implement longer backoff
@@ -785,6 +869,11 @@ func (c *Consumer) handleReconnection(ctx context.Context) error {
 		}
 	}
 
+	// Final shutdown check before reconnecting
+	if atomic.LoadInt32(&c.isShutdown) == 1 {
+		return fmt.Errorf("consumer is shutdown")
+	}
+
 	// Mark as disconnected
 	atomic.StoreInt32(&c.isConnected, 0)
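`handleReconnection` now re-checks the shutdown flag before every wait and before dialing, so a `Close()` issued mid-backoff is honoured quickly. The sketch below shows the general shape of such a loop; the helper, delays, and attempt cap are illustrative assumptions, not the library's actual policy:

```go
package main

import (
	"errors"
	"fmt"
	"sync/atomic"
	"time"
)

var errShutdown = errors.New("consumer is shutdown")

// reconnectWithBackoff retries dial with capped exponential backoff and bails
// out as soon as the shutdown flag flips, instead of checking it only once.
func reconnectWithBackoff(isShutdown *int32, dial func() error) error {
	delay := 100 * time.Millisecond
	for attempt := 1; attempt <= 5; attempt++ {
		if atomic.LoadInt32(isShutdown) == 1 {
			return errShutdown // checked before every attempt
		}
		if err := dial(); err == nil {
			return nil
		}
		time.Sleep(delay)
		delay *= 2
		if delay > 2*time.Second {
			delay = 2 * time.Second // cap the backoff
		}
	}
	return errors.New("reconnection attempts exhausted")
}

func main() {
	var shutdown int32
	err := reconnectWithBackoff(&shutdown, func() error { return errors.New("connection refused") })
	fmt.Println(err) // "reconnection attempts exhausted" after a few seconds of backoff
}
```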
diff --git a/examples/consumer_example/main.go b/examples/consumer_example/main.go
index 2fbef22..e973b7e 100644
--- a/examples/consumer_example/main.go
+++ b/examples/consumer_example/main.go
@@ -7,6 +7,7 @@ import (
 	"os"
 	"os/signal"
 	"strings"
+	"sync"
 	"syscall"
 	"time"
 
@@ -67,18 +68,25 @@ func main() {
 
 	fmt.Println("\nāœ… Consumers created")
 
-	// Start periodic statistics reporting for first consumer
-	go reportStatistics(consumers[0])
-
 	// Start consuming messages
 	fmt.Println("\nšŸ”„ Starting message consumption...")
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
+	// Start periodic statistics reporting for first consumer with context
+	statsCtx, statsCancel := context.WithCancel(context.Background())
+	defer statsCancel()
+	go reportStatistics(statsCtx, consumers[0])
+
+	// Wait group to track all consumers
+	var wg sync.WaitGroup
+
 	// Run all consumers in background
 	for _, consumer := range consumers {
 		c := consumer // capture for goroutine
+		wg.Add(1)
 		go func() {
+			defer wg.Done()
 			if err := c.Consume(ctx); err != nil {
 				log.Printf("āŒ Consumer error: %v", err)
 			}
@@ -96,13 +104,14 @@ func main() {
 
 	fmt.Println("\n\nšŸ›‘ Shutdown signal received...")
 
-	// Cancel context to stop consumption
-	cancel()
+	// Stop statistics reporting first
+	fmt.Println(" 1. Stopping statistics reporting...")
+	statsCancel()
+	// Give statistics goroutine time to finish its current print cycle
+	time.Sleep(100 * time.Millisecond)
+	fmt.Println(" āœ… Statistics reporting stopped")
 
-	// Give a moment for context cancellation to propagate
-	time.Sleep(500 * time.Millisecond)
-
-	fmt.Println(" 1. Closing consumers (this will stop worker pools)...")
+	fmt.Println(" 2. Closing consumers (this will stop worker pools)...")
 	for i, consumer := range consumers {
 		if err := consumer.Close(); err != nil {
 			fmt.Printf("āŒ Consumer %d close error: %v\n", i, err)
@@ -110,6 +119,26 @@ func main() {
 	}
 	fmt.Println(" āœ… All consumers closed")
 
+	// Cancel context to stop consumption
+	fmt.Println(" 3. Cancelling context to stop message processing...")
+	cancel()
+
+	// Wait for all Consume() goroutines to finish
+	fmt.Println(" 4. Waiting for all consumers to finish...")
+	done := make(chan struct{})
+	go func() {
+		wg.Wait()
+		close(done)
+	}()
+
+	// Wait with timeout
+	select {
+	case <-done:
+		fmt.Println(" āœ… All consumers finished")
+	case <-time.After(5 * time.Second):
+		fmt.Println(" āš ļø Timeout waiting for consumers to finish")
+	}
+
 	fmt.Println("\nāœ… Graceful shutdown complete")
 	fmt.Println("šŸ‘‹ Consumer stopped")
 }
@@ -289,33 +318,39 @@ func isRetryableError(err error) bool {
 }
 
 // reportStatistics periodically reports consumer statistics
-func reportStatistics(consumer *mq.Consumer) {
+func reportStatistics(ctx context.Context, consumer *mq.Consumer) {
 	ticker := time.NewTicker(30 * time.Second)
 	defer ticker.Stop()
 
-	for range ticker.C {
-		metrics := consumer.Metrics()
+	for {
+		select {
+		case <-ctx.Done():
+			// Context cancelled, stop reporting
+			return
+		case <-ticker.C:
+			metrics := consumer.Metrics()
 
-		fmt.Println("\nšŸ“Š Consumer Statistics:")
-		fmt.Println(" " + strings.Repeat("-", 50))
-		fmt.Printf(" Consumer ID: %s\n", consumer.GetKey())
-		fmt.Printf(" Total Tasks: %d\n", metrics.TotalTasks)
-		fmt.Printf(" Completed Tasks: %d\n", metrics.CompletedTasks)
-		fmt.Printf(" Failed Tasks: %d\n", metrics.ErrorCount)
-		fmt.Printf(" Scheduled Tasks: %d\n", metrics.TotalScheduled)
-		fmt.Printf(" Memory Used: %d bytes\n", metrics.TotalMemoryUsed)
+			fmt.Println("\nšŸ“Š Consumer Statistics:")
+			fmt.Println(" " + strings.Repeat("-", 50))
+			fmt.Printf(" Consumer ID: %s\n", consumer.GetKey())
+			fmt.Printf(" Total Tasks: %d\n", metrics.TotalTasks)
+			fmt.Printf(" Completed Tasks: %d\n", metrics.CompletedTasks)
+			fmt.Printf(" Failed Tasks: %d\n", metrics.ErrorCount)
+			fmt.Printf(" Scheduled Tasks: %d\n", metrics.TotalScheduled)
+			fmt.Printf(" Memory Used: %d bytes\n", metrics.TotalMemoryUsed)
 
-		if metrics.TotalTasks > 0 {
-			successRate := float64(metrics.CompletedTasks) / float64(metrics.TotalTasks) * 100
-			fmt.Printf(" Success Rate: %.1f%%\n", successRate)
+			if metrics.TotalTasks > 0 {
+				successRate := float64(metrics.CompletedTasks) / float64(metrics.TotalTasks) * 100
+				fmt.Printf(" Success Rate: %.1f%%\n", successRate)
+			}
+
+			if metrics.TotalTasks > 0 && metrics.ExecutionTime > 0 {
+				avgTime := time.Duration(metrics.ExecutionTime/metrics.TotalTasks) * time.Millisecond
+				fmt.Printf(" Avg Processing Time: %v\n", avgTime)
+			}
+
+			fmt.Println(" " + strings.Repeat("-", 50))
 		}
-
-		if metrics.TotalTasks > 0 && metrics.ExecutionTime > 0 {
-			avgTime := time.Duration(metrics.ExecutionTime/metrics.TotalTasks) * time.Millisecond
-			fmt.Printf(" Avg Processing Time: %v\n", avgTime)
-		}
-
-		fmt.Println(" " + strings.Repeat("-", 50))
 	}
 }
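The example now blocks on the `WaitGroup` through an intermediate channel so the wait can be abandoned after a deadline; `sync.WaitGroup` itself has no timeout. The same idea as a small reusable helper (a sketch, not part of the mq API):

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// waitTimeout reports whether the WaitGroup finished before the timeout.
// Note: if the timeout fires first, the helper goroutine lingers until wg completes.
func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()
	select {
	case <-done:
		return true
	case <-time.After(timeout):
		return false
	}
}

func main() {
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		time.Sleep(200 * time.Millisecond) // simulated consumer shutdown work
	}()

	if waitTimeout(&wg, 5*time.Second) {
		fmt.Println("all consumers finished")
	} else {
		fmt.Println("timeout waiting for consumers to finish")
	}
}
```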
diff --git a/mq.go b/mq.go
index bf7e491..e65f115 100644
--- a/mq.go
+++ b/mq.go
@@ -1269,6 +1269,11 @@ func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
 func (b *Broker) dispatchWorker(ctx context.Context, queue *Queue) {
 	delay := b.opts.initialDelay
 	for task := range queue.tasks {
+		// Check if broker is shutting down
+		if atomic.LoadInt32(&b.isShutdown) == 1 {
+			return
+		}
+
 		// Handle each task in a separate goroutine to avoid blocking the dispatch loop
 		go func(t *QueuedTask) {
 			if b.opts.BrokerRateLimiter != nil {
@@ -1279,6 +1284,11 @@ func (b *Broker) dispatchWorker(ctx context.Context, queue *Queue) {
 			currentDelay := delay
 
 			for !success && t.RetryCount <= b.opts.maxRetries {
+				// Check if broker is shutting down
+				if atomic.LoadInt32(&b.isShutdown) == 1 {
+					return
+				}
+
 				if b.dispatchTaskToConsumer(ctx, queue, t) {
 					success = true
 					b.acknowledgeTask(ctx, t.Message.Queue, queue.name)
@@ -1387,12 +1397,99 @@ func (b *Broker) URL() string {
 }
 
 func (b *Broker) Close() error {
-	if b != nil && b.listener != nil {
-		log.Printf("Broker is closing...")
-		// Stop deferred task processor
-		b.StopDeferredTaskProcessor()
-		return b.listener.Close()
+	// Check if already shutdown
+	if !atomic.CompareAndSwapInt32(&b.isShutdown, 0, 1) {
+		log.Printf("Broker already closed")
+		return nil // Already shutdown
 	}
+
+	log.Printf("Broker is closing...")
+
+	// Stop deferred task processor first
+	b.StopDeferredTaskProcessor()
+
+	// Stop health checker
+	if b.healthChecker != nil {
+		b.healthChecker.Stop()
+	}
+
+	// Stop admin server
+	if b.adminServer != nil {
+		b.adminServer.Stop()
+	}
+
+	// Stop metrics server
+	if b.metricsServer != nil {
+		b.metricsServer.Stop()
+	}
+
+	// Signal shutdown to main Start loop and background routines
+	select {
+	case <-b.shutdown:
+		// Already closed
+	default:
+		close(b.shutdown)
+	}
+
+	// Close listener to stop accepting new connections
+	if b.listener != nil {
+		b.listener.Close()
+	}
+
+	// Close all consumer connections
+	b.consumers.ForEach(func(_ string, con *consumer) bool {
+		if con.conn != nil {
+			con.conn.Close()
+		}
+		return true
+	})
+
+	// Close all publisher connections
+	b.publishers.ForEach(func(_ string, pub *publisher) bool {
+		if pub.conn != nil {
+			pub.conn.Close()
+		}
+		return true
+	})
+
+	// Wait for background goroutines to finish (with timeout)
+	done := make(chan struct{})
+	go func() {
+		b.wg.Wait()
+		close(done)
+	}()
+
+	select {
+	case <-done:
+		log.Printf("All background routines finished")
+	case <-time.After(5 * time.Second):
+		log.Printf("Timeout waiting for background routines, forcing shutdown")
+	}
+
+	// Close all queue task channels to stop dispatch workers
+	// Do this AFTER waiting for goroutines to prevent panic on closed channel
+	b.queues.ForEach(func(name string, queue *Queue) bool {
+		select {
+		case <-queue.tasks:
+			// Already closed
+		default:
+			close(queue.tasks)
+		}
+		return true
+	})
+
+	// Close all DLQ task channels
+	b.deadLetter.ForEach(func(name string, dlq *Queue) bool {
+		select {
+		case <-dlq.tasks:
+			// Already closed
+		default:
+			close(dlq.tasks)
+		}
+		return true
+	})
+
+	log.Printf("Broker shutdown complete")
 	return nil
 }
diff --git a/pool.go b/pool.go
index 87ab0b4..8073246 100644
--- a/pool.go
+++ b/pool.go
@@ -714,7 +714,8 @@ func (wp *Pool) handleTask(task *QueueTask) {
 	// Cleanup task from storage
 	if wp.taskStorage != nil {
 		if err := wp.taskStorage.DeleteTask(task.payload.ID); err != nil {
-			wp.logger.Warn().Str("taskID", task.payload.ID).Msgf("Failed to delete task from storage: %v", err)
+			// Task might already be deleted (duplicate processing) - this is expected with at-least-once delivery
+			wp.logger.Debug().Str("taskID", task.payload.ID).Msgf("Task already removed from storage: %v", err)
 		}
 	}
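`Broker.Close` is made idempotent with `CompareAndSwapInt32` on the shutdown flag: only the first caller wins the swap and runs the teardown, so repeated or concurrent `Close()` calls cannot double-close channels or connections. A stripped-down sketch of that pattern (type and field names are illustrative):

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type broker struct {
	isShutdown int32
	shutdown   chan struct{}
}

// Close runs teardown exactly once; later or concurrent calls return immediately.
func (b *broker) Close() error {
	if !atomic.CompareAndSwapInt32(&b.isShutdown, 0, 1) {
		return nil // already closed
	}
	close(b.shutdown) // safe: only the caller that won the swap reaches this line
	fmt.Println("teardown ran")
	return nil
}

func main() {
	b := &broker{shutdown: make(chan struct{})}
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			b.Close() // no double-close panic; "teardown ran" prints exactly once
		}()
	}
	wg.Wait()
}
```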