Files
mq/examples/dag.go
2024-09-29 09:26:18 +05:45

236 lines
6.0 KiB
Go

package main
import (
"context"
"encoding/json"
"fmt"
"log"
"sync"
"time"
"github.com/oarkflow/mq"
)
func main() {
dag := NewDAG()
dag.AddNode("queue1", func(ctx context.Context, task mq.Task) mq.Result {
log.Printf("Handling task for queue1: %s", string(task.Payload))
return mq.Result{Payload: task.Payload, MessageID: task.ID}
})
dag.AddNode("queue2", func(ctx context.Context, task mq.Task) mq.Result {
log.Printf("Handling task for queue2: %s", string(task.Payload))
return mq.Result{Payload: task.Payload, MessageID: task.ID}
})
dag.AddNode("queue3", func(ctx context.Context, task mq.Task) mq.Result {
var data map[string]any
err := json.Unmarshal(task.Payload, &data)
if err != nil {
return mq.Result{Error: err}
}
data["salary"] = fmt.Sprintf("12000%v", data["user_id"])
bt, _ := json.Marshal(data)
log.Printf("Handling task for queue3: %s", string(bt))
return mq.Result{Payload: bt, MessageID: task.ID}
})
dag.AddNode("queue4", func(ctx context.Context, task mq.Task) mq.Result {
var data []map[string]any
err := json.Unmarshal(task.Payload, &data)
if err != nil {
return mq.Result{Error: err}
}
log.Printf("Handling task for queue4: %s", string(task.Payload))
payload := map[string]any{"storage": data}
bt, _ := json.Marshal(payload)
return mq.Result{Payload: bt, MessageID: task.ID}
})
dag.AddEdge("queue1", "queue2")
dag.AddLoop("queue2", "queue3")
dag.AddEdge("queue2", "queue4")
go func() {
time.Sleep(2 * time.Second)
finalResult := dag.Send([]byte(`[{"user_id": 1}, {"user_id": 2}]`))
log.Printf("Final result received: %s", string(finalResult.Payload))
}()
err := dag.Start(context.TODO())
if err != nil {
panic(err)
}
}
type DAG struct {
server *mq.Broker
nodes map[string]*mq.Consumer
edges map[string][]string
loopEdges map[string]string
taskChMap map[string]chan mq.Result
loopTaskMap map[string]*loopTaskContext
mu sync.Mutex
}
type loopTaskContext struct {
subResultCh chan mq.Result
totalItems int
completed int
results []json.RawMessage
}
func NewDAG(opts ...mq.Option) *DAG {
d := &DAG{
nodes: make(map[string]*mq.Consumer),
edges: make(map[string][]string),
loopEdges: make(map[string]string),
taskChMap: make(map[string]chan mq.Result),
loopTaskMap: make(map[string]*loopTaskContext),
}
opts = append(opts, mq.WithCallback(d.TaskCallback))
d.server = mq.NewBroker(opts...)
return d
}
func (d *DAG) AddNode(name string, handler mq.Handler) {
con := mq.NewConsumer(name)
con.RegisterHandler(name, handler)
d.nodes[name] = con
}
func (d *DAG) AddEdge(fromNode string, toNodes ...string) {
d.edges[fromNode] = toNodes
}
func (d *DAG) AddLoop(fromNode string, toNode string) {
d.loopEdges[fromNode] = toNode
}
func (d *DAG) Start(ctx context.Context) error {
for _, con := range d.nodes {
go con.Consume(ctx)
}
return d.server.Start(ctx)
}
func (d *DAG) PublishTask(ctx context.Context, payload []byte, queueName string, taskID ...string) (*mq.Task, error) {
task := mq.Task{
Payload: payload,
}
if len(taskID) > 0 {
task.ID = taskID[0]
}
return d.server.Publish(ctx, task, queueName)
}
func (d *DAG) Send(payload []byte) mq.Result {
resultCh := make(chan mq.Result)
task, err := d.PublishTask(context.TODO(), payload, "queue1")
if err != nil {
panic(err)
}
d.mu.Lock()
d.taskChMap[task.ID] = resultCh
d.mu.Unlock()
finalResult := <-resultCh
return finalResult
}
func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error {
log.Printf("Callback from queue %s with result: %s", task.CurrentQueue, string(task.Result))
d.mu.Lock()
loopCtx, isLoopTask := d.loopTaskMap[task.ID]
d.mu.Unlock()
if isLoopTask {
loopCtx.subResultCh <- mq.Result{Payload: task.Result, MessageID: task.ID}
}
if loopNode, exists := d.loopEdges[task.CurrentQueue]; exists {
var items []json.RawMessage
if err := json.Unmarshal(task.Result, &items); err != nil {
return err
}
loopCtx := &loopTaskContext{
subResultCh: make(chan mq.Result, len(items)),
totalItems: len(items),
results: make([]json.RawMessage, 0, len(items)),
}
d.mu.Lock()
d.loopTaskMap[task.ID] = loopCtx
d.mu.Unlock()
for _, item := range items {
_, err := d.PublishTask(ctx, item, loopNode, task.ID)
if err != nil {
return err
}
}
go d.waitForLoopCompletion(ctx, task.ID, task.CurrentQueue)
return nil
}
edges, exists := d.edges[task.CurrentQueue]
if exists {
for _, edge := range edges {
_, err := d.PublishTask(ctx, task.Result, edge, task.ID)
if err != nil {
return err
}
}
} else {
d.mu.Lock()
if resultCh, ok := d.taskChMap[task.ID]; ok {
resultCh <- mq.Result{
Command: "complete",
Payload: task.Result,
Queue: task.CurrentQueue,
MessageID: task.ID,
Status: "done",
}
delete(d.taskChMap, task.ID)
}
d.mu.Unlock()
}
return nil
}
func (d *DAG) waitForLoopCompletion(ctx context.Context, taskID string, currentQueue string) {
d.mu.Lock()
loopCtx := d.loopTaskMap[taskID]
d.mu.Unlock()
for result := range loopCtx.subResultCh {
loopCtx.results = append(loopCtx.results, result.Payload)
loopCtx.completed++
if loopCtx.completed == loopCtx.totalItems {
close(loopCtx.subResultCh)
aggregatedResult, err := json.Marshal(loopCtx.results)
if err != nil {
log.Printf("Error aggregating results: %v", err)
return
}
d.mu.Lock()
delete(d.loopTaskMap, taskID)
d.mu.Unlock()
edges, exists := d.edges[currentQueue]
if exists {
for _, edge := range edges {
_, err := d.PublishTask(ctx, aggregatedResult, edge, taskID)
if err != nil {
log.Printf("Error publishing aggregated result: %v", err)
return
}
}
} else {
d.mu.Lock()
if resultCh, ok := d.taskChMap[taskID]; ok {
resultCh <- mq.Result{
Command: "complete",
Payload: aggregatedResult,
Queue: currentQueue,
MessageID: taskID,
Status: "done",
}
delete(d.taskChMap, taskID)
}
d.mu.Unlock()
}
}
}
}