init: publisher

This commit is contained in:
sujit
2024-09-30 21:09:48 +05:45
parent 93190ffc74
commit afd43302ef
6 changed files with 71 additions and 88 deletions

View File

@@ -243,9 +243,9 @@ func (b *Broker) HandleProcessedMessage(ctx context.Context, clientMsg Result) e
} }
for _, callback := range b.opts.callback { for _, callback := range b.opts.callback {
if callback != nil { if callback != nil {
err := callback(ctx, msg) result := callback(ctx, msg)
if err != nil { if result.Error != nil {
return err return result.Error
} }
} }
} }

View File

@@ -1,6 +1,6 @@
package mq package mq
type CMD int type CMD byte
const ( const (
SUBSCRIBE CMD = iota + 1 SUBSCRIBE CMD = iota + 1

View File

@@ -123,23 +123,20 @@ func (d *DAG) Send(payload []byte) mq.Result {
return finalResult return finalResult
} }
func (d *DAG) handleTriggerNode(ctx context.Context, task *mq.Task) (bool, string, any, []byte, string) { func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) mq.Result {
if task.Error != nil {
return mq.Result{Error: task.Error}
}
triggeredNode, ok := mq.GetTriggerNode(ctx) triggeredNode, ok := mq.GetTriggerNode(ctx)
var result any var result any
var payload []byte var payload []byte
completed := false completed := false
var nodeType string var nodeType string
if !(ok && triggeredNode != "") { if ok && triggeredNode != "" {
return false, nodeType, result, payload, triggeredNode
}
taskResults, ok := d.taskResults[task.ID] taskResults, ok := d.taskResults[task.ID]
if !ok { if ok {
return false, nodeType, result, payload, triggeredNode
}
nodeResult, exists := taskResults[triggeredNode] nodeResult, exists := taskResults[triggeredNode]
if exists { if exists {
return false, nodeType, result, payload, triggeredNode
}
nodeResult.completed++ nodeResult.completed++
if nodeResult.completed == nodeResult.totalItems { if nodeResult.completed == nodeResult.totalItems {
completed = true completed = true
@@ -147,24 +144,52 @@ func (d *DAG) handleTriggerNode(ctx context.Context, task *mq.Task) (bool, strin
switch nodeResult.nodeType { switch nodeResult.nodeType {
case "loop": case "loop":
nodeResult.results = append(nodeResult.results, task.Result) nodeResult.results = append(nodeResult.results, task.Result)
if completed {
result = nodeResult.results result = nodeResult.results
}
nodeType = "loop" nodeType = "loop"
case "edge": case "edge":
nodeResult.result = task.Result nodeResult.result = task.Result
if completed {
result = nodeResult.result result = nodeResult.result
}
nodeType = "edge" nodeType = "edge"
} }
}
if completed { if completed {
delete(taskResults, triggeredNode) delete(taskResults, triggeredNode)
} }
return completed, nodeType, result, payload, triggeredNode }
} }
if completed {
payload, _ = json.Marshal(result)
} else {
payload = task.Result
}
if loopNodes, exists := d.loopEdges[task.CurrentQueue]; exists {
var items []json.RawMessage
if err := json.Unmarshal(payload, &items); err != nil {
return mq.Result{Error: task.Error}
}
d.taskResults[task.ID] = map[string]*taskContext{
task.CurrentQueue: {
totalItems: len(items),
nodeType: "loop",
},
}
func (d *DAG) handleEdges(ctx context.Context, task *mq.Task, payload []byte, completed bool) error { ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue})
for _, loopNode := range loopNodes {
for _, item := range items {
_, err := d.PublishTask(ctx, item, loopNode, task.ID)
if err != nil {
return mq.Result{Error: task.Error}
}
}
}
return mq.Result{}
}
if nodeType == "loop" && completed {
task.CurrentQueue = triggeredNode
}
ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue})
edges, exists := d.edges[task.CurrentQueue] edges, exists := d.edges[task.CurrentQueue]
if exists { if exists {
d.taskResults[task.ID] = map[string]*taskContext{ d.taskResults[task.ID] = map[string]*taskContext{
@@ -176,7 +201,7 @@ func (d *DAG) handleEdges(ctx context.Context, task *mq.Task, payload []byte, co
for _, edge := range edges { for _, edge := range edges {
_, err := d.PublishTask(ctx, payload, edge, task.ID) _, err := d.PublishTask(ctx, payload, edge, task.ID)
if err != nil { if err != nil {
return err return mq.Result{Error: task.Error}
} }
} }
} else if completed { } else if completed {
@@ -194,49 +219,5 @@ func (d *DAG) handleEdges(ctx context.Context, task *mq.Task, payload []byte, co
} }
d.mu.Unlock() d.mu.Unlock()
} }
return nil return mq.Result{}
}
func (d *DAG) handleLoopEdges(ctx context.Context, task *mq.Task, payload []byte, loopNodes []string) error {
var items []json.RawMessage
if err := json.Unmarshal(payload, &items); err != nil {
return err
}
d.taskResults[task.ID] = map[string]*taskContext{
task.CurrentQueue: {
totalItems: len(items),
nodeType: "loop",
},
}
ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue})
for _, loopNode := range loopNodes {
for _, item := range items {
_, err := d.PublishTask(ctx, item, loopNode, task.ID)
if err != nil {
return err
}
}
}
return nil
}
func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error {
if task.Error != nil {
return task.Error
}
completed, nodeType, result, payload, triggeredNode := d.handleTriggerNode(ctx, task)
if completed {
payload, _ = json.Marshal(result)
} else {
payload = task.Result
}
if loopNodes, exists := d.loopEdges[task.CurrentQueue]; exists {
return d.handleLoopEdges(ctx, task, payload, loopNodes)
}
if nodeType == "loop" && completed {
task.CurrentQueue = triggeredNode
}
ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue})
return d.handleEdges(ctx, task, payload, completed)
} }

View File

@@ -4,11 +4,12 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/dag"
"io" "io"
"log" "log"
"net/http" "net/http"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/dag"
) )
var d *dag.DAG var d *dag.DAG
@@ -72,6 +73,7 @@ func sendTaskHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Empty request body", http.StatusBadRequest) http.Error(w, "Empty request body", http.StatusBadRequest)
return return
} }
fmt.Println(string(payload))
finalResult := d.Send(payload) finalResult := d.Send(payload)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
result := map[string]any{ result := map[string]any{

View File

@@ -8,9 +8,9 @@ import (
) )
func main() { func main() {
b := mq.NewBroker(mq.WithCallback(func(ctx context.Context, task *mq.Task) error { b := mq.NewBroker(mq.WithCallback(func(ctx context.Context, task *mq.Task) mq.Result {
fmt.Println("Received task", task.ID, "Payload", string(task.Payload), "Result", string(task.Result), task.Error, task.CurrentQueue) fmt.Println("Received task", task.ID, "Payload", string(task.Payload), "Result", string(task.Result), task.Error, task.CurrentQueue)
return nil return mq.Result{}
})) }))
b.NewQueue("queue1") b.NewQueue("queue1")
b.NewQueue("queue2") b.NewQueue("queue2")

View File

@@ -11,7 +11,7 @@ type Options struct {
messageHandler MessageHandler messageHandler MessageHandler
closeHandler CloseHandler closeHandler CloseHandler
errorHandler ErrorHandler errorHandler ErrorHandler
callback []func(context.Context, *Task) error callback []func(context.Context, *Task) Result
maxRetries int maxRetries int
initialDelay time.Duration initialDelay time.Duration
maxBackoff time.Duration maxBackoff time.Duration
@@ -68,7 +68,7 @@ func WithMaxBackoff(val time.Duration) Option {
} }
// WithCallback - // WithCallback -
func WithCallback(val ...func(context.Context, *Task) error) Option { func WithCallback(val ...func(context.Context, *Task) Result) Option {
return func(opts *Options) { return func(opts *Options) {
opts.callback = val opts.callback = val
} }