feat: [wip] - implement storage

This commit is contained in:
sujit
2024-10-19 10:29:30 +05:45
parent 2fb9dfb803
commit 8fef4e69cc
8 changed files with 176 additions and 109 deletions

View File

@@ -38,6 +38,7 @@ type Broker struct {
queues storage.IMap[string, *Queue]
consumers storage.IMap[string, *consumer]
publishers storage.IMap[string, *publisher]
deadLetter storage.IMap[string, *Queue] // DLQ mapping for each queue
opts *Options
}
@@ -47,6 +48,7 @@ func NewBroker(opts ...Option) *Broker {
queues: memory.New[string, *Queue](),
publishers: memory.New[string, *publisher](),
consumers: memory.New[string, *consumer](),
deadLetter: memory.New[string, *Queue](),
opts: options,
}
}
@@ -422,6 +424,19 @@ func (b *Broker) dispatchWorker(queue *Queue) {
delay = b.backoffRetry(queue, task, delay)
}
}
if task.RetryCount > b.opts.maxRetries {
b.sendToDLQ(queue, task)
}
}
}
func (b *Broker) sendToDLQ(queue *Queue, task *QueuedTask) {
id, _ := jsonparser.GetString(task.Message.Payload, "id")
if dlq, ok := b.deadLetter.Get(queue.name); ok {
log.Printf("Sending task %s to dead-letter queue for %s", id, queue.name)
dlq.tasks <- task
} else {
log.Printf("No dead-letter queue for %s, discarding task %s", queue.name, id)
}
}

13
ctx.go
View File

