feat: [wip] - implement storage

This commit is contained in:
sujit
2024-10-19 10:29:30 +05:45
parent 2fb9dfb803
commit 8fef4e69cc
8 changed files with 176 additions and 109 deletions

View File

@@ -38,6 +38,7 @@ type Broker struct {
queues storage.IMap[string, *Queue] queues storage.IMap[string, *Queue]
consumers storage.IMap[string, *consumer] consumers storage.IMap[string, *consumer]
publishers storage.IMap[string, *publisher] publishers storage.IMap[string, *publisher]
deadLetter storage.IMap[string, *Queue] // DLQ mapping for each queue
opts *Options opts *Options
} }
@@ -47,6 +48,7 @@ func NewBroker(opts ...Option) *Broker {
queues: memory.New[string, *Queue](), queues: memory.New[string, *Queue](),
publishers: memory.New[string, *publisher](), publishers: memory.New[string, *publisher](),
consumers: memory.New[string, *consumer](), consumers: memory.New[string, *consumer](),
deadLetter: memory.New[string, *Queue](),
opts: options, opts: options,
} }
} }
@@ -422,6 +424,19 @@ func (b *Broker) dispatchWorker(queue *Queue) {
delay = b.backoffRetry(queue, task, delay) delay = b.backoffRetry(queue, task, delay)
} }
} }
if task.RetryCount > b.opts.maxRetries {
b.sendToDLQ(queue, task)
}
}
}
func (b *Broker) sendToDLQ(queue *Queue, task *QueuedTask) {
id, _ := jsonparser.GetString(task.Message.Payload, "id")
if dlq, ok := b.deadLetter.Get(queue.name); ok {
log.Printf("Sending task %s to dead-letter queue for %s", id, queue.name)
dlq.tasks <- task
} else {
log.Printf("No dead-letter queue for %s, discarding task %s", queue.name, id)
} }
} }

13
ctx.go
View File

