mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-06 00:16:49 +08:00
init: publisher
This commit is contained in:
@@ -13,14 +13,17 @@ import (
|
||||
|
||||
func main() {
|
||||
dag := NewDAG()
|
||||
|
||||
dag.AddNode("queue1", func(ctx context.Context, task mq.Task) mq.Result {
|
||||
log.Printf("Handling task for queue1: %s", string(task.Payload))
|
||||
return mq.Result{Payload: task.Payload, MessageID: task.ID}
|
||||
})
|
||||
|
||||
dag.AddNode("queue2", func(ctx context.Context, task mq.Task) mq.Result {
|
||||
log.Printf("Handling task for queue2: %s", string(task.Payload))
|
||||
return mq.Result{Payload: task.Payload, MessageID: task.ID}
|
||||
})
|
||||
|
||||
dag.AddNode("queue3", func(ctx context.Context, task mq.Task) mq.Result {
|
||||
var data map[string]any
|
||||
err := json.Unmarshal(task.Payload, &data)
|
||||
@@ -29,16 +32,21 @@ func main() {
|
||||
}
|
||||
data["salary"] = fmt.Sprintf("12000%v", data["user_id"])
|
||||
bt, _ := json.Marshal(data)
|
||||
log.Printf("Handling task for queue3: %s", string(task.Payload))
|
||||
log.Printf("Handling task for queue3: %s", string(bt))
|
||||
return mq.Result{Payload: bt, MessageID: task.ID}
|
||||
})
|
||||
|
||||
dag.AddNode("queue4", func(ctx context.Context, task mq.Task) mq.Result {
|
||||
log.Printf("Handling task for queue4: %s", string(task.Payload))
|
||||
return mq.Result{Payload: task.Payload, MessageID: task.ID}
|
||||
})
|
||||
dag.AddEdge("queue1", "queue2")
|
||||
dag.AddLoop("queue2", "queue3") // Add loop to handle array
|
||||
|
||||
// Define edges and loops
|
||||
dag.AddEdge("queue1", "queue2")
|
||||
dag.AddLoop("queue2", "queue3") // Loop through queue3 for each item from queue2
|
||||
dag.AddEdge("queue2", "queue4") // After processing queue2 (including loop), continue to queue4
|
||||
|
||||
// Send task payload
|
||||
go func() {
|
||||
time.Sleep(2 * time.Second)
|
||||
finalResult := dag.Send([]byte(`[{"user_id": 1}, {"user_id": 2}]`))
|
||||
@@ -51,14 +59,14 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
// DAG struct to handle tasks and loops
|
||||
// DAG struct to handle tasks, edges, and loops
|
||||
type DAG struct {
|
||||
server *mq.Broker
|
||||
nodes map[string]*mq.Consumer
|
||||
edges map[string][]string
|
||||
loopEdges map[string]string // Handles loop edges
|
||||
loopEdges map[string]string // Tracks loop edges
|
||||
taskChMap map[string]chan mq.Result
|
||||
loopTaskMap map[string]*loopTaskContext // Map to handle loop tasks
|
||||
loopTaskMap map[string]*loopTaskContext // Manages loop tasks
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
@@ -70,7 +78,7 @@ type loopTaskContext struct {
|
||||
results []json.RawMessage
|
||||
}
|
||||
|
||||
// NewDAG initializes the DAG structure with necessary fields
|
||||
// NewDAG initializes the DAG structure
|
||||
func NewDAG(opts ...mq.Option) *DAG {
|
||||
d := &DAG{
|
||||
nodes: make(map[string]*mq.Consumer),
|
||||
@@ -105,7 +113,7 @@ func (d *DAG) Start(ctx context.Context) error {
|
||||
return d.server.Start(ctx)
|
||||
}
|
||||
|
||||
// PublishTask sends a task to a queue
|
||||
// PublishTask sends a task to a specific queue
|
||||
func (d *DAG) PublishTask(ctx context.Context, payload []byte, queueName string, taskID ...string) (*mq.Task, error) {
|
||||
task := mq.Task{
|
||||
Payload: payload,
|
||||
@@ -116,39 +124,48 @@ func (d *DAG) PublishTask(ctx context.Context, payload []byte, queueName string,
|
||||
return d.server.Publish(ctx, task, queueName)
|
||||
}
|
||||
|
||||
// TaskCallback is called when a task completes and decides the next step
|
||||
// TaskCallback processes the completion of a task and continues the flow
|
||||
func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error {
|
||||
log.Printf("Callback from queue %s with result: %s", task.CurrentQueue, string(task.Result))
|
||||
|
||||
// Check if the task belongs to a loop
|
||||
d.mu.Lock()
|
||||
loopCtx, isLoopTask := d.loopTaskMap[task.ID]
|
||||
d.mu.Unlock()
|
||||
|
||||
if isLoopTask {
|
||||
loopCtx.subResultCh <- mq.Result{Payload: task.Result, MessageID: task.ID}
|
||||
}
|
||||
|
||||
// Handle loopEdges first
|
||||
if loopNode, exists := d.loopEdges[task.CurrentQueue]; exists {
|
||||
var items []json.RawMessage
|
||||
if err := json.Unmarshal(task.Result, &items); err != nil {
|
||||
return err
|
||||
}
|
||||
loopCtx := &loopTaskContext{
|
||||
subResultCh: make(chan mq.Result, len(items)), // A channel to collect sub-task results
|
||||
subResultCh: make(chan mq.Result, len(items)), // Collect sub-task results
|
||||
totalItems: len(items),
|
||||
results: make([]json.RawMessage, 0, len(items)),
|
||||
}
|
||||
d.mu.Lock()
|
||||
d.loopTaskMap[task.ID] = loopCtx
|
||||
d.mu.Unlock()
|
||||
|
||||
// Publish a sub-task for each item
|
||||
for _, item := range items {
|
||||
_, err := d.PublishTask(ctx, item, loopNode, task.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for loop completion
|
||||
go d.waitForLoopCompletion(ctx, task.ID, task.CurrentQueue)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Normal edge processing
|
||||
// Handle normal edges
|
||||
edges, exists := d.edges[task.CurrentQueue]
|
||||
if exists {
|
||||
for _, edge := range edges {
|
||||
@@ -158,10 +175,11 @@ func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitForLoopCompletion waits until all sub-tasks are processed, aggregates results, and proceeds.
|
||||
// waitForLoopCompletion waits for all sub-tasks in a loop to finish, aggregates the results, and proceeds
|
||||
func (d *DAG) waitForLoopCompletion(ctx context.Context, taskID string, currentQueue string) {
|
||||
// Get the loop context
|
||||
d.mu.Lock()
|
||||
@@ -169,37 +187,51 @@ func (d *DAG) waitForLoopCompletion(ctx context.Context, taskID string, currentQ
|
||||
d.mu.Unlock()
|
||||
|
||||
for result := range loopCtx.subResultCh {
|
||||
// Collect the result
|
||||
// Collect results
|
||||
loopCtx.results = append(loopCtx.results, result.Payload)
|
||||
loopCtx.completed++
|
||||
// If all sub-tasks are completed, aggregate results and proceed
|
||||
|
||||
// If all sub-tasks are completed, aggregate results and continue
|
||||
if loopCtx.completed == loopCtx.totalItems {
|
||||
close(loopCtx.subResultCh)
|
||||
|
||||
// Aggregate the results
|
||||
// Aggregate results
|
||||
aggregatedResult, err := json.Marshal(loopCtx.results)
|
||||
if err != nil {
|
||||
log.Printf("Error aggregating results: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Continue flow after loop completion
|
||||
d.mu.Lock()
|
||||
delete(d.loopTaskMap, taskID)
|
||||
d.mu.Unlock()
|
||||
|
||||
// Handle the next nodes in the DAG
|
||||
edges, exists := d.edges[currentQueue]
|
||||
if exists {
|
||||
for _, edge := range edges {
|
||||
_, err := d.PublishTask(ctx, aggregatedResult, edge, taskID)
|
||||
if err != nil {
|
||||
log.Printf("Error publishing aggregated result: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If no further edges, finalize
|
||||
d.mu.Lock()
|
||||
if resultCh, ok := d.taskChMap[taskID]; ok {
|
||||
result := mq.Result{
|
||||
resultCh <- mq.Result{
|
||||
Command: "complete",
|
||||
Payload: aggregatedResult,
|
||||
Queue: currentQueue,
|
||||
MessageID: taskID,
|
||||
Status: "done",
|
||||
}
|
||||
resultCh <- result
|
||||
delete(d.taskChMap, taskID)
|
||||
}
|
||||
d.mu.Unlock()
|
||||
|
||||
// Remove the loop context
|
||||
d.mu.Lock()
|
||||
delete(d.loopTaskMap, taskID)
|
||||
d.mu.Unlock()
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
Reference in New Issue
Block a user