diff --git a/consumer.go b/consumer.go index efc0cc7..6cc3266 100644 --- a/consumer.go +++ b/consumer.go @@ -9,13 +9,15 @@ import ( "net/http" "strings" "time" - + "github.com/oarkflow/json" - + "github.com/oarkflow/json/jsonparser" - + "github.com/oarkflow/mq/codec" "github.com/oarkflow/mq/consts" + "github.com/oarkflow/mq/storage" + "github.com/oarkflow/mq/storage/memory" "github.com/oarkflow/mq/utils" ) @@ -38,6 +40,7 @@ type Consumer struct { opts *Options id string queue string + pIDs storage.IMap[string, bool] } func NewConsumer(id string, queue string, handler Handler, opts ...Option) *Consumer { @@ -47,6 +50,7 @@ func NewConsumer(id string, queue string, handler Handler, opts ...Option) *Cons opts: options, queue: queue, handler: handler, + pIDs: memory.New[string, bool](), } } @@ -154,11 +158,33 @@ func (c *Consumer) ConsumeMessage(ctx context.Context, msg *codec.Message, conn log.Printf("Error unmarshalling message: %v", err) return } - ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue}) - if err := c.pool.EnqueueTask(ctx, &task, 1); err != nil { - c.sendDenyMessage(ctx, task.ID, msg.Queue, err) + + // Check if the task has already been processed + if _, exists := c.pIDs.Get(task.ID); exists { + log.Printf("Task %s already processed, skipping...", task.ID) return } + + ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue}) + retryCount := 0 + for { + err := c.pool.EnqueueTask(ctx, &task, 1) + if err == nil { + // Mark the task as processed + c.pIDs.Set(task.ID, true) + break + } + + if retryCount >= c.opts.maxRetries { + c.sendDenyMessage(ctx, task.ID, msg.Queue, err) + return + } + + retryCount++ + backoffDuration := utils.CalculateJitter(c.opts.initialDelay*(1< received from %s on %s for Task %s", pub.id, msg.Queue, taskID) - + ack := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers) if err := b.send(ctx, conn, ack); err != nil { log.Printf("Error sending PUBLISH_ACK: %v\n", err) @@ -591,9 +599,14 @@ func (b *Broker) receive(ctx context.Context, c net.Conn) (*codec.Message, error } func (b *Broker) broadcastToConsumers(msg *codec.Message) { - if queue, ok := b.queues.Get(msg.Queue); ok { - task := &QueuedTask{Message: msg, RetryCount: 0} - queue.tasks <- task + if tenantQueues, ok := b.queues.Get(msg.Queue); ok { + tenantQueues.ForEach(func(_, queueName string) bool { + if queue, ok := tenantQueues.Get(queueName); ok { + task := &QueuedTask{Message: msg, RetryCount: 0} + queue.tasks <- task + } + return true + }) } } @@ -755,6 +768,7 @@ func (b *Broker) dispatchWorker(ctx context.Context, queue *Queue) { for !success && task.RetryCount <= b.opts.maxRetries { if b.dispatchTaskToConsumer(ctx, queue, task) { success = true + b.acknowledgeTask(ctx, task.Message.Queue, queue.name) } else { task.RetryCount++ delay = b.backoffRetry(queue, task, delay) @@ -779,6 +793,14 @@ func (b *Broker) sendToDLQ(queue *Queue, task *QueuedTask) { func (b *Broker) dispatchTaskToConsumer(ctx context.Context, queue *Queue, task *QueuedTask) bool { var consumerFound bool var err error + + // Deduplication: Check if the task has already been processed + taskID, _ := jsonparser.GetString(task.Message.Payload, "id") + if _, exists := b.pIDs.Get(taskID); exists { + log.Printf("Task %s already processed, skipping...", taskID) + return true + } + queue.consumers.ForEach(func(_ string, con *consumer) bool { if con.state != consts.ConsumerStateActive { err = fmt.Errorf("consumer %s is not active", con.id) @@ -786,10 +808,13 @@ func (b *Broker) dispatchTaskToConsumer(ctx context.Context, queue *Queue, task } if err := b.send(ctx, con.conn, task.Message); err == nil { consumerFound = true + // Mark the task as processed + b.pIDs.Set(taskID, true) return false } return true }) + if err != nil { log.Println(err.Error()) return false @@ -800,7 +825,7 @@ func (b *Broker) dispatchTaskToConsumer(ctx context.Context, queue *Queue, task result := Result{ Status: "NO_CONSUMER", Topic: queue.name, - TaskID: "", + TaskID: taskID, Ctx: ctx, } _ = b.opts.notifyResponse(ctx, result) @@ -843,8 +868,32 @@ func (b *Broker) NewQueue(name string) *Queue { tasks: make(chan *QueuedTask, b.opts.queueSize), consumers: memory.New[string, *consumer](), } - b.queues.Set(name, q) - + b.queues.Set(name, memory.New[string, *Queue]()) + b.queues.Get(name).Set(name, q) + + // Create DLQ for the queue + dlq := &Queue{ + name: name + "_dlq", + tasks: make(chan *QueuedTask, b.opts.queueSize), + consumers: memory.New[string, *consumer](), + } + b.deadLetter.Set(name, dlq) + ctx := context.Background() + go b.dispatchWorker(ctx, q) + go b.dispatchWorker(ctx, dlq) + return q +} + +// Ensure message ordering in task queues +func (b *Broker) NewQueueWithOrdering(name string) *Queue { + q := &Queue{ + name: name, + tasks: make(chan *QueuedTask, b.opts.queueSize), + consumers: memory.New[string, *consumer](), + } + b.queues.Set(name, memory.New[string, *Queue]()) + b.queues.Get(name).Set(name, q) + // Create DLQ for the queue dlq := &Queue{ name: name + "_dlq", @@ -885,3 +934,134 @@ func (b *Broker) HandleCallback(ctx context.Context, msg *codec.Message) { } } } + +// Add explicit acknowledgment for successful task processing +func (b *Broker) acknowledgeTask(ctx context.Context, taskID string, queueName string) { + log.Printf("Acknowledging task %s on queue %s", taskID, queueName) + if b.opts.notifyResponse != nil { + result := Result{ + Status: "ACKNOWLEDGED", + Topic: queueName, + TaskID: taskID, + Ctx: ctx, + } + _ = b.opts.notifyResponse(ctx, result) + } +} + +// Add authentication and authorization for publishers and consumers +func (b *Broker) Authenticate(ctx context.Context, credentials map[string]string) error { + username, userExists := credentials["username"] + password, passExists := credentials["password"] + if !userExists || !passExists { + return fmt.Errorf("missing credentials") + } + // Example: Hardcoded credentials for simplicity + if username != "admin" || password != "password" { + return fmt.Errorf("invalid credentials") + } + return nil +} + +func (b *Broker) Authorize(ctx context.Context, role string, action string) error { + // Example: Simple role-based authorization + if role == "publisher" && action == "publish" { + return nil + } + if role == "consumer" && action == "consume" { + return nil + } + return fmt.Errorf("unauthorized action") +} + +// Add support for multi-tenancy +func (b *Broker) AddTenant(tenantID string) error { + if _, exists := b.queues.Get(tenantID); exists { + return fmt.Errorf("tenant %s already exists", tenantID) + } + b.queues.Set(tenantID, memory.New[string, *Queue]()) + return nil +} + +func (b *Broker) RemoveTenant(tenantID string) error { + if _, exists := b.queues.Get(tenantID); !exists { + return fmt.Errorf("tenant %s does not exist", tenantID) + } + b.queues.Del(tenantID) + return nil +} + +// Ensure tenant-specific queues and operations +func (b *Broker) NewQueueForTenant(tenantID, queueName string) (*Queue, error) { + tenantQueues, ok := b.queues.Get(tenantID) + if !ok { + return nil, fmt.Errorf("tenant %s does not exist", tenantID) + } + if _, exists := tenantQueues.Get(queueName); exists { + return nil, fmt.Errorf("queue %s already exists for tenant %s", queueName, tenantID) + } + q := &Queue{ + name: queueName, + tasks: make(chan *QueuedTask, b.opts.queueSize), + consumers: memory.New[string, *consumer](), + } + tenantQueues.Set(queueName, q) + + // Create tenant-specific DLQ + dlq := &Queue{ + name: queueName + "_dlq", + tasks: make(chan *QueuedTask, b.opts.queueSize), + consumers: memory.New[string, *consumer](), + } + tenantQueues.Set(queueName+"_dlq", dlq) + ctx := context.Background() + go b.dispatchWorker(ctx, q) + go b.dispatchWorker(ctx, dlq) + return q, nil +} + +func (b *Broker) PublishForTenant(ctx context.Context, tenantID string, task *Task, queueName string) error { + tenantQueues, ok := b.queues.Get(tenantID) + if !ok { + return fmt.Errorf("tenant %s does not exist", tenantID) + } + queue, ok := tenantQueues.Get(queueName) + if !ok { + return fmt.Errorf("queue %s does not exist for tenant %s", queueName, tenantID) + } + taskID := task.ID + if taskID == "" { + taskID = NewID() + task.ID = taskID + } + queuedTask := &QueuedTask{Message: codec.NewMessage(consts.PUBLISH, task.Payload, queueName, nil), RetryCount: 0} + queue.tasks <- queuedTask + return nil +} + +func (b *Broker) SubscribeForTenant(ctx context.Context, tenantID, queueName string, conn net.Conn) error { + tenantQueues, ok := b.queues.Get(tenantID) + if !ok { + return fmt.Errorf("tenant %s does not exist", tenantID) + } + queue, ok := tenantQueues.Get(queueName) + if !ok { + return fmt.Errorf("queue %s does not exist for tenant %s", queueName, tenantID) + } + consumerID := b.AddConsumer(ctx, queueName, conn) + queue.consumers.Set(consumerID, &consumer{id: consumerID, conn: conn}) + return nil +} + +func (b *Broker) ListQueuesForTenant(tenantID string) ([]string, error) { + tenantQueues, ok := b.queues.Get(tenantID) + if !ok { + return nil, fmt.Errorf("tenant %s does not exist", tenantID) + } + var queueNames []string + tenantQueues.ForEach(func(queueName string, _ *Queue) bool { + queueNames = append(queueNames, queueName) + return true + }) + return queueNames, nil +} diff --git a/task.go b/task.go index e01e36e..d5e4490 100644 --- a/task.go +++ b/task.go @@ -2,6 +2,7 @@ package mq import ( "context" + "fmt" "time" "github.com/oarkflow/json" @@ -93,3 +94,20 @@ func WithDedupKey(key string) TaskOption { t.DedupKey = key } } + +// Add advanced dead-letter queue management +func (b *Broker) ReprocessDLQ(queueName string) error { + dlqName := queueName + "_dlq" + dlq, ok := b.deadLetter.Get(dlqName) + if !ok { + return fmt.Errorf("dead-letter queue %s does not exist", dlqName) + } + for { + select { + case task := <-dlq.tasks: + b.NewQueue(queueName).tasks <- task + default: + return nil + } + } +}