@@ -4,11 +4,9 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json"
"fmt" "fmt"
"net" "net"
"os" "os"
"time"
"github.com/oarkflow/xid" "github.com/oarkflow/xid"
@@ -17,17 +15,6 @@ import (
"github.com/oarkflow/mq/storage/memory" "github.com/oarkflow/mq/storage/memory"
) )
type Task struct {
CreatedAt time.Time `json:"created_at"`
ProcessedAt time.Time `json:"processed_at"`
Expiry time.Time `json:"expiry"`
Error error `json:"error"`
ID string `json:"id"`
Topic string `json:"topic"`
Status string `json:"status"`
Payload json.RawMessage `json:"payload"`
}
type Handler func(context.Context, *Task) Result type Handler func(context.Context, *Task) Result
func IsClosed(conn net.Conn) bool { func IsClosed(conn net.Conn) bool {

View File

@@ -2,7 +2,9 @@ package dag
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"io"
"log" "log"
"net/http" "net/http"
"sync" "sync"
@@ -180,6 +182,35 @@ func (tm *DAG) GetStartNode() string {
return tm.startNode return tm.startNode
} }
func (tm *DAG) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
return
}
var payload []byte
if r.Body != nil {
defer r.Body.Close()
var err error
payload, err = io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
} else {
http.Error(w, "Empty request body", http.StatusBadRequest)
return
}
ctx := r.Context()
ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
rs := tm.Process(ctx, payload)
if rs.Error != nil {
http.Error(w, fmt.Sprintf("[DAG Error] - %v", rs.Error), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(rs)
}
func (tm *DAG) Start(ctx context.Context, addr string) error { func (tm *DAG) Start(ctx context.Context, addr string) error {
if !tm.server.SyncMode() { if !tm.server.SyncMode() {
go func() { go func() {

View File

@@ -4,10 +4,8 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http" "net/http"
"github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/examples/tasks" "github.com/oarkflow/mq/examples/tasks"
"github.com/oarkflow/mq/services" "github.com/oarkflow/mq/services"
@@ -16,8 +14,8 @@ import (
) )
func main() { func main() {
sync() Sync()
async() aSync()
} }
func setup(f *dag.DAG) { func setup(f *dag.DAG) {
@@ -46,7 +44,7 @@ func sendData(f *dag.DAG) {
fmt.Println(string(result.Payload)) fmt.Println(string(result.Payload))
} }
func sync() { func Sync() {
f := dag.NewDAG("Sample DAG", "sample-dag", mq.WithSyncMode(true), mq.WithNotifyResponse(tasks.NotifyResponse)) f := dag.NewDAG("Sample DAG", "sample-dag", mq.WithSyncMode(true), mq.WithNotifyResponse(tasks.NotifyResponse))
setup(f) setup(f)
fmt.Println(f.ExportDOT()) fmt.Println(f.ExportDOT())
@@ -54,46 +52,10 @@ func sync() {
fmt.Println(f.SaveSVG("dag.svg")) fmt.Println(f.SaveSVG("dag.svg"))
} }
func async() { func aSync() {
f := dag.NewDAG("Sample DAG", "sample-dag", mq.WithNotifyResponse(tasks.NotifyResponse)) f := dag.NewDAG("Sample DAG", "sample-dag", mq.WithNotifyResponse(tasks.NotifyResponse))
setup(f) setup(f)
http.HandleFunc("POST /request", f.ServeHTTP)
requestHandler := func(requestType string) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
return
}
var payload []byte
if r.Body != nil {
defer r.Body.Close()
var err error
payload, err = io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
} else {
http.Error(w, "Empty request body", http.StatusBadRequest)
return
}
ctx := r.Context()
if requestType == "request" {
ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
}
// ctx = context.WithValue(ctx, "initial_node", "E")
rs := f.Process(ctx, payload)
if rs.Error != nil {
http.Error(w, fmt.Sprintf("[DAG Error] - %v", rs.Error), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(rs)
}
}
http.HandleFunc("POST /publish", requestHandler("publish"))
http.HandleFunc("POST /request", requestHandler("request"))
http.HandleFunc("/pause-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) { http.HandleFunc("/pause-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) {
id := request.PathValue("id") id := request.PathValue("id")
if id != "" { if id != "" {

View File

@@ -11,12 +11,17 @@ import (
func main() { func main() {
pool := mq.NewPool(2, 5, 1000, tasks.SchedulerHandler, tasks.SchedulerCallback, mq.NewMemoryTaskStorage(10*time.Minute)) pool := mq.NewPool(2, 5, 1000, tasks.SchedulerHandler, tasks.SchedulerCallback, mq.NewMemoryTaskStorage(10*time.Minute))
time.Sleep(time.Millisecond) for i := 0; i < 100; i++ {
if i%10 == 0 {
pool.EnqueueTask(context.Background(), &mq.Task{ID: "High Priority Task: I'm high"}, 10)
} else if i%15 == 0 {
pool.EnqueueTask(context.Background(), &mq.Task{ID: "Super High Priority Task: {}"}, 15)
} else {
pool.EnqueueTask(context.Background(), &mq.Task{ID: "Low Priority Task"}, 1) pool.EnqueueTask(context.Background(), &mq.Task{ID: "Low Priority Task"}, 1)
pool.EnqueueTask(context.Background(), &mq.Task{ID: "Medium Priority Task"}, 5) }
pool.EnqueueTask(context.Background(), &mq.Task{ID: "High Priority Task"}, 10) }
time.Sleep(5 * time.Second) time.Sleep(15 * time.Second)
pool.PrintMetrics() pool.PrintMetrics()
pool.Stop() pool.Stop()
} }

78
pool.go
View File

@@ -29,19 +29,15 @@ type Pool struct {
numOfWorkers int32 numOfWorkers int32
paused bool paused bool
scheduler *Scheduler scheduler *Scheduler
totalScheduledTasks int overflowBufferLock sync.RWMutex
overflowBuffer []*QueueTask
} }
func NewPool( func NewPool(numOfWorkers, taskQueueSize int, maxMemoryLoad int64, handler Handler, callback Callback, storage TaskStorage) *Pool {
numOfWorkers, taskQueueSize int,
maxMemoryLoad int64,
handler Handler,
callback Callback,
storage TaskStorage) *Pool {
pool := &Pool{ pool := &Pool{
taskQueue: make(PriorityQueue, 0, taskQueueSize), taskQueue: make(PriorityQueue, 0, taskQueueSize),
stop: make(chan struct{}), stop: make(chan struct{}),
taskNotify: make(chan struct{}, 1), taskNotify: make(chan struct{}, numOfWorkers), // Buffer for workers
maxMemoryLoad: maxMemoryLoad, maxMemoryLoad: maxMemoryLoad,
handler: handler, handler: handler,
callback: callback, callback: callback,
@@ -70,6 +66,7 @@ func (wp *Pool) Start(numWorkers int) {
} }
atomic.StoreInt32(&wp.numOfWorkers, int32(numWorkers)) atomic.StoreInt32(&wp.numOfWorkers, int32(numWorkers))
go wp.monitorWorkerAdjustments() go wp.monitorWorkerAdjustments()
go wp.startOverflowDrainer()
} }
func (wp *Pool) worker() { func (wp *Pool) worker() {
@@ -77,6 +74,14 @@ func (wp *Pool) worker() {
for { for {
select { select {
case <-wp.taskNotify: case <-wp.taskNotify:
wp.processNextTask()
case <-wp.stop:
return
}
}
}
func (wp *Pool) processNextTask() {
wp.taskQueueLock.Lock() wp.taskQueueLock.Lock()
var task *QueueTask var task *QueueTask
if len(wp.taskQueue) > 0 && !wp.paused { if len(wp.taskQueue) > 0 && !wp.paused {
@@ -87,11 +92,15 @@ func (wp *Pool) worker() {
var err error var err error
task, err = wp.taskStorage.FetchNextTask() task, err = wp.taskStorage.FetchNextTask()
if err != nil { if err != nil {
return
continue
} }
} }
if task != nil { if task != nil {
wp.handleTask(task)
}
}
func (wp *Pool) handleTask(task *QueueTask) {
taskSize := int64(utils.SizeOf(task.payload)) taskSize := int64(utils.SizeOf(task.payload))
wp.totalMemoryUsed += taskSize wp.totalMemoryUsed += taskSize
wp.totalTasks++ wp.totalTasks++
@@ -107,14 +116,9 @@ func (wp *Pool) worker() {
} }
} }
if err := wp.taskStorage.DeleteTask(task.payload.ID); err != nil { if err := wp.taskStorage.DeleteTask(task.payload.ID); err != nil {
// Handle deletion error
} }
wp.totalMemoryUsed -= taskSize wp.totalMemoryUsed -= taskSize
}
case <-wp.stop:
return
}
}
} }
func (wp *Pool) monitorWorkerAdjustments() { func (wp *Pool) monitorWorkerAdjustments() {
@@ -162,9 +166,12 @@ func (wp *Pool) EnqueueTask(ctx context.Context, payload *Task, priority int) er
return fmt.Errorf("max memory load reached, cannot add task of size %d", taskSize) return fmt.Errorf("max memory load reached, cannot add task of size %d", taskSize)
} }
heap.Push(&wp.taskQueue, task) heap.Push(&wp.taskQueue, task)
// Non-blocking task notification
select { select {
case wp.taskNotify <- struct{}{}: case wp.taskNotify <- struct{}{}:
default: default:
wp.storeInOverflow(task)
} }
return nil return nil
} }
@@ -177,6 +184,45 @@ func (wp *Pool) Resume() {
wp.paused = false wp.paused = false
} }
// Overflow Handling
func (wp *Pool) storeInOverflow(task *QueueTask) {
wp.overflowBufferLock.Lock()
wp.overflowBuffer = append(wp.overflowBuffer, task)
wp.overflowBufferLock.Unlock()
}
// Drains tasks from the overflow buffer when taskNotify is not full
func (wp *Pool) startOverflowDrainer() {
for {
wp.drainOverflowBuffer()
select {
case <-wp.stop:
return
default:
continue
}
}
}
func (wp *Pool) drainOverflowBuffer() {
wp.overflowBufferLock.Lock()
defer wp.overflowBufferLock.Unlock()
for len(wp.overflowBuffer) > 0 {
select {
case wp.taskNotify <- struct{}{}:
// Move the first task from the overflow buffer to the queue
wp.taskQueueLock.Lock()
heap.Push(&wp.taskQueue, wp.overflowBuffer[0])
wp.overflowBuffer = wp.overflowBuffer[1:]
wp.taskQueueLock.Unlock()
default:
// Stop if taskNotify is full
return
}
}
}
func (wp *Pool) Stop() { func (wp *Pool) Stop() {
close(wp.stop) close(wp.stop)
wp.wg.Wait() wp.wg.Wait()

View File

@@ -21,14 +21,24 @@ func newQueue(name string, queueSize int) *Queue {
} }
} }
func (b *Broker) NewQueue(qName string) *Queue { func (b *Broker) NewQueue(name string) *Queue {
q, ok := b.queues.Get(qName) q := &Queue{
if ok { name: name,
return q tasks: make(chan *QueuedTask, b.opts.queueSize),
consumers: memory.New[string, *consumer](),
} }
q = newQueue(qName, b.opts.queueSize) b.queues.Set(name, q)
b.queues.Set(qName, q)
// Create DLQ for the queue
dlq := &Queue{
name: name + "_dlq",
tasks: make(chan *QueuedTask, b.opts.queueSize),
consumers: memory.New[string, *consumer](),
}
b.deadLetter.Set(name, dlq)
go b.dispatchWorker(q) go b.dispatchWorker(q)
go b.dispatchWorker(dlq)
return q return q
} }

11
task.go
View File

@@ -5,6 +5,17 @@ import (
"time" "time"
) )
type Task struct {
CreatedAt time.Time `json:"created_at"`
ProcessedAt time.Time `json:"processed_at"`
Expiry time.Time `json:"expiry"`
Error error `json:"error"`
ID string `json:"id"`
Topic string `json:"topic"`
Status string `json:"status"`
Payload json.RawMessage `json:"payload"`
}
func NewTask(id string, payload json.RawMessage, nodeKey string) *Task { func NewTask(id string, payload json.RawMessage, nodeKey string) *Task {
if id == "" { if id == "" {
id = NewID() id = NewID()