init: publisher

sujit committed on 2024-09-29 08:39:33 +05:45
parent cf9b97d95b
commit f0635931a2

@@ -13,17 +13,14 @@ import (
 func main() {
 	dag := NewDAG()
 	dag.AddNode("queue1", func(ctx context.Context, task mq.Task) mq.Result {
 		log.Printf("Handling task for queue1: %s", string(task.Payload))
 		return mq.Result{Payload: task.Payload, MessageID: task.ID}
 	})
 	dag.AddNode("queue2", func(ctx context.Context, task mq.Task) mq.Result {
 		log.Printf("Handling task for queue2: %s", string(task.Payload))
 		return mq.Result{Payload: task.Payload, MessageID: task.ID}
 	})
 	dag.AddNode("queue3", func(ctx context.Context, task mq.Task) mq.Result {
 		var data map[string]any
 		err := json.Unmarshal(task.Payload, &data)
@@ -35,18 +32,22 @@ func main() {
log.Printf("Handling task for queue3: %s", string(bt)) log.Printf("Handling task for queue3: %s", string(bt))
return mq.Result{Payload: bt, MessageID: task.ID} return mq.Result{Payload: bt, MessageID: task.ID}
}) })
dag.AddNode("queue4", func(ctx context.Context, task mq.Task) mq.Result { dag.AddNode("queue4", func(ctx context.Context, task mq.Task) mq.Result {
var data []map[string]any
err := json.Unmarshal(task.Payload, &data)
if err != nil {
return mq.Result{Error: err}
}
log.Printf("Handling task for queue4: %s", string(task.Payload)) log.Printf("Handling task for queue4: %s", string(task.Payload))
return mq.Result{Payload: task.Payload, MessageID: task.ID} payload := map[string]any{"storage": data}
bt, _ := json.Marshal(payload)
return mq.Result{Payload: bt, MessageID: task.ID}
}) })
// Define edges and loops
dag.AddEdge("queue1", "queue2") dag.AddEdge("queue1", "queue2")
dag.AddLoop("queue2", "queue3") // Loop through queue3 for each item from queue2 dag.AddLoop("queue2", "queue3")
dag.AddEdge("queue2", "queue4") // After processing queue2 (including loop), continue to queue4 dag.AddEdge("queue2", "queue4")
// Send task payload
go func() { go func() {
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
finalResult := dag.Send([]byte(`[{"user_id": 1}, {"user_id": 2}]`)) finalResult := dag.Send([]byte(`[{"user_id": 1}, {"user_id": 2}]`))
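For context, a minimal standalone sketch (not part of this commit) of the transformation the new queue4 handler performs, assuming the loop over queue3 hands it the aggregated sub-results back as a JSON array:

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Hypothetical aggregated output of the queue2 -> queue3 loop.
	aggregated := []byte(`[{"user_id":1},{"user_id":2}]`)

	var data []map[string]any
	if err := json.Unmarshal(aggregated, &data); err != nil {
		panic(err)
	}

	// queue4 wraps the aggregated items under a "storage" key.
	payload, _ := json.Marshal(map[string]any{"storage": data})
	fmt.Println(string(payload)) // {"storage":[{"user_id":1},{"user_id":2}]}
}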
@@ -59,18 +60,16 @@ func main() {
 	}
 }
-// DAG struct to handle tasks, edges, and loops
 type DAG struct {
 	server      *mq.Broker
 	nodes       map[string]*mq.Consumer
 	edges       map[string][]string
-	loopEdges   map[string]string // Tracks loop edges
+	loopEdges   map[string]string
 	taskChMap   map[string]chan mq.Result
-	loopTaskMap map[string]*loopTaskContext // Manages loop tasks
+	loopTaskMap map[string]*loopTaskContext
 	mu          sync.Mutex
 }
-// Structure to store the loop task context
 type loopTaskContext struct {
 	subResultCh chan mq.Result
 	totalItems  int
@@ -78,7 +77,6 @@ type loopTaskContext struct {
 	results     []json.RawMessage
 }
-// NewDAG initializes the DAG structure
 func NewDAG(opts ...mq.Option) *DAG {
 	d := &DAG{
 		nodes:       make(map[string]*mq.Consumer),
@@ -113,7 +111,6 @@ func (d *DAG) Start(ctx context.Context) error {
 	return d.server.Start(ctx)
 }
-// PublishTask sends a task to a specific queue
 func (d *DAG) PublishTask(ctx context.Context, payload []byte, queueName string, taskID ...string) (*mq.Task, error) {
 	task := mq.Task{
 		Payload: payload,
@@ -124,48 +121,37 @@ func (d *DAG) PublishTask(ctx context.Context, payload []byte, queueName string,
 	return d.server.Publish(ctx, task, queueName)
 }
-// TaskCallback processes the completion of a task and continues the flow
 func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error {
 	log.Printf("Callback from queue %s with result: %s", task.CurrentQueue, string(task.Result))
-	// Check if the task belongs to a loop
 	d.mu.Lock()
 	loopCtx, isLoopTask := d.loopTaskMap[task.ID]
 	d.mu.Unlock()
 	if isLoopTask {
 		loopCtx.subResultCh <- mq.Result{Payload: task.Result, MessageID: task.ID}
 	}
-	// Handle loopEdges first
 	if loopNode, exists := d.loopEdges[task.CurrentQueue]; exists {
 		var items []json.RawMessage
 		if err := json.Unmarshal(task.Result, &items); err != nil {
 			return err
 		}
 		loopCtx := &loopTaskContext{
-			subResultCh: make(chan mq.Result, len(items)), // Collect sub-task results
+			subResultCh: make(chan mq.Result, len(items)),
 			totalItems:  len(items),
 			results:     make([]json.RawMessage, 0, len(items)),
 		}
 		d.mu.Lock()
 		d.loopTaskMap[task.ID] = loopCtx
 		d.mu.Unlock()
-		// Publish a sub-task for each item
 		for _, item := range items {
 			_, err := d.PublishTask(ctx, item, loopNode, task.ID)
 			if err != nil {
 				return err
 			}
 		}
-		// Wait for loop completion
 		go d.waitForLoopCompletion(ctx, task.ID, task.CurrentQueue)
 		return nil
 	}
-	// Handle normal edges
 	edges, exists := d.edges[task.CurrentQueue]
 	if exists {
 		for _, edge := range edges {
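The loop branch above fans one task out into one sub-task per array element and later fans the sub-results back in. A minimal sketch of that pattern in plain Go (illustrative only; fanOutFanIn and its callback are assumptions, not part of this code or the mq API):

package main

import (
	"encoding/json"
	"fmt"
)

// fanOutFanIn runs handle once per item, collects every sub-result, and
// re-marshals the results into a single JSON array (arrival order, not input order).
func fanOutFanIn(items []json.RawMessage, handle func(json.RawMessage) json.RawMessage) ([]byte, error) {
	resultCh := make(chan json.RawMessage, len(items))
	for _, item := range items {
		go func(it json.RawMessage) { resultCh <- handle(it) }(item)
	}
	results := make([]json.RawMessage, 0, len(items))
	for range items {
		results = append(results, <-resultCh)
	}
	return json.Marshal(results)
}

func main() {
	var items []json.RawMessage
	_ = json.Unmarshal([]byte(`[{"user_id":1},{"user_id":2}]`), &items)
	out, _ := fanOutFanIn(items, func(it json.RawMessage) json.RawMessage { return it })
	fmt.Println(string(out))
}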
@@ -174,40 +160,40 @@ func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error {
 				return err
 			}
 		}
+	} else {
+		d.mu.Lock()
+		if resultCh, ok := d.taskChMap[task.ID]; ok {
+			resultCh <- mq.Result{
+				Command:   "complete",
+				Payload:   task.Result,
+				Queue:     task.CurrentQueue,
+				MessageID: task.ID,
+				Status:    "done",
+			}
+			delete(d.taskChMap, task.ID)
+		}
+		d.mu.Unlock()
 	}
 	return nil
 }
-// waitForLoopCompletion waits for all sub-tasks in a loop to finish, aggregates the results, and proceeds
 func (d *DAG) waitForLoopCompletion(ctx context.Context, taskID string, currentQueue string) {
-	// Get the loop context
 	d.mu.Lock()
 	loopCtx := d.loopTaskMap[taskID]
 	d.mu.Unlock()
 	for result := range loopCtx.subResultCh {
-		// Collect results
 		loopCtx.results = append(loopCtx.results, result.Payload)
 		loopCtx.completed++
-		// If all sub-tasks are completed, aggregate results and continue
 		if loopCtx.completed == loopCtx.totalItems {
 			close(loopCtx.subResultCh)
-			// Aggregate results
 			aggregatedResult, err := json.Marshal(loopCtx.results)
 			if err != nil {
 				log.Printf("Error aggregating results: %v", err)
 				return
 			}
-			// Continue flow after loop completion
 			d.mu.Lock()
 			delete(d.loopTaskMap, taskID)
 			d.mu.Unlock()
-			// Handle the next nodes in the DAG
 			edges, exists := d.edges[currentQueue]
 			if exists {
 				for _, edge := range edges {
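The new else branch in TaskCallback finalizes a task that has no outgoing edges by delivering the result on the channel that Send registered in taskChMap. A minimal sketch of that request/response pattern (illustrative type and method names, not the library API):

package main

import (
	"fmt"
	"sync"
)

type completions struct {
	mu  sync.Mutex
	chs map[string]chan []byte
}

// register creates a buffered result channel for a task ID.
func (c *completions) register(id string) chan []byte {
	ch := make(chan []byte, 1)
	c.mu.Lock()
	c.chs[id] = ch
	c.mu.Unlock()
	return ch
}

// complete delivers the final result to whoever registered the ID.
func (c *completions) complete(id string, result []byte) {
	c.mu.Lock()
	ch, ok := c.chs[id]
	if ok {
		delete(c.chs, id)
	}
	c.mu.Unlock()
	if ok {
		ch <- result
	}
}

func main() {
	c := &completions{chs: make(map[string]chan []byte)}
	ch := c.register("task-1")
	go c.complete("task-1", []byte(`{"storage":[]}`))
	fmt.Println(string(<-ch))
}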
@@ -218,7 +204,6 @@ func (d *DAG) waitForLoopCompletion(ctx context.Context, taskID string, currentQ
 					}
 				}
 			} else {
-				// If no further edges, finalize
 				d.mu.Lock()
 				if resultCh, ok := d.taskChMap[taskID]; ok {
 					resultCh <- mq.Result{
@@ -232,13 +217,10 @@ func (d *DAG) waitForLoopCompletion(ctx context.Context, taskID string, currentQ
 				}
 				d.mu.Unlock()
 			}
-			break
 		}
 	}
 }
-// Send sends the initial task and waits for the final result
 func (d *DAG) Send(payload []byte) mq.Result {
 	resultCh := make(chan mq.Result)
 	task, err := d.PublishTask(context.TODO(), payload, "queue1")