This commit is contained in:
sujit
2025-09-24 18:09:55 +05:45
parent 43f315e07c
commit 428ec634f7
3 changed files with 73 additions and 16 deletions

74
mq.go
View File

@@ -3,6 +3,7 @@ package mq
import (
"context"
"fmt"
"io"
"log"
"net"
"strings"
@@ -341,6 +342,7 @@ func HeadersWithConsumerIDAndQueue(ctx context.Context, id, queue string) map[st
type QueuedTask struct {
Message *codec.Message
Task *Task
RetryCount int
}
@@ -482,6 +484,8 @@ type Broker struct {
metricsServer *MetricsServer
authenticatedConns storage.IMap[string, bool] // authenticated connections
taskHeaders storage.IMap[string, map[string]string] // task headers by task ID
pendingTasks map[string]map[string]*Task // consumerID -> taskID -> task
mu sync.RWMutex // for pendingTasks
isShutdown int32
shutdown chan struct{}
wg sync.WaitGroup
@@ -493,12 +497,13 @@ func NewBroker(opts ...Option) *Broker {
broker := &Broker{
// Core broker functionality
queues: memory.New[string, *Queue](),
publishers: memory.New[string, *publisher](),
consumers: memory.New[string, *consumer](),
deadLetter: memory.New[string, *Queue](),
pIDs: memory.New[string, bool](),
opts: options,
queues: memory.New[string, *Queue](),
publishers: memory.New[string, *publisher](),
consumers: memory.New[string, *consumer](),
deadLetter: memory.New[string, *Queue](),
pIDs: memory.New[string, bool](),
pendingTasks: make(map[string]map[string]*Task),
opts: options,
// Enhanced production features
connectionPool: NewConnectionPool(1000), // max 1000 connections
@@ -611,10 +616,42 @@ func (b *Broker) OnClose(ctx context.Context, conn net.Conn) error {
func (b *Broker) OnError(_ context.Context, conn net.Conn, err error) {
if conn != nil {
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
if b.isConnectionClosed(err) {
log.Printf("Connection closed for consumer at %s", conn.RemoteAddr())
// Find and remove the consumer
b.consumers.ForEach(func(id string, con *consumer) bool {
if con.conn == conn {
b.RemoveConsumer(id)
b.mu.Lock()
if tasks, ok := b.pendingTasks[id]; ok {
for _, task := range tasks {
// Put back to queue
if q, ok := b.queues.Get(task.Topic); ok {
select {
case q.tasks <- &QueuedTask{Task: task, RetryCount: task.Retries}:
log.Printf("Requeued task %s for consumer %s", task.ID, id)
default:
log.Printf("Failed to requeue task %s, queue full", task.ID)
}
}
}
delete(b.pendingTasks, id)
}
b.mu.Unlock()
return false
}
return true
})
} else {
log.Printf("Error reading from connection: %v", err)
}
}
}
func (b *Broker) isConnectionClosed(err error) bool {
return err == io.EOF || strings.Contains(err.Error(), "connection closed") || strings.Contains(err.Error(), "connection reset")
}
func (b *Broker) isAuthenticated(connID string) bool {
if b.securityManager == nil {
return true // no security, allow all
@@ -722,6 +759,11 @@ func (b *Broker) MessageAck(ctx context.Context, msg *codec.Message) {
consumerID, _ := GetConsumerID(ctx)
taskID, _ := jsonparser.GetString(msg.Payload, "id")
log.Printf("BROKER - MESSAGE_ACK ~> %s on %s for Task %s", consumerID, msg.Queue, taskID)
b.mu.Lock()
if tasks, ok := b.pendingTasks[consumerID]; ok {
delete(tasks, taskID)
}
b.mu.Unlock()
}
func (b *Broker) MessageDeny(ctx context.Context, msg *codec.Message) {
@@ -947,7 +989,9 @@ func (b *Broker) receive(ctx context.Context, c net.Conn) (*codec.Message, error
func (b *Broker) broadcastToConsumers(msg *codec.Message) {
if queue, ok := b.queues.Get(msg.Queue); ok {
task := &QueuedTask{Message: msg, RetryCount: 0}
var t Task
json.Unmarshal(msg.Payload, &t)
task := &QueuedTask{Message: msg, Task: &t, RetryCount: 0}
queue.tasks <- task
}
}
@@ -1166,7 +1210,7 @@ func (b *Broker) dispatchWorker(ctx context.Context, queue *Queue) {
}
func (b *Broker) sendToDLQ(queue *Queue, task *QueuedTask) {
id, _ := jsonparser.GetString(task.Message.Payload, "id")
id := task.Task.ID
if dlq, ok := b.deadLetter.Get(queue.name); ok {
log.Printf("Sending task %s to dead-letter queue for %s", id, queue.name)
dlq.tasks <- task
@@ -1180,7 +1224,7 @@ func (b *Broker) dispatchTaskToConsumer(ctx context.Context, queue *Queue, task
var err error
// Deduplication: Check if the task has already been processed
taskID, _ := jsonparser.GetString(task.Message.Payload, "id")
taskID := task.Task.ID
if _, exists := b.pIDs.Get(taskID); exists {
log.Printf("Task %s already processed, skipping...", taskID)
return true
@@ -1193,6 +1237,12 @@ func (b *Broker) dispatchTaskToConsumer(ctx context.Context, queue *Queue, task
}
// Send message asynchronously to avoid blocking
b.mu.Lock()
if b.pendingTasks[con.id] == nil {
b.pendingTasks[con.id] = make(map[string]*Task)
}
b.pendingTasks[con.id][taskID] = task.Task
b.mu.Unlock()
go func(consumer *consumer, message *codec.Message) {
sendCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
@@ -1229,10 +1279,10 @@ func (b *Broker) dispatchTaskToConsumer(ctx context.Context, queue *Queue, task
return consumerFound
}
// Modified backoffRetry: Removed reinsertion of the task into queue.tasks.
// Modified backoffRetry: Re-insert the task into queue.tasks after backoff.
func (b *Broker) backoffRetry(queue *Queue, task *QueuedTask, delay time.Duration) time.Duration {
backoffDuration := utils.CalculateJitter(delay, b.opts.jitterPercent)
log.Printf("Backing off for %v before retrying task for queue %s", backoffDuration, task.Message.Queue)
log.Printf("Backing off for %v before retrying task for queue %s", backoffDuration, task.Task.Topic)
// Perform backoff sleep in a goroutine to avoid blocking
go func() {