mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-05 07:57:00 +08:00
init: publisher
This commit is contained in:
@@ -243,9 +243,9 @@ func (b *Broker) HandleProcessedMessage(ctx context.Context, clientMsg Result) e
|
||||
}
|
||||
for _, callback := range b.opts.callback {
|
||||
if callback != nil {
|
||||
err := callback(ctx, msg)
|
||||
if err != nil {
|
||||
return err
|
||||
result := callback(ctx, msg)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -1,6 +1,6 @@
|
||||
package mq
|
||||
|
||||
type CMD int
|
||||
type CMD byte
|
||||
|
||||
const (
|
||||
SUBSCRIBE CMD = iota + 1
|
||||
|
103
dag/dag.go
103
dag/dag.go
@@ -123,23 +123,20 @@ func (d *DAG) Send(payload []byte) mq.Result {
|
||||
return finalResult
|
||||
}
|
||||
|
||||
func (d *DAG) handleTriggerNode(ctx context.Context, task *mq.Task) (bool, string, any, []byte, string) {
|
||||
func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) mq.Result {
|
||||
if task.Error != nil {
|
||||
return mq.Result{Error: task.Error}
|
||||
}
|
||||
triggeredNode, ok := mq.GetTriggerNode(ctx)
|
||||
var result any
|
||||
var payload []byte
|
||||
completed := false
|
||||
var nodeType string
|
||||
if !(ok && triggeredNode != "") {
|
||||
return false, nodeType, result, payload, triggeredNode
|
||||
}
|
||||
if ok && triggeredNode != "" {
|
||||
taskResults, ok := d.taskResults[task.ID]
|
||||
if !ok {
|
||||
return false, nodeType, result, payload, triggeredNode
|
||||
}
|
||||
if ok {
|
||||
nodeResult, exists := taskResults[triggeredNode]
|
||||
if exists {
|
||||
return false, nodeType, result, payload, triggeredNode
|
||||
}
|
||||
nodeResult.completed++
|
||||
if nodeResult.completed == nodeResult.totalItems {
|
||||
completed = true
|
||||
@@ -147,24 +144,52 @@ func (d *DAG) handleTriggerNode(ctx context.Context, task *mq.Task) (bool, strin
|
||||
switch nodeResult.nodeType {
|
||||
case "loop":
|
||||
nodeResult.results = append(nodeResult.results, task.Result)
|
||||
if completed {
|
||||
result = nodeResult.results
|
||||
}
|
||||
nodeType = "loop"
|
||||
case "edge":
|
||||
nodeResult.result = task.Result
|
||||
if completed {
|
||||
result = nodeResult.result
|
||||
}
|
||||
nodeType = "edge"
|
||||
}
|
||||
}
|
||||
if completed {
|
||||
delete(taskResults, triggeredNode)
|
||||
}
|
||||
return completed, nodeType, result, payload, triggeredNode
|
||||
}
|
||||
}
|
||||
if completed {
|
||||
payload, _ = json.Marshal(result)
|
||||
} else {
|
||||
payload = task.Result
|
||||
}
|
||||
if loopNodes, exists := d.loopEdges[task.CurrentQueue]; exists {
|
||||
var items []json.RawMessage
|
||||
if err := json.Unmarshal(payload, &items); err != nil {
|
||||
return mq.Result{Error: task.Error}
|
||||
}
|
||||
d.taskResults[task.ID] = map[string]*taskContext{
|
||||
task.CurrentQueue: {
|
||||
totalItems: len(items),
|
||||
nodeType: "loop",
|
||||
},
|
||||
}
|
||||
|
||||
func (d *DAG) handleEdges(ctx context.Context, task *mq.Task, payload []byte, completed bool) error {
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue})
|
||||
for _, loopNode := range loopNodes {
|
||||
for _, item := range items {
|
||||
_, err := d.PublishTask(ctx, item, loopNode, task.ID)
|
||||
if err != nil {
|
||||
return mq.Result{Error: task.Error}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return mq.Result{}
|
||||
}
|
||||
if nodeType == "loop" && completed {
|
||||
task.CurrentQueue = triggeredNode
|
||||
}
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue})
|
||||
edges, exists := d.edges[task.CurrentQueue]
|
||||
if exists {
|
||||
d.taskResults[task.ID] = map[string]*taskContext{
|
||||
@@ -176,7 +201,7 @@ func (d *DAG) handleEdges(ctx context.Context, task *mq.Task, payload []byte, co
|
||||
for _, edge := range edges {
|
||||
_, err := d.PublishTask(ctx, payload, edge, task.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
return mq.Result{Error: task.Error}
|
||||
}
|
||||
}
|
||||
} else if completed {
|
||||
@@ -194,49 +219,5 @@ func (d *DAG) handleEdges(ctx context.Context, task *mq.Task, payload []byte, co
|
||||
}
|
||||
d.mu.Unlock()
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (d *DAG) handleLoopEdges(ctx context.Context, task *mq.Task, payload []byte, loopNodes []string) error {
|
||||
var items []json.RawMessage
|
||||
if err := json.Unmarshal(payload, &items); err != nil {
|
||||
return err
|
||||
}
|
||||
d.taskResults[task.ID] = map[string]*taskContext{
|
||||
task.CurrentQueue: {
|
||||
totalItems: len(items),
|
||||
nodeType: "loop",
|
||||
},
|
||||
}
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue})
|
||||
for _, loopNode := range loopNodes {
|
||||
for _, item := range items {
|
||||
_, err := d.PublishTask(ctx, item, loopNode, task.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error {
|
||||
if task.Error != nil {
|
||||
return task.Error
|
||||
}
|
||||
completed, nodeType, result, payload, triggeredNode := d.handleTriggerNode(ctx, task)
|
||||
if completed {
|
||||
payload, _ = json.Marshal(result)
|
||||
} else {
|
||||
payload = task.Result
|
||||
}
|
||||
if loopNodes, exists := d.loopEdges[task.CurrentQueue]; exists {
|
||||
return d.handleLoopEdges(ctx, task, payload, loopNodes)
|
||||
}
|
||||
if nodeType == "loop" && completed {
|
||||
task.CurrentQueue = triggeredNode
|
||||
}
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue})
|
||||
return d.handleEdges(ctx, task, payload, completed)
|
||||
return mq.Result{}
|
||||
}
|
||||
|
@@ -4,11 +4,12 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/oarkflow/mq"
|
||||
"github.com/oarkflow/mq/dag"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/oarkflow/mq"
|
||||
"github.com/oarkflow/mq/dag"
|
||||
)
|
||||
|
||||
var d *dag.DAG
|
||||
@@ -72,6 +73,7 @@ func sendTaskHandler(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Empty request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
fmt.Println(string(payload))
|
||||
finalResult := d.Send(payload)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
result := map[string]any{
|
||||
|
@@ -8,9 +8,9 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
b := mq.NewBroker(mq.WithCallback(func(ctx context.Context, task *mq.Task) error {
|
||||
b := mq.NewBroker(mq.WithCallback(func(ctx context.Context, task *mq.Task) mq.Result {
|
||||
fmt.Println("Received task", task.ID, "Payload", string(task.Payload), "Result", string(task.Result), task.Error, task.CurrentQueue)
|
||||
return nil
|
||||
return mq.Result{}
|
||||
}))
|
||||
b.NewQueue("queue1")
|
||||
b.NewQueue("queue2")
|
||||
|
@@ -11,7 +11,7 @@ type Options struct {
|
||||
messageHandler MessageHandler
|
||||
closeHandler CloseHandler
|
||||
errorHandler ErrorHandler
|
||||
callback []func(context.Context, *Task) error
|
||||
callback []func(context.Context, *Task) Result
|
||||
maxRetries int
|
||||
initialDelay time.Duration
|
||||
maxBackoff time.Duration
|
||||
@@ -68,7 +68,7 @@ func WithMaxBackoff(val time.Duration) Option {
|
||||
}
|
||||
|
||||
// WithCallback -
|
||||
func WithCallback(val ...func(context.Context, *Task) error) Option {
|
||||
func WithCallback(val ...func(context.Context, *Task) Result) Option {
|
||||
return func(opts *Options) {
|
||||
opts.callback = val
|
||||
}
|
||||
|
Reference in New Issue
Block a user