diff --git a/broker.go b/broker.go
index 946ba61..027b22d 100644
--- a/broker.go
+++ b/broker.go
@@ -38,6 +38,7 @@ type Broker struct {
 	queues     storage.IMap[string, *Queue]
 	consumers  storage.IMap[string, *consumer]
 	publishers storage.IMap[string, *publisher]
+	deadLetter storage.IMap[string, *Queue] // DLQ mapping for each queue
 	opts       *Options
 }
 
@@ -47,6 +48,7 @@ func NewBroker(opts ...Option) *Broker {
 		queues:     memory.New[string, *Queue](),
 		publishers: memory.New[string, *publisher](),
 		consumers:  memory.New[string, *consumer](),
+		deadLetter: memory.New[string, *Queue](),
 		opts:       options,
 	}
 }
@@ -422,6 +424,19 @@ func (b *Broker) dispatchWorker(queue *Queue) {
 				delay = b.backoffRetry(queue, task, delay)
 			}
 		}
+		if task.RetryCount > b.opts.maxRetries {
+			b.sendToDLQ(queue, task)
+		}
+	}
+}
+
+func (b *Broker) sendToDLQ(queue *Queue, task *QueuedTask) {
+	id, _ := jsonparser.GetString(task.Message.Payload, "id")
+	if dlq, ok := b.deadLetter.Get(queue.name); ok {
+		log.Printf("Sending task %s to dead-letter queue for %s", id, queue.name)
+		dlq.tasks <- task
+	} else {
+		log.Printf("No dead-letter queue for %s, discarding task %s", queue.name, id)
 	}
 }
 
