// mq/consumer.go

package mq

import (
	"context"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/oarkflow/json"
	"github.com/oarkflow/json/jsonparser"

	"github.com/oarkflow/mq/codec"
	"github.com/oarkflow/mq/consts"
	"github.com/oarkflow/mq/logger"
	"github.com/oarkflow/mq/storage"
	"github.com/oarkflow/mq/storage/memory"
	"github.com/oarkflow/mq/utils"
)

// Processor is the contract a runnable consumer satisfies: task handling
// plus lifecycle control (pause/resume/stop) and identity.
type Processor interface {
	ProcessTask(ctx context.Context, msg *Task) Result
	Consume(ctx context.Context) error
	Pause(ctx context.Context) error
	Resume(ctx context.Context) error
	Stop(ctx context.Context) error
	Close() error
	GetKey() string
	SetKey(key string)
	GetType() string
}

// Consumer connects to a broker, subscribes to a queue, and dispatches
// received tasks to a worker pool.
type Consumer struct {
	conn                 net.Conn
	handler              Handler
	pool                 *Pool
	opts                 *Options
	id                   string
	queue                string
	pIDs                 storage.IMap[string, bool]
	connMutex            sync.RWMutex
	isConnected          int32 // atomic flag
	isShutdown           int32 // atomic flag
	shutdown             chan struct{}
	reconnectCh          chan struct{}
	healthTicker         *time.Ticker
	logger               logger.Logger
	reconnectAttempts    int32      // consecutive reconnection attempts
	lastReconnectAttempt time.Time  // when the last reconnect was attempted
	reconnectMutex       sync.Mutex // protects reconnection attempt tracking
}

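// Compile-time assertion that *Consumer satisfies Processor; all nine
// methods are implemented below, so this line simply documents the fact.
var _ Processor = (*Consumer)(nil)
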
// NewConsumer builds a consumer for the given queue with the supplied
// handler and options; call Consume to connect and start processing.
func NewConsumer(id string, queue string, handler Handler, opts ...Option) *Consumer {
	options := SetupOptions(opts...)
	return &Consumer{
		id:          id,
		opts:        options,
		queue:       queue,
		handler:     handler,
		pIDs:        memory.New[string, bool](),
		shutdown:    make(chan struct{}),
		reconnectCh: make(chan struct{}, 1),
		logger:      options.Logger(),
	}
}

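// Example (sketch): constructing and running a consumer. The handler
// signature mirrors how c.handler is invoked in ProcessTask below; broker
// address and worker counts are configured through the Option arguments,
// whose constructors live in this package's options code and are not
// shown here.
//
//	handler := func(ctx context.Context, task *Task) Result {
//	    return Result{Status: "SUCCESS"}
//	}
//	c := NewConsumer("consumer-1", "orders", handler)
//	if err := c.Consume(context.Background()); err != nil {
//	    log.Fatalf("consumer stopped: %v", err)
//	}
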
func (c *Consumer) send(ctx context.Context, conn net.Conn, msg *codec.Message) error {
	c.connMutex.RLock()
	defer c.connMutex.RUnlock()
	if atomic.LoadInt32(&c.isShutdown) == 1 {
		return fmt.Errorf("consumer is shutdown")
	}
	if conn == nil {
		return fmt.Errorf("connection is nil")
	}
	return codec.SendMessage(ctx, conn, msg)
}

func (c *Consumer) receive(ctx context.Context, conn net.Conn) (*codec.Message, error) {
	c.connMutex.RLock()
	defer c.connMutex.RUnlock()
	if atomic.LoadInt32(&c.isShutdown) == 1 {
		return nil, fmt.Errorf("consumer is shutdown")
	}
	if conn == nil {
		return nil, fmt.Errorf("connection is nil")
	}
	return codec.ReadMessage(ctx, conn)
}

// Close shuts the consumer down exactly once: it stops the health checker
// and worker pool, then closes the broker connection.
func (c *Consumer) Close() error {
	// Signal shutdown
	if !atomic.CompareAndSwapInt32(&c.isShutdown, 0, 1) {
		return nil // Already shutdown
	}
	close(c.shutdown)
	// Stop health checker
	if c.healthTicker != nil {
		c.healthTicker.Stop()
	}
	// Stop pool gracefully
	if c.pool != nil {
		c.pool.Stop()
	}
	// Close connection
	c.connMutex.Lock()
	defer c.connMutex.Unlock()
	if c.conn != nil {
		err := c.conn.Close()
		c.conn = nil
		atomic.StoreInt32(&c.isConnected, 0)
		c.logger.Info("Connection closed for consumer", logger.Field{Key: "consumer_id", Value: c.id})
		return err
	}
	c.logger.Info("Consumer closed successfully", logger.Field{Key: "consumer_id", Value: c.id})
	return nil
}

func (c *Consumer) GetKey() string {
	return c.id
}

func (c *Consumer) GetType() string {
	return "consumer"
}

func (c *Consumer) SetKey(key string) {
	c.id = key
}

func (c *Consumer) Metrics() Metrics {
	return c.pool.Metrics()
}

func (c *Consumer) subscribe(ctx context.Context, queue string) error {
	headers := HeadersWithConsumerID(ctx, c.id)
	msg, err := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers)
	if err != nil {
		return fmt.Errorf("error creating subscribe message: %v", err)
	}
	if err := c.send(ctx, c.conn, msg); err != nil {
		return fmt.Errorf("error while trying to subscribe: %v", err)
	}
	return c.waitForAck(ctx, c.conn)
}

// Auth authenticates the consumer with the broker
func (c *Consumer) Auth(ctx context.Context, username, password string) error {
	authPayload := map[string]string{
		"username": username,
		"password": password,
	}
	payload, err := json.Marshal(authPayload)
	if err != nil {
		return err
	}
	headers := HeadersWithConsumerID(ctx, c.id)
	msg, err := codec.NewMessage(consts.AUTH, payload, "", headers)
	if err != nil {
		return fmt.Errorf("error creating auth message: %v", err)
	}
	if err := c.send(ctx, c.conn, msg); err != nil {
		return fmt.Errorf("error sending auth: %v", err)
	}
	// Wait for AUTH_ACK
	resp, err := c.receive(ctx, c.conn)
	if err != nil {
		return fmt.Errorf("error receiving auth response: %v", err)
	}
	if resp.Command != consts.AUTH_ACK {
		return fmt.Errorf("authentication failed: %s", string(resp.Payload))
	}
	return nil
}

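// Example (sketch): when opts.enableSecurity is set, Consume performs this
// handshake automatically. A manual call is only needed with custom
// connection management, and it requires an already-established connection:
//
//	if err := c.Auth(ctx, "user", "secret"); err != nil {
//	    log.Printf("broker rejected credentials: %v", err)
//	}
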
func (c *Consumer) OnClose(_ context.Context, _ net.Conn) error {
	fmt.Println("Consumer closed")
	return nil
}

func (c *Consumer) OnError(_ context.Context, conn net.Conn, err error) {
	if c.isConnectionClosed(err) {
		log.Printf("Connection to broker closed for consumer %s at %s", c.id, conn.RemoteAddr())
		// Trigger reconnection without blocking; drop the signal if one is already pending
		select {
		case c.reconnectCh <- struct{}{}:
		default:
		}
	} else {
		log.Printf("Error reading from connection: %v", err)
	}
}

func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) error {
	switch msg.Command {
	case consts.PUBLISH:
		// Handle message consumption asynchronously to prevent blocking
		go c.ConsumeMessage(ctx, msg, conn)
		return nil
	case consts.CONSUMER_PAUSE:
		err := c.Pause(ctx)
		if err != nil {
			log.Printf("Unable to pause consumer: %v", err)
		}
		return err
	case consts.CONSUMER_RESUME:
		err := c.Resume(ctx)
		if err != nil {
			log.Printf("Unable to resume consumer: %v", err)
		}
		return err
	case consts.CONSUMER_STOP:
		err := c.Stop(ctx)
		if err != nil {
			log.Printf("Unable to stop consumer: %v", err)
		}
		return err
	case consts.CONSUMER_UPDATE:
		err := c.Update(ctx, msg.Payload)
		if err != nil {
			log.Printf("Unable to update consumer: %v", err)
		}
		return err
	default:
		log.Printf("CONSUMER - UNKNOWN_COMMAND ~> %s on %s", msg.Command, msg.Queue)
	}
	return nil
}

func (c *Consumer) sendMessageAck(ctx context.Context, msg *codec.Message, conn net.Conn) {
	headers := HeadersWithConsumerIDAndQueue(ctx, c.id, msg.Queue)
	taskID, _ := jsonparser.GetString(msg.Payload, "id")
	reply, err := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
	if err != nil {
		c.logger.Error("Failed to create MESSAGE_ACK",
			logger.Field{Key: "queue", Value: msg.Queue},
			logger.Field{Key: "task_id", Value: taskID},
			logger.Field{Key: "error", Value: err.Error()})
		return
	}
	// Send with timeout to avoid blocking
	sendCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()
	if err := c.send(sendCtx, conn, reply); err != nil {
		c.logger.Error("Failed to send MESSAGE_ACK",
			logger.Field{Key: "queue", Value: msg.Queue},
			logger.Field{Key: "task_id", Value: taskID},
			logger.Field{Key: "error", Value: err.Error()})
	}
}

func (c *Consumer) ConsumeMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
	// Send acknowledgment asynchronously
	go c.sendMessageAck(ctx, msg, conn)
	if msg.Payload == nil {
		log.Printf("Received empty message payload")
		return
	}
	var task Task
	err := json.Unmarshal(msg.Payload, &task)
	if err != nil {
		log.Printf("Error unmarshalling message: %v", err)
		return
	}
	// Check if the task has already been processed
	if _, exists := c.pIDs.Get(task.ID); exists {
		log.Printf("Task %s already processed, skipping...", task.ID)
		return
	}
	// Process the task asynchronously to avoid blocking the main consumer loop
	go c.processTaskAsync(ctx, &task, msg.Queue)
}

func (c *Consumer) processTaskAsync(ctx context.Context, task *Task, queue string) {
	ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: queue})
	// Try to enqueue the task with timeout
	enqueueDone := make(chan error, 1)
	go func() {
		enqueueDone <- c.pool.EnqueueTask(ctx, task, 1)
	}()
	// Wait for enqueue with timeout
	select {
	case err := <-enqueueDone:
		if err == nil {
			// Mark the task as processed
			c.pIDs.Set(task.ID, true)
			return
		}
		// Handle enqueue error with retry logic
		c.retryTaskEnqueue(ctx, task, queue, err)
	case <-time.After(30 * time.Second): // Enqueue timeout
		c.logger.Error("Task enqueue timeout",
			logger.Field{Key: "task_id", Value: task.ID},
			logger.Field{Key: "queue", Value: queue})
		c.sendDenyMessage(ctx, task.ID, queue, fmt.Errorf("enqueue timeout"))
	}
}

func (c *Consumer) retryTaskEnqueue(ctx context.Context, task *Task, queue string, initialErr error) {
	retryCount := 0
	for retryCount < c.opts.maxRetries {
		retryCount++
		// Calculate backoff duration with exponential growth and jitter
		backoffDuration := utils.CalculateJitter(
			c.opts.initialDelay*time.Duration(1<<retryCount),
			c.opts.jitterPercent,
		)
		c.logger.Warn("Retrying task enqueue",
			logger.Field{Key: "task_id", Value: task.ID},
			logger.Field{Key: "attempt", Value: fmt.Sprintf("%d/%d", retryCount, c.opts.maxRetries)},
			logger.Field{Key: "backoff", Value: backoffDuration.String()},
			logger.Field{Key: "error", Value: initialErr.Error()})
		// Safe to block here: this runs on the per-task goroutine, not the consumer loop
		time.Sleep(backoffDuration)
		// Try enqueue again
		if err := c.pool.EnqueueTask(ctx, task, 1); err == nil {
			c.pIDs.Set(task.ID, true)
			c.logger.Info("Task enqueue successful after retry",
				logger.Field{Key: "task_id", Value: task.ID},
				logger.Field{Key: "attempts", Value: retryCount})
			return
		}
	}
	// All retries failed
	c.logger.Error("Task enqueue failed after all retries",
		logger.Field{Key: "task_id", Value: task.ID},
		logger.Field{Key: "max_retries", Value: c.opts.maxRetries})
	c.sendDenyMessage(ctx, task.ID, queue, fmt.Errorf("enqueue failed after %d retries", c.opts.maxRetries))
}

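// Worked example of the backoff above, assuming initialDelay = 100ms and
// maxRetries = 3: retry 1 sleeps ~200ms (100ms << 1), retry 2 ~400ms,
// retry 3 ~800ms, each perturbed by utils.CalculateJitter with
// c.opts.jitterPercent before the next EnqueueTask attempt.
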
func (c *Consumer) ProcessTask(ctx context.Context, msg *Task) Result {
	defer RecoverPanic(RecoverTitle)
	queue, _ := GetQueue(ctx)
	if msg.Topic == "" && queue != "" {
		msg.Topic = queue
	}
	// Apply timeout to individual task processing (not consumer connection)
	timeout := c.opts.ConsumerTimeout()
	if timeout > 0 {
		// Use configurable timeout for task processing
		taskCtx, cancel := context.WithTimeout(ctx, timeout)
		defer cancel()
		// Run task with timeout in a goroutine
		resultCh := make(chan Result, 1)
		go func() {
			result := c.handler(taskCtx, msg)
			result.Topic = msg.Topic
			result.TaskID = msg.ID
			resultCh <- result
		}()
		// Wait for result or timeout
		select {
		case result := <-resultCh:
			return result
		case <-taskCtx.Done():
			// Task processing timeout
			return Result{
				Error:  fmt.Errorf("task processing timeout after %v", timeout),
				Topic:  msg.Topic,
				TaskID: msg.ID,
				Status: "FAILED",
				Ctx:    ctx,
			}
		}
	}
	// No timeout - for page nodes that need unlimited time for user input
	result := c.handler(ctx, msg)
	result.Topic = msg.Topic
	result.TaskID = msg.ID
	return result
}

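// Example (sketch): a handler that cooperates with the per-task timeout
// imposed by ProcessTask. Because the handler runs in its own goroutine,
// work that never checks the context keeps running (and leaks) after the
// deadline fires; selecting on ctx.Done() avoids that:
//
//	func slowHandler(ctx context.Context, task *Task) Result {
//	    select {
//	    case <-time.After(5 * time.Second): // stand-in for real work
//	        return Result{Status: "SUCCESS"}
//	    case <-ctx.Done():
//	        return Result{Status: "FAILED", Error: ctx.Err()}
//	    }
//	}
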
func (c *Consumer) OnResponse(ctx context.Context, result Result) error {
	if result.Status == "PENDING" && c.opts.respondPendingResult {
		return nil
	}
	// Send response asynchronously to avoid blocking task processing
	go func() {
		headers := HeadersWithConsumerIDAndQueue(ctx, c.id, result.Topic)
		if result.Status == "" {
			if result.Error != nil {
				result.Status = "FAILED"
			} else {
				result.Status = "SUCCESS"
			}
		}
		bt, _ := json.Marshal(result)
		reply, err := codec.NewMessage(consts.MESSAGE_RESPONSE, bt, result.Topic, headers)
		if err != nil {
			c.logger.Error("Failed to create MESSAGE_RESPONSE",
				logger.Field{Key: "topic", Value: result.Topic},
				logger.Field{Key: "task_id", Value: result.TaskID},
				logger.Field{Key: "error", Value: err.Error()})
			return
		}
		sendCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		if err := c.send(sendCtx, c.conn, reply); err != nil {
			c.logger.Error("Failed to send MESSAGE_RESPONSE",
				logger.Field{Key: "topic", Value: result.Topic},
				logger.Field{Key: "task_id", Value: result.TaskID},
				logger.Field{Key: "error", Value: err.Error()})
		}
	}()
	return nil
}

func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, err error) {
	// Send deny message asynchronously to avoid blocking
	go func() {
		headers := HeadersWithConsumerID(ctx, c.id)
		// Use a distinct variable so the original err is not shadowed;
		// shadowing it would nil-deref err.Error() in the send-error log below
		reply, msgErr := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers)
		if msgErr != nil {
			c.logger.Error("Failed to create MESSAGE_DENY",
				logger.Field{Key: "queue", Value: queue},
				logger.Field{Key: "task_id", Value: taskID},
				logger.Field{Key: "error", Value: msgErr.Error()})
			return
		}
		sendCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		if sendErr := c.send(sendCtx, c.conn, reply); sendErr != nil {
			c.logger.Error("Failed to send MESSAGE_DENY",
				logger.Field{Key: "queue", Value: queue},
				logger.Field{Key: "task_id", Value: taskID},
				logger.Field{Key: "original_error", Value: err.Error()},
				logger.Field{Key: "send_error", Value: sendErr.Error()})
		}
	}()
}

// isHealthy checks whether the connection is still usable WITHOUT setting
// deadlines. Broker-consumer connections are persistent and must remain
// open indefinitely, so this is a lightweight state check only; actual
// health is determined by the next read or write on the connection.
func (c *Consumer) isHealthy() bool {
	c.connMutex.RLock()
	defer c.connMutex.RUnlock()
	if c.conn == nil || atomic.LoadInt32(&c.isConnected) == 0 {
		return false
	}
	// Guard against a typed-nil *net.TCPConn stored in the interface
	if tcpConn, ok := c.conn.(*net.TCPConn); ok {
		return tcpConn != nil
	}
	// For non-TCP connections, assume healthy while the connection exists
	return true
}

// startHealthChecker starts periodic health checks
func (c *Consumer) startHealthChecker() {
	c.healthTicker = time.NewTicker(30 * time.Second)
	go func() {
		defer c.healthTicker.Stop()
		for {
			select {
			case <-c.healthTicker.C:
				if !c.isHealthy() {
					c.logger.Warn("Connection health check failed, triggering reconnection",
						logger.Field{Key: "consumer_id", Value: c.id})
					select {
					case c.reconnectCh <- struct{}{}:
					default:
						// Channel is full, reconnection already pending
					}
				}
			case <-c.shutdown:
				return
			}
		}
	}()
}

func (c *Consumer) attemptConnect() error {
	if atomic.LoadInt32(&c.isShutdown) == 1 {
		return fmt.Errorf("consumer is shutdown")
	}
	var err error
	delay := c.opts.initialDelay
	for i := 0; i < c.opts.maxRetries; i++ {
		if atomic.LoadInt32(&c.isShutdown) == 1 {
			return fmt.Errorf("consumer is shutdown")
		}
		// Assign rather than shadow err so the last failure survives the loop
		// and is wrapped in the final error below
		var conn net.Conn
		conn, err = GetConnection(c.opts.brokerAddr, c.opts.tlsConfig)
		if err == nil {
			c.connMutex.Lock()
			c.conn = conn
			atomic.StoreInt32(&c.isConnected, 1)
			c.connMutex.Unlock()
			c.logger.Info("Successfully connected to broker",
				logger.Field{Key: "consumer_id", Value: c.id},
				logger.Field{Key: "broker_addr", Value: c.opts.brokerAddr})
			return nil
		}
		sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent)
		c.logger.Warn("Failed to connect to broker, retrying",
			logger.Field{Key: "consumer_id", Value: c.id},
			logger.Field{Key: "broker_addr", Value: c.opts.brokerAddr},
			logger.Field{Key: "attempt", Value: fmt.Sprintf("%d/%d", i+1, c.opts.maxRetries)},
			logger.Field{Key: "error", Value: err.Error()},
			logger.Field{Key: "retry_in", Value: sleepDuration.String()})
		time.Sleep(sleepDuration)
		delay *= 2
		if delay > c.opts.maxBackoff {
			delay = c.opts.maxBackoff
		}
	}
	return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err)
}

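// Worked example of the doubling schedule above, assuming
// initialDelay = 100ms and maxBackoff = 2s: successive failures sleep
// ~100ms, 200ms, 400ms, 800ms, 1.6s, then 2s for every remaining retry,
// each value jittered by utils.CalculateJitter before sleeping.
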
func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error {
	msg, err := c.receive(ctx, conn)
	if err == nil {
		ctx = SetHeaders(ctx, msg.Headers)
		return c.OnMessage(ctx, msg, conn)
	}
	if err == io.EOF || strings.Contains(err.Error(), "closed network connection") {
		if err1 := c.OnClose(ctx, conn); err1 != nil {
			return err1
		}
		return err
	}
	c.OnError(ctx, conn, err)
	return err
}

// Consume connects to the broker, subscribes to the queue, starts the
// worker pool, and runs the main processing loop until the context is
// cancelled or the consumer shuts down.
func (c *Consumer) Consume(ctx context.Context) error {
	// Create a context that can be cancelled
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	// Initial connection
	if err := c.attemptConnect(); err != nil {
		return fmt.Errorf("initial connection failed: %w", err)
	}
	// Authenticate if security is enabled
	if c.opts.enableSecurity {
		if c.opts.username == "" || c.opts.password == "" {
			return fmt.Errorf("username and password required for authentication")
		}
		if err := c.Auth(ctx, c.opts.username, c.opts.password); err != nil {
			return fmt.Errorf("authentication failed: %w", err)
		}
	}
	// Initialize the worker pool
	c.pool = NewPool(
		c.opts.numOfWorkers,
		WithTaskQueueSize(c.opts.queueSize),
		WithMaxMemoryLoad(c.opts.maxMemoryLoad),
		WithHandler(c.ProcessTask),
		WithPoolCallback(c.OnResponse),
		WithTaskStorage(c.opts.storage),
	)
	// Subscribe to the queue
	if err := c.subscribe(ctx, c.queue); err != nil {
		return fmt.Errorf("failed to subscribe to queue %s: %w", c.queue, err)
	}
	// Start the worker pool and health checker
	c.pool.Start(c.opts.numOfWorkers)
	c.startHealthChecker()
	// Start the HTTP API if enabled
	if c.opts.enableHTTPApi {
		go func() {
			if _, err := c.StartHTTPAPI(); err != nil {
				c.logger.Error("Failed to start HTTP API",
					logger.Field{Key: "consumer_id", Value: c.id},
					logger.Field{Key: "error", Value: err.Error()})
			}
		}()
	}
	c.logger.Info("Consumer started successfully",
		logger.Field{Key: "consumer_id", Value: c.id},
		logger.Field{Key: "queue", Value: c.queue})
	// Main processing loop
	for {
		select {
		case <-ctx.Done():
			c.logger.Info("Context cancelled, stopping consumer",
				logger.Field{Key: "consumer_id", Value: c.id})
			return c.Close()
		case <-c.shutdown:
			c.logger.Info("Shutdown signal received",
				logger.Field{Key: "consumer_id", Value: c.id})
			return nil
		case <-c.reconnectCh:
			c.logger.Info("Reconnection triggered",
				logger.Field{Key: "consumer_id", Value: c.id})
			if err := c.handleReconnection(ctx); err != nil {
				// handleReconnection implements its own backoff, so just log and continue the loop
				c.logger.Error("Reconnection failed, will retry based on backoff policy",
					logger.Field{Key: "consumer_id", Value: c.id},
					logger.Field{Key: "error", Value: err.Error()})
			}
		default:
			// Apply rate limiting if configured
			if c.opts.ConsumerRateLimiter != nil {
				c.opts.ConsumerRateLimiter.Wait()
			}
			if err := c.processWithTimeout(ctx); err != nil {
				if atomic.LoadInt32(&c.isShutdown) == 1 {
					return nil
				}
				c.logger.Error("Error processing message",
					logger.Field{Key: "consumer_id", Value: c.id},
					logger.Field{Key: "error", Value: err.Error()})
				// Trigger reconnection for connection errors
				if isConnectionError(err) {
					select {
					case c.reconnectCh <- struct{}{}:
					default:
					}
				}
				// Brief pause before retrying
				time.Sleep(100 * time.Millisecond)
			}
		}
	}
}

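// Example (sketch): running Consume in the background and shutting down on
// SIGINT. signal.NotifyContext is from the standard library os/signal
// package; the rest mirrors the lifecycle implemented above.
//
//	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
//	defer stop()
//	go func() {
//	    if err := c.Consume(ctx); err != nil {
//	        log.Printf("consumer exited: %v", err)
//	    }
//	}()
//	<-ctx.Done()
//	_ = c.Close()
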
// processWithTimeout reads the next message WITHOUT an I/O timeout: the
// consumer is a long-running background service on a persistent broker
// connection, so only individual task processing is bounded by deadlines,
// never the read loop itself.
func (c *Consumer) processWithTimeout(ctx context.Context) error {
	c.connMutex.RLock()
	conn := c.conn
	c.connMutex.RUnlock()
	if conn == nil {
		return fmt.Errorf("no connection available")
	}
	// Block until the broker sends a message
	err := c.readMessage(ctx, conn)
	// A successfully processed message resets the reconnection counter
	if err == nil && atomic.LoadInt32(&c.reconnectAttempts) > 0 {
		atomic.StoreInt32(&c.reconnectAttempts, 0)
		c.logger.Debug("Reset reconnection attempts after successful message processing",
			logger.Field{Key: "consumer_id", Value: c.id})
	}
	return err
}

func (c *Consumer) handleReconnection(ctx context.Context) error {
	c.reconnectMutex.Lock()
	defer c.reconnectMutex.Unlock()
	// Track consecutive attempts and derive an exponential backoff from them
	attempts := atomic.AddInt32(&c.reconnectAttempts, 1)
	backoffDelay := utils.CalculateJitter(
		c.opts.initialDelay*time.Duration(1<<minInt(int(attempts-1), 6)), // cap exponential growth at 2^6
		c.opts.jitterPercent,
	)
	// Cap maximum backoff
	if backoffDelay > c.opts.maxBackoff {
		backoffDelay = c.opts.maxBackoff
	}
	// Throttle if we are reconnecting more frequently than the backoff allows
	timeSinceLastAttempt := time.Since(c.lastReconnectAttempt)
	if attempts > 1 && timeSinceLastAttempt < backoffDelay {
		remainingWait := backoffDelay - timeSinceLastAttempt
		c.logger.Warn("Throttling reconnection attempt",
			logger.Field{Key: "consumer_id", Value: c.id},
			logger.Field{Key: "consecutive_attempts", Value: int(attempts)},
			logger.Field{Key: "wait_duration", Value: remainingWait.String()})
		// Wait with context cancellation support
		select {
		case <-time.After(remainingWait):
		case <-ctx.Done():
			return ctx.Err()
		case <-c.shutdown:
			return fmt.Errorf("consumer is shutdown")
		}
	}
	c.lastReconnectAttempt = time.Now()
	// Circuit breaker: after too many consecutive attempts, impose a long backoff
	if attempts > int32(c.opts.maxRetries*2) {
		longBackoff := 5 * time.Minute
		c.logger.Warn("Too many consecutive reconnection attempts, entering long backoff",
			logger.Field{Key: "consumer_id", Value: c.id},
			logger.Field{Key: "consecutive_attempts", Value: int(attempts)},
			logger.Field{Key: "backoff_duration", Value: longBackoff.String()})
		select {
		case <-time.After(longBackoff):
		case <-ctx.Done():
			return ctx.Err()
		case <-c.shutdown:
			return fmt.Errorf("consumer is shutdown")
		}
	}
	// Mark as disconnected and close the existing connection
	atomic.StoreInt32(&c.isConnected, 0)
	c.connMutex.Lock()
	if c.conn != nil {
		c.conn.Close()
		c.conn = nil
	}
	c.connMutex.Unlock()
	c.logger.Info("Attempting reconnection",
		logger.Field{Key: "consumer_id", Value: c.id},
		logger.Field{Key: "attempt", Value: int(attempts)},
		logger.Field{Key: "backoff_delay", Value: backoffDelay.String()})
	if err := c.attemptConnect(); err != nil {
		c.logger.Error("Reconnection failed",
			logger.Field{Key: "consumer_id", Value: c.id},
			logger.Field{Key: "attempt", Value: int(attempts)},
			logger.Field{Key: "error", Value: err.Error()})
		return fmt.Errorf("failed to reconnect (attempt %d): %w", attempts, err)
	}
	// Reconnection successful; resubscribe to the queue
	if err := c.subscribe(ctx, c.queue); err != nil {
		c.logger.Error("Failed to resubscribe after reconnection",
			logger.Field{Key: "consumer_id", Value: c.id},
			logger.Field{Key: "error", Value: err.Error()})
		return fmt.Errorf("failed to resubscribe after reconnection: %w", err)
	}
	// Reset reconnection attempts on successful reconnection
	atomic.StoreInt32(&c.reconnectAttempts, 0)
	c.logger.Info("Successfully reconnected and resubscribed",
		logger.Field{Key: "consumer_id", Value: c.id})
	return nil
}

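// Worked example of the reconnection policy above, assuming maxRetries = 5:
// attempts 1-10 are throttled by the jittered exponential backoff (growth
// capped at 2^6 and at maxBackoff); from attempt 11 onward each cycle also
// waits the fixed 5-minute circuit-breaker period before redialing.
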
// minInt returns the smaller of two integers.
func minInt(a, b int) int {
	if a < b {
		return a
	}
	return b
}

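// Note: Go 1.21 added a built-in min that could replace minInt; the local
// helper keeps the package buildable on older toolchains.
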
func (c *Consumer) isConnectionClosed(err error) bool {
	// Guard against nil so callers cannot trigger a nil-deref via err.Error()
	if err == nil {
		return false
	}
	return err == io.EOF ||
		strings.Contains(err.Error(), "connection closed") ||
		strings.Contains(err.Error(), "connection reset")
}

func isConnectionError(err error) bool {
	if err == nil {
		return false
	}
	errStr := err.Error()
	return strings.Contains(errStr, "connection") ||
		strings.Contains(errStr, "EOF") ||
		strings.Contains(errStr, "closed network") ||
		strings.Contains(errStr, "broken pipe")
}

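// Sketch of a sturdier classification using sentinel errors. The string
// matching above is kept for errors that lose their type through
// fmt.Errorf("%v", ...), but errors.Is/As handle io.EOF and net.ErrClosed
// directly (both standard library; this helper is illustrative and not
// used elsewhere in the file):
//
//	func isConnectionErrorTyped(err error) bool {
//	    var netErr net.Error
//	    return errors.Is(err, io.EOF) ||
//	        errors.Is(err, net.ErrClosed) ||
//	        errors.As(err, &netErr)
//	}
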
func (c *Consumer) waitForAck(ctx context.Context, conn net.Conn) error {
	msg, err := c.receive(ctx, conn)
	if err != nil {
		return err
	}
	if msg.Command == consts.SUBSCRIBE_ACK {
		log.Printf("CONSUMER - SUBSCRIBE_ACK ~> %s on %s", c.id, msg.Queue)
		return nil
	}
	return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command)
}

func (c *Consumer) Pause(ctx context.Context) error {
	return c.operate(ctx, consts.CONSUMER_PAUSED, c.pool.Pause)
}

func (c *Consumer) Update(ctx context.Context, payload []byte) error {
	var newConfig DynamicConfig
	if err := json.Unmarshal(payload, &newConfig); err != nil {
		log.Printf("Invalid payload for CONSUMER_UPDATE: %v", err)
		return err
	}
	if c.pool != nil {
		if err := c.pool.UpdateConfig(&newConfig); err != nil {
			log.Printf("Failed to update pool config: %v", err)
			return err
		}
	}
	return c.sendOpsMessage(ctx, consts.CONSUMER_UPDATED)
}

func (c *Consumer) Resume(ctx context.Context) error {
	return c.operate(ctx, consts.CONSUMER_RESUMED, c.pool.Resume)
}

func (c *Consumer) Stop(ctx context.Context) error {
	return c.operate(ctx, consts.CONSUMER_STOPPED, c.pool.Stop)
}

func (c *Consumer) operate(ctx context.Context, cmd consts.CMD, poolOperation func()) error {
	poolOperation()
	return c.sendOpsMessage(ctx, cmd)
}

func (c *Consumer) sendOpsMessage(ctx context.Context, cmd consts.CMD) error {
	headers := HeadersWithConsumerID(ctx, c.id)
	msg, err := codec.NewMessage(cmd, nil, c.queue, headers)
	if err != nil {
		return fmt.Errorf("error creating operation message: %v", err)
	}
	return c.send(ctx, c.conn, msg)
}

// Conn returns the current broker connection, guarded by the connection
// mutex so concurrent reconnects do not race with readers.
func (c *Consumer) Conn() net.Conn {
	c.connMutex.RLock()
	defer c.connMutex.RUnlock()
	return c.conn
}

// StartHTTPAPI starts an HTTP server on a random available port and registers API endpoints.
// It returns the port number the server is listening on.
func (c *Consumer) StartHTTPAPI() (int, error) {
	// Listen on a random port.
	ln, err := net.Listen("tcp", ":0")
	if err != nil {
		return 0, fmt.Errorf("failed to start listener: %w", err)
	}
	port := ln.Addr().(*net.TCPAddr).Port
	// Create a new HTTP mux and register endpoints.
	mux := http.NewServeMux()
	mux.HandleFunc("/stats", c.handleStats)
	mux.HandleFunc("/update", c.handleUpdate)
	mux.HandleFunc("/pause", c.handlePause)
	mux.HandleFunc("/resume", c.handleResume)
	mux.HandleFunc("/stop", c.handleStop)
	// Serve in a new goroutine, logging the error if the server stops.
	go func() {
		if err := http.Serve(ln, mux); err != nil {
			log.Printf("HTTP server error on port %d: %v", port, err)
		}
	}()
	log.Printf("HTTP API for consumer %s started on port %d", c.id, port)
	return port, nil
}

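// Example (sketch): exercising the API from the same process once the port
// is known. The "workers" field in the update payload is illustrative and
// must match whatever DynamicConfig actually accepts; os and strings are
// assumed imported in the caller.
//
//	port, _ := c.StartHTTPAPI()
//	base := fmt.Sprintf("http://localhost:%d", port)
//	resp, err := http.Get(base + "/stats")
//	if err == nil {
//	    io.Copy(os.Stdout, resp.Body)
//	    resp.Body.Close()
//	}
//	_, _ = http.Post(base+"/update", "application/json",
//	    strings.NewReader(`{"workers": 8}`))
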
// handleStats responds with JSON containing consumer and pool metrics.
func (c *Consumer) handleStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
		return
	}
	// Gather consumer and pool stats using formatted metrics.
	stats := map[string]any{
		"consumer_id":  c.id,
		"queue":        c.queue,
		"pool_metrics": c.pool.FormattedMetrics(),
	}
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(stats); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode stats: %v", err), http.StatusInternalServerError)
	}
}

// handleUpdate accepts a POST request with a JSON payload and applies it
// via the consumer's Update method, which updates the pool configuration.
func (c *Consumer) handleUpdate(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
		return
	}
	// Read the request body.
	body, err := io.ReadAll(r.Body)
	if err != nil {
		http.Error(w, fmt.Sprintf("failed to read request body: %v", err), http.StatusBadRequest)
		return
	}
	defer r.Body.Close()
	if err := c.Update(r.Context(), body); err != nil {
		http.Error(w, fmt.Sprintf("failed to update configuration: %v", err), http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	resp := map[string]string{"status": "configuration updated"}
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
	}
}

// handlePause pauses the consumer's pool.
func (c *Consumer) handlePause(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
		return
	}
	if err := c.Pause(r.Context()); err != nil {
		http.Error(w, fmt.Sprintf("failed to pause consumer: %v", err), http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	resp := map[string]string{"status": "consumer paused"}
	json.NewEncoder(w).Encode(resp)
}

// handleResume resumes the consumer's pool.
func (c *Consumer) handleResume(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
		return
	}
	if err := c.Resume(r.Context()); err != nil {
		http.Error(w, fmt.Sprintf("failed to resume consumer: %v", err), http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	resp := map[string]string{"status": "consumer resumed"}
	json.NewEncoder(w).Encode(resp)
}

// handleStop stops the consumer's pool.
func (c *Consumer) handleStop(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
		return
	}
	if err := c.Stop(r.Context()); err != nil {
		http.Error(w, fmt.Sprintf("failed to stop consumer: %v", err), http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	resp := map[string]string{"status": "consumer stopped"}
	json.NewEncoder(w).Encode(resp)
}