init: publisher

This commit is contained in:
sujit
2024-10-01 08:39:10 +05:45
parent a2859aa4be
commit c04f1baa16

View File

@@ -10,11 +10,11 @@ import (
)
type taskContext struct {
totalItems int
completed int
results []json.RawMessage
result json.RawMessage
nodeType string
totalItems int
completed int
results []json.RawMessage
result json.RawMessage
multipleResults bool
}
type DAG struct {
@@ -191,37 +191,31 @@ func (d *DAG) sendSync(ctx context.Context, task mq.Result) mq.Result {
return result
}
func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
if task.Error != nil {
return mq.Result{Error: task.Error}
}
triggeredNode, ok := mq.GetTriggerNode(ctx)
func (d *DAG) getCompletedResults(task mq.Result, ok bool, triggeredNode string) ([]byte, bool, bool) {
var result any
var payload []byte
completed := false
var nodeType string
multipleResults := false
if ok && triggeredNode != "" {
taskResults, ok := d.taskResults[task.MessageID]
if ok {
nodeResult, exists := taskResults[triggeredNode]
if exists {
multipleResults = nodeResult.multipleResults
nodeResult.completed++
if nodeResult.completed == nodeResult.totalItems {
completed = true
}
switch nodeResult.nodeType {
case "loop":
if multipleResults {
nodeResult.results = append(nodeResult.results, task.Payload)
if completed {
result = nodeResult.results
}
nodeType = "loop"
case "edge":
} else {
nodeResult.result = task.Payload
if completed {
result = nodeResult.result
}
nodeType = "edge"
}
}
if completed {
@@ -234,6 +228,15 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
} else {
payload = task.Payload
}
return payload, completed, multipleResults
}
func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
if task.Error != nil {
return mq.Result{Error: task.Error}
}
triggeredNode, ok := mq.GetTriggerNode(ctx)
payload, completed, multipleResults := d.getCompletedResults(task, ok, triggeredNode)
if loopNodes, exists := d.loopEdges[task.Queue]; exists {
var items []json.RawMessage
if err := json.Unmarshal(payload, &items); err != nil {
@@ -241,8 +244,8 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
}
d.taskResults[task.MessageID] = map[string]*taskContext{
task.Queue: {
totalItems: len(items),
nodeType: "loop",
totalItems: len(items),
multipleResults: true,
},
}
@@ -256,9 +259,9 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
}
}
return mq.Result{}
return task
}
if nodeType == "loop" && completed {
if multipleResults && completed {
task.Queue = triggeredNode
}
ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.Queue})
@@ -267,7 +270,6 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
d.taskResults[task.MessageID] = map[string]*taskContext{
task.Queue: {
totalItems: 1,
nodeType: "edge",
},
}
rs := d.PublishTask(ctx, payload, edge, task.MessageID)