@@ -4,11 +4,9 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"net"
"os"
"time"
"github.com/oarkflow/xid"
@@ -17,17 +15,6 @@ import (
"github.com/oarkflow/mq/storage/memory"
)
type Task struct {
CreatedAt time.Time `json:"created_at"`
ProcessedAt time.Time `json:"processed_at"`
Expiry time.Time `json:"expiry"`
Error error `json:"error"`
ID string `json:"id"`
Topic string `json:"topic"`
Status string `json:"status"`
Payload json.RawMessage `json:"payload"`
}
type Handler func(context.Context, *Task) Result
func IsClosed(conn net.Conn) bool {

View File

@@ -2,7 +2,9 @@ package dag
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"sync"
@@ -180,6 +182,35 @@ func (tm *DAG) GetStartNode() string {
return tm.startNode
}
func (tm *DAG) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
return
}
var payload []byte
if r.Body != nil {
defer r.Body.Close()
var err error
payload, err = io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
} else {
http.Error(w, "Empty request body", http.StatusBadRequest)
return
}
ctx := r.Context()
ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
rs := tm.Process(ctx, payload)
if rs.Error != nil {
http.Error(w, fmt.Sprintf("[DAG Error] - %v", rs.Error), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(rs)
}
func (tm *DAG) Start(ctx context.Context, addr string) error {
if !tm.server.SyncMode() {
go func() {

View File

@@ -4,10 +4,8 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/examples/tasks"
"github.com/oarkflow/mq/services"
@@ -16,8 +14,8 @@ import (
)
func main() {
sync()
async()
Sync()
aSync()
}
func setup(f *dag.DAG) {
@@ -46,7 +44,7 @@ func sendData(f *dag.DAG) {
fmt.Println(string(result.Payload))
}
func sync() {
func Sync() {
f := dag.NewDAG("Sample DAG", "sample-dag", mq.WithSyncMode(true), mq.WithNotifyResponse(tasks.NotifyResponse))
setup(f)
fmt.Println(f.ExportDOT())
@@ -54,46 +52,10 @@ func sync() {
fmt.Println(f.SaveSVG("dag.svg"))
}
func async() {
func aSync() {
f := dag.NewDAG("Sample DAG", "sample-dag", mq.WithNotifyResponse(tasks.NotifyResponse))
setup(f)
requestHandler := func(requestType string) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
return
}
var payload []byte
if r.Body != nil {
defer r.Body.Close()
var err error
payload, err = io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
} else {
http.Error(w, "Empty request body", http.StatusBadRequest)
return
}
ctx := r.Context()
if requestType == "request" {
ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
}
// ctx = context.WithValue(ctx, "initial_node", "E")
rs := f.Process(ctx, payload)
if rs.Error != nil {
http.Error(w, fmt.Sprintf("[DAG Error] - %v", rs.Error), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(rs)
}
}
http.HandleFunc("POST /publish", requestHandler("publish"))
http.HandleFunc("POST /request", requestHandler("request"))
http.HandleFunc("POST /request", f.ServeHTTP)
http.HandleFunc("/pause-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) {
id := request.PathValue("id")
if id != "" {

View File

@@ -11,12 +11,17 @@ import (
func main() {
pool := mq.NewPool(2, 5, 1000, tasks.SchedulerHandler, tasks.SchedulerCallback, mq.NewMemoryTaskStorage(10*time.Minute))
time.Sleep(time.Millisecond)
pool.EnqueueTask(context.Background(), &mq.Task{ID: "Low Priority Task"}, 1)
pool.EnqueueTask(context.Background(), &mq.Task{ID: "Medium Priority Task"}, 5)
pool.EnqueueTask(context.Background(), &mq.Task{ID: "High Priority Task"}, 10)
for i := 0; i < 100; i++ {
if i%10 == 0 {
pool.EnqueueTask(context.Background(), &mq.Task{ID: "High Priority Task: I'm high"}, 10)
} else if i%15 == 0 {
pool.EnqueueTask(context.Background(), &mq.Task{ID: "Super High Priority Task: {}"}, 15)
} else {
pool.EnqueueTask(context.Background(), &mq.Task{ID: "Low Priority Task"}, 1)
}
}
time.Sleep(5 * time.Second)
time.Sleep(15 * time.Second)
pool.PrintMetrics()
pool.Stop()
}

130
pool.go
View File

@@ -29,19 +29,15 @@ type Pool struct {
numOfWorkers int32
paused bool
scheduler *Scheduler
totalScheduledTasks int
overflowBufferLock sync.RWMutex
overflowBuffer []*QueueTask
}
func NewPool(
numOfWorkers, taskQueueSize int,
maxMemoryLoad int64,
handler Handler,
callback Callback,
storage TaskStorage) *Pool {
func NewPool(numOfWorkers, taskQueueSize int, maxMemoryLoad int64, handler Handler, callback Callback, storage TaskStorage) *Pool {
pool := &Pool{
taskQueue: make(PriorityQueue, 0, taskQueueSize),
stop: make(chan struct{}),
taskNotify: make(chan struct{}, 1),
taskNotify: make(chan struct{}, numOfWorkers), // Buffer for workers
maxMemoryLoad: maxMemoryLoad,
handler: handler,
callback: callback,
@@ -70,6 +66,7 @@ func (wp *Pool) Start(numWorkers int) {
}
atomic.StoreInt32(&wp.numOfWorkers, int32(numWorkers))
go wp.monitorWorkerAdjustments()
go wp.startOverflowDrainer()
}
func (wp *Pool) worker() {
@@ -77,46 +74,53 @@ func (wp *Pool) worker() {
for {
select {
case <-wp.taskNotify:
wp.taskQueueLock.Lock()
var task *QueueTask
if len(wp.taskQueue) > 0 && !wp.paused {
task = heap.Pop(&wp.taskQueue).(*QueueTask)
}
wp.taskQueueLock.Unlock()
if task == nil && !wp.paused {
var err error
task, err = wp.taskStorage.FetchNextTask()
if err != nil {
continue
}
}
if task != nil {
taskSize := int64(utils.SizeOf(task.payload))
wp.totalMemoryUsed += taskSize
wp.totalTasks++
result := wp.handler(task.ctx, task.payload)
if result.Error != nil {
wp.errorCount++
} else {
wp.completedTasks++
}
if wp.callback != nil {
if err := wp.callback(task.ctx, result); err != nil {
wp.errorCount++
}
}
if err := wp.taskStorage.DeleteTask(task.payload.ID); err != nil {
}
wp.totalMemoryUsed -= taskSize
}
wp.processNextTask()
case <-wp.stop:
return
}
}
}
func (wp *Pool) processNextTask() {
wp.taskQueueLock.Lock()
var task *QueueTask
if len(wp.taskQueue) > 0 && !wp.paused {
task = heap.Pop(&wp.taskQueue).(*QueueTask)
}
wp.taskQueueLock.Unlock()
if task == nil && !wp.paused {
var err error
task, err = wp.taskStorage.FetchNextTask()
if err != nil {
return
}
}
if task != nil {
wp.handleTask(task)
}
}
func (wp *Pool) handleTask(task *QueueTask) {
taskSize := int64(utils.SizeOf(task.payload))
wp.totalMemoryUsed += taskSize
wp.totalTasks++
result := wp.handler(task.ctx, task.payload)
if result.Error != nil {
wp.errorCount++
} else {
wp.completedTasks++
}
if wp.callback != nil {
if err := wp.callback(task.ctx, result); err != nil {
wp.errorCount++
}
}
if err := wp.taskStorage.DeleteTask(task.payload.ID); err != nil {
// Handle deletion error
}
wp.totalMemoryUsed -= taskSize
}
func (wp *Pool) monitorWorkerAdjustments() {
for {
select {
@@ -162,9 +166,12 @@ func (wp *Pool) EnqueueTask(ctx context.Context, payload *Task, priority int) er
return fmt.Errorf("max memory load reached, cannot add task of size %d", taskSize)
}
heap.Push(&wp.taskQueue, task)
// Non-blocking task notification
select {
case wp.taskNotify <- struct{}{}:
default:
wp.storeInOverflow(task)
}
return nil
}
@@ -177,6 +184,45 @@ func (wp *Pool) Resume() {
wp.paused = false
}
// Overflow Handling
func (wp *Pool) storeInOverflow(task *QueueTask) {
wp.overflowBufferLock.Lock()
wp.overflowBuffer = append(wp.overflowBuffer, task)
wp.overflowBufferLock.Unlock()
}
// Drains tasks from the overflow buffer when taskNotify is not full
func (wp *Pool) startOverflowDrainer() {
for {
wp.drainOverflowBuffer()
select {
case <-wp.stop:
return
default:
continue
}
}
}
func (wp *Pool) drainOverflowBuffer() {
wp.overflowBufferLock.Lock()
defer wp.overflowBufferLock.Unlock()
for len(wp.overflowBuffer) > 0 {
select {
case wp.taskNotify <- struct{}{}:
// Move the first task from the overflow buffer to the queue
wp.taskQueueLock.Lock()
heap.Push(&wp.taskQueue, wp.overflowBuffer[0])
wp.overflowBuffer = wp.overflowBuffer[1:]
wp.taskQueueLock.Unlock()
default:
// Stop if taskNotify is full
return
}
}
}
func (wp *Pool) Stop() {
close(wp.stop)
wp.wg.Wait()

View File

@@ -21,14 +21,24 @@ func newQueue(name string, queueSize int) *Queue {
}
}
func (b *Broker) NewQueue(qName string) *Queue {
q, ok := b.queues.Get(qName)
if ok {
return q
func (b *Broker) NewQueue(name string) *Queue {
q := &Queue{
name: name,
tasks: make(chan *QueuedTask, b.opts.queueSize),
consumers: memory.New[string, *consumer](),
}
q = newQueue(qName, b.opts.queueSize)
b.queues.Set(qName, q)
b.queues.Set(name, q)
// Create DLQ for the queue
dlq := &Queue{
name: name + "_dlq",
tasks: make(chan *QueuedTask, b.opts.queueSize),
consumers: memory.New[string, *consumer](),
}
b.deadLetter.Set(name, dlq)
go b.dispatchWorker(q)
go b.dispatchWorker(dlq)
return q
}

11
task.go
View File

@@ -5,6 +5,17 @@ import (
"time"
)
type Task struct {
CreatedAt time.Time `json:"created_at"`
ProcessedAt time.Time `json:"processed_at"`
Expiry time.Time `json:"expiry"`
Error error `json:"error"`
ID string `json:"id"`
Topic string `json:"topic"`
Status string `json:"status"`
Payload json.RawMessage `json:"payload"`
}
func NewTask(id string, payload json.RawMessage, nodeKey string) *Task {
if id == "" {
id = NewID()