diff --git a/broker.go b/broker.go index edd3bea..c660f32 100644 --- a/broker.go +++ b/broker.go @@ -71,7 +71,6 @@ type Task struct { CreatedAt time.Time `json:"created_at"` ProcessedAt time.Time `json:"processed_at"` CurrentQueue string `json:"current_queue"` - Result json.RawMessage `json:"result"` Status string `json:"status"` Error error `json:"error"` } @@ -141,7 +140,7 @@ func (b *Broker) sendToPublisher(ctx context.Context, publisherID string, result return pub.send(ctx, result) } -func (b *Broker) onClose(ctx context.Context, conn net.Conn) error { +func (b *Broker) onClose(ctx context.Context, _ net.Conn) error { consumerID, ok := GetConsumerID(ctx) if ok && consumerID != "" { if con, exists := b.consumers.Get(consumerID); exists { @@ -163,7 +162,7 @@ func (b *Broker) onClose(ctx context.Context, conn net.Conn) error { return nil } -func (b *Broker) onError(ctx context.Context, conn net.Conn, err error) { +func (b *Broker) onError(_ context.Context, conn net.Conn, err error) { fmt.Println("Error reading from connection:", err, conn.RemoteAddr()) } @@ -186,18 +185,24 @@ func (b *Broker) Start(ctx context.Context) error { } } -func (b *Broker) Publish(ctx context.Context, message Task, queueName string) (*Task, error) { +func (b *Broker) Publish(ctx context.Context, message Task, queueName string) Result { queue, task, err := b.AddMessageToQueue(&message, queueName) if err != nil { - return nil, err + return Result{Error: err} + } + result := Result{ + Command: "PUBLISH", + Payload: message.Payload, + Queue: queueName, + MessageID: task.ID, } if queue.consumers.Size() == 0 { queue.deferred.Set(NewID(), &message) fmt.Println("task deferred as no consumers are connected", queueName) - return task, nil + return result } queue.send(ctx, message) - return task, nil + return result } func (b *Broker) NewQueue(qName string) *Queue { @@ -210,44 +215,32 @@ func (b *Broker) NewQueue(qName string) *Queue { return q } -func (b *Broker) AddMessageToQueue(message *Task, queueName string) (*Queue, *Task, error) { +func (b *Broker) AddMessageToQueue(task *Task, queueName string) (*Queue, *Task, error) { queue := b.NewQueue(queueName) - if message.ID == "" { - message.ID = NewID() + if task.ID == "" { + task.ID = NewID() } if queueName != "" { - message.CurrentQueue = queueName + task.CurrentQueue = queueName } - message.CreatedAt = time.Now() - queue.messages.Set(message.ID, message) - return queue, message, nil + task.CreatedAt = time.Now() + queue.messages.Set(task.ID, task) + return queue, task, nil } -func (b *Broker) HandleProcessedMessage(ctx context.Context, clientMsg Result) error { +func (b *Broker) HandleProcessedMessage(ctx context.Context, result Result) error { publisherID, ok := GetPublisherID(ctx) if ok && publisherID != "" { - err := b.sendToPublisher(ctx, publisherID, clientMsg) + err := b.sendToPublisher(ctx, publisherID, result) if err != nil { return err } } - if queue, ok := b.queues.Get(clientMsg.Queue); ok { - if msg, ok := queue.messages.Get(clientMsg.MessageID); ok { - msg.ProcessedAt = time.Now() - msg.Status = clientMsg.Status - msg.Result = clientMsg.Payload - msg.Error = clientMsg.Error - msg.CurrentQueue = clientMsg.Queue - if clientMsg.Error != nil { - msg.Status = "error" - } - for _, callback := range b.opts.callback { - if callback != nil { - result := callback(ctx, msg) - if result.Error != nil { - return result.Error - } - } + for _, callback := range b.opts.callback { + if callback != nil { + rs := callback(ctx, result) + if rs.Error != nil { + return rs.Error } } } @@ -327,6 +320,10 @@ func (b *Broker) handleTaskMessage(ctx context.Context, _ net.Conn, msg Result) } func (b *Broker) publish(ctx context.Context, conn net.Conn, msg Command) error { + status := "PUBLISH" + if msg.Command == REQUEST { + status = "REQUEST" + } b.addPublisher(ctx, msg.Queue, conn) task := Task{ ID: msg.MessageID, @@ -334,41 +331,14 @@ func (b *Broker) publish(ctx context.Context, conn net.Conn, msg Command) error CreatedAt: time.Now(), CurrentQueue: msg.Queue, } - _, err := b.Publish(ctx, task, msg.Queue) - if err != nil { - return err + result := b.Publish(ctx, task, msg.Queue) + if result.Error != nil { + return result.Error } if task.ID != "" { - result := Result{ - Command: "PUBLISH", - MessageID: task.ID, - Status: "success", - Queue: msg.Queue, - } - return Write(ctx, conn, result) - } - return nil -} - -func (b *Broker) request(ctx context.Context, conn net.Conn, msg Command) error { - b.addPublisher(ctx, msg.Queue, conn) - task := Task{ - ID: msg.MessageID, - Payload: msg.Payload, - CreatedAt: time.Now(), - CurrentQueue: msg.Queue, - } - _, err := b.Publish(ctx, task, msg.Queue) - if err != nil { - return err - } - if task.ID != "" { - result := Result{ - Command: "REQUEST", - MessageID: task.ID, - Status: "success", - Queue: msg.Queue, - } + result.Status = status + result.MessageID = task.ID + result.Queue = msg.Queue return Write(ctx, conn, result) } return nil @@ -379,10 +349,8 @@ func (b *Broker) handleCommandMessage(ctx context.Context, conn net.Conn, msg Co case SUBSCRIBE: b.subscribe(ctx, msg.Queue, conn) return nil - case PUBLISH: + case PUBLISH, REQUEST: return b.publish(ctx, conn, msg) - case REQUEST: - return b.request(ctx, conn, msg) default: return fmt.Errorf("unknown command: %d", msg.Command) } diff --git a/dag/dag.go b/dag/dag.go index cb699e0..8e6c6cc 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -21,7 +21,7 @@ type DAG struct { FirstNode string server *mq.Broker nodes map[string]*mq.Consumer - edges map[string][]string + edges map[string]string loopEdges map[string][]string taskChMap map[string]chan mq.Result taskResults map[string]map[string]*taskContext @@ -31,7 +31,7 @@ type DAG struct { func New(opts ...mq.Option) *DAG { d := &DAG{ nodes: make(map[string]*mq.Consumer), - edges: make(map[string][]string), + edges: make(map[string]string), loopEdges: make(map[string][]string), taskChMap: make(map[string]chan mq.Result), taskResults: make(map[string]map[string]*taskContext), @@ -50,7 +50,7 @@ func (d *DAG) AddNode(name string, handler mq.Handler, firstNode ...bool) { d.nodes[name] = con } -func (d *DAG) AddEdge(fromNode string, toNodes ...string) { +func (d *DAG) AddEdge(fromNode string, toNodes string) { d.edges[fromNode] = toNodes } @@ -58,13 +58,17 @@ func (d *DAG) AddLoop(fromNode string, toNode ...string) { d.loopEdges[fromNode] = toNode } -func (d *DAG) Start(ctx context.Context) error { +func (d *DAG) Prepare() { if d.FirstNode == "" { firstNode, ok := d.FindFirstNode() if ok && firstNode != "" { d.FirstNode = firstNode } } +} + +func (d *DAG) Start(ctx context.Context) error { + d.Prepare() if d.server.SyncMode() { return nil } @@ -74,7 +78,7 @@ func (d *DAG) Start(ctx context.Context) error { return d.server.Start(ctx) } -func (d *DAG) PublishTask(ctx context.Context, payload []byte, queueName string, taskID ...string) (*mq.Task, error) { +func (d *DAG) PublishTask(ctx context.Context, payload []byte, queueName string, taskID ...string) mq.Result { task := mq.Task{ Payload: payload, } @@ -89,10 +93,8 @@ func (d *DAG) FindFirstNode() (string, bool) { for n, _ := range d.nodes { inDegree[n] = 0 } - for _, targets := range d.edges { - for _, outNode := range targets { - inDegree[outNode]++ - } + for _, outNode := range d.edges { + inDegree[outNode]++ } for _, targets := range d.loopEdges { for _, outNode := range targets { @@ -107,23 +109,89 @@ func (d *DAG) FindFirstNode() (string, bool) { return "", false } -func (d *DAG) Send(payload []byte) mq.Result { +func (d *DAG) Request(ctx context.Context, payload []byte) mq.Result { + return d.sendSync(ctx, mq.Result{Payload: payload}) +} + +func (d *DAG) Send(ctx context.Context, payload []byte) mq.Result { if d.FirstNode == "" { return mq.Result{Error: fmt.Errorf("initial node not defined")} } + if d.server.SyncMode() { + return d.sendSync(ctx, mq.Result{Payload: payload}) + } resultCh := make(chan mq.Result) - task, err := d.PublishTask(context.TODO(), payload, d.FirstNode) - if err != nil { - return mq.Result{Error: err} + result := d.PublishTask(ctx, payload, d.FirstNode) + if result.Error != nil { + return result } d.mu.Lock() - d.taskChMap[task.ID] = resultCh + d.taskChMap[result.MessageID] = resultCh d.mu.Unlock() finalResult := <-resultCh return finalResult } -func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) mq.Result { +func (d *DAG) processNode(ctx context.Context, task mq.Result) mq.Result { + if con, ok := d.nodes[task.Queue]; ok { + return con.ProcessTask(ctx, mq.Task{ + ID: task.MessageID, + Payload: task.Payload, + CurrentQueue: task.Queue, + }) + } + return mq.Result{Error: fmt.Errorf("no consumer to process %s", task.Queue)} +} + +func (d *DAG) sendSync(ctx context.Context, task mq.Result) mq.Result { + if task.MessageID == "" { + task.MessageID = mq.NewID() + } + if task.Queue == "" { + task.Queue = d.FirstNode + } + result := d.processNode(ctx, task) + if result.Error != nil { + return result + } + for _, target := range d.loopEdges[task.Queue] { + var items, results []json.RawMessage + if err := json.Unmarshal(result.Payload, &items); err != nil { + return mq.Result{Error: err} + } + for _, item := range items { + result = d.sendSync(ctx, mq.Result{ + Command: result.Command, + Payload: item, + Queue: target, + MessageID: result.MessageID, + }) + if result.Error != nil { + return result + } + results = append(results, result.Payload) + } + bt, err := json.Marshal(results) + if err != nil { + return mq.Result{Error: err} + } + result.Payload = bt + } + if target, ok := d.edges[task.Queue]; ok { + result = d.sendSync(ctx, mq.Result{ + Command: result.Command, + Payload: result.Payload, + Queue: target, + MessageID: result.MessageID, + }) + if result.Error != nil { + return result + } + } + return result +} + +func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result { if task.Error != nil { return mq.Result{Error: task.Error} } @@ -133,7 +201,7 @@ func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) mq.Result { completed := false var nodeType string if ok && triggeredNode != "" { - taskResults, ok := d.taskResults[task.ID] + taskResults, ok := d.taskResults[task.MessageID] if ok { nodeResult, exists := taskResults[triggeredNode] if exists { @@ -143,12 +211,16 @@ func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) mq.Result { } switch nodeResult.nodeType { case "loop": - nodeResult.results = append(nodeResult.results, task.Result) - result = nodeResult.results + nodeResult.results = append(nodeResult.results, task.Payload) + if completed { + result = nodeResult.results + } nodeType = "loop" case "edge": - nodeResult.result = task.Result - result = nodeResult.result + nodeResult.result = task.Payload + if completed { + result = nodeResult.result + } nodeType = "edge" } } @@ -160,26 +232,26 @@ func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) mq.Result { if completed { payload, _ = json.Marshal(result) } else { - payload = task.Result + payload = task.Payload } - if loopNodes, exists := d.loopEdges[task.CurrentQueue]; exists { + if loopNodes, exists := d.loopEdges[task.Queue]; exists { var items []json.RawMessage if err := json.Unmarshal(payload, &items); err != nil { return mq.Result{Error: task.Error} } - d.taskResults[task.ID] = map[string]*taskContext{ - task.CurrentQueue: { + d.taskResults[task.MessageID] = map[string]*taskContext{ + task.Queue: { totalItems: len(items), nodeType: "loop", }, } - ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue}) + ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.Queue}) for _, loopNode := range loopNodes { for _, item := range items { - _, err := d.PublishTask(ctx, item, loopNode, task.ID) - if err != nil { - return mq.Result{Error: task.Error} + rs := d.PublishTask(ctx, item, loopNode, task.MessageID) + if rs.Error != nil { + return rs } } } @@ -187,37 +259,35 @@ func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) mq.Result { return mq.Result{} } if nodeType == "loop" && completed { - task.CurrentQueue = triggeredNode + task.Queue = triggeredNode } - ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue}) - edges, exists := d.edges[task.CurrentQueue] + ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.Queue}) + edge, exists := d.edges[task.Queue] if exists { - d.taskResults[task.ID] = map[string]*taskContext{ - task.CurrentQueue: { + d.taskResults[task.MessageID] = map[string]*taskContext{ + task.Queue: { totalItems: 1, nodeType: "edge", }, } - for _, edge := range edges { - _, err := d.PublishTask(ctx, payload, edge, task.ID) - if err != nil { - return mq.Result{Error: task.Error} - } + rs := d.PublishTask(ctx, payload, edge, task.MessageID) + if rs.Error != nil { + return rs } } else if completed { d.mu.Lock() - if resultCh, ok := d.taskChMap[task.ID]; ok { + if resultCh, ok := d.taskChMap[task.MessageID]; ok { resultCh <- mq.Result{ Command: "complete", Payload: payload, - Queue: task.CurrentQueue, - MessageID: task.ID, + Queue: task.Queue, + MessageID: task.MessageID, Status: "done", } - delete(d.taskChMap, task.ID) - delete(d.taskResults, task.ID) + delete(d.taskChMap, task.MessageID) + delete(d.taskResults, task.MessageID) } d.mu.Unlock() } - return mq.Result{} + return task } diff --git a/examples/dag.go b/examples/dag.go index 8e046d0..e54e47a 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -3,84 +3,71 @@ package main import ( "context" "encoding/json" - "fmt" "io" "log" "net/http" "github.com/oarkflow/mq" "github.com/oarkflow/mq/dag" + "github.com/oarkflow/mq/examples/tasks" ) var d *dag.DAG func main() { d = dag.New() - d.AddNode("queue1", func(ctx context.Context, task mq.Task) mq.Result { - return mq.Result{Payload: task.Payload, MessageID: task.ID} - }) - d.AddNode("queue2", func(ctx context.Context, task mq.Task) mq.Result { - return mq.Result{Payload: task.Payload, MessageID: task.ID} - }) - d.AddNode("queue3", func(ctx context.Context, task mq.Task) mq.Result { - var data map[string]any - err := json.Unmarshal(task.Payload, &data) - if err != nil { - return mq.Result{Error: err} - } - data["salary"] = fmt.Sprintf("12000%v", data["user_id"]) - bt, _ := json.Marshal(data) - return mq.Result{Payload: bt, MessageID: task.ID} - }) - d.AddNode("queue4", func(ctx context.Context, task mq.Task) mq.Result { - var data []map[string]any - err := json.Unmarshal(task.Payload, &data) - if err != nil { - return mq.Result{Error: err} - } - payload := map[string]any{"storage": data} - bt, _ := json.Marshal(payload) - return mq.Result{Payload: bt, MessageID: task.ID} - }) + d.AddNode("queue1", tasks.Node1) + d.AddNode("queue2", tasks.Node2) + d.AddNode("queue3", tasks.Node3) + d.AddNode("queue4", tasks.Node4) + d.AddEdge("queue1", "queue2") d.AddLoop("queue2", "queue3") d.AddEdge("queue2", "queue4") + d.Prepare() go func() { err := d.Start(context.TODO()) if err != nil { panic(err) } }() - http.HandleFunc("/send-task", sendTaskHandler) + http.HandleFunc("/publish", requestHandler("publish")) + http.HandleFunc("/request", requestHandler("request")) log.Println("HTTP server started on http://localhost:8083") log.Fatal(http.ListenAndServe(":8083", nil)) } -func sendTaskHandler(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - http.Error(w, "Invalid request method", http.StatusMethodNotAllowed) - return - } - var payload []byte - if r.Body != nil { - defer r.Body.Close() - var err error - payload, err = io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Failed to read request body", http.StatusBadRequest) + +func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Invalid request method", http.StatusMethodNotAllowed) return } - } else { - http.Error(w, "Empty request body", http.StatusBadRequest) - return + var payload []byte + if r.Body != nil { + defer r.Body.Close() + var err error + payload, err = io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + } else { + http.Error(w, "Empty request body", http.StatusBadRequest) + return + } + var rs mq.Result + if requestType == "request" { + rs = d.Request(context.Background(), payload) + } else { + rs = d.Send(context.Background(), payload) + } + w.Header().Set("Content-Type", "application/json") + result := map[string]any{ + "message_id": rs.MessageID, + "payload": string(rs.Payload), + "error": rs.Error, + } + json.NewEncoder(w).Encode(result) } - fmt.Println(string(payload)) - finalResult := d.Send(payload) - w.Header().Set("Content-Type", "application/json") - result := map[string]any{ - "message_id": finalResult.MessageID, - "payload": string(finalResult.Payload), - "error": finalResult.Error, - } - - json.NewEncoder(w).Encode(result) } diff --git a/examples/server.go b/examples/server.go index 2be640d..bad655d 100644 --- a/examples/server.go +++ b/examples/server.go @@ -8,8 +8,8 @@ import ( ) func main() { - b := mq.NewBroker(mq.WithCallback(func(ctx context.Context, task *mq.Task) mq.Result { - fmt.Println("Received task", task.ID, "Payload", string(task.Payload), "Result", string(task.Result), task.Error, task.CurrentQueue) + b := mq.NewBroker(mq.WithCallback(func(ctx context.Context, task mq.Result) mq.Result { + fmt.Println("Received task", task.MessageID, "Payload", string(task.Payload), task.Error, task.Queue) return mq.Result{} })) b.NewQueue("queue1") diff --git a/examples/tasks/tasks.go b/examples/tasks/tasks.go new file mode 100644 index 0000000..e37b84b --- /dev/null +++ b/examples/tasks/tasks.go @@ -0,0 +1,40 @@ +package tasks + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/oarkflow/mq" +) + +func Node1(ctx context.Context, task mq.Task) mq.Result { + fmt.Println("Processing queue1") + return mq.Result{Payload: task.Payload, MessageID: task.ID} +} + +func Node2(ctx context.Context, task mq.Task) mq.Result { + return mq.Result{Payload: task.Payload, MessageID: task.ID} +} + +func Node3(ctx context.Context, task mq.Task) mq.Result { + var data map[string]any + err := json.Unmarshal(task.Payload, &data) + if err != nil { + return mq.Result{Error: err} + } + data["salary"] = fmt.Sprintf("12000%v", data["user_id"]) + bt, _ := json.Marshal(data) + return mq.Result{Payload: bt, MessageID: task.ID} +} + +func Node4(ctx context.Context, task mq.Task) mq.Result { + var data []map[string]any + err := json.Unmarshal(task.Payload, &data) + if err != nil { + return mq.Result{Error: err} + } + payload := map[string]any{"storage": data} + bt, _ := json.Marshal(payload) + return mq.Result{Payload: bt, MessageID: task.ID} +} diff --git a/options.go b/options.go index 7bb6502..c1e3d3a 100644 --- a/options.go +++ b/options.go @@ -11,7 +11,7 @@ type Options struct { messageHandler MessageHandler closeHandler CloseHandler errorHandler ErrorHandler - callback []func(context.Context, *Task) Result + callback []func(context.Context, Result) Result maxRetries int initialDelay time.Duration maxBackoff time.Duration @@ -68,7 +68,7 @@ func WithMaxBackoff(val time.Duration) Option { } // WithCallback - -func WithCallback(val ...func(context.Context, *Task) Result) Option { +func WithCallback(val ...func(context.Context, Result) Result) Option { return func(opts *Options) { opts.callback = val }