mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-25 10:00:20 +08:00
init: publisher
This commit is contained in:
44
dag/dag.go
44
dag/dag.go
@@ -10,11 +10,11 @@ import (
|
||||
)
|
||||
|
||||
type taskContext struct {
|
||||
totalItems int
|
||||
completed int
|
||||
results []json.RawMessage
|
||||
result json.RawMessage
|
||||
nodeType string
|
||||
totalItems int
|
||||
completed int
|
||||
results []json.RawMessage
|
||||
result json.RawMessage
|
||||
multipleResults bool
|
||||
}
|
||||
|
||||
type DAG struct {
|
||||
@@ -191,37 +191,31 @@ func (d *DAG) sendSync(ctx context.Context, task mq.Result) mq.Result {
|
||||
return result
|
||||
}
|
||||
|
||||
func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
|
||||
if task.Error != nil {
|
||||
return mq.Result{Error: task.Error}
|
||||
}
|
||||
triggeredNode, ok := mq.GetTriggerNode(ctx)
|
||||
func (d *DAG) getCompletedResults(task mq.Result, ok bool, triggeredNode string) ([]byte, bool, bool) {
|
||||
var result any
|
||||
var payload []byte
|
||||
completed := false
|
||||
var nodeType string
|
||||
multipleResults := false
|
||||
if ok && triggeredNode != "" {
|
||||
taskResults, ok := d.taskResults[task.MessageID]
|
||||
if ok {
|
||||
nodeResult, exists := taskResults[triggeredNode]
|
||||
if exists {
|
||||
multipleResults = nodeResult.multipleResults
|
||||
nodeResult.completed++
|
||||
if nodeResult.completed == nodeResult.totalItems {
|
||||
completed = true
|
||||
}
|
||||
switch nodeResult.nodeType {
|
||||
case "loop":
|
||||
if multipleResults {
|
||||
nodeResult.results = append(nodeResult.results, task.Payload)
|
||||
if completed {
|
||||
result = nodeResult.results
|
||||
}
|
||||
nodeType = "loop"
|
||||
case "edge":
|
||||
} else {
|
||||
nodeResult.result = task.Payload
|
||||
if completed {
|
||||
result = nodeResult.result
|
||||
}
|
||||
nodeType = "edge"
|
||||
}
|
||||
}
|
||||
if completed {
|
||||
@@ -234,6 +228,15 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
|
||||
} else {
|
||||
payload = task.Payload
|
||||
}
|
||||
return payload, completed, multipleResults
|
||||
}
|
||||
|
||||
func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
|
||||
if task.Error != nil {
|
||||
return mq.Result{Error: task.Error}
|
||||
}
|
||||
triggeredNode, ok := mq.GetTriggerNode(ctx)
|
||||
payload, completed, multipleResults := d.getCompletedResults(task, ok, triggeredNode)
|
||||
if loopNodes, exists := d.loopEdges[task.Queue]; exists {
|
||||
var items []json.RawMessage
|
||||
if err := json.Unmarshal(payload, &items); err != nil {
|
||||
@@ -241,8 +244,8 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
|
||||
}
|
||||
d.taskResults[task.MessageID] = map[string]*taskContext{
|
||||
task.Queue: {
|
||||
totalItems: len(items),
|
||||
nodeType: "loop",
|
||||
totalItems: len(items),
|
||||
multipleResults: true,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -256,9 +259,9 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
|
||||
}
|
||||
}
|
||||
|
||||
return mq.Result{}
|
||||
return task
|
||||
}
|
||||
if nodeType == "loop" && completed {
|
||||
if multipleResults && completed {
|
||||
task.Queue = triggeredNode
|
||||
}
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.Queue})
|
||||
@@ -267,7 +270,6 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
|
||||
d.taskResults[task.MessageID] = map[string]*taskContext{
|
||||
task.Queue: {
|
||||
totalItems: 1,
|
||||
nodeType: "edge",
|
||||
},
|
||||
}
|
||||
rs := d.PublishTask(ctx, payload, edge, task.MessageID)
|
||||
|
||||
Reference in New Issue
Block a user