init: publisher

This commit is contained in:
sujit
2024-09-29 16:26:54 +05:45
parent 488652d322
commit 4221627389
3 changed files with 79 additions and 73 deletions

View File

@@ -17,4 +17,5 @@ var (
ContentType = "Content-Type"
TypeJson = "application/json"
HeaderKey = "headers"
TriggerNode = "triggerNode"
)

9
ctx.go
View File

@@ -62,6 +62,15 @@ func GetConsumerID(ctx context.Context) (string, bool) {
return contentType, ok
}
func GetTriggerNode(ctx context.Context) (string, bool) {
headers, ok := GetHeaders(ctx)
if !ok {
return "", false
}
contentType, ok := headers[TriggerNode]
return contentType, ok
}
func GetPublisherID(ctx context.Context) (string, bool) {
headers, ok := GetHeaders(ctx)
if !ok {

View File

@@ -3,36 +3,36 @@ package dag
import (
"context"
"encoding/json"
"log"
"sync"
"github.com/oarkflow/mq"
)
type taskContext struct {
totalItems int
completed int
results []json.RawMessage
result json.RawMessage
nodeType string
}
type DAG struct {
server *mq.Broker
nodes map[string]*mq.Consumer
edges map[string][]string
loopEdges map[string]string
taskChMap map[string]chan mq.Result
loopTaskMap map[string]*loopTaskContext
taskResults map[string]map[string]*taskContext
mu sync.Mutex
}
type loopTaskContext struct {
subResultCh chan mq.Result
totalItems int
completed int
results []json.RawMessage
}
func New(opts ...mq.Option) *DAG {
d := &DAG{
nodes: make(map[string]*mq.Consumer),
edges: make(map[string][]string),
loopEdges: make(map[string]string),
taskChMap: make(map[string]chan mq.Result),
loopTaskMap: make(map[string]*loopTaskContext),
taskResults: make(map[string]map[string]*taskContext),
}
opts = append(opts, mq.WithCallback(d.TaskCallback))
d.server = mq.NewBroker(opts...)
@@ -84,50 +84,90 @@ func (d *DAG) Send(payload []byte) mq.Result {
}
func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error {
log.Printf("Callback from queue %s with result: %s", task.CurrentQueue, string(task.Result))
d.mu.Lock()
loopCtx, isLoopTask := d.loopTaskMap[task.ID]
d.mu.Unlock()
if isLoopTask {
loopCtx.subResultCh <- mq.Result{Payload: task.Result, MessageID: task.ID}
if task.Error != nil {
return task.Error
}
triggeredNode, ok := mq.GetTriggerNode(ctx)
var result any
var payload []byte
completed := false
var nodeType string
if ok && triggeredNode != "" {
taskResults, ok := d.taskResults[task.ID]
if ok {
nodeResult, exists := taskResults[triggeredNode]
if exists {
nodeResult.completed++
switch nodeResult.nodeType {
case "loop":
nodeResult.results = append(nodeResult.results, task.Result)
nodeType = "loop"
case "edge":
nodeResult.result = task.Result
nodeType = "edge"
}
if nodeResult.completed == nodeResult.totalItems {
completed = true
switch nodeResult.nodeType {
case "loop":
result = nodeResult.results
case "edge":
result = nodeResult.result
}
}
}
}
}
if completed {
payload, _ = json.Marshal(result)
} else {
payload = task.Result
}
if loopNode, exists := d.loopEdges[task.CurrentQueue]; exists {
var items []json.RawMessage
if err := json.Unmarshal(task.Result, &items); err != nil {
if err := json.Unmarshal(payload, &items); err != nil {
return err
}
loopCtx := &loopTaskContext{
subResultCh: make(chan mq.Result, len(items)),
totalItems: len(items),
results: make([]json.RawMessage, 0, len(items)),
d.taskResults[task.ID] = map[string]*taskContext{
task.CurrentQueue: {
totalItems: len(items),
nodeType: "loop",
},
}
d.mu.Lock()
d.loopTaskMap[task.ID] = loopCtx
d.mu.Unlock()
ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue})
for _, item := range items {
_, err := d.PublishTask(ctx, item, loopNode, task.ID)
if err != nil {
return err
}
}
go d.waitForLoopCompletion(ctx, task.ID, task.CurrentQueue)
return nil
}
if nodeType == "loop" && completed {
task.CurrentQueue = triggeredNode
}
ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue})
edges, exists := d.edges[task.CurrentQueue]
if exists {
d.taskResults[task.ID] = map[string]*taskContext{
task.CurrentQueue: {
totalItems: 1,
nodeType: "edge",
},
}
for _, edge := range edges {
_, err := d.PublishTask(ctx, task.Result, edge, task.ID)
_, err := d.PublishTask(ctx, payload, edge, task.ID)
if err != nil {
return err
}
}
} else {
} else if completed {
d.mu.Lock()
if resultCh, ok := d.taskChMap[task.ID]; ok {
resultCh <- mq.Result{
Command: "complete",
Payload: task.Result,
Payload: payload,
Queue: task.CurrentQueue,
MessageID: task.ID,
Status: "done",
@@ -138,47 +178,3 @@ func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error {
}
return nil
}
func (d *DAG) waitForLoopCompletion(ctx context.Context, taskID string, currentQueue string) {
d.mu.Lock()
loopCtx := d.loopTaskMap[taskID]
d.mu.Unlock()
for result := range loopCtx.subResultCh {
loopCtx.results = append(loopCtx.results, result.Payload)
loopCtx.completed++
if loopCtx.completed == loopCtx.totalItems {
close(loopCtx.subResultCh)
aggregatedResult, err := json.Marshal(loopCtx.results)
if err != nil {
log.Printf("Error aggregating results: %v", err)
return
}
d.mu.Lock()
delete(d.loopTaskMap, taskID)
d.mu.Unlock()
edges, exists := d.edges[currentQueue]
if exists {
for _, edge := range edges {
_, err := d.PublishTask(ctx, aggregatedResult, edge, taskID)
if err != nil {
log.Printf("Error publishing aggregated result: %v", err)
return
}
}
} else {
d.mu.Lock()
if resultCh, ok := d.taskChMap[taskID]; ok {
resultCh <- mq.Result{
Command: "complete",
Payload: aggregatedResult,
Queue: currentQueue,
MessageID: taskID,
Status: "done",
}
delete(d.taskChMap, taskID)
}
d.mu.Unlock()
}
}
}
}