diff --git a/ctx.go b/ctx.go
index 2d16ed7..60c995e 100644
--- a/ctx.go
+++ b/ctx.go
@@ -4,11 +4,9 @@ import (
 	"context"
 	"crypto/tls"
 	"crypto/x509"
-	"encoding/json"
 	"fmt"
 	"net"
 	"os"
-	"time"
 
 	"github.com/oarkflow/xid"
 
@@ -17,17 +15,6 @@ import (
 	"github.com/oarkflow/mq/storage/memory"
 )
 
-type Task struct {
-	CreatedAt   time.Time       `json:"created_at"`
-	ProcessedAt time.Time       `json:"processed_at"`
-	Expiry      time.Time       `json:"expiry"`
-	Error       error           `json:"error"`
-	ID          string          `json:"id"`
-	Topic       string          `json:"topic"`
-	Status      string          `json:"status"`
-	Payload     json.RawMessage `json:"payload"`
-}
-
 type Handler func(context.Context, *Task) Result
 
 func IsClosed(conn net.Conn) bool {
diff --git a/dag/dag.go b/dag/dag.go
--- a/dag/dag.go
+++ b/dag/dag.go
@@ -2,7 +2,9 @@ package dag
 
 import (
 	"context"
+	"encoding/json"
 	"fmt"
+	"io"
 	"log"
 	"net/http"
 	"sync"
@@ -180,6 +182,42 @@ func (tm *DAG) GetStartNode() string {
 	return tm.startNode
 }
 
+// ServeHTTP handles a POST whose body is the task payload: it marks the
+// request as awaiting a response, runs the payload through the DAG and writes
+// the result as JSON. Other methods get 405; a failed run gets 500.
+func (tm *DAG) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	if r.Method != http.MethodPost {
+		http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
+		return
+	}
+	var payload []byte
+	if r.Body != nil {
+		defer r.Body.Close()
+		var err error
+		payload, err = io.ReadAll(r.Body)
+		if err != nil {
+			http.Error(w, "Failed to read request body", http.StatusBadRequest)
+			return
+		}
+	} else {
+		http.Error(w, "Empty request body", http.StatusBadRequest)
+		return
+	}
+	ctx := r.Context()
+	ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
+	rs := tm.Process(ctx, payload)
+	if rs.Error != nil {
+		http.Error(w, fmt.Sprintf("[DAG Error] - %v", rs.Error), http.StatusInternalServerError)
+		return
+	}
+	w.Header().Set("Content-Type", "application/json")
+	// Part of the body may already be on the wire when encoding fails, so the
+	// error is logged instead of being turned into a second status code.
+	if err := json.NewEncoder(w).Encode(rs); err != nil {
+		log.Printf("[DAG Error] - writing response: %v", err)
+	}
+}
+
 func (tm *DAG) Start(ctx context.Context, addr string) error {
 	if !tm.server.SyncMode() {
 		go func() {
diff --git a/examples/dag.go b/examples/dag.go
index d7371f7..10bcc80 100644
--- a/examples/dag.go
+++ b/examples/dag.go
@@ -4,10 +4,8 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
-	"io"
 	"net/http"
 
-	"github.com/oarkflow/mq/consts"
 	"github.com/oarkflow/mq/examples/tasks"
 	"github.com/oarkflow/mq/services"
 
@@ -16,8 +14,8 @@ import (
 )
 
 func main() {
-	sync()
-	async()
+	Sync()
+	aSync()
 }
 
 func setup(f *dag.DAG) {
@@ -46,7 +44,7 @@ func sendData(f *dag.DAG) {
 	fmt.Println(string(result.Payload))
 }
 
-func sync() {
+func Sync() {
 	f := dag.NewDAG("Sample DAG", "sample-dag", mq.WithSyncMode(true), mq.WithNotifyResponse(tasks.NotifyResponse))
 	setup(f)
 	fmt.Println(f.ExportDOT())
@@ -54,46 +52,10 @@ func sync() {
 	fmt.Println(f.SaveSVG("dag.svg"))
 }
 
-func async() {
+func aSync() {
 	f := dag.NewDAG("Sample DAG", "sample-dag", mq.WithNotifyResponse(tasks.NotifyResponse))
 	setup(f)
-
-	requestHandler := func(requestType string) func(w http.ResponseWriter, r *http.Request) {
-		return func(w http.ResponseWriter, r *http.Request) {
-			if r.Method != http.MethodPost {
-				http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
-				return
-			}
-			var payload []byte
-			if r.Body != nil {
-				defer r.Body.Close()
-				var err error
-				payload, err = io.ReadAll(r.Body)
-				if err != nil {
-					http.Error(w, "Failed to read request body", http.StatusBadRequest)
-					return
-				}
-			} else {
-				http.Error(w, "Empty request body", http.StatusBadRequest)
-				return
-			}
-			ctx := r.Context()
-			if requestType == "request" {
-				ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
-			}
-			// ctx = context.WithValue(ctx, "initial_node", "E")
-			rs := f.Process(ctx, payload)
-			if rs.Error != nil {
-				http.Error(w, fmt.Sprintf("[DAG Error] - %v", rs.Error), http.StatusInternalServerError)
-				return
-			}
-			w.Header().Set("Content-Type", "application/json")
-			json.NewEncoder(w).Encode(rs)
-		}
-	}
-
-	http.HandleFunc("POST /publish", requestHandler("publish"))
-	http.HandleFunc("POST /request", requestHandler("request"))
+	http.HandleFunc("POST /request", f.ServeHTTP)
 	http.HandleFunc("/pause-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) {
 		id := request.PathValue("id")
 		if id != "" {
diff --git a/examples/priority.go b/examples/priority.go
index c97f862..7d21558 100644
--- a/examples/priority.go
+++ b/examples/priority.go
@@ -11,12 +11,17 @@ import (
 
 func main() {
 	pool := mq.NewPool(2, 5, 1000, tasks.SchedulerHandler, tasks.SchedulerCallback, mq.NewMemoryTaskStorage(10*time.Minute))
-	time.Sleep(time.Millisecond)
-	pool.EnqueueTask(context.Background(), &mq.Task{ID: "Low Priority Task"}, 1)
-	pool.EnqueueTask(context.Background(), &mq.Task{ID: "Medium Priority Task"}, 5)
-	pool.EnqueueTask(context.Background(), &mq.Task{ID: "High Priority Task"}, 10)
+	for i := 0; i < 100; i++ {
+		if i%10 == 0 {
+			pool.EnqueueTask(context.Background(), &mq.Task{ID: "High Priority Task: I'm high"}, 10)
+		} else if i%15 == 0 {
+			pool.EnqueueTask(context.Background(), &mq.Task{ID: "Super High Priority Task: {}"}, 15)
+		} else {
+			pool.EnqueueTask(context.Background(), &mq.Task{ID: "Low Priority Task"}, 1)
+		}
+	}
 
-	time.Sleep(5 * time.Second)
+	time.Sleep(15 * time.Second)
 	pool.PrintMetrics()
 	pool.Stop()
 }
diff --git a/pool.go b/pool.go
--- a/pool.go
+++ b/pool.go
@@ -29,19 +29,17 @@ type Pool struct {
 	numOfWorkers        int32
 	paused              bool
 	scheduler           *Scheduler
-	totalScheduledTasks int
+	overflowBufferLock  sync.RWMutex
+	overflowBuffer      []*QueueTask
+	overflowWake        chan struct{} // signals the overflow drainer
 }
 
-func NewPool(
-	numOfWorkers, taskQueueSize int,
-	maxMemoryLoad int64,
-	handler Handler,
-	callback Callback,
-	storage TaskStorage) *Pool {
+func NewPool(numOfWorkers, taskQueueSize int, maxMemoryLoad int64, handler Handler, callback Callback, storage TaskStorage) *Pool {
 	pool := &Pool{
 		taskQueue:     make(PriorityQueue, 0, taskQueueSize),
 		stop:          make(chan struct{}),
-		taskNotify:    make(chan struct{}, 1),
+		taskNotify:    make(chan struct{}, numOfWorkers), // Buffer for workers
+		overflowWake:  make(chan struct{}, 1),
 		maxMemoryLoad: maxMemoryLoad,
 		handler:       handler,
 		callback:      callback,
@@ -70,6 +68,7 @@ func (wp *Pool) Start(numWorkers int) {
 	}
 	atomic.StoreInt32(&wp.numOfWorkers, int32(numWorkers))
 	go wp.monitorWorkerAdjustments()
+	go wp.startOverflowDrainer()
 }
 
 func (wp *Pool) worker() {
@@ -77,46 +76,53 @@ func (wp *Pool) worker() {
 	for {
 		select {
 		case <-wp.taskNotify:
-			wp.taskQueueLock.Lock()
-			var task *QueueTask
-			if len(wp.taskQueue) > 0 && !wp.paused {
-				task = heap.Pop(&wp.taskQueue).(*QueueTask)
-			}
-			wp.taskQueueLock.Unlock()
-			if task == nil && !wp.paused {
-				var err error
-				task, err = wp.taskStorage.FetchNextTask()
-				if err != nil {
-
-					continue
-				}
-			}
-			if task != nil {
-				taskSize := int64(utils.SizeOf(task.payload))
-				wp.totalMemoryUsed += taskSize
-				wp.totalTasks++
-				result := wp.handler(task.ctx, task.payload)
-				if result.Error != nil {
-					wp.errorCount++
-				} else {
-					wp.completedTasks++
-				}
-				if wp.callback != nil {
-					if err := wp.callback(task.ctx, result); err != nil {
-						wp.errorCount++
-					}
-				}
-				if err := wp.taskStorage.DeleteTask(task.payload.ID); err != nil {
-
-				}
-				wp.totalMemoryUsed -= taskSize
-			}
+			wp.processNextTask()
 		case <-wp.stop:
 			return
 		}
 	}
 }
 
+func (wp *Pool) processNextTask() {
+	wp.taskQueueLock.Lock()
+	var task *QueueTask
+	if len(wp.taskQueue) > 0 && !wp.paused {
+		task = heap.Pop(&wp.taskQueue).(*QueueTask)
+	}
+	wp.taskQueueLock.Unlock()
+	if task == nil && !wp.paused {
+		var err error
+		task, err = wp.taskStorage.FetchNextTask()
+		if err != nil {
+			return
+		}
+	}
+	if task != nil {
+		wp.handleTask(task)
+	}
+}
+
+func (wp *Pool) handleTask(task *QueueTask) {
+	taskSize := int64(utils.SizeOf(task.payload))
+	wp.totalMemoryUsed += taskSize
+	wp.totalTasks++
+	result := wp.handler(task.ctx, task.payload)
+	if result.Error != nil {
+		wp.errorCount++
+	} else {
+		wp.completedTasks++
+	}
+	if wp.callback != nil {
+		if err := wp.callback(task.ctx, result); err != nil {
+			wp.errorCount++
+		}
+	}
+	if err := wp.taskStorage.DeleteTask(task.payload.ID); err != nil {
+		// Handle deletion error
+	}
+	wp.totalMemoryUsed -= taskSize
+}
+
 func (wp *Pool) monitorWorkerAdjustments() {
 	for {
 		select {
@@ -162,9 +168,12 @@ func (wp *Pool) EnqueueTask(ctx context.Context, payload *Task, priority int) er
 		return fmt.Errorf("max memory load reached, cannot add task of size %d", taskSize)
 	}
 	heap.Push(&wp.taskQueue, task)
+
+	// Non-blocking task notification
 	select {
 	case wp.taskNotify <- struct{}{}:
 	default:
+		wp.storeInOverflow(task)
 	}
 	return nil
 }
@@ -177,6 +186,58 @@ func (wp *Pool) Resume() {
 	wp.paused = false
 }
 
+// storeInOverflow records a task whose wake-up could not be delivered because
+// taskNotify was full. The task is already on taskQueue, so the buffer only
+// tracks how many notifications are still owed to the workers.
+func (wp *Pool) storeInOverflow(task *QueueTask) {
+	wp.overflowBufferLock.Lock()
+	wp.overflowBuffer = append(wp.overflowBuffer, task)
+	wp.overflowBufferLock.Unlock()
+
+	// One pending wake-up is enough for the drainer, so never block here.
+	select {
+	case wp.overflowWake <- struct{}{}:
+	default:
+	}
+}
+
+// startOverflowDrainer delivers the owed notifications until the pool stops.
+// It blocks between bursts instead of polling, so an idle pool costs no CPU.
+func (wp *Pool) startOverflowDrainer() {
+	for {
+		select {
+		case <-wp.stop:
+			return
+		case <-wp.overflowWake:
+			wp.drainOverflowBuffer()
+		}
+	}
+}
+
+// drainOverflowBuffer sends one notification per buffered task, waiting for
+// room in taskNotify. It never re-pushes the task (that would run it twice)
+// and never holds the buffer lock across the blocking send.
+func (wp *Pool) drainOverflowBuffer() {
+	for {
+		wp.overflowBufferLock.Lock()
+		pending := len(wp.overflowBuffer)
+		wp.overflowBufferLock.Unlock()
+		if pending == 0 {
+			return
+		}
+
+		select {
+		case wp.taskNotify <- struct{}{}:
+			wp.overflowBufferLock.Lock()
+			wp.overflowBuffer[0] = nil // release the reference for the GC
+			wp.overflowBuffer = wp.overflowBuffer[1:]
+			wp.overflowBufferLock.Unlock()
+		case <-wp.stop:
+			return
+		}
+	}
+}
+
 func (wp *Pool) Stop() {
 	close(wp.stop)
 	wp.wg.Wait()
diff --git a/queue.go b/queue.go
--- a/queue.go
+++ b/queue.go
@@ -21,14 +21,22 @@ func newQueue(name string, queueSize int) *Queue {
 	}
 }
 
-func (b *Broker) NewQueue(qName string) *Queue {
-	q, ok := b.queues.Get(qName)
-	if ok {
+// NewQueue returns the queue registered under name, creating it together with
+// its dead-letter queue and their dispatch workers on first use. Repeated
+// calls return the existing queue instead of replacing it.
+func (b *Broker) NewQueue(name string) *Queue {
+	if q, ok := b.queues.Get(name); ok {
 		return q
 	}
-	q = newQueue(qName, b.opts.queueSize)
-	b.queues.Set(qName, q)
+	q := newQueue(name, b.opts.queueSize)
+	b.queues.Set(name, q)
+
+	// Create DLQ for the queue
+	dlq := newQueue(name+"_dlq", b.opts.queueSize)
+	b.deadLetter.Set(name, dlq)
+
 	go b.dispatchWorker(q)
+	go b.dispatchWorker(dlq)
 	return q
 }
 
diff --git a/task.go b/task.go
index 15785ca..4930d0a 100644
--- a/task.go
+++ b/task.go
@@ -5,6 +5,17 @@ import (
 	"time"
 )
 
+type Task struct {
+	CreatedAt   time.Time       `json:"created_at"`
+	ProcessedAt time.Time       `json:"processed_at"`
+	Expiry      time.Time       `json:"expiry"`
+	Error       error           `json:"error"`
+	ID          string          `json:"id"`
+	Topic       string          `json:"topic"`
+	Status      string          `json:"status"`
+	Payload     json.RawMessage `json:"payload"`
+}
+
 func NewTask(id string, payload json.RawMessage, nodeKey string) *Task {
 	if id == "" {
 		id = NewID()