init: publisher

This commit is contained in:
sujit
2024-09-29 00:50:59 +05:45
parent c6c83c8a3f
commit 1368b9a9e7
6 changed files with 147 additions and 584 deletions

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"net"
"time"
@@ -170,7 +171,7 @@ func (b *Broker) Start(ctx context.Context) error {
defer func() {
_ = listener.Close()
}()
fmt.Println("Broker server started on", b.opts.brokerAddr)
log.Println("Server started on", b.opts.brokerAddr)
for {
conn, err := listener.Accept()
if err != nil {
@@ -181,30 +182,34 @@ func (b *Broker) Start(ctx context.Context) error {
}
}
func (b *Broker) Publish(ctx context.Context, message Task, queueName string) error {
queue, err := b.AddMessageToQueue(&message, queueName)
func (b *Broker) Publish(ctx context.Context, message Task, queueName string) (*Task, error) {
queue, task, err := b.AddMessageToQueue(&message, queueName)
if err != nil {
return err
return nil, err
}
if queue.consumers.Size() == 0 {
queue.deferred.Set(NewID(), &message)
fmt.Println("task deferred as no consumers are connected", queueName)
return nil
return task, nil
}
queue.send(ctx, message)
return nil
return task, nil
}
func (b *Broker) NewQueue(qName string) {
if _, ok := b.queues.Get(qName); !ok {
b.queues.Set(qName, newQueue(qName))
func (b *Broker) NewQueue(qName string) *Queue {
q, ok := b.queues.Get(qName)
if ok {
return q
}
q = newQueue(qName)
b.queues.Set(qName, q)
return q
}
func (b *Broker) AddMessageToQueue(message *Task, queueName string) (*Queue, error) {
func (b *Broker) AddMessageToQueue(message *Task, queueName string) (*Queue, *Task, error) {
queue, ok := b.queues.Get(queueName)
if !ok {
return nil, fmt.Errorf("queue %s not found", queueName)
return nil, nil, fmt.Errorf("queue %s not found", queueName)
}
if message.ID == "" {
message.ID = NewID()
@@ -214,7 +219,7 @@ func (b *Broker) AddMessageToQueue(message *Task, queueName string) (*Queue, err
}
message.CreatedAt = time.Now()
queue.messages.Set(message.ID, message)
return queue, nil
return queue, message, nil
}
func (b *Broker) HandleProcessedMessage(ctx context.Context, clientMsg Result) error {
@@ -250,9 +255,18 @@ func (b *Broker) HandleProcessedMessage(ctx context.Context, clientMsg Result) e
func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Conn) string {
consumerID, ok := GetConsumerID(ctx)
defer func() {
cmd := Command{
Command: SUBSCRIBE_ACK,
Queue: queueName,
Error: "",
}
Write(ctx, conn, cmd)
log.Printf("Consumer %s joined server on queue %s", consumerID, queueName)
}()
q, ok := b.queues.Get(queueName)
if !ok {
b.NewQueue(queueName)
q = b.NewQueue(queueName)
}
con := &consumer{id: consumerID, conn: conn}
b.consumers.Set(consumerID, con)
@@ -319,7 +333,7 @@ func (b *Broker) publish(ctx context.Context, conn net.Conn, msg Command) error
CreatedAt: time.Now(),
CurrentQueue: msg.Queue,
}
err := b.Publish(ctx, task, msg.Queue)
_, err := b.Publish(ctx, task, msg.Queue)
if err != nil {
return err
}
@@ -343,7 +357,7 @@ func (b *Broker) request(ctx context.Context, conn net.Conn, msg Command) error
CreatedAt: time.Now(),
CurrentQueue: msg.Queue,
}
err := b.Publish(ctx, task, msg.Queue)
_, err := b.Publish(ctx, task, msg.Queue)
if err != nil {
return err
}

View File

@@ -4,7 +4,7 @@ type CMD int
const (
SUBSCRIBE CMD = iota + 1
ACK
SUBSCRIBE_ACK
PUBLISH
REQUEST
RESPONSE

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"net"
"sync"
"time"
@@ -74,6 +75,9 @@ func (c *Consumer) handleCommandMessage(msg Command) error {
switch msg.Command {
case STOP:
return c.Close()
case SUBSCRIBE_ACK:
log.Printf("Consumer %s subscribed to queue %s\n", c.id, msg.Queue)
return nil
default:
return fmt.Errorf("unknown command in consumer %d", msg.Command)
}

238
dag.go
View File

@@ -1,238 +0,0 @@
package mq
import (
"context"
"errors"
"fmt"
"github.com/oarkflow/xsync"
)
const (
triggerNodeKey string = "triggerNode"
)
type Node interface {
Queue() string
Consumer() *Consumer
Handler() Handler
}
type node struct {
queue string
consumer *Consumer
handler Handler
}
func (n *node) Queue() string {
return n.queue
}
func (n *node) Consumer() *Consumer {
return n.consumer
}
func (n *node) Handler() Handler {
return n.handler
}
type DAG struct {
nodes *xsync.MapOf[string, Node]
edges [][]string
loopEdges [][]string
broker *Broker
startNode Node
conditions map[string]map[string]string
syncMode bool
}
func NewDAG(syncMode bool) *DAG {
dag := &DAG{
nodes: xsync.NewMap[string, Node](),
conditions: make(map[string]map[string]string),
syncMode: syncMode,
}
dag.broker = NewBroker(WithCallback(dag.TaskCallback))
return dag
}
func (dag *DAG) TaskCallback(ctx context.Context, task *Task) error {
return nil
}
func (dag *DAG) AddNode(queue string, handler Handler, firstNode ...bool) {
con := NewConsumer("consume-" + queue)
con.RegisterHandler(queue, handler)
dag.broker.NewQueue(queue)
n := &node{
queue: queue,
consumer: con,
handler: handler,
}
if len(firstNode) > 0 && firstNode[0] {
dag.startNode = n
}
dag.nodes.Set(queue, n)
}
func (dag *DAG) AddEdge(fromNodeID, toNodeID string) error {
err := dag.validateNodes(fromNodeID, toNodeID)
if err != nil {
return err
}
dag.edges = append(dag.edges, []string{fromNodeID, toNodeID})
return nil
}
func (dag *DAG) AddCondition(conditionNodeID string, conditions map[string]string) error {
for _, nodeID := range conditions {
if err := dag.validateNodes(nodeID); err != nil {
return err
}
}
dag.conditions[conditionNodeID] = conditions
return nil
}
func (dag *DAG) AddLoop(fromNodeID, toNodeID string) error {
err := dag.validateNodes(fromNodeID, toNodeID)
if err != nil {
return err
}
dag.loopEdges = append(dag.loopEdges, []string{fromNodeID, toNodeID})
return nil
}
func (dag *DAG) Start(ctx context.Context) error {
if dag.syncMode {
return nil
}
return dag.broker.Start(ctx)
}
func (dag *DAG) Prepare(ctx context.Context) error {
startNode, err := dag.findInitialNode()
if err != nil {
return err
}
if startNode == nil {
return fmt.Errorf("no initial node found")
}
dag.startNode = startNode
if dag.syncMode {
return nil
}
dag.nodes.ForEach(func(_ string, node Node) bool {
go node.Consumer().Consume(ctx)
return true
})
return nil
}
func (dag *DAG) ProcessTask(ctx context.Context, task Task) Result {
return dag.processNode(ctx, &task, dag.startNode.Queue())
}
func (dag *DAG) getConditionalNode(status, currentNode string) string {
conditions, ok := dag.conditions[currentNode]
if !ok {
return ""
}
conditionNodeID, ok := conditions[status]
if !ok {
return ""
}
return conditionNodeID
}
func (dag *DAG) validateNodes(nodeIDs ...string) error {
for _, nodeID := range nodeIDs {
if _, ok := dag.nodes.Get(nodeID); !ok {
return fmt.Errorf("node %s not found", nodeID)
}
}
return nil
}
func (dag *DAG) processEdge(ctx context.Context, id string, payload []byte, targets []string) {
newTask := &Task{
ID: id,
Payload: payload,
}
for _, target := range targets {
if target != "" {
dag.processNode(ctx, newTask, target)
}
}
}
func (dag *DAG) calculateForFirstNode() (string, bool) {
inDegree := make(map[string]int)
for _, n := range dag.nodes.Keys() {
inDegree[n] = 0
}
for _, edge := range dag.edges {
inDegree[edge[1]]++
}
for _, edge := range dag.loopEdges {
inDegree[edge[1]]++
}
for n, count := range inDegree {
if count == 0 {
return n, true
}
}
return "", false
}
func (dag *DAG) findInitialNode() (Node, error) {
if dag.startNode != nil {
return dag.startNode, nil
}
var nt Node
n, ok := dag.calculateForFirstNode()
if !ok {
return nil, errors.New("no initial node found")
}
nt, ok = dag.nodes.Get(n)
if !ok {
return nil, errors.New("no initial node found")
}
return nt, nil
}
func (dag *DAG) processNode(ctx context.Context, task *Task, queue string) Result {
if !dag.syncMode {
if err := dag.broker.Publish(ctx, *task, queue); err != nil {
fmt.Println("Failed to publish task:", err)
}
return Result{}
}
n, ok := dag.nodes.Get(queue)
if task.CurrentQueue == "" {
task.CurrentQueue = queue
}
if !ok {
fmt.Println("Node not found:", queue)
return Result{Error: fmt.Errorf("node not found %s", queue)}
}
_, err := dag.broker.AddMessageToQueue(task, queue)
if err != nil {
return Result{Error: err}
}
result := n.Handler()(ctx, *task)
if result.Queue == "" {
result.Queue = task.CurrentQueue
}
if result.MessageID == "" {
result.MessageID = task.ID
}
if result.Error != nil {
return result
}
err = dag.broker.HandleProcessedMessage(ctx, result)
if err != nil {
return Result{Error: err, Status: result.Status}
}
return result
}

View File

@@ -1,252 +0,0 @@
package main
import (
"bufio"
"encoding/json"
"fmt"
"net"
"os"
"strings"
"sync"
"time"
)
type DataItem map[string]interface{}
type NodeInfo struct {
Name string
Conn net.Conn
}
type Broker struct {
nodes map[string]NodeInfo
edges map[string]string
loops map[string][]string
conditions map[string]ConditionConfig
results map[string][]DataItem // Track task results by task ID
mu sync.Mutex
}
type ConditionConfig struct {
TrueNode string
FalseNode string
}
func NewBroker() *Broker {
return &Broker{
nodes: make(map[string]NodeInfo),
edges: make(map[string]string),
loops: make(map[string][]string),
conditions: make(map[string]ConditionConfig),
results: make(map[string][]DataItem),
}
}
func (b *Broker) RegisterNode(name string, conn net.Conn) {
b.mu.Lock()
defer b.mu.Unlock()
fmt.Printf("Registering node: %s\n", name)
b.nodes[name] = NodeInfo{Name: name, Conn: conn}
}
func (b *Broker) AddEdge(fromNode string, toNode string) {
b.mu.Lock()
defer b.mu.Unlock()
fmt.Printf("Adding edge from %s to %s\n", fromNode, toNode)
b.edges[fromNode] = toNode
}
func (b *Broker) AddLoop(loopNode string, targetNodes []string) {
b.mu.Lock()
defer b.mu.Unlock()
fmt.Printf("Adding loop at %s with targets: %v\n", loopNode, targetNodes)
b.loops[loopNode] = targetNodes
}
func (b *Broker) AddCondition(condNode string, trueNode string, falseNode string) {
b.mu.Lock()
defer b.mu.Unlock()
fmt.Printf("Adding condition at %s, True: %s, False: %s\n", condNode, trueNode, falseNode)
b.conditions[condNode] = ConditionConfig{
TrueNode: trueNode,
FalseNode: falseNode,
}
}
func (b *Broker) SendDataToNode(nodeName string, taskID string, data []DataItem, resultChannel chan []DataItem) {
b.mu.Lock()
node, exists := b.nodes[nodeName]
b.mu.Unlock()
if !exists {
fmt.Printf("Node %s not found!\n", nodeName)
return
}
fmt.Printf("Sending data to %s for task %s...\n", nodeName, taskID)
encoder := json.NewEncoder(node.Conn)
err := encoder.Encode(data)
if err != nil {
fmt.Printf("Error sending data to %s: %v\n", nodeName, err)
return
}
// Receive the processed data back from the node asynchronously
go func() {
decoder := json.NewDecoder(node.Conn)
var result []DataItem
err = decoder.Decode(&result)
if err != nil {
fmt.Printf("Error receiving data from %s for task %s: %v\n", nodeName, taskID, err)
return
}
fmt.Printf("Received processed data from %s for task %s\n", nodeName, taskID)
// Send the result to the result aggregation channel
resultChannel <- result
}()
}
func (b *Broker) DispatchData(startNode string, data []DataItem, taskID string) []DataItem {
finalResult := []DataItem{}
currentNode := startNode
resultChannel := make(chan []DataItem, len(data)) // Create a channel to handle async results
for {
b.mu.Lock()
nextNode, hasEdge := b.edges[currentNode]
loopTargets, hasLoop := b.loops[currentNode]
conditionConfig, hasCondition := b.conditions[currentNode]
b.mu.Unlock()
// Handle Loops (async dispatch)
if hasLoop {
var wg sync.WaitGroup
fmt.Printf("Dispatching to loop nodes from %s for task %s...\n", currentNode, taskID)
for _, targetNode := range loopTargets {
wg.Add(1)
go func(node string) {
defer wg.Done()
b.SendDataToNode(node, taskID, data, resultChannel)
}(targetNode)
}
// Wait for loop processing to complete
go func() {
wg.Wait()
close(resultChannel)
}()
// Collect async results
for res := range resultChannel {
finalResult = append(finalResult, res...)
}
b.AggregateResults(taskID, finalResult)
return finalResult // Exit after loop processing
}
// Handle Conditions
if hasCondition {
for _, item := range data {
resultChannel := make(chan []DataItem, 1)
go b.SendDataToNode(currentNode, taskID, []DataItem{item}, resultChannel)
select {
case result := <-resultChannel:
nextNode = conditionConfig.TrueNode
finalResult = append(finalResult, b.DispatchData(nextNode, result, taskID)...)
case <-time.After(5 * time.Second): // Timeout if no response
fmt.Printf("Condition check timed out at node: %s\n", currentNode)
nextNode = conditionConfig.FalseNode
}
}
b.AggregateResults(taskID, finalResult)
return finalResult // Exit after condition processing
}
// Handle simple edges (sequential flow)
if hasEdge {
b.SendDataToNode(currentNode, taskID, data, resultChannel)
select {
case result := <-resultChannel:
currentNode = nextNode
data = result
case <-time.After(5 * time.Second): // Timeout if no response
fmt.Printf("Processing timed out at node: %s\n", currentNode)
return finalResult
}
} else {
fmt.Printf("No edge found for node: %s, stopping...\n", currentNode)
break
}
}
b.AggregateResults(taskID, finalResult)
return finalResult
}
func (b *Broker) AggregateResults(taskID string, result []DataItem) {
b.mu.Lock()
defer b.mu.Unlock()
b.results[taskID] = append(b.results[taskID], result...)
fmt.Printf("Aggregated result for task %s: %v\n", taskID, b.results[taskID])
}
func (b *Broker) HandleConnections() {
listener, err := net.Listen("tcp", ":8081")
if err != nil {
fmt.Println("Error setting up TCP server:", err)
os.Exit(1)
}
defer listener.Close()
fmt.Println("Broker is listening on port 8081...")
for {
conn, err := listener.Accept()
if err != nil {
fmt.Println("Error accepting connection:", err)
continue
}
go func(conn net.Conn) {
defer conn.Close()
reader := bufio.NewReader(conn)
nodeName, err := reader.ReadString('\n')
if err != nil {
fmt.Println("Error reading node name:", err)
return
}
nodeName = strings.TrimSpace(nodeName)
b.RegisterNode(nodeName, conn)
}(conn)
}
}
func main() {
broker := NewBroker()
// Set up the flow
broker.AddEdge("Node1", "Node2")
broker.AddLoop("Node2", []string{"Node3"})
broker.AddCondition("Node3", "Node4", "")
// Start the broker to listen for node connections
go broker.HandleConnections()
fmt.Println("Press ENTER to start the flow after nodes are connected...")
bufio.NewReader(os.Stdin).ReadString('\n')
// Example Data Items
dataItems := []DataItem{
{"id": 1, "value": "item1"},
{"id": 2, "value": "item2"},
{"id": 3, "value": "item3"},
}
taskID := "task-001" // Unique ID to track this task
finalResult := broker.DispatchData("Node1", dataItems, taskID)
fmt.Println("Final result after processing:", finalResult)
}

View File

@@ -2,90 +2,125 @@ package main
import (
"context"
"encoding/json"
"fmt"
"log"
"sync"
"time"
"github.com/oarkflow/mq"
)
func handleNode1(_ context.Context, task mq.Task) mq.Result {
result := []map[string]string{
{"field": "facility", "item": "item1"},
{"field": "facility", "item": "item2"},
{"field": "facility", "item": "item3"},
}
var payload string
err := json.Unmarshal(task.Payload, &payload)
if err != nil {
return mq.Result{Status: "fail", Payload: json.RawMessage(`{"field": "node1", "item": "error"}`)}
}
fmt.Printf("Processing task at node1: %s\n", string(task.Payload))
bt, _ := json.Marshal(result)
return mq.Result{Status: "completed", Payload: bt}
}
func handleNode2(_ context.Context, task mq.Task) mq.Result {
var payload map[string]string
err := json.Unmarshal(task.Payload, &payload)
if err != nil {
return mq.Result{Status: "fail", Payload: json.RawMessage(`{"field": "node2", "item": "error"}`)}
}
status := "fail"
if payload["item"] == "item2" {
status = "pass"
}
fmt.Printf("Processing task at node2: %s %s\n", payload, status)
bt, _ := json.Marshal(payload)
return mq.Result{Status: status, Payload: bt}
}
func handleNode3(_ context.Context, task mq.Task) mq.Result {
var data map[string]any
err := json.Unmarshal(task.Payload, &data)
if err != nil {
return mq.Result{Error: err}
}
data["item"] = "Item processed in node3"
bt, _ := json.Marshal(data)
return mq.Result{Status: "completed", Payload: bt}
}
func handleNode4(_ context.Context, task mq.Task) mq.Result {
var data map[string]any
err := json.Unmarshal(task.Payload, &data)
if err != nil {
return mq.Result{Error: err}
}
data["item"] = "An Item processed in node4"
bt, _ := json.Marshal(data)
return mq.Result{Status: "completed", Payload: bt}
}
func main() {
ctx := context.Background()
d := mq.NewDAG(false)
d.AddNode("node1", handleNode1, true)
d.AddNode("node2", handleNode2)
d.AddNode("node3", handleNode3)
d.AddNode("node4", handleNode4)
d.AddCondition("node2", map[string]string{"pass": "node3", "fail": "node4"})
err := d.AddLoop("node1", "node2")
if err != nil {
panic(err)
}
err = d.Prepare(ctx)
if err != nil {
panic(err)
}
// Start the DAG and process the task
dag := NewDAG()
dag.AddNode("queue1", func(ctx context.Context, task mq.Task) mq.Result {
log.Printf("Handling task for queue1: %s", string(task.Payload))
return mq.Result{Payload: []byte(`{"task": 123}`), MessageID: task.ID}
})
dag.AddNode("queue2", func(ctx context.Context, task mq.Task) mq.Result {
log.Printf("Handling task for queue2: %s", string(task.Payload))
return mq.Result{Payload: []byte(`{"task": 456}`), MessageID: task.ID}
})
dag.AddEdge("queue1", "queue2")
// Start DAG processing
go func() {
if err := d.Start(ctx); err != nil {
fmt.Println("Error starting DAG:", err)
}
time.Sleep(2 * time.Second)
finalResult := dag.Send([]byte(`{"task": 1}`))
log.Printf("Final result received: %s", string(finalResult.Payload))
}()
result := d.ProcessTask(ctx, mq.Task{Payload: []byte(`"Start processing"`)})
fmt.Println(string(result.Payload))
time.Sleep(50 * time.Second)
err := dag.Start(context.TODO())
if err != nil {
panic(err)
}
}
type DAG struct {
server *mq.Broker
nodes map[string]*mq.Consumer
edges map[string][]string
taskChMap map[string]chan mq.Result // A map to store result channels for each task
mu sync.Mutex // Mutex to protect the taskChMap
}
func NewDAG(opts ...mq.Option) *DAG {
d := &DAG{
nodes: make(map[string]*mq.Consumer),
edges: make(map[string][]string),
taskChMap: make(map[string]chan mq.Result),
}
opts = append(opts, mq.WithCallback(d.TaskCallback))
d.server = mq.NewBroker(opts...)
return d
}
func (d *DAG) AddNode(name string, handler mq.Handler) {
con := mq.NewConsumer(name)
con.RegisterHandler(name, handler)
d.nodes[name] = con
}
func (d *DAG) AddEdge(fromNode string, toNodes ...string) {
d.edges[fromNode] = toNodes
}
func (d *DAG) Start(ctx context.Context) error {
for _, con := range d.nodes {
go con.Consume(ctx)
}
return d.server.Start(ctx)
}
func (d *DAG) PublishTask(ctx context.Context, payload []byte, queueName string) (*mq.Task, error) {
task := mq.Task{
Payload: payload,
}
return d.server.Publish(ctx, task, queueName)
}
// TaskCallback is the function triggered after each task completion.
func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error {
log.Printf("Callback from queue %s with result: %s", task.CurrentQueue, string(task.Result))
edges, exists := d.edges[task.CurrentQueue]
if !exists {
// Lock and send the result to the specific task channel
d.mu.Lock()
fmt.Println(d.taskChMap, task.ID)
for _, resultCh := range d.taskChMap {
result := mq.Result{
Command: "complete",
Payload: task.Result,
Queue: task.CurrentQueue,
MessageID: task.ID,
Status: "done",
}
resultCh <- result
delete(d.taskChMap, task.ID) // Clean up the channel
}
d.mu.Unlock()
return nil
}
// Forward the task to the next node(s)
for _, edge := range edges {
_, err := d.PublishTask(ctx, task.Result, edge)
if err != nil {
return err
}
}
return nil
}
// Send sends the task and waits for the final result.
func (d *DAG) Send(payload []byte) mq.Result {
resultCh := make(chan mq.Result)
task, err := d.PublishTask(context.TODO(), payload, "queue1")
if err != nil {
panic(err)
}
d.mu.Lock()
d.taskChMap[task.ID] = resultCh
d.mu.Unlock()
finalResult := <-resultCh
return finalResult
}