diff --git a/broker.go b/broker.go index c840e63..edd3bea 100644 --- a/broker.go +++ b/broker.go @@ -243,9 +243,9 @@ func (b *Broker) HandleProcessedMessage(ctx context.Context, clientMsg Result) e } for _, callback := range b.opts.callback { if callback != nil { - err := callback(ctx, msg) - if err != nil { - return err + result := callback(ctx, msg) + if result.Error != nil { + return result.Error } } } diff --git a/constants.go b/constants.go index dd314f7..9bd8dbf 100644 --- a/constants.go +++ b/constants.go @@ -1,6 +1,6 @@ package mq -type CMD int +type CMD byte const ( SUBSCRIBE CMD = iota + 1 diff --git a/dag/dag.go b/dag/dag.go index 36d8cb1..cb699e0 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -123,48 +123,73 @@ func (d *DAG) Send(payload []byte) mq.Result { return finalResult } -func (d *DAG) handleTriggerNode(ctx context.Context, task *mq.Task) (bool, string, any, []byte, string) { +func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) mq.Result { + if task.Error != nil { + return mq.Result{Error: task.Error} + } triggeredNode, ok := mq.GetTriggerNode(ctx) var result any var payload []byte completed := false var nodeType string - if !(ok && triggeredNode != "") { - return false, nodeType, result, payload, triggeredNode - } - taskResults, ok := d.taskResults[task.ID] - if !ok { - return false, nodeType, result, payload, triggeredNode - } - nodeResult, exists := taskResults[triggeredNode] - if exists { - return false, nodeType, result, payload, triggeredNode - } - nodeResult.completed++ - if nodeResult.completed == nodeResult.totalItems { - completed = true - } - switch nodeResult.nodeType { - case "loop": - nodeResult.results = append(nodeResult.results, task.Result) - if completed { - result = nodeResult.results + if ok && triggeredNode != "" { + taskResults, ok := d.taskResults[task.ID] + if ok { + nodeResult, exists := taskResults[triggeredNode] + if exists { + nodeResult.completed++ + if nodeResult.completed == nodeResult.totalItems { + completed = true + } + switch nodeResult.nodeType { + case "loop": + nodeResult.results = append(nodeResult.results, task.Result) + result = nodeResult.results + nodeType = "loop" + case "edge": + nodeResult.result = task.Result + result = nodeResult.result + nodeType = "edge" + } + } + if completed { + delete(taskResults, triggeredNode) + } } - nodeType = "loop" - case "edge": - nodeResult.result = task.Result - if completed { - result = nodeResult.result - } - nodeType = "edge" } if completed { - delete(taskResults, triggeredNode) + payload, _ = json.Marshal(result) + } else { + payload = task.Result } - return completed, nodeType, result, payload, triggeredNode -} + if loopNodes, exists := d.loopEdges[task.CurrentQueue]; exists { + var items []json.RawMessage + if err := json.Unmarshal(payload, &items); err != nil { + return mq.Result{Error: task.Error} + } + d.taskResults[task.ID] = map[string]*taskContext{ + task.CurrentQueue: { + totalItems: len(items), + nodeType: "loop", + }, + } -func (d *DAG) handleEdges(ctx context.Context, task *mq.Task, payload []byte, completed bool) error { + ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue}) + for _, loopNode := range loopNodes { + for _, item := range items { + _, err := d.PublishTask(ctx, item, loopNode, task.ID) + if err != nil { + return mq.Result{Error: task.Error} + } + } + } + + return mq.Result{} + } + if nodeType == "loop" && completed { + task.CurrentQueue = triggeredNode + } + ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue}) edges, exists := d.edges[task.CurrentQueue] if exists { d.taskResults[task.ID] = map[string]*taskContext{ @@ -176,7 +201,7 @@ func (d *DAG) handleEdges(ctx context.Context, task *mq.Task, payload []byte, co for _, edge := range edges { _, err := d.PublishTask(ctx, payload, edge, task.ID) if err != nil { - return err + return mq.Result{Error: task.Error} } } } else if completed { @@ -194,49 +219,5 @@ func (d *DAG) handleEdges(ctx context.Context, task *mq.Task, payload []byte, co } d.mu.Unlock() } - return nil - -} - -func (d *DAG) handleLoopEdges(ctx context.Context, task *mq.Task, payload []byte, loopNodes []string) error { - var items []json.RawMessage - if err := json.Unmarshal(payload, &items); err != nil { - return err - } - d.taskResults[task.ID] = map[string]*taskContext{ - task.CurrentQueue: { - totalItems: len(items), - nodeType: "loop", - }, - } - ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue}) - for _, loopNode := range loopNodes { - for _, item := range items { - _, err := d.PublishTask(ctx, item, loopNode, task.ID) - if err != nil { - return err - } - } - } - return nil -} - -func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error { - if task.Error != nil { - return task.Error - } - completed, nodeType, result, payload, triggeredNode := d.handleTriggerNode(ctx, task) - if completed { - payload, _ = json.Marshal(result) - } else { - payload = task.Result - } - if loopNodes, exists := d.loopEdges[task.CurrentQueue]; exists { - return d.handleLoopEdges(ctx, task, payload, loopNodes) - } - if nodeType == "loop" && completed { - task.CurrentQueue = triggeredNode - } - ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue}) - return d.handleEdges(ctx, task, payload, completed) + return mq.Result{} } diff --git a/examples/dag.go b/examples/dag.go index 08c91c0..8e046d0 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -4,11 +4,12 @@ import ( "context" "encoding/json" "fmt" - "github.com/oarkflow/mq" - "github.com/oarkflow/mq/dag" "io" "log" "net/http" + + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/dag" ) var d *dag.DAG @@ -72,6 +73,7 @@ func sendTaskHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "Empty request body", http.StatusBadRequest) return } + fmt.Println(string(payload)) finalResult := d.Send(payload) w.Header().Set("Content-Type", "application/json") result := map[string]any{ diff --git a/examples/server.go b/examples/server.go index 3f6a916..2be640d 100644 --- a/examples/server.go +++ b/examples/server.go @@ -8,9 +8,9 @@ import ( ) func main() { - b := mq.NewBroker(mq.WithCallback(func(ctx context.Context, task *mq.Task) error { + b := mq.NewBroker(mq.WithCallback(func(ctx context.Context, task *mq.Task) mq.Result { fmt.Println("Received task", task.ID, "Payload", string(task.Payload), "Result", string(task.Result), task.Error, task.CurrentQueue) - return nil + return mq.Result{} })) b.NewQueue("queue1") b.NewQueue("queue2") diff --git a/options.go b/options.go index 207f04d..7bb6502 100644 --- a/options.go +++ b/options.go @@ -11,7 +11,7 @@ type Options struct { messageHandler MessageHandler closeHandler CloseHandler errorHandler ErrorHandler - callback []func(context.Context, *Task) error + callback []func(context.Context, *Task) Result maxRetries int initialDelay time.Duration maxBackoff time.Duration @@ -68,7 +68,7 @@ func WithMaxBackoff(val time.Duration) Option { } // WithCallback - -func WithCallback(val ...func(context.Context, *Task) error) Option { +func WithCallback(val ...func(context.Context, *Task) Result) Option { return func(opts *Options) { opts.callback = val }