diff --git a/broker.go b/broker.go index 97470b8..891eb44 100644 --- a/broker.go +++ b/broker.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "net" "time" @@ -170,7 +171,7 @@ func (b *Broker) Start(ctx context.Context) error { defer func() { _ = listener.Close() }() - fmt.Println("Broker server started on", b.opts.brokerAddr) + log.Println("Server started on", b.opts.brokerAddr) for { conn, err := listener.Accept() if err != nil { @@ -181,30 +182,34 @@ func (b *Broker) Start(ctx context.Context) error { } } -func (b *Broker) Publish(ctx context.Context, message Task, queueName string) error { - queue, err := b.AddMessageToQueue(&message, queueName) +func (b *Broker) Publish(ctx context.Context, message Task, queueName string) (*Task, error) { + queue, task, err := b.AddMessageToQueue(&message, queueName) if err != nil { - return err + return nil, err } if queue.consumers.Size() == 0 { queue.deferred.Set(NewID(), &message) fmt.Println("task deferred as no consumers are connected", queueName) - return nil + return task, nil } queue.send(ctx, message) - return nil + return task, nil } -func (b *Broker) NewQueue(qName string) { - if _, ok := b.queues.Get(qName); !ok { - b.queues.Set(qName, newQueue(qName)) +func (b *Broker) NewQueue(qName string) *Queue { + q, ok := b.queues.Get(qName) + if ok { + return q } + q = newQueue(qName) + b.queues.Set(qName, q) + return q } -func (b *Broker) AddMessageToQueue(message *Task, queueName string) (*Queue, error) { +func (b *Broker) AddMessageToQueue(message *Task, queueName string) (*Queue, *Task, error) { queue, ok := b.queues.Get(queueName) if !ok { - return nil, fmt.Errorf("queue %s not found", queueName) + return nil, nil, fmt.Errorf("queue %s not found", queueName) } if message.ID == "" { message.ID = NewID() @@ -214,7 +219,7 @@ func (b *Broker) AddMessageToQueue(message *Task, queueName string) (*Queue, err } message.CreatedAt = time.Now() queue.messages.Set(message.ID, message) - return queue, nil + return queue, message, nil } func (b *Broker) HandleProcessedMessage(ctx context.Context, clientMsg Result) error { @@ -250,9 +255,18 @@ func (b *Broker) HandleProcessedMessage(ctx context.Context, clientMsg Result) e func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Conn) string { consumerID, ok := GetConsumerID(ctx) + defer func() { + cmd := Command{ + Command: SUBSCRIBE_ACK, + Queue: queueName, + Error: "", + } + Write(ctx, conn, cmd) + log.Printf("Consumer %s joined server on queue %s", consumerID, queueName) + }() q, ok := b.queues.Get(queueName) if !ok { - b.NewQueue(queueName) + q = b.NewQueue(queueName) } con := &consumer{id: consumerID, conn: conn} b.consumers.Set(consumerID, con) @@ -319,7 +333,7 @@ func (b *Broker) publish(ctx context.Context, conn net.Conn, msg Command) error CreatedAt: time.Now(), CurrentQueue: msg.Queue, } - err := b.Publish(ctx, task, msg.Queue) + _, err := b.Publish(ctx, task, msg.Queue) if err != nil { return err } @@ -343,7 +357,7 @@ func (b *Broker) request(ctx context.Context, conn net.Conn, msg Command) error CreatedAt: time.Now(), CurrentQueue: msg.Queue, } - err := b.Publish(ctx, task, msg.Queue) + _, err := b.Publish(ctx, task, msg.Queue) if err != nil { return err } diff --git a/constants.go b/constants.go index 8414a3d..9e58b4c 100644 --- a/constants.go +++ b/constants.go @@ -4,7 +4,7 @@ type CMD int const ( SUBSCRIBE CMD = iota + 1 - ACK + SUBSCRIBE_ACK PUBLISH REQUEST RESPONSE diff --git a/consumer.go b/consumer.go index fab3a37..4048502 100644 --- a/consumer.go +++ b/consumer.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "net" "sync" "time" @@ -74,6 +75,9 @@ func (c *Consumer) handleCommandMessage(msg Command) error { switch msg.Command { case STOP: return c.Close() + case SUBSCRIBE_ACK: + log.Printf("Consumer %s subscribed to queue %s\n", c.id, msg.Queue) + return nil default: return fmt.Errorf("unknown command in consumer %d", msg.Command) } diff --git a/dag.go b/dag.go deleted file mode 100644 index ef8d446..0000000 --- a/dag.go +++ /dev/null @@ -1,238 +0,0 @@ -package mq - -import ( - "context" - "errors" - "fmt" - - "github.com/oarkflow/xsync" -) - -const ( - triggerNodeKey string = "triggerNode" -) - -type Node interface { - Queue() string - Consumer() *Consumer - Handler() Handler -} - -type node struct { - queue string - consumer *Consumer - handler Handler -} - -func (n *node) Queue() string { - return n.queue -} - -func (n *node) Consumer() *Consumer { - return n.consumer -} - -func (n *node) Handler() Handler { - return n.handler -} - -type DAG struct { - nodes *xsync.MapOf[string, Node] - edges [][]string - loopEdges [][]string - broker *Broker - startNode Node - conditions map[string]map[string]string - syncMode bool -} - -func NewDAG(syncMode bool) *DAG { - dag := &DAG{ - nodes: xsync.NewMap[string, Node](), - conditions: make(map[string]map[string]string), - syncMode: syncMode, - } - dag.broker = NewBroker(WithCallback(dag.TaskCallback)) - return dag -} - -func (dag *DAG) TaskCallback(ctx context.Context, task *Task) error { - return nil -} - -func (dag *DAG) AddNode(queue string, handler Handler, firstNode ...bool) { - con := NewConsumer("consume-" + queue) - con.RegisterHandler(queue, handler) - dag.broker.NewQueue(queue) - n := &node{ - queue: queue, - consumer: con, - handler: handler, - } - if len(firstNode) > 0 && firstNode[0] { - dag.startNode = n - } - dag.nodes.Set(queue, n) -} - -func (dag *DAG) AddEdge(fromNodeID, toNodeID string) error { - err := dag.validateNodes(fromNodeID, toNodeID) - if err != nil { - return err - } - dag.edges = append(dag.edges, []string{fromNodeID, toNodeID}) - return nil -} - -func (dag *DAG) AddCondition(conditionNodeID string, conditions map[string]string) error { - for _, nodeID := range conditions { - if err := dag.validateNodes(nodeID); err != nil { - return err - } - } - dag.conditions[conditionNodeID] = conditions - return nil -} - -func (dag *DAG) AddLoop(fromNodeID, toNodeID string) error { - err := dag.validateNodes(fromNodeID, toNodeID) - if err != nil { - return err - } - dag.loopEdges = append(dag.loopEdges, []string{fromNodeID, toNodeID}) - return nil -} - -func (dag *DAG) Start(ctx context.Context) error { - if dag.syncMode { - return nil - } - return dag.broker.Start(ctx) -} - -func (dag *DAG) Prepare(ctx context.Context) error { - startNode, err := dag.findInitialNode() - if err != nil { - return err - } - if startNode == nil { - return fmt.Errorf("no initial node found") - } - dag.startNode = startNode - if dag.syncMode { - return nil - } - dag.nodes.ForEach(func(_ string, node Node) bool { - go node.Consumer().Consume(ctx) - return true - }) - return nil -} - -func (dag *DAG) ProcessTask(ctx context.Context, task Task) Result { - return dag.processNode(ctx, &task, dag.startNode.Queue()) -} - -func (dag *DAG) getConditionalNode(status, currentNode string) string { - conditions, ok := dag.conditions[currentNode] - if !ok { - return "" - } - conditionNodeID, ok := conditions[status] - if !ok { - return "" - } - return conditionNodeID -} - -func (dag *DAG) validateNodes(nodeIDs ...string) error { - for _, nodeID := range nodeIDs { - if _, ok := dag.nodes.Get(nodeID); !ok { - return fmt.Errorf("node %s not found", nodeID) - } - } - return nil -} - -func (dag *DAG) processEdge(ctx context.Context, id string, payload []byte, targets []string) { - newTask := &Task{ - ID: id, - Payload: payload, - } - for _, target := range targets { - if target != "" { - dag.processNode(ctx, newTask, target) - } - } -} - -func (dag *DAG) calculateForFirstNode() (string, bool) { - inDegree := make(map[string]int) - for _, n := range dag.nodes.Keys() { - inDegree[n] = 0 - } - for _, edge := range dag.edges { - inDegree[edge[1]]++ - } - for _, edge := range dag.loopEdges { - inDegree[edge[1]]++ - } - for n, count := range inDegree { - if count == 0 { - return n, true - } - } - return "", false -} - -func (dag *DAG) findInitialNode() (Node, error) { - if dag.startNode != nil { - return dag.startNode, nil - } - var nt Node - n, ok := dag.calculateForFirstNode() - if !ok { - return nil, errors.New("no initial node found") - } - nt, ok = dag.nodes.Get(n) - if !ok { - return nil, errors.New("no initial node found") - } - return nt, nil -} - -func (dag *DAG) processNode(ctx context.Context, task *Task, queue string) Result { - if !dag.syncMode { - if err := dag.broker.Publish(ctx, *task, queue); err != nil { - fmt.Println("Failed to publish task:", err) - } - return Result{} - } - n, ok := dag.nodes.Get(queue) - if task.CurrentQueue == "" { - task.CurrentQueue = queue - } - if !ok { - fmt.Println("Node not found:", queue) - return Result{Error: fmt.Errorf("node not found %s", queue)} - } - _, err := dag.broker.AddMessageToQueue(task, queue) - if err != nil { - return Result{Error: err} - } - result := n.Handler()(ctx, *task) - if result.Queue == "" { - result.Queue = task.CurrentQueue - } - if result.MessageID == "" { - result.MessageID = task.ID - } - if result.Error != nil { - return result - } - err = dag.broker.HandleProcessedMessage(ctx, result) - if err != nil { - return Result{Error: err, Status: result.Status} - } - return result -} diff --git a/examples/broker.go b/examples/broker.go deleted file mode 100644 index ab89a69..0000000 --- a/examples/broker.go +++ /dev/null @@ -1,252 +0,0 @@ -package main - -import ( - "bufio" - "encoding/json" - "fmt" - "net" - "os" - "strings" - "sync" - "time" -) - -type DataItem map[string]interface{} - -type NodeInfo struct { - Name string - Conn net.Conn -} - -type Broker struct { - nodes map[string]NodeInfo - edges map[string]string - loops map[string][]string - conditions map[string]ConditionConfig - results map[string][]DataItem // Track task results by task ID - mu sync.Mutex -} - -type ConditionConfig struct { - TrueNode string - FalseNode string -} - -func NewBroker() *Broker { - return &Broker{ - nodes: make(map[string]NodeInfo), - edges: make(map[string]string), - loops: make(map[string][]string), - conditions: make(map[string]ConditionConfig), - results: make(map[string][]DataItem), - } -} - -func (b *Broker) RegisterNode(name string, conn net.Conn) { - b.mu.Lock() - defer b.mu.Unlock() - fmt.Printf("Registering node: %s\n", name) - b.nodes[name] = NodeInfo{Name: name, Conn: conn} -} - -func (b *Broker) AddEdge(fromNode string, toNode string) { - b.mu.Lock() - defer b.mu.Unlock() - fmt.Printf("Adding edge from %s to %s\n", fromNode, toNode) - b.edges[fromNode] = toNode -} - -func (b *Broker) AddLoop(loopNode string, targetNodes []string) { - b.mu.Lock() - defer b.mu.Unlock() - fmt.Printf("Adding loop at %s with targets: %v\n", loopNode, targetNodes) - b.loops[loopNode] = targetNodes -} - -func (b *Broker) AddCondition(condNode string, trueNode string, falseNode string) { - b.mu.Lock() - defer b.mu.Unlock() - fmt.Printf("Adding condition at %s, True: %s, False: %s\n", condNode, trueNode, falseNode) - b.conditions[condNode] = ConditionConfig{ - TrueNode: trueNode, - FalseNode: falseNode, - } -} - -func (b *Broker) SendDataToNode(nodeName string, taskID string, data []DataItem, resultChannel chan []DataItem) { - b.mu.Lock() - node, exists := b.nodes[nodeName] - b.mu.Unlock() - if !exists { - fmt.Printf("Node %s not found!\n", nodeName) - return - } - - fmt.Printf("Sending data to %s for task %s...\n", nodeName, taskID) - encoder := json.NewEncoder(node.Conn) - err := encoder.Encode(data) - if err != nil { - fmt.Printf("Error sending data to %s: %v\n", nodeName, err) - return - } - - // Receive the processed data back from the node asynchronously - go func() { - decoder := json.NewDecoder(node.Conn) - var result []DataItem - err = decoder.Decode(&result) - if err != nil { - fmt.Printf("Error receiving data from %s for task %s: %v\n", nodeName, taskID, err) - return - } - fmt.Printf("Received processed data from %s for task %s\n", nodeName, taskID) - - // Send the result to the result aggregation channel - resultChannel <- result - }() -} - -func (b *Broker) DispatchData(startNode string, data []DataItem, taskID string) []DataItem { - finalResult := []DataItem{} - currentNode := startNode - resultChannel := make(chan []DataItem, len(data)) // Create a channel to handle async results - - for { - b.mu.Lock() - nextNode, hasEdge := b.edges[currentNode] - loopTargets, hasLoop := b.loops[currentNode] - conditionConfig, hasCondition := b.conditions[currentNode] - b.mu.Unlock() - - // Handle Loops (async dispatch) - if hasLoop { - var wg sync.WaitGroup - fmt.Printf("Dispatching to loop nodes from %s for task %s...\n", currentNode, taskID) - for _, targetNode := range loopTargets { - wg.Add(1) - go func(node string) { - defer wg.Done() - b.SendDataToNode(node, taskID, data, resultChannel) - }(targetNode) - } - - // Wait for loop processing to complete - go func() { - wg.Wait() - close(resultChannel) - }() - - // Collect async results - for res := range resultChannel { - finalResult = append(finalResult, res...) - } - - b.AggregateResults(taskID, finalResult) - return finalResult // Exit after loop processing - } - - // Handle Conditions - if hasCondition { - for _, item := range data { - resultChannel := make(chan []DataItem, 1) - go b.SendDataToNode(currentNode, taskID, []DataItem{item}, resultChannel) - - select { - case result := <-resultChannel: - nextNode = conditionConfig.TrueNode - finalResult = append(finalResult, b.DispatchData(nextNode, result, taskID)...) - case <-time.After(5 * time.Second): // Timeout if no response - fmt.Printf("Condition check timed out at node: %s\n", currentNode) - nextNode = conditionConfig.FalseNode - } - } - b.AggregateResults(taskID, finalResult) - return finalResult // Exit after condition processing - } - - // Handle simple edges (sequential flow) - if hasEdge { - b.SendDataToNode(currentNode, taskID, data, resultChannel) - - select { - case result := <-resultChannel: - currentNode = nextNode - data = result - case <-time.After(5 * time.Second): // Timeout if no response - fmt.Printf("Processing timed out at node: %s\n", currentNode) - return finalResult - } - } else { - fmt.Printf("No edge found for node: %s, stopping...\n", currentNode) - break - } - } - - b.AggregateResults(taskID, finalResult) - return finalResult -} - -func (b *Broker) AggregateResults(taskID string, result []DataItem) { - b.mu.Lock() - defer b.mu.Unlock() - b.results[taskID] = append(b.results[taskID], result...) - fmt.Printf("Aggregated result for task %s: %v\n", taskID, b.results[taskID]) -} - -func (b *Broker) HandleConnections() { - listener, err := net.Listen("tcp", ":8081") - if err != nil { - fmt.Println("Error setting up TCP server:", err) - os.Exit(1) - } - defer listener.Close() - - fmt.Println("Broker is listening on port 8081...") - - for { - conn, err := listener.Accept() - if err != nil { - fmt.Println("Error accepting connection:", err) - continue - } - - go func(conn net.Conn) { - defer conn.Close() - reader := bufio.NewReader(conn) - nodeName, err := reader.ReadString('\n') - if err != nil { - fmt.Println("Error reading node name:", err) - return - } - nodeName = strings.TrimSpace(nodeName) - - b.RegisterNode(nodeName, conn) - }(conn) - } -} - -func main() { - broker := NewBroker() - - // Set up the flow - broker.AddEdge("Node1", "Node2") - broker.AddLoop("Node2", []string{"Node3"}) - broker.AddCondition("Node3", "Node4", "") - - // Start the broker to listen for node connections - go broker.HandleConnections() - - fmt.Println("Press ENTER to start the flow after nodes are connected...") - bufio.NewReader(os.Stdin).ReadString('\n') - - // Example Data Items - dataItems := []DataItem{ - {"id": 1, "value": "item1"}, - {"id": 2, "value": "item2"}, - {"id": 3, "value": "item3"}, - } - - taskID := "task-001" // Unique ID to track this task - finalResult := broker.DispatchData("Node1", dataItems, taskID) - fmt.Println("Final result after processing:", finalResult) -} diff --git a/examples/dag.go b/examples/dag.go index e2444d9..8235f53 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -2,90 +2,125 @@ package main import ( "context" - "encoding/json" "fmt" + "log" + "sync" "time" - + "github.com/oarkflow/mq" ) -func handleNode1(_ context.Context, task mq.Task) mq.Result { - result := []map[string]string{ - {"field": "facility", "item": "item1"}, - {"field": "facility", "item": "item2"}, - {"field": "facility", "item": "item3"}, - } - var payload string - err := json.Unmarshal(task.Payload, &payload) - if err != nil { - return mq.Result{Status: "fail", Payload: json.RawMessage(`{"field": "node1", "item": "error"}`)} - } - fmt.Printf("Processing task at node1: %s\n", string(task.Payload)) - bt, _ := json.Marshal(result) - return mq.Result{Status: "completed", Payload: bt} -} - -func handleNode2(_ context.Context, task mq.Task) mq.Result { - var payload map[string]string - err := json.Unmarshal(task.Payload, &payload) - if err != nil { - return mq.Result{Status: "fail", Payload: json.RawMessage(`{"field": "node2", "item": "error"}`)} - } - status := "fail" - if payload["item"] == "item2" { - status = "pass" - } - fmt.Printf("Processing task at node2: %s %s\n", payload, status) - bt, _ := json.Marshal(payload) - return mq.Result{Status: status, Payload: bt} -} - -func handleNode3(_ context.Context, task mq.Task) mq.Result { - var data map[string]any - err := json.Unmarshal(task.Payload, &data) - if err != nil { - return mq.Result{Error: err} - } - data["item"] = "Item processed in node3" - bt, _ := json.Marshal(data) - return mq.Result{Status: "completed", Payload: bt} -} - -func handleNode4(_ context.Context, task mq.Task) mq.Result { - var data map[string]any - err := json.Unmarshal(task.Payload, &data) - if err != nil { - return mq.Result{Error: err} - } - data["item"] = "An Item processed in node4" - bt, _ := json.Marshal(data) - return mq.Result{Status: "completed", Payload: bt} -} - func main() { - ctx := context.Background() - d := mq.NewDAG(false) - - d.AddNode("node1", handleNode1, true) - d.AddNode("node2", handleNode2) - d.AddNode("node3", handleNode3) - d.AddNode("node4", handleNode4) - d.AddCondition("node2", map[string]string{"pass": "node3", "fail": "node4"}) - err := d.AddLoop("node1", "node2") - if err != nil { - panic(err) - } - err = d.Prepare(ctx) - if err != nil { - panic(err) - } - // Start the DAG and process the task + dag := NewDAG() + dag.AddNode("queue1", func(ctx context.Context, task mq.Task) mq.Result { + log.Printf("Handling task for queue1: %s", string(task.Payload)) + return mq.Result{Payload: []byte(`{"task": 123}`), MessageID: task.ID} + }) + dag.AddNode("queue2", func(ctx context.Context, task mq.Task) mq.Result { + log.Printf("Handling task for queue2: %s", string(task.Payload)) + return mq.Result{Payload: []byte(`{"task": 456}`), MessageID: task.ID} + }) + dag.AddEdge("queue1", "queue2") + + // Start DAG processing go func() { - if err := d.Start(ctx); err != nil { - fmt.Println("Error starting DAG:", err) - } + time.Sleep(2 * time.Second) + finalResult := dag.Send([]byte(`{"task": 1}`)) + log.Printf("Final result received: %s", string(finalResult.Payload)) }() - result := d.ProcessTask(ctx, mq.Task{Payload: []byte(`"Start processing"`)}) - fmt.Println(string(result.Payload)) - time.Sleep(50 * time.Second) + + err := dag.Start(context.TODO()) + if err != nil { + panic(err) + } +} + +type DAG struct { + server *mq.Broker + nodes map[string]*mq.Consumer + edges map[string][]string + taskChMap map[string]chan mq.Result // A map to store result channels for each task + mu sync.Mutex // Mutex to protect the taskChMap +} + +func NewDAG(opts ...mq.Option) *DAG { + d := &DAG{ + nodes: make(map[string]*mq.Consumer), + edges: make(map[string][]string), + taskChMap: make(map[string]chan mq.Result), + } + opts = append(opts, mq.WithCallback(d.TaskCallback)) + d.server = mq.NewBroker(opts...) + return d +} + +func (d *DAG) AddNode(name string, handler mq.Handler) { + con := mq.NewConsumer(name) + con.RegisterHandler(name, handler) + d.nodes[name] = con +} + +func (d *DAG) AddEdge(fromNode string, toNodes ...string) { + d.edges[fromNode] = toNodes +} + +func (d *DAG) Start(ctx context.Context) error { + for _, con := range d.nodes { + go con.Consume(ctx) + } + return d.server.Start(ctx) +} + +func (d *DAG) PublishTask(ctx context.Context, payload []byte, queueName string) (*mq.Task, error) { + task := mq.Task{ + Payload: payload, + } + return d.server.Publish(ctx, task, queueName) +} + +// TaskCallback is the function triggered after each task completion. +func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error { + log.Printf("Callback from queue %s with result: %s", task.CurrentQueue, string(task.Result)) + edges, exists := d.edges[task.CurrentQueue] + if !exists { + // Lock and send the result to the specific task channel + d.mu.Lock() + fmt.Println(d.taskChMap, task.ID) + for _, resultCh := range d.taskChMap { + result := mq.Result{ + Command: "complete", + Payload: task.Result, + Queue: task.CurrentQueue, + MessageID: task.ID, + Status: "done", + } + resultCh <- result + delete(d.taskChMap, task.ID) // Clean up the channel + } + d.mu.Unlock() + return nil + } + + // Forward the task to the next node(s) + for _, edge := range edges { + _, err := d.PublishTask(ctx, task.Result, edge) + if err != nil { + return err + } + } + return nil +} + +// Send sends the task and waits for the final result. +func (d *DAG) Send(payload []byte) mq.Result { + resultCh := make(chan mq.Result) + task, err := d.PublishTask(context.TODO(), payload, "queue1") + if err != nil { + panic(err) + } + d.mu.Lock() + d.taskChMap[task.ID] = resultCh + d.mu.Unlock() + finalResult := <-resultCh + return finalResult }