init: publisher

This commit is contained in:
sujit
2024-09-30 21:09:48 +05:45
parent 93190ffc74
commit afd43302ef
6 changed files with 71 additions and 88 deletions

View File

@@ -243,9 +243,9 @@ func (b *Broker) HandleProcessedMessage(ctx context.Context, clientMsg Result) e
}
for _, callback := range b.opts.callback {
if callback != nil {
err := callback(ctx, msg)
if err != nil {
return err
result := callback(ctx, msg)
if result.Error != nil {
return result.Error
}
}
}

View File

@@ -1,6 +1,6 @@
package mq
type CMD int
type CMD byte
const (
SUBSCRIBE CMD = iota + 1

View File

@@ -123,48 +123,73 @@ func (d *DAG) Send(payload []byte) mq.Result {
return finalResult
}
func (d *DAG) handleTriggerNode(ctx context.Context, task *mq.Task) (bool, string, any, []byte, string) {
func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) mq.Result {
if task.Error != nil {
return mq.Result{Error: task.Error}
}
triggeredNode, ok := mq.GetTriggerNode(ctx)
var result any
var payload []byte
completed := false
var nodeType string
if !(ok && triggeredNode != "") {
return false, nodeType, result, payload, triggeredNode
}
taskResults, ok := d.taskResults[task.ID]
if !ok {
return false, nodeType, result, payload, triggeredNode
}
nodeResult, exists := taskResults[triggeredNode]
if exists {
return false, nodeType, result, payload, triggeredNode
}
nodeResult.completed++
if nodeResult.completed == nodeResult.totalItems {
completed = true
}
switch nodeResult.nodeType {
case "loop":
nodeResult.results = append(nodeResult.results, task.Result)
if completed {
result = nodeResult.results
if ok && triggeredNode != "" {
taskResults, ok := d.taskResults[task.ID]
if ok {
nodeResult, exists := taskResults[triggeredNode]
if exists {
nodeResult.completed++
if nodeResult.completed == nodeResult.totalItems {
completed = true
}
switch nodeResult.nodeType {
case "loop":
nodeResult.results = append(nodeResult.results, task.Result)
result = nodeResult.results
nodeType = "loop"
case "edge":
nodeResult.result = task.Result
result = nodeResult.result
nodeType = "edge"
}
}
if completed {
delete(taskResults, triggeredNode)
}
}
nodeType = "loop"
case "edge":
nodeResult.result = task.Result
if completed {
result = nodeResult.result
}
nodeType = "edge"
}
if completed {
delete(taskResults, triggeredNode)
payload, _ = json.Marshal(result)
} else {
payload = task.Result
}
return completed, nodeType, result, payload, triggeredNode
}
if loopNodes, exists := d.loopEdges[task.CurrentQueue]; exists {
var items []json.RawMessage
if err := json.Unmarshal(payload, &items); err != nil {
return mq.Result{Error: task.Error}
}
d.taskResults[task.ID] = map[string]*taskContext{
task.CurrentQueue: {
totalItems: len(items),
nodeType: "loop",
},
}
func (d *DAG) handleEdges(ctx context.Context, task *mq.Task, payload []byte, completed bool) error {
ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue})
for _, loopNode := range loopNodes {
for _, item := range items {
_, err := d.PublishTask(ctx, item, loopNode, task.ID)
if err != nil {
return mq.Result{Error: task.Error}
}
}
}
return mq.Result{}
}
if nodeType == "loop" && completed {
task.CurrentQueue = triggeredNode
}
ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue})
edges, exists := d.edges[task.CurrentQueue]
if exists {
d.taskResults[task.ID] = map[string]*taskContext{
@@ -176,7 +201,7 @@ func (d *DAG) handleEdges(ctx context.Context, task *mq.Task, payload []byte, co
for _, edge := range edges {
_, err := d.PublishTask(ctx, payload, edge, task.ID)
if err != nil {
return err
return mq.Result{Error: task.Error}
}
}
} else if completed {
@@ -194,49 +219,5 @@ func (d *DAG) handleEdges(ctx context.Context, task *mq.Task, payload []byte, co
}
d.mu.Unlock()
}
return nil
}
func (d *DAG) handleLoopEdges(ctx context.Context, task *mq.Task, payload []byte, loopNodes []string) error {
var items []json.RawMessage
if err := json.Unmarshal(payload, &items); err != nil {
return err
}
d.taskResults[task.ID] = map[string]*taskContext{
task.CurrentQueue: {
totalItems: len(items),
nodeType: "loop",
},
}
ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue})
for _, loopNode := range loopNodes {
for _, item := range items {
_, err := d.PublishTask(ctx, item, loopNode, task.ID)
if err != nil {
return err
}
}
}
return nil
}
func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error {
if task.Error != nil {
return task.Error
}
completed, nodeType, result, payload, triggeredNode := d.handleTriggerNode(ctx, task)
if completed {
payload, _ = json.Marshal(result)
} else {
payload = task.Result
}
if loopNodes, exists := d.loopEdges[task.CurrentQueue]; exists {
return d.handleLoopEdges(ctx, task, payload, loopNodes)
}
if nodeType == "loop" && completed {
task.CurrentQueue = triggeredNode
}
ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue})
return d.handleEdges(ctx, task, payload, completed)
return mq.Result{}
}

View File

@@ -4,11 +4,12 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/dag"
"io"
"log"
"net/http"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/dag"
)
var d *dag.DAG
@@ -72,6 +73,7 @@ func sendTaskHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Empty request body", http.StatusBadRequest)
return
}
fmt.Println(string(payload))
finalResult := d.Send(payload)
w.Header().Set("Content-Type", "application/json")
result := map[string]any{

View File

@@ -8,9 +8,9 @@ import (
)
func main() {
b := mq.NewBroker(mq.WithCallback(func(ctx context.Context, task *mq.Task) error {
b := mq.NewBroker(mq.WithCallback(func(ctx context.Context, task *mq.Task) mq.Result {
fmt.Println("Received task", task.ID, "Payload", string(task.Payload), "Result", string(task.Result), task.Error, task.CurrentQueue)
return nil
return mq.Result{}
}))
b.NewQueue("queue1")
b.NewQueue("queue2")

View File

@@ -11,7 +11,7 @@ type Options struct {
messageHandler MessageHandler
closeHandler CloseHandler
errorHandler ErrorHandler
callback []func(context.Context, *Task) error
callback []func(context.Context, *Task) Result
maxRetries int
initialDelay time.Duration
maxBackoff time.Duration
@@ -68,7 +68,7 @@ func WithMaxBackoff(val time.Duration) Option {
}
// WithCallback -
func WithCallback(val ...func(context.Context, *Task) error) Option {
func WithCallback(val ...func(context.Context, *Task) Result) Option {
return func(opts *Options) {
opts.callback = val
}