mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-04 23:52:48 +08:00
feat: [wip] - implement storage
This commit is contained in:
15
broker.go
15
broker.go
@@ -38,6 +38,7 @@ type Broker struct {
|
||||
queues storage.IMap[string, *Queue]
|
||||
consumers storage.IMap[string, *consumer]
|
||||
publishers storage.IMap[string, *publisher]
|
||||
deadLetter storage.IMap[string, *Queue] // DLQ mapping for each queue
|
||||
opts *Options
|
||||
}
|
||||
|
||||
@@ -47,6 +48,7 @@ func NewBroker(opts ...Option) *Broker {
|
||||
queues: memory.New[string, *Queue](),
|
||||
publishers: memory.New[string, *publisher](),
|
||||
consumers: memory.New[string, *consumer](),
|
||||
deadLetter: memory.New[string, *Queue](),
|
||||
opts: options,
|
||||
}
|
||||
}
|
||||
@@ -422,6 +424,19 @@ func (b *Broker) dispatchWorker(queue *Queue) {
|
||||
delay = b.backoffRetry(queue, task, delay)
|
||||
}
|
||||
}
|
||||
if task.RetryCount > b.opts.maxRetries {
|
||||
b.sendToDLQ(queue, task)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) sendToDLQ(queue *Queue, task *QueuedTask) {
|
||||
id, _ := jsonparser.GetString(task.Message.Payload, "id")
|
||||
if dlq, ok := b.deadLetter.Get(queue.name); ok {
|
||||
log.Printf("Sending task %s to dead-letter queue for %s", id, queue.name)
|
||||
dlq.tasks <- task
|
||||
} else {
|
||||
log.Printf("No dead-letter queue for %s, discarding task %s", queue.name, id)
|
||||
}
|
||||
}
|
||||
|
||||
|
13
ctx.go
13
ctx.go
@@ -4,11 +4,9 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/xid"
|
||||
|
||||
@@ -17,17 +15,6 @@ import (
|
||||
"github.com/oarkflow/mq/storage/memory"
|
||||
)
|
||||
|
||||
type Task struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ProcessedAt time.Time `json:"processed_at"`
|
||||
Expiry time.Time `json:"expiry"`
|
||||
Error error `json:"error"`
|
||||
ID string `json:"id"`
|
||||
Topic string `json:"topic"`
|
||||
Status string `json:"status"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
}
|
||||
|
||||
type Handler func(context.Context, *Task) Result
|
||||
|
||||
func IsClosed(conn net.Conn) bool {
|
||||
|
31
dag/dag.go
31
dag/dag.go
@@ -2,7 +2,9 @@ package dag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
@@ -180,6 +182,35 @@ func (tm *DAG) GetStartNode() string {
|
||||
return tm.startNode
|
||||
}
|
||||
|
||||
func (tm *DAG) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
var payload []byte
|
||||
if r.Body != nil {
|
||||
defer r.Body.Close()
|
||||
var err error
|
||||
payload, err = io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
http.Error(w, "Empty request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
ctx := r.Context()
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
|
||||
rs := tm.Process(ctx, payload)
|
||||
if rs.Error != nil {
|
||||
http.Error(w, fmt.Sprintf("[DAG Error] - %v", rs.Error), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(rs)
|
||||
}
|
||||
|
||||
func (tm *DAG) Start(ctx context.Context, addr string) error {
|
||||
if !tm.server.SyncMode() {
|
||||
go func() {
|
||||
|
@@ -4,10 +4,8 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/oarkflow/mq/consts"
|
||||
"github.com/oarkflow/mq/examples/tasks"
|
||||
"github.com/oarkflow/mq/services"
|
||||
|
||||
@@ -16,8 +14,8 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
sync()
|
||||
async()
|
||||
Sync()
|
||||
aSync()
|
||||
}
|
||||
|
||||
func setup(f *dag.DAG) {
|
||||
@@ -46,7 +44,7 @@ func sendData(f *dag.DAG) {
|
||||
fmt.Println(string(result.Payload))
|
||||
}
|
||||
|
||||
func sync() {
|
||||
func Sync() {
|
||||
f := dag.NewDAG("Sample DAG", "sample-dag", mq.WithSyncMode(true), mq.WithNotifyResponse(tasks.NotifyResponse))
|
||||
setup(f)
|
||||
fmt.Println(f.ExportDOT())
|
||||
@@ -54,46 +52,10 @@ func sync() {
|
||||
fmt.Println(f.SaveSVG("dag.svg"))
|
||||
}
|
||||
|
||||
func async() {
|
||||
func aSync() {
|
||||
f := dag.NewDAG("Sample DAG", "sample-dag", mq.WithNotifyResponse(tasks.NotifyResponse))
|
||||
setup(f)
|
||||
|
||||
requestHandler := func(requestType string) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
var payload []byte
|
||||
if r.Body != nil {
|
||||
defer r.Body.Close()
|
||||
var err error
|
||||
payload, err = io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
http.Error(w, "Empty request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
ctx := r.Context()
|
||||
if requestType == "request" {
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
|
||||
}
|
||||
// ctx = context.WithValue(ctx, "initial_node", "E")
|
||||
rs := f.Process(ctx, payload)
|
||||
if rs.Error != nil {
|
||||
http.Error(w, fmt.Sprintf("[DAG Error] - %v", rs.Error), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(rs)
|
||||
}
|
||||
}
|
||||
|
||||
http.HandleFunc("POST /publish", requestHandler("publish"))
|
||||
http.HandleFunc("POST /request", requestHandler("request"))
|
||||
http.HandleFunc("POST /request", f.ServeHTTP)
|
||||
http.HandleFunc("/pause-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) {
|
||||
id := request.PathValue("id")
|
||||
if id != "" {
|
||||
|
@@ -11,12 +11,17 @@ import (
|
||||
func main() {
|
||||
pool := mq.NewPool(2, 5, 1000, tasks.SchedulerHandler, tasks.SchedulerCallback, mq.NewMemoryTaskStorage(10*time.Minute))
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
pool.EnqueueTask(context.Background(), &mq.Task{ID: "Low Priority Task"}, 1)
|
||||
pool.EnqueueTask(context.Background(), &mq.Task{ID: "Medium Priority Task"}, 5)
|
||||
pool.EnqueueTask(context.Background(), &mq.Task{ID: "High Priority Task"}, 10)
|
||||
for i := 0; i < 100; i++ {
|
||||
if i%10 == 0 {
|
||||
pool.EnqueueTask(context.Background(), &mq.Task{ID: "High Priority Task: I'm high"}, 10)
|
||||
} else if i%15 == 0 {
|
||||
pool.EnqueueTask(context.Background(), &mq.Task{ID: "Super High Priority Task: {}"}, 15)
|
||||
} else {
|
||||
pool.EnqueueTask(context.Background(), &mq.Task{ID: "Low Priority Task"}, 1)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
time.Sleep(15 * time.Second)
|
||||
pool.PrintMetrics()
|
||||
pool.Stop()
|
||||
}
|
||||
|
130
pool.go
130
pool.go
@@ -29,19 +29,15 @@ type Pool struct {
|
||||
numOfWorkers int32
|
||||
paused bool
|
||||
scheduler *Scheduler
|
||||
totalScheduledTasks int
|
||||
overflowBufferLock sync.RWMutex
|
||||
overflowBuffer []*QueueTask
|
||||
}
|
||||
|
||||
func NewPool(
|
||||
numOfWorkers, taskQueueSize int,
|
||||
maxMemoryLoad int64,
|
||||
handler Handler,
|
||||
callback Callback,
|
||||
storage TaskStorage) *Pool {
|
||||
func NewPool(numOfWorkers, taskQueueSize int, maxMemoryLoad int64, handler Handler, callback Callback, storage TaskStorage) *Pool {
|
||||
pool := &Pool{
|
||||
taskQueue: make(PriorityQueue, 0, taskQueueSize),
|
||||
stop: make(chan struct{}),
|
||||
taskNotify: make(chan struct{}, 1),
|
||||
taskNotify: make(chan struct{}, numOfWorkers), // Buffer for workers
|
||||
maxMemoryLoad: maxMemoryLoad,
|
||||
handler: handler,
|
||||
callback: callback,
|
||||
@@ -70,6 +66,7 @@ func (wp *Pool) Start(numWorkers int) {
|
||||
}
|
||||
atomic.StoreInt32(&wp.numOfWorkers, int32(numWorkers))
|
||||
go wp.monitorWorkerAdjustments()
|
||||
go wp.startOverflowDrainer()
|
||||
}
|
||||
|
||||
func (wp *Pool) worker() {
|
||||
@@ -77,46 +74,53 @@ func (wp *Pool) worker() {
|
||||
for {
|
||||
select {
|
||||
case <-wp.taskNotify:
|
||||
wp.taskQueueLock.Lock()
|
||||
var task *QueueTask
|
||||
if len(wp.taskQueue) > 0 && !wp.paused {
|
||||
task = heap.Pop(&wp.taskQueue).(*QueueTask)
|
||||
}
|
||||
wp.taskQueueLock.Unlock()
|
||||
if task == nil && !wp.paused {
|
||||
var err error
|
||||
task, err = wp.taskStorage.FetchNextTask()
|
||||
if err != nil {
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
if task != nil {
|
||||
taskSize := int64(utils.SizeOf(task.payload))
|
||||
wp.totalMemoryUsed += taskSize
|
||||
wp.totalTasks++
|
||||
result := wp.handler(task.ctx, task.payload)
|
||||
if result.Error != nil {
|
||||
wp.errorCount++
|
||||
} else {
|
||||
wp.completedTasks++
|
||||
}
|
||||
if wp.callback != nil {
|
||||
if err := wp.callback(task.ctx, result); err != nil {
|
||||
wp.errorCount++
|
||||
}
|
||||
}
|
||||
if err := wp.taskStorage.DeleteTask(task.payload.ID); err != nil {
|
||||
|
||||
}
|
||||
wp.totalMemoryUsed -= taskSize
|
||||
}
|
||||
wp.processNextTask()
|
||||
case <-wp.stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (wp *Pool) processNextTask() {
|
||||
wp.taskQueueLock.Lock()
|
||||
var task *QueueTask
|
||||
if len(wp.taskQueue) > 0 && !wp.paused {
|
||||
task = heap.Pop(&wp.taskQueue).(*QueueTask)
|
||||
}
|
||||
wp.taskQueueLock.Unlock()
|
||||
if task == nil && !wp.paused {
|
||||
var err error
|
||||
task, err = wp.taskStorage.FetchNextTask()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if task != nil {
|
||||
wp.handleTask(task)
|
||||
}
|
||||
}
|
||||
|
||||
func (wp *Pool) handleTask(task *QueueTask) {
|
||||
taskSize := int64(utils.SizeOf(task.payload))
|
||||
wp.totalMemoryUsed += taskSize
|
||||
wp.totalTasks++
|
||||
result := wp.handler(task.ctx, task.payload)
|
||||
if result.Error != nil {
|
||||
wp.errorCount++
|
||||
} else {
|
||||
wp.completedTasks++
|
||||
}
|
||||
if wp.callback != nil {
|
||||
if err := wp.callback(task.ctx, result); err != nil {
|
||||
wp.errorCount++
|
||||
}
|
||||
}
|
||||
if err := wp.taskStorage.DeleteTask(task.payload.ID); err != nil {
|
||||
// Handle deletion error
|
||||
}
|
||||
wp.totalMemoryUsed -= taskSize
|
||||
}
|
||||
|
||||
func (wp *Pool) monitorWorkerAdjustments() {
|
||||
for {
|
||||
select {
|
||||
@@ -162,9 +166,12 @@ func (wp *Pool) EnqueueTask(ctx context.Context, payload *Task, priority int) er
|
||||
return fmt.Errorf("max memory load reached, cannot add task of size %d", taskSize)
|
||||
}
|
||||
heap.Push(&wp.taskQueue, task)
|
||||
|
||||
// Non-blocking task notification
|
||||
select {
|
||||
case wp.taskNotify <- struct{}{}:
|
||||
default:
|
||||
wp.storeInOverflow(task)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -177,6 +184,45 @@ func (wp *Pool) Resume() {
|
||||
wp.paused = false
|
||||
}
|
||||
|
||||
// Overflow Handling
|
||||
func (wp *Pool) storeInOverflow(task *QueueTask) {
|
||||
wp.overflowBufferLock.Lock()
|
||||
wp.overflowBuffer = append(wp.overflowBuffer, task)
|
||||
wp.overflowBufferLock.Unlock()
|
||||
}
|
||||
|
||||
// Drains tasks from the overflow buffer when taskNotify is not full
|
||||
func (wp *Pool) startOverflowDrainer() {
|
||||
for {
|
||||
wp.drainOverflowBuffer()
|
||||
select {
|
||||
case <-wp.stop:
|
||||
return
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (wp *Pool) drainOverflowBuffer() {
|
||||
wp.overflowBufferLock.Lock()
|
||||
defer wp.overflowBufferLock.Unlock()
|
||||
|
||||
for len(wp.overflowBuffer) > 0 {
|
||||
select {
|
||||
case wp.taskNotify <- struct{}{}:
|
||||
// Move the first task from the overflow buffer to the queue
|
||||
wp.taskQueueLock.Lock()
|
||||
heap.Push(&wp.taskQueue, wp.overflowBuffer[0])
|
||||
wp.overflowBuffer = wp.overflowBuffer[1:]
|
||||
wp.taskQueueLock.Unlock()
|
||||
default:
|
||||
// Stop if taskNotify is full
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (wp *Pool) Stop() {
|
||||
close(wp.stop)
|
||||
wp.wg.Wait()
|
||||
|
22
queue.go
22
queue.go
@@ -21,14 +21,24 @@ func newQueue(name string, queueSize int) *Queue {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) NewQueue(qName string) *Queue {
|
||||
q, ok := b.queues.Get(qName)
|
||||
if ok {
|
||||
return q
|
||||
func (b *Broker) NewQueue(name string) *Queue {
|
||||
q := &Queue{
|
||||
name: name,
|
||||
tasks: make(chan *QueuedTask, b.opts.queueSize),
|
||||
consumers: memory.New[string, *consumer](),
|
||||
}
|
||||
q = newQueue(qName, b.opts.queueSize)
|
||||
b.queues.Set(qName, q)
|
||||
b.queues.Set(name, q)
|
||||
|
||||
// Create DLQ for the queue
|
||||
dlq := &Queue{
|
||||
name: name + "_dlq",
|
||||
tasks: make(chan *QueuedTask, b.opts.queueSize),
|
||||
consumers: memory.New[string, *consumer](),
|
||||
}
|
||||
b.deadLetter.Set(name, dlq)
|
||||
|
||||
go b.dispatchWorker(q)
|
||||
go b.dispatchWorker(dlq)
|
||||
return q
|
||||
}
|
||||
|
||||
|
11
task.go
11
task.go
@@ -5,6 +5,17 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type Task struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ProcessedAt time.Time `json:"processed_at"`
|
||||
Expiry time.Time `json:"expiry"`
|
||||
Error error `json:"error"`
|
||||
ID string `json:"id"`
|
||||
Topic string `json:"topic"`
|
||||
Status string `json:"status"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
}
|
||||
|
||||
func NewTask(id string, payload json.RawMessage, nodeKey string) *Task {
|
||||
if id == "" {
|
||||
id = NewID()
|
||||
|
Reference in New Issue
Block a user