diff --git a/consumer.go b/consumer.go index 563f723..fab3a37 100644 --- a/consumer.go +++ b/consumer.go @@ -5,10 +5,11 @@ import ( "encoding/json" "errors" "fmt" - "math/rand" "net" "sync" "time" + + "github.com/oarkflow/mq/utils" ) type Consumer struct { @@ -109,38 +110,26 @@ func (c *Consumer) readMessage(ctx context.Context, message []byte) error { return nil } -const ( - maxRetries = 5 - initialDelay = 2 * time.Second - maxBackoff = 30 * time.Second // Upper limit for backoff delay - jitterPercent = 0.5 // 50% jitter -) - func (c *Consumer) AttemptConnect() error { var conn net.Conn var err error - delay := initialDelay - for i := 0; i < maxRetries; i++ { + delay := c.opts.initialDelay + for i := 0; i < c.opts.maxRetries; i++ { conn, err = net.Dial("tcp", c.opts.brokerAddr) if err == nil { c.conn = conn return nil } - sleepDuration := calculateJitter(delay) - fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, maxRetries, err, sleepDuration) + sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent) + fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration) time.Sleep(sleepDuration) delay *= 2 - if delay > maxBackoff { - delay = maxBackoff + if delay > c.opts.maxBackoff { + delay = c.opts.maxBackoff } } - return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, maxRetries, err) -} - -func calculateJitter(baseDelay time.Duration) time.Duration { - jitter := time.Duration(rand.Float64()*jitterPercent*float64(baseDelay)) - time.Duration(jitterPercent*float64(baseDelay)/2) - return baseDelay + jitter + return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err) } func (c *Consumer) readConn(ctx context.Context, conn net.Conn, message []byte) error { diff --git a/utils/retry.go b/utils/retry.go new file mode 100644 index 0000000..d3bdb0a --- /dev/null +++ b/utils/retry.go @@ -0,0 +1,11 @@ +package utils + +import ( + "math/rand" + "time" +) + +func CalculateJitter(baseDelay time.Duration, percent float64) time.Duration { + jitter := time.Duration(rand.Float64()*percent*float64(baseDelay)) - time.Duration(percent*float64(baseDelay)/2) + return baseDelay + jitter +}