mirror of
https://github.com/oarkflow/mq.git
synced 2025-11-03 12:00:50 +08:00
init: publisher
This commit is contained in:
@@ -17,4 +17,5 @@ var (
|
|||||||
ContentType = "Content-Type"
|
ContentType = "Content-Type"
|
||||||
TypeJson = "application/json"
|
TypeJson = "application/json"
|
||||||
HeaderKey = "headers"
|
HeaderKey = "headers"
|
||||||
|
TriggerNode = "triggerNode"
|
||||||
)
|
)
|
||||||
|
|||||||
9
ctx.go
9
ctx.go
@@ -62,6 +62,15 @@ func GetConsumerID(ctx context.Context) (string, bool) {
|
|||||||
return contentType, ok
|
return contentType, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetTriggerNode(ctx context.Context) (string, bool) {
|
||||||
|
headers, ok := GetHeaders(ctx)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
contentType, ok := headers[TriggerNode]
|
||||||
|
return contentType, ok
|
||||||
|
}
|
||||||
|
|
||||||
func GetPublisherID(ctx context.Context) (string, bool) {
|
func GetPublisherID(ctx context.Context) (string, bool) {
|
||||||
headers, ok := GetHeaders(ctx)
|
headers, ok := GetHeaders(ctx)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|||||||
140
dag/dag.go
140
dag/dag.go
@@ -3,36 +3,36 @@ package dag
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"log"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/oarkflow/mq"
|
"github.com/oarkflow/mq"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type taskContext struct {
|
||||||
|
totalItems int
|
||||||
|
completed int
|
||||||
|
results []json.RawMessage
|
||||||
|
result json.RawMessage
|
||||||
|
nodeType string
|
||||||
|
}
|
||||||
|
|
||||||
type DAG struct {
|
type DAG struct {
|
||||||
server *mq.Broker
|
server *mq.Broker
|
||||||
nodes map[string]*mq.Consumer
|
nodes map[string]*mq.Consumer
|
||||||
edges map[string][]string
|
edges map[string][]string
|
||||||
loopEdges map[string]string
|
loopEdges map[string]string
|
||||||
taskChMap map[string]chan mq.Result
|
taskChMap map[string]chan mq.Result
|
||||||
loopTaskMap map[string]*loopTaskContext
|
taskResults map[string]map[string]*taskContext
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
type loopTaskContext struct {
|
|
||||||
subResultCh chan mq.Result
|
|
||||||
totalItems int
|
|
||||||
completed int
|
|
||||||
results []json.RawMessage
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(opts ...mq.Option) *DAG {
|
func New(opts ...mq.Option) *DAG {
|
||||||
d := &DAG{
|
d := &DAG{
|
||||||
nodes: make(map[string]*mq.Consumer),
|
nodes: make(map[string]*mq.Consumer),
|
||||||
edges: make(map[string][]string),
|
edges: make(map[string][]string),
|
||||||
loopEdges: make(map[string]string),
|
loopEdges: make(map[string]string),
|
||||||
taskChMap: make(map[string]chan mq.Result),
|
taskChMap: make(map[string]chan mq.Result),
|
||||||
loopTaskMap: make(map[string]*loopTaskContext),
|
taskResults: make(map[string]map[string]*taskContext),
|
||||||
}
|
}
|
||||||
opts = append(opts, mq.WithCallback(d.TaskCallback))
|
opts = append(opts, mq.WithCallback(d.TaskCallback))
|
||||||
d.server = mq.NewBroker(opts...)
|
d.server = mq.NewBroker(opts...)
|
||||||
@@ -84,50 +84,90 @@ func (d *DAG) Send(payload []byte) mq.Result {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error {
|
func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error {
|
||||||
log.Printf("Callback from queue %s with result: %s", task.CurrentQueue, string(task.Result))
|
if task.Error != nil {
|
||||||
d.mu.Lock()
|
return task.Error
|
||||||
loopCtx, isLoopTask := d.loopTaskMap[task.ID]
|
}
|
||||||
d.mu.Unlock()
|
triggeredNode, ok := mq.GetTriggerNode(ctx)
|
||||||
if isLoopTask {
|
var result any
|
||||||
loopCtx.subResultCh <- mq.Result{Payload: task.Result, MessageID: task.ID}
|
var payload []byte
|
||||||
|
completed := false
|
||||||
|
var nodeType string
|
||||||
|
if ok && triggeredNode != "" {
|
||||||
|
taskResults, ok := d.taskResults[task.ID]
|
||||||
|
if ok {
|
||||||
|
nodeResult, exists := taskResults[triggeredNode]
|
||||||
|
if exists {
|
||||||
|
nodeResult.completed++
|
||||||
|
switch nodeResult.nodeType {
|
||||||
|
case "loop":
|
||||||
|
nodeResult.results = append(nodeResult.results, task.Result)
|
||||||
|
nodeType = "loop"
|
||||||
|
case "edge":
|
||||||
|
nodeResult.result = task.Result
|
||||||
|
nodeType = "edge"
|
||||||
|
}
|
||||||
|
if nodeResult.completed == nodeResult.totalItems {
|
||||||
|
completed = true
|
||||||
|
switch nodeResult.nodeType {
|
||||||
|
case "loop":
|
||||||
|
result = nodeResult.results
|
||||||
|
case "edge":
|
||||||
|
result = nodeResult.result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if completed {
|
||||||
|
payload, _ = json.Marshal(result)
|
||||||
|
} else {
|
||||||
|
payload = task.Result
|
||||||
}
|
}
|
||||||
|
|
||||||
if loopNode, exists := d.loopEdges[task.CurrentQueue]; exists {
|
if loopNode, exists := d.loopEdges[task.CurrentQueue]; exists {
|
||||||
var items []json.RawMessage
|
var items []json.RawMessage
|
||||||
if err := json.Unmarshal(task.Result, &items); err != nil {
|
if err := json.Unmarshal(payload, &items); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
loopCtx := &loopTaskContext{
|
d.taskResults[task.ID] = map[string]*taskContext{
|
||||||
subResultCh: make(chan mq.Result, len(items)),
|
task.CurrentQueue: {
|
||||||
totalItems: len(items),
|
totalItems: len(items),
|
||||||
results: make([]json.RawMessage, 0, len(items)),
|
nodeType: "loop",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
d.mu.Lock()
|
|
||||||
d.loopTaskMap[task.ID] = loopCtx
|
ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue})
|
||||||
d.mu.Unlock()
|
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
_, err := d.PublishTask(ctx, item, loopNode, task.ID)
|
_, err := d.PublishTask(ctx, item, loopNode, task.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
go d.waitForLoopCompletion(ctx, task.ID, task.CurrentQueue)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
if nodeType == "loop" && completed {
|
||||||
|
task.CurrentQueue = triggeredNode
|
||||||
|
}
|
||||||
|
ctx = mq.SetHeaders(ctx, map[string]string{mq.TriggerNode: task.CurrentQueue})
|
||||||
edges, exists := d.edges[task.CurrentQueue]
|
edges, exists := d.edges[task.CurrentQueue]
|
||||||
if exists {
|
if exists {
|
||||||
|
d.taskResults[task.ID] = map[string]*taskContext{
|
||||||
|
task.CurrentQueue: {
|
||||||
|
totalItems: 1,
|
||||||
|
nodeType: "edge",
|
||||||
|
},
|
||||||
|
}
|
||||||
for _, edge := range edges {
|
for _, edge := range edges {
|
||||||
_, err := d.PublishTask(ctx, task.Result, edge, task.ID)
|
_, err := d.PublishTask(ctx, payload, edge, task.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if completed {
|
||||||
d.mu.Lock()
|
d.mu.Lock()
|
||||||
if resultCh, ok := d.taskChMap[task.ID]; ok {
|
if resultCh, ok := d.taskChMap[task.ID]; ok {
|
||||||
resultCh <- mq.Result{
|
resultCh <- mq.Result{
|
||||||
Command: "complete",
|
Command: "complete",
|
||||||
Payload: task.Result,
|
Payload: payload,
|
||||||
Queue: task.CurrentQueue,
|
Queue: task.CurrentQueue,
|
||||||
MessageID: task.ID,
|
MessageID: task.ID,
|
||||||
Status: "done",
|
Status: "done",
|
||||||
@@ -138,47 +178,3 @@ func (d *DAG) TaskCallback(ctx context.Context, task *mq.Task) error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DAG) waitForLoopCompletion(ctx context.Context, taskID string, currentQueue string) {
|
|
||||||
d.mu.Lock()
|
|
||||||
loopCtx := d.loopTaskMap[taskID]
|
|
||||||
d.mu.Unlock()
|
|
||||||
for result := range loopCtx.subResultCh {
|
|
||||||
loopCtx.results = append(loopCtx.results, result.Payload)
|
|
||||||
loopCtx.completed++
|
|
||||||
if loopCtx.completed == loopCtx.totalItems {
|
|
||||||
close(loopCtx.subResultCh)
|
|
||||||
aggregatedResult, err := json.Marshal(loopCtx.results)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error aggregating results: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
d.mu.Lock()
|
|
||||||
delete(d.loopTaskMap, taskID)
|
|
||||||
d.mu.Unlock()
|
|
||||||
edges, exists := d.edges[currentQueue]
|
|
||||||
if exists {
|
|
||||||
for _, edge := range edges {
|
|
||||||
_, err := d.PublishTask(ctx, aggregatedResult, edge, taskID)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error publishing aggregated result: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
d.mu.Lock()
|
|
||||||
if resultCh, ok := d.taskChMap[taskID]; ok {
|
|
||||||
resultCh <- mq.Result{
|
|
||||||
Command: "complete",
|
|
||||||
Payload: aggregatedResult,
|
|
||||||
Queue: currentQueue,
|
|
||||||
MessageID: taskID,
|
|
||||||
Status: "done",
|
|
||||||
}
|
|
||||||
delete(d.taskChMap, taskID)
|
|
||||||
}
|
|
||||||
d.mu.Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user