Mirror of https://github.com/oarkflow/mq.git (synced 2025-10-31 07:06:21 +08:00)

Commit: init: publisher
broker.go (44 changed lines)
@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"log"
 	"net"
 	"time"

@@ -170,7 +171,7 @@ func (b *Broker) Start(ctx context.Context) error {
 	defer func() {
 		_ = listener.Close()
 	}()
-	fmt.Println("Broker server started on", b.opts.brokerAddr)
+	log.Println("Server started on", b.opts.brokerAddr)
 	for {
 		conn, err := listener.Accept()
 		if err != nil {
@@ -181,30 +182,34 @@ func (b *Broker) Start(ctx context.Context) error {
 	}
 }

-func (b *Broker) Publish(ctx context.Context, message Task, queueName string) error {
-	queue, err := b.AddMessageToQueue(&message, queueName)
+func (b *Broker) Publish(ctx context.Context, message Task, queueName string) (*Task, error) {
+	queue, task, err := b.AddMessageToQueue(&message, queueName)
 	if err != nil {
-		return err
+		return nil, err
 	}
 	if queue.consumers.Size() == 0 {
 		queue.deferred.Set(NewID(), &message)
 		fmt.Println("task deferred as no consumers are connected", queueName)
-		return nil
+		return task, nil
 	}
 	queue.send(ctx, message)
-	return nil
+	return task, nil
 }

-func (b *Broker) NewQueue(qName string) {
-	if _, ok := b.queues.Get(qName); !ok {
-		b.queues.Set(qName, newQueue(qName))
+func (b *Broker) NewQueue(qName string) *Queue {
+	q, ok := b.queues.Get(qName)
+	if ok {
+		return q
 	}
+	q = newQueue(qName)
+	b.queues.Set(qName, q)
+	return q
 }

-func (b *Broker) AddMessageToQueue(message *Task, queueName string) (*Queue, error) {
+func (b *Broker) AddMessageToQueue(message *Task, queueName string) (*Queue, *Task, error) {
 	queue, ok := b.queues.Get(queueName)
 	if !ok {
-		return nil, fmt.Errorf("queue %s not found", queueName)
+		return nil, nil, fmt.Errorf("queue %s not found", queueName)
 	}
 	if message.ID == "" {
 		message.ID = NewID()
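Publish now hands the enqueued *Task back to the caller, including when the message is only deferred because no consumer is connected, so the broker-assigned ID can be kept for later correlation. A minimal caller sketch under that reading of the diff; the function name, the "orders" queue and the payload are invented for illustration, and the snippet assumes it lives inside this package.

package mq

import (
	"context"
	"log"
)

// Sketch only: a hypothetical caller using the Publish and NewQueue signatures
// introduced in this commit.
func examplePublishAndTrack(ctx context.Context, b *Broker) error {
	b.NewQueue("orders") // idempotent now: returns the existing *Queue if one is registered
	task, err := b.Publish(ctx, Task{Payload: []byte(`{"order_id": 1}`)}, "orders")
	if err != nil {
		return err
	}
	// AddMessageToQueue assigns an ID when the caller leaves it empty, so the
	// returned *Task can be used to track the message later.
	log.Println("published task", task.ID)
	return nil
}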
@@ -214,7 +219,7 @@ func (b *Broker) AddMessageToQueue(message *Task, queueName string) (*Queue, err
 	}
 	message.CreatedAt = time.Now()
 	queue.messages.Set(message.ID, message)
-	return queue, nil
+	return queue, message, nil
 }

 func (b *Broker) HandleProcessedMessage(ctx context.Context, clientMsg Result) error {
@@ -250,9 +255,18 @@ func (b *Broker) HandleProcessedMessage(ctx context.Context, clientMsg Result) e

 func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Conn) string {
 	consumerID, ok := GetConsumerID(ctx)
+	defer func() {
+		cmd := Command{
+			Command: SUBSCRIBE_ACK,
+			Queue: queueName,
+			Error: "",
+		}
+		Write(ctx, conn, cmd)
+		log.Printf("Consumer %s joined server on queue %s", consumerID, queueName)
+	}()
 	q, ok := b.queues.Get(queueName)
 	if !ok {
-		b.NewQueue(queueName)
+		q = b.NewQueue(queueName)
 	}
 	con := &consumer{id: consumerID, conn: conn}
 	b.consumers.Set(consumerID, con)
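addConsumer wraps the SUBSCRIBE_ACK write in a defer, so the acknowledgement (and the "joined server" log line) is only sent once the rest of the function has registered the consumer on the queue. A tiny standalone illustration of that ordering; the names here are invented and stand in for the real registration work.

package main

import "fmt"

// register mirrors the shape of addConsumer: the deferred ack runs after the
// function body, i.e. after registration has completed.
func register(id string) {
	defer fmt.Println("ack sent to", id) // runs last
	fmt.Println("registering", id)
	// ... add the consumer to the queue and consumer maps here ...
}

func main() {
	register("consumer-1") // prints "registering consumer-1", then "ack sent to consumer-1"
}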
@@ -319,7 +333,7 @@ func (b *Broker) publish(ctx context.Context, conn net.Conn, msg Command) error
 		CreatedAt: time.Now(),
 		CurrentQueue: msg.Queue,
 	}
-	err := b.Publish(ctx, task, msg.Queue)
+	_, err := b.Publish(ctx, task, msg.Queue)
 	if err != nil {
 		return err
 	}

@@ -343,7 +357,7 @@ func (b *Broker) request(ctx context.Context, conn net.Conn, msg Command) error
 		CreatedAt: time.Now(),
 		CurrentQueue: msg.Queue,
 	}
-	err := b.Publish(ctx, task, msg.Queue)
+	_, err := b.Publish(ctx, task, msg.Queue)
 	if err != nil {
 		return err
 	}
@@ -4,7 +4,7 @@ type CMD int

 const (
 	SUBSCRIBE CMD = iota + 1
-	ACK
+	SUBSCRIBE_ACK
 	PUBLISH
 	REQUEST
 	RESPONSE
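Because the commands are iota-based, swapping ACK for SUBSCRIBE_ACK at the same position leaves the numeric values of PUBLISH, REQUEST and RESPONSE unchanged, while any code still referring to ACK no longer compiles and wire value 2 now means the subscribe acknowledgement. A self-contained sketch of the numbering; it mirrors only the constants visible in this hunk (the real file also has at least STOP, which is referenced in the consumer below but not shown here).

package main

import "fmt"

type CMD int

const (
	SUBSCRIBE CMD = iota + 1 // 1
	SUBSCRIBE_ACK            // 2 (the slot ACK previously held)
	PUBLISH                  // 3
	REQUEST                  // 4
	RESPONSE                 // 5
)

func main() {
	fmt.Println(SUBSCRIBE, SUBSCRIBE_ACK, PUBLISH, REQUEST, RESPONSE) // 1 2 3 4 5
}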
@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"log"
 	"net"
 	"sync"
 	"time"
@@ -74,6 +75,9 @@ func (c *Consumer) handleCommandMessage(msg Command) error {
 	switch msg.Command {
 	case STOP:
 		return c.Close()
+	case SUBSCRIBE_ACK:
+		log.Printf("Consumer %s subscribed to queue %s\n", c.id, msg.Queue)
+		return nil
 	default:
 		return fmt.Errorf("unknown command in consumer %d", msg.Command)
 	}
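The default branch above formats msg.Command with %d, so an unknown command is reported as a raw number. If named output is ever wanted, the usual Go approach is a Stringer consulted via %v; a self-contained toy illustration (the type and names below are not the repository's CMD):

package main

import "fmt"

type cmd int

func (c cmd) String() string { return [...]string{"STOP", "SUBSCRIBE_ACK"}[c] }

func main() {
	c := cmd(1)
	fmt.Printf("%d\n", c) // 1: %d always prints the number, Stringer is ignored
	fmt.Printf("%v\n", c) // SUBSCRIBE_ACK: %v consults the Stringer
}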

dag.go (238 changed lines)
@@ -1,238 +0,0 @@
package mq

import (
	"context"
	"errors"
	"fmt"

	"github.com/oarkflow/xsync"
)

const (
	triggerNodeKey string = "triggerNode"
)

type Node interface {
	Queue() string
	Consumer() *Consumer
	Handler() Handler
}

type node struct {
	queue string
	consumer *Consumer
	handler Handler
}

func (n *node) Queue() string {
	return n.queue
}

func (n *node) Consumer() *Consumer {
	return n.consumer
}

func (n *node) Handler() Handler {
	return n.handler
}

type DAG struct {
	nodes *xsync.MapOf[string, Node]
	edges [][]string
	loopEdges [][]string
	broker *Broker
	startNode Node
	conditions map[string]map[string]string
	syncMode bool
}

func NewDAG(syncMode bool) *DAG {
	dag := &DAG{
		nodes: xsync.NewMap[string, Node](),
		conditions: make(map[string]map[string]string),
		syncMode: syncMode,
	}
	dag.broker = NewBroker(WithCallback(dag.TaskCallback))
	return dag
}

func (dag *DAG) TaskCallback(ctx context.Context, task *Task) error {
	return nil
}

func (dag *DAG) AddNode(queue string, handler Handler, firstNode ...bool) {
	con := NewConsumer("consume-" + queue)
	con.RegisterHandler(queue, handler)
	dag.broker.NewQueue(queue)
	n := &node{
		queue: queue,
		consumer: con,
		handler: handler,
	}
	if len(firstNode) > 0 && firstNode[0] {
		dag.startNode = n
	}
	dag.nodes.Set(queue, n)
}

func (dag *DAG) AddEdge(fromNodeID, toNodeID string) error {
	err := dag.validateNodes(fromNodeID, toNodeID)
	if err != nil {
		return err
	}
	dag.edges = append(dag.edges, []string{fromNodeID, toNodeID})
	return nil
}

func (dag *DAG) AddCondition(conditionNodeID string, conditions map[string]string) error {
	for _, nodeID := range conditions {
		if err := dag.validateNodes(nodeID); err != nil {
			return err
		}
	}
	dag.conditions[conditionNodeID] = conditions
	return nil
}

func (dag *DAG) AddLoop(fromNodeID, toNodeID string) error {
	err := dag.validateNodes(fromNodeID, toNodeID)
	if err != nil {
		return err
	}
	dag.loopEdges = append(dag.loopEdges, []string{fromNodeID, toNodeID})
	return nil
}

func (dag *DAG) Start(ctx context.Context) error {
	if dag.syncMode {
		return nil
	}
	return dag.broker.Start(ctx)
}

func (dag *DAG) Prepare(ctx context.Context) error {
	startNode, err := dag.findInitialNode()
	if err != nil {
		return err
	}
	if startNode == nil {
		return fmt.Errorf("no initial node found")
	}
	dag.startNode = startNode
	if dag.syncMode {
		return nil
	}
	dag.nodes.ForEach(func(_ string, node Node) bool {
		go node.Consumer().Consume(ctx)
		return true
	})
	return nil
}

func (dag *DAG) ProcessTask(ctx context.Context, task Task) Result {
	return dag.processNode(ctx, &task, dag.startNode.Queue())
}

func (dag *DAG) getConditionalNode(status, currentNode string) string {
	conditions, ok := dag.conditions[currentNode]
	if !ok {
		return ""
	}
	conditionNodeID, ok := conditions[status]
	if !ok {
		return ""
	}
	return conditionNodeID
}

func (dag *DAG) validateNodes(nodeIDs ...string) error {
	for _, nodeID := range nodeIDs {
		if _, ok := dag.nodes.Get(nodeID); !ok {
			return fmt.Errorf("node %s not found", nodeID)
		}
	}
	return nil
}

func (dag *DAG) processEdge(ctx context.Context, id string, payload []byte, targets []string) {
	newTask := &Task{
		ID: id,
		Payload: payload,
	}
	for _, target := range targets {
		if target != "" {
			dag.processNode(ctx, newTask, target)
		}
	}
}

func (dag *DAG) calculateForFirstNode() (string, bool) {
	inDegree := make(map[string]int)
	for _, n := range dag.nodes.Keys() {
		inDegree[n] = 0
	}
	for _, edge := range dag.edges {
		inDegree[edge[1]]++
	}
	for _, edge := range dag.loopEdges {
		inDegree[edge[1]]++
	}
	for n, count := range inDegree {
		if count == 0 {
			return n, true
		}
	}
	return "", false
}

func (dag *DAG) findInitialNode() (Node, error) {
	if dag.startNode != nil {
		return dag.startNode, nil
	}
	var nt Node
	n, ok := dag.calculateForFirstNode()
	if !ok {
		return nil, errors.New("no initial node found")
	}
	nt, ok = dag.nodes.Get(n)
	if !ok {
		return nil, errors.New("no initial node found")
	}
	return nt, nil
}

func (dag *DAG) processNode(ctx context.Context, task *Task, queue string) Result {
	if !dag.syncMode {
		if err := dag.broker.Publish(ctx, *task, queue); err != nil {
			fmt.Println("Failed to publish task:", err)
		}
		return Result{}
	}
	n, ok := dag.nodes.Get(queue)
	if task.CurrentQueue == "" {
		task.CurrentQueue = queue
	}
	if !ok {
		fmt.Println("Node not found:", queue)
		return Result{Error: fmt.Errorf("node not found %s", queue)}
	}
	_, err := dag.broker.AddMessageToQueue(task, queue)
	if err != nil {
		return Result{Error: err}
	}
	result := n.Handler()(ctx, *task)
	if result.Queue == "" {
		result.Queue = task.CurrentQueue
	}
	if result.MessageID == "" {
		result.MessageID = task.ID
	}
	if result.Error != nil {
		return result
	}
	err = dag.broker.HandleProcessedMessage(ctx, result)
	if err != nil {
		return Result{Error: err, Status: result.Status}
	}
	return result
}
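The removed calculateForFirstNode above picks the DAG entry point by computing in-degrees over both edges and loopEdges and returning any node with in-degree zero; because it iterates a map, the choice is arbitrary when several nodes qualify. A standalone sketch of the same zero-in-degree rule on a plain edge list, with made-up node names and a slice iteration to keep the choice deterministic:

package main

import "fmt"

// firstNode returns a node that no edge points to, mirroring the start-node
// selection rule used by the deleted calculateForFirstNode.
func firstNode(nodes []string, edges [][2]string) (string, bool) {
	inDegree := make(map[string]int, len(nodes))
	for _, n := range nodes {
		inDegree[n] = 0
	}
	for _, e := range edges {
		inDegree[e[1]]++
	}
	for _, n := range nodes { // slice order, so the result is deterministic
		if inDegree[n] == 0 {
			return n, true
		}
	}
	return "", false
}

func main() {
	nodes := []string{"node1", "node2", "node3"}
	edges := [][2]string{{"node1", "node2"}, {"node2", "node3"}}
	fmt.Println(firstNode(nodes, edges)) // node1 true
}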
@@ -1,252 +0,0 @@
package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"net"
	"os"
	"strings"
	"sync"
	"time"
)

type DataItem map[string]interface{}

type NodeInfo struct {
	Name string
	Conn net.Conn
}

type Broker struct {
	nodes map[string]NodeInfo
	edges map[string]string
	loops map[string][]string
	conditions map[string]ConditionConfig
	results map[string][]DataItem // Track task results by task ID
	mu sync.Mutex
}

type ConditionConfig struct {
	TrueNode string
	FalseNode string
}

func NewBroker() *Broker {
	return &Broker{
		nodes: make(map[string]NodeInfo),
		edges: make(map[string]string),
		loops: make(map[string][]string),
		conditions: make(map[string]ConditionConfig),
		results: make(map[string][]DataItem),
	}
}

func (b *Broker) RegisterNode(name string, conn net.Conn) {
	b.mu.Lock()
	defer b.mu.Unlock()
	fmt.Printf("Registering node: %s\n", name)
	b.nodes[name] = NodeInfo{Name: name, Conn: conn}
}

func (b *Broker) AddEdge(fromNode string, toNode string) {
	b.mu.Lock()
	defer b.mu.Unlock()
	fmt.Printf("Adding edge from %s to %s\n", fromNode, toNode)
	b.edges[fromNode] = toNode
}

func (b *Broker) AddLoop(loopNode string, targetNodes []string) {
	b.mu.Lock()
	defer b.mu.Unlock()
	fmt.Printf("Adding loop at %s with targets: %v\n", loopNode, targetNodes)
	b.loops[loopNode] = targetNodes
}

func (b *Broker) AddCondition(condNode string, trueNode string, falseNode string) {
	b.mu.Lock()
	defer b.mu.Unlock()
	fmt.Printf("Adding condition at %s, True: %s, False: %s\n", condNode, trueNode, falseNode)
	b.conditions[condNode] = ConditionConfig{
		TrueNode: trueNode,
		FalseNode: falseNode,
	}
}

func (b *Broker) SendDataToNode(nodeName string, taskID string, data []DataItem, resultChannel chan []DataItem) {
	b.mu.Lock()
	node, exists := b.nodes[nodeName]
	b.mu.Unlock()
	if !exists {
		fmt.Printf("Node %s not found!\n", nodeName)
		return
	}

	fmt.Printf("Sending data to %s for task %s...\n", nodeName, taskID)
	encoder := json.NewEncoder(node.Conn)
	err := encoder.Encode(data)
	if err != nil {
		fmt.Printf("Error sending data to %s: %v\n", nodeName, err)
		return
	}

	// Receive the processed data back from the node asynchronously
	go func() {
		decoder := json.NewDecoder(node.Conn)
		var result []DataItem
		err = decoder.Decode(&result)
		if err != nil {
			fmt.Printf("Error receiving data from %s for task %s: %v\n", nodeName, taskID, err)
			return
		}
		fmt.Printf("Received processed data from %s for task %s\n", nodeName, taskID)

		// Send the result to the result aggregation channel
		resultChannel <- result
	}()
}

func (b *Broker) DispatchData(startNode string, data []DataItem, taskID string) []DataItem {
	finalResult := []DataItem{}
	currentNode := startNode
	resultChannel := make(chan []DataItem, len(data)) // Create a channel to handle async results

	for {
		b.mu.Lock()
		nextNode, hasEdge := b.edges[currentNode]
		loopTargets, hasLoop := b.loops[currentNode]
		conditionConfig, hasCondition := b.conditions[currentNode]
		b.mu.Unlock()

		// Handle Loops (async dispatch)
		if hasLoop {
			var wg sync.WaitGroup
			fmt.Printf("Dispatching to loop nodes from %s for task %s...\n", currentNode, taskID)
			for _, targetNode := range loopTargets {
				wg.Add(1)
				go func(node string) {
					defer wg.Done()
					b.SendDataToNode(node, taskID, data, resultChannel)
				}(targetNode)
			}

			// Wait for loop processing to complete
			go func() {
				wg.Wait()
				close(resultChannel)
			}()

			// Collect async results
			for res := range resultChannel {
				finalResult = append(finalResult, res...)
			}

			b.AggregateResults(taskID, finalResult)
			return finalResult // Exit after loop processing
		}

		// Handle Conditions
		if hasCondition {
			for _, item := range data {
				resultChannel := make(chan []DataItem, 1)
				go b.SendDataToNode(currentNode, taskID, []DataItem{item}, resultChannel)

				select {
				case result := <-resultChannel:
					nextNode = conditionConfig.TrueNode
					finalResult = append(finalResult, b.DispatchData(nextNode, result, taskID)...)
				case <-time.After(5 * time.Second): // Timeout if no response
					fmt.Printf("Condition check timed out at node: %s\n", currentNode)
					nextNode = conditionConfig.FalseNode
				}
			}
			b.AggregateResults(taskID, finalResult)
			return finalResult // Exit after condition processing
		}

		// Handle simple edges (sequential flow)
		if hasEdge {
			b.SendDataToNode(currentNode, taskID, data, resultChannel)

			select {
			case result := <-resultChannel:
				currentNode = nextNode
				data = result
			case <-time.After(5 * time.Second): // Timeout if no response
				fmt.Printf("Processing timed out at node: %s\n", currentNode)
				return finalResult
			}
		} else {
			fmt.Printf("No edge found for node: %s, stopping...\n", currentNode)
			break
		}
	}

	b.AggregateResults(taskID, finalResult)
	return finalResult
}

func (b *Broker) AggregateResults(taskID string, result []DataItem) {
	b.mu.Lock()
	defer b.mu.Unlock()
	b.results[taskID] = append(b.results[taskID], result...)
	fmt.Printf("Aggregated result for task %s: %v\n", taskID, b.results[taskID])
}

func (b *Broker) HandleConnections() {
	listener, err := net.Listen("tcp", ":8081")
	if err != nil {
		fmt.Println("Error setting up TCP server:", err)
		os.Exit(1)
	}
	defer listener.Close()

	fmt.Println("Broker is listening on port 8081...")

	for {
		conn, err := listener.Accept()
		if err != nil {
			fmt.Println("Error accepting connection:", err)
			continue
		}

		go func(conn net.Conn) {
			defer conn.Close()
			reader := bufio.NewReader(conn)
			nodeName, err := reader.ReadString('\n')
			if err != nil {
				fmt.Println("Error reading node name:", err)
				return
			}
			nodeName = strings.TrimSpace(nodeName)

			b.RegisterNode(nodeName, conn)
		}(conn)
	}
}

func main() {
	broker := NewBroker()

	// Set up the flow
	broker.AddEdge("Node1", "Node2")
	broker.AddLoop("Node2", []string{"Node3"})
	broker.AddCondition("Node3", "Node4", "")

	// Start the broker to listen for node connections
	go broker.HandleConnections()

	fmt.Println("Press ENTER to start the flow after nodes are connected...")
	bufio.NewReader(os.Stdin).ReadString('\n')

	// Example Data Items
	dataItems := []DataItem{
		{"id": 1, "value": "item1"},
		{"id": 2, "value": "item2"},
		{"id": 3, "value": "item3"},
	}

	taskID := "task-001" // Unique ID to track this task
	finalResult := broker.DispatchData("Node1", dataItems, taskID)
	fmt.Println("Final result after processing:", finalResult)
}
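DispatchData's loop branch in the removed prototype fans each loop target out to its own goroutine and collects the results on a single channel, closing the channel from a goroutine that waits on the WaitGroup so the collection range can terminate. The same fan-out/fan-in shape in isolation, with a toy worker standing in for SendDataToNode:

package main

import (
	"fmt"
	"sync"
)

func main() {
	targets := []string{"Node3", "Node4"}
	results := make(chan string, len(targets))

	var wg sync.WaitGroup
	for _, t := range targets {
		wg.Add(1)
		go func(node string) {
			defer wg.Done()
			results <- "processed by " + node // stands in for SendDataToNode
		}(t)
	}

	// Close the channel once every worker is done so the range below ends.
	go func() {
		wg.Wait()
		close(results)
	}()

	var final []string
	for r := range results {
		final = append(final, r)
	}
	fmt.Println(final)
}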

examples/dag.go (187 changed lines)
@@ -2,90 +2,125 @@ package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"sync"
	"time"

	"github.com/oarkflow/mq"
)

func handleNode1(_ context.Context, task mq.Task) mq.Result {
	result := []map[string]string{
		{"field": "facility", "item": "item1"},
		{"field": "facility", "item": "item2"},
		{"field": "facility", "item": "item3"},
	}
	var payload string
	err := json.Unmarshal(task.Payload, &payload)
	if err != nil {
		return mq.Result{Status: "fail", Payload: json.RawMessage(`{"field": "node1", "item": "error"}`)}
	}
	fmt.Printf("Processing task at node1: %s\n", string(task.Payload))
	bt, _ := json.Marshal(result)
	return mq.Result{Status: "completed", Payload: bt}
}

func handleNode2(_ context.Context, task mq.Task) mq.Result {
	var payload map[string]string
	err := json.Unmarshal(task.Payload, &payload)
	if err != nil {
		return mq.Result{Status: "fail", Payload: json.RawMessage(`{"field": "node2", "item": "error"}`)}
	}
	status := "fail"
	if payload["item"] == "item2" {
		status = "pass"
	}
	fmt.Printf("Processing task at node2: %s %s\n", payload, status)
	bt, _ := json.Marshal(payload)
	return mq.Result{Status: status, Payload: bt}
}

func handleNode3(_ context.Context, task mq.Task) mq.Result {
	var data map[string]any
	err := json.Unmarshal(task.Payload, &data)
	if err != nil {
		return mq.Result{Error: err}
	}
	data["item"] = "Item processed in node3"
	bt, _ := json.Marshal(data)
	return mq.Result{Status: "completed", Payload: bt}
}

func handleNode4(_ context.Context, task mq.Task) mq.Result {
	var data map[string]any
	err := json.Unmarshal(task.Payload, &data)
	if err != nil {
		return mq.Result{Error: err}
	}
	data["item"] = "An Item processed in node4"
	bt, _ := json.Marshal(data)
	return mq.Result{Status: "completed", Payload: bt}
}

func main() {
	ctx := context.Background()
	d := mq.NewDAG(false)
	dag := NewDAG()
	dag.AddNode("queue1", func(ctx context.Context, task mq.Task) mq.Result {
		log.Printf("Handling task for queue1: %s", string(task.Payload))
		return mq.Result{Payload: []byte(`{"task": 123}`), MessageID: task.ID}
	})
	dag.AddNode("queue2", func(ctx context.Context, task mq.Task) mq.Result {
		log.Printf("Handling task for queue2: %s", string(task.Payload))
		return mq.Result{Payload: []byte(`{"task": 456}`), MessageID: task.ID}
	})
	dag.AddEdge("queue1", "queue2")

	d.AddNode("node1", handleNode1, true)
	d.AddNode("node2", handleNode2)
	d.AddNode("node3", handleNode3)
	d.AddNode("node4", handleNode4)
	d.AddCondition("node2", map[string]string{"pass": "node3", "fail": "node4"})
	err := d.AddLoop("node1", "node2")
	if err != nil {
		panic(err)
	}
	err = d.Prepare(ctx)
	if err != nil {
		panic(err)
	}
	// Start the DAG and process the task
	// Start DAG processing
	go func() {
		if err := d.Start(ctx); err != nil {
			fmt.Println("Error starting DAG:", err)
		}
		time.Sleep(2 * time.Second)
		finalResult := dag.Send([]byte(`{"task": 1}`))
		log.Printf("Final result received: %s", string(finalResult.Payload))
	}()
	result := d.ProcessTask(ctx, mq.Task{Payload: []byte(`"Start processing"`)})
	fmt.Println(string(result.Payload))
	time.Sleep(50 * time.Second)

	err := dag.Start(context.TODO())
	if err != nil {
		panic(err)
	}
}

type DAG struct {
	server *mq.Broker
	nodes map[string]*mq.Consumer
	edges map[string][]string
	taskChMap map[string]chan mq.Result // A map to store result channels for each task
	mu sync.Mutex // Mutex to protect the taskChMap
}

func NewDAG(opts ...mq.Option) *DAG {
	d := &DAG{
		nodes: make(map[string]*mq.Consumer),
		edges: make(map[string][]string),
		taskChMap: make(map[string]chan mq.Result),
	}
	opts = append(opts, mq.WithCallback(d.TaskCallback))
	d.server = mq.NewBroker(opts...)
	return d
}

func (d *DAG) AddNode(name string, handler mq.Handler) {
	con := mq.NewConsumer(name)
	con.RegisterHandler(name, handler)
	d.nodes[name] = con
}

func (d *DAG) AddEdge(fromNode string, toNodes ...string) {
	d.edges[fromNode] = toNodes
}

func (d *DAG) Start(ctx context.Context) error {
	for _, con := range d.nodes {
		go con.Consume(ctx)
	}
	return d.server.Start(ctx)
}

func (d *DAG) PublishTask(ctx context.Context, payload []byte, queueName string) (*mq.Task, error) {
	task := mq.Task{
		Payload: payload,
	}
	return d.server.Publish(ctx, task, queueName)
}

// TaskCallback is the function triggered after each task completion.
func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error {
	log.Printf("Callback from queue %s with result: %s", task.CurrentQueue, string(task.Result))
	edges, exists := d.edges[task.CurrentQueue]
	if !exists {
		// Lock and send the result to the specific task channel
		d.mu.Lock()
		fmt.Println(d.taskChMap, task.ID)
		for _, resultCh := range d.taskChMap {
			result := mq.Result{
				Command: "complete",
				Payload: task.Result,
				Queue: task.CurrentQueue,
				MessageID: task.ID,
				Status: "done",
			}
			resultCh <- result
			delete(d.taskChMap, task.ID) // Clean up the channel
		}
		d.mu.Unlock()
		return nil
	}

	// Forward the task to the next node(s)
	for _, edge := range edges {
		_, err := d.PublishTask(ctx, task.Result, edge)
		if err != nil {
			return err
		}
	}
	return nil
}

// Send sends the task and waits for the final result.
func (d *DAG) Send(payload []byte) mq.Result {
	resultCh := make(chan mq.Result)
	task, err := d.PublishTask(context.TODO(), payload, "queue1")
	if err != nil {
		panic(err)
	}
	d.mu.Lock()
	d.taskChMap[task.ID] = resultCh
	d.mu.Unlock()
	finalResult := <-resultCh
	return finalResult
}