mirror of
https://github.com/oarkflow/mq.git
synced 2025-11-03 04:13:18 +08:00
init: publisher
This commit is contained in:
@@ -17,4 +17,5 @@ var (
|
||||
ContentType = "Content-Type"
|
||||
TypeJson = "application/json"
|
||||
HeaderKey = "headers"
|
||||
TriggerNode = "triggerNode"
|
||||
)
|
||||
|
||||
9
ctx.go
9
ctx.go
@@ -62,6 +62,15 @@ func GetConsumerID(ctx context.Context) (string, bool) {
|
||||
return contentType, ok
|
||||
}
|
||||
|
||||
func GetTriggerNode(ctx context.Context) (string, bool) {
|
||||
headers, ok := GetHeaders(ctx)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
contentType, ok := headers[TriggerNode]
|
||||
return contentType, ok
|
||||
}
|
||||
|
||||
func GetPublisherID(ctx context.Context) (string, bool) {
|
||||
headers, ok := GetHeaders(ctx)
|
||||
if !ok {
|
||||
|
||||
142
dag/dag.go
142
dag/dag.go
@@ -3,36 +3,36 @@ package dag
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/oarkflow/mq"
|
||||
)
|
||||
|
||||
type taskContext struct {
|
||||
totalItems int
|
||||
completed int
|
||||
results []json.RawMessage
|
||||
result json.RawMessage
|
||||
nodeType string
|
||||
}
|
||||
|
||||
type DAG struct {
|
||||
server *mq.Broker
|
||||
nodes map[string]*mq.Consumer
|
||||
edges map[string][]string
|
||||
loopEdges map[string]string
|
||||
taskChMap map[string]chan mq.Result
|
||||
loopTaskMap map[string]*loopTaskContext
|
||||
taskResults map[string]map[string]*taskContext
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type loopTaskContext struct {
|
||||
subResultCh chan mq.Result
|
||||
totalItems int
|
||||
completed int
|
||||
results []json.RawMessage
|
||||
}
|
||||
|
||||
func New(opts ...mq.Option) *DAG {
|
||||
d := &DAG{
|
||||
nodes: make(map[string]*mq.Consumer),
|
||||
edges: make(map[string][]string),
|
||||
loopEdges: make(map[string]string),
|
||||
taskChMap: make(map[string]chan mq.Result),
|
||||
loopTaskMap: make(map[string]*loopTaskContext),
|
||||
taskResults: make(map[string]map[string]*taskContext),
|
||||
}
|
||||
opts = append(opts, mq.WithCallback(d.TaskCallback))
|
||||
d.server = mq.NewBroker(opts...)
|
||||
@@ -84,50 +84,90 @@ func (d *DAG) Send(payload []byte) mq.Result {
|
||||
}
|
||||
|
||||
func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error {
|
||||
log.Printf("Callback from queue %s with result: %s", task.CurrentQueue, string(task.Result))
|
||||
d.mu.Lock()
|
||||
loopCtx, isLoopTask := d.loopTaskMap[task.ID]
|
||||
d.mu.Unlock()
|
||||
if isLoopTask {
|
||||
loopCtx.subResultCh <- mq.Result{Payload: task.Result, MessageID: task.ID}
|
||||
if task.Error != nil {
|
||||
return task.Error
|
||||
}
|
||||
triggeredNode, ok := mq.GetTriggerNode(ctx)
|
||||
var result any
|
||||
var payload []byte
|
||||
completed := false
|
||||
var nodeType string
|
||||
if ok && triggeredNode != "" {
|
||||
taskResults, ok := d.taskResults[task.ID]
|
||||
if ok {
|
||||
nodeResult, exists := taskResults[triggeredNode]
|
||||
if exists {
|
||||
nodeResult.completed++
|
||||
switch nodeResult.nodeType {
|
||||
case "loop":
|
||||
nodeResult.results = append(nodeResult.results, task.Result)
|
||||
nodeType = "loop"
|
||||
case "edge":
|
||||
nodeResult.result = task.Result
|
||||
nodeType = "edge"
|
||||
}
|
||||
if nodeResult.completed == nodeResult.totalItems {
|
||||
completed = true
|
||||
switch nodeResult.nodeType {
|
||||
case "loop":
|
||||
result = nodeResult.results
|
||||
case "edge":
|
||||
result = nodeResult.result
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if completed {
|
||||
payload, _ = json.Marshal(result)
|
||||
} else {
|
||||
payload = task.Result
|
||||
}
|
||||
|
||||
if loopNode, exists := d.loopEdges[task.CurrentQueue]; exists {
|
||||
var items []json.RawMessage
|
||||
if err := json.Unmarshal(task.Result, &items); err != nil {
|
||||
if err := json.Unmarshal(payload, &items); err != nil {
|
||||
return err
|
||||
}
|
||||
loopCtx := &loopTaskContext{
|
||||
subResultCh: make(chan mq.Result, len(items)),
|
||||
totalItems: len(items),
|
||||
results: make([]json.RawMessage, 0, len(items)),
|
||||
d.taskResults[task.ID] = map[string]*taskContext{
|
||||
task.CurrentQueue: {
|
||||
totalItems: len(items),
|
||||
nodeType: "loop",
|
||||
},
|
||||
}
|
||||
d.mu.Lock()
|
||||
d.loopTaskMap[task.ID] = loopCtx
|
||||
d.mu.Unlock()
|
||||
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue})
|
||||
for _, item := range items {
|
||||
_, err := d.PublishTask(ctx, item, loopNode, task.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
go d.waitForLoopCompletion(ctx, task.ID, task.CurrentQueue)
|
||||
return nil
|
||||
}
|
||||
if nodeType == "loop" && completed {
|
||||
task.CurrentQueue = triggeredNode
|
||||
}
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue})
|
||||
edges, exists := d.edges[task.CurrentQueue]
|
||||
if exists {
|
||||
d.taskResults[task.ID] = map[string]*taskContext{
|
||||
task.CurrentQueue: {
|
||||
totalItems: 1,
|
||||
nodeType: "edge",
|
||||
},
|
||||
}
|
||||
for _, edge := range edges {
|
||||
_, err := d.PublishTask(ctx, task.Result, edge, task.ID)
|
||||
_, err := d.PublishTask(ctx, payload, edge, task.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
} else if completed {
|
||||
d.mu.Lock()
|
||||
if resultCh, ok := d.taskChMap[task.ID]; ok {
|
||||
resultCh <- mq.Result{
|
||||
Command: "complete",
|
||||
Payload: task.Result,
|
||||
Payload: payload,
|
||||
Queue: task.CurrentQueue,
|
||||
MessageID: task.ID,
|
||||
Status: "done",
|
||||
@@ -138,47 +178,3 @@ func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DAG) waitForLoopCompletion(ctx context.Context, taskID string, currentQueue string) {
|
||||
d.mu.Lock()
|
||||
loopCtx := d.loopTaskMap[taskID]
|
||||
d.mu.Unlock()
|
||||
for result := range loopCtx.subResultCh {
|
||||
loopCtx.results = append(loopCtx.results, result.Payload)
|
||||
loopCtx.completed++
|
||||
if loopCtx.completed == loopCtx.totalItems {
|
||||
close(loopCtx.subResultCh)
|
||||
aggregatedResult, err := json.Marshal(loopCtx.results)
|
||||
if err != nil {
|
||||
log.Printf("Error aggregating results: %v", err)
|
||||
return
|
||||
}
|
||||
d.mu.Lock()
|
||||
delete(d.loopTaskMap, taskID)
|
||||
d.mu.Unlock()
|
||||
edges, exists := d.edges[currentQueue]
|
||||
if exists {
|
||||
for _, edge := range edges {
|
||||
_, err := d.PublishTask(ctx, aggregatedResult, edge, taskID)
|
||||
if err != nil {
|
||||
log.Printf("Error publishing aggregated result: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
d.mu.Lock()
|
||||
if resultCh, ok := d.taskChMap[taskID]; ok {
|
||||
resultCh <- mq.Result{
|
||||
Command: "complete",
|
||||
Payload: aggregatedResult,
|
||||
Queue: currentQueue,
|
||||
MessageID: taskID,
|
||||
Status: "done",
|
||||
}
|
||||
delete(d.taskChMap, taskID)
|
||||
}
|
||||
d.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user