mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-05 07:57:00 +08:00
update: dependencies
This commit is contained in:
@@ -4,247 +4,307 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq/consts"
|
||||
"github.com/oarkflow/mq"
|
||||
|
||||
"github.com/oarkflow/mq/storage"
|
||||
"github.com/oarkflow/mq/storage/memory"
|
||||
|
||||
"github.com/oarkflow/mq"
|
||||
)
|
||||
|
||||
type TaskState struct {
|
||||
NodeID string
|
||||
Status mq.Status
|
||||
UpdatedAt time.Time
|
||||
Result mq.Result
|
||||
targetResults storage.IMap[string, mq.Result]
|
||||
}
|
||||
|
||||
func newTaskState(nodeID string) *TaskState {
|
||||
return &TaskState{
|
||||
NodeID: nodeID,
|
||||
Status: mq.Pending,
|
||||
UpdatedAt: time.Now(),
|
||||
targetResults: memory.New[string, mq.Result](),
|
||||
}
|
||||
}
|
||||
|
||||
type nodeResult struct {
|
||||
ctx context.Context
|
||||
nodeID string
|
||||
status mq.Status
|
||||
result mq.Result
|
||||
}
|
||||
|
||||
type TaskManager struct {
|
||||
createdAt time.Time
|
||||
processedAt time.Time
|
||||
status string
|
||||
dag *DAG
|
||||
taskID string
|
||||
wg *WaitGroup
|
||||
topic string
|
||||
result mq.Result
|
||||
|
||||
iteratorNodes storage.IMap[string, []Edge]
|
||||
taskNodeStatus storage.IMap[string, *taskNodeStatus]
|
||||
taskStates storage.IMap[string, *TaskState]
|
||||
parentNodes storage.IMap[string, string]
|
||||
childNodes storage.IMap[string, int]
|
||||
deferredTasks storage.IMap[string, *task]
|
||||
iteratorNodes storage.IMap[string, []Edge]
|
||||
currentNodePayload storage.IMap[string, json.RawMessage]
|
||||
currentNodeResult storage.IMap[string, mq.Result]
|
||||
result *mq.Result
|
||||
dag *DAG
|
||||
taskID string
|
||||
taskQueue chan *task
|
||||
resultQueue chan nodeResult
|
||||
resultCh chan mq.Result
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
func NewTaskManager(d *DAG, taskID string, iteratorNodes storage.IMap[string, []Edge]) *TaskManager {
|
||||
return &TaskManager{
|
||||
dag: d,
|
||||
taskNodeStatus: memory.New[string, *taskNodeStatus](),
|
||||
taskID: taskID,
|
||||
iteratorNodes: iteratorNodes,
|
||||
wg: NewWaitGroup(),
|
||||
type task struct {
|
||||
ctx context.Context
|
||||
taskID string
|
||||
nodeID string
|
||||
payload json.RawMessage
|
||||
}
|
||||
|
||||
func newTask(ctx context.Context, taskID, nodeID string, payload json.RawMessage) *task {
|
||||
return &task{
|
||||
ctx: ctx,
|
||||
taskID: taskID,
|
||||
nodeID: nodeID,
|
||||
payload: payload,
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *TaskManager) dispatchFinalResult(ctx context.Context) mq.Result {
|
||||
tm.updateTS(&tm.result)
|
||||
tm.dag.callbackToConsumer(ctx, tm.result)
|
||||
if tm.dag.server.NotifyHandler() != nil {
|
||||
_ = tm.dag.server.NotifyHandler()(ctx, tm.result)
|
||||
func NewTaskManager(dag *DAG, taskID string, resultCh chan mq.Result, iteratorNodes storage.IMap[string, []Edge]) *TaskManager {
|
||||
tm := &TaskManager{
|
||||
taskStates: memory.New[string, *TaskState](),
|
||||
parentNodes: memory.New[string, string](),
|
||||
childNodes: memory.New[string, int](),
|
||||
deferredTasks: memory.New[string, *task](),
|
||||
currentNodePayload: memory.New[string, json.RawMessage](),
|
||||
currentNodeResult: memory.New[string, mq.Result](),
|
||||
taskQueue: make(chan *task, DefaultChannelSize),
|
||||
resultQueue: make(chan nodeResult, DefaultChannelSize),
|
||||
iteratorNodes: iteratorNodes,
|
||||
stopCh: make(chan struct{}),
|
||||
resultCh: resultCh,
|
||||
taskID: taskID,
|
||||
dag: dag,
|
||||
}
|
||||
tm.dag.taskCleanupCh <- tm.taskID
|
||||
tm.topic = tm.result.Topic
|
||||
return tm.result
|
||||
go tm.run()
|
||||
go tm.waitForResult()
|
||||
return tm
|
||||
}
|
||||
|
||||
func (tm *TaskManager) reportNodeResult(result mq.Result, final bool) {
|
||||
if tm.dag.reportNodeResultCallback != nil {
|
||||
tm.dag.reportNodeResultCallback(result)
|
||||
}
|
||||
func (tm *TaskManager) ProcessTask(ctx context.Context, startNode string, payload json.RawMessage) {
|
||||
tm.send(ctx, startNode, tm.taskID, payload)
|
||||
}
|
||||
|
||||
func (tm *TaskManager) SetTotalItems(topic string, i int) {
|
||||
if nodeStatus, ok := tm.taskNodeStatus.Get(topic); ok {
|
||||
nodeStatus.totalItems = i
|
||||
func (tm *TaskManager) send(ctx context.Context, startNode, taskID string, payload json.RawMessage) {
|
||||
if index, ok := ctx.Value(ContextIndex).(string); ok {
|
||||
startNode = strings.Split(startNode, Delimiter)[0]
|
||||
startNode = fmt.Sprintf("%s%s%s", startNode, Delimiter, index)
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) {
|
||||
topic := getTopic(ctx, node.Key)
|
||||
tm.taskNodeStatus.Set(topic, newNodeStatus(topic))
|
||||
defer mq.RecoverPanic(mq.RecoverTitle)
|
||||
dag, isDAG := isDAGNode(node)
|
||||
if isDAG {
|
||||
if tm.dag.server.SyncMode() && !dag.server.SyncMode() {
|
||||
dag.server.Options().SetSyncMode(true)
|
||||
}
|
||||
}
|
||||
tm.ChangeNodeStatus(ctx, node.Key, Processing, mq.Result{Payload: payload, Topic: node.Key})
|
||||
var result mq.Result
|
||||
if tm.dag.server.SyncMode() {
|
||||
defer func() {
|
||||
if isDAG {
|
||||
result.Topic = dag.consumerTopic
|
||||
result.TaskID = tm.taskID
|
||||
tm.reportNodeResult(result, false)
|
||||
tm.handleNextTask(result.Ctx, result)
|
||||
} else {
|
||||
result.Topic = node.Key
|
||||
tm.reportNodeResult(result, false)
|
||||
tm.handleNextTask(ctx, result)
|
||||
}
|
||||
}()
|
||||
if _, exists := tm.taskStates.Get(startNode); !exists {
|
||||
tm.taskStates.Set(startNode, newTaskState(startNode))
|
||||
}
|
||||
t := newTask(ctx, taskID, startNode, payload)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
result = mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: ctx.Err(), Ctx: ctx}
|
||||
tm.reportNodeResult(result, true)
|
||||
tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
|
||||
return
|
||||
case tm.taskQueue <- t:
|
||||
default:
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: node.Key})
|
||||
if tm.dag.server.SyncMode() {
|
||||
result = node.ProcessTask(ctx, mq.NewTask(tm.taskID, payload, node.Key))
|
||||
if result.Error != nil {
|
||||
tm.reportNodeResult(result, true)
|
||||
tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
|
||||
return
|
||||
log.Println("task queue is full, dropping task.")
|
||||
tm.deferredTasks.Set(taskID, t)
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *TaskManager) run() {
|
||||
for {
|
||||
select {
|
||||
case <-tm.stopCh:
|
||||
log.Println("Stopping TaskManager")
|
||||
return
|
||||
case task := <-tm.taskQueue:
|
||||
tm.processNode(task)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *TaskManager) waitForResult() {
|
||||
for {
|
||||
select {
|
||||
case <-tm.stopCh:
|
||||
log.Println("Stopping Result Listener")
|
||||
return
|
||||
case nr := <-tm.resultQueue:
|
||||
tm.onNodeCompleted(nr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *TaskManager) processNode(exec *task) {
|
||||
pureNodeID := strings.Split(exec.nodeID, Delimiter)[0]
|
||||
node, exists := tm.dag.nodes.Get(pureNodeID)
|
||||
if !exists {
|
||||
log.Printf("Node %s does not exist while processing node\n", pureNodeID)
|
||||
return
|
||||
}
|
||||
state, _ := tm.taskStates.Get(exec.nodeID)
|
||||
if state == nil {
|
||||
log.Printf("State for node %s not found; creating new state.\n", exec.nodeID)
|
||||
state = newTaskState(exec.nodeID)
|
||||
tm.taskStates.Set(exec.nodeID, state)
|
||||
}
|
||||
state.Status = mq.Processing
|
||||
state.UpdatedAt = time.Now()
|
||||
tm.currentNodePayload.Clear()
|
||||
tm.currentNodeResult.Clear()
|
||||
tm.currentNodePayload.Set(exec.nodeID, exec.payload)
|
||||
result := node.processor.ProcessTask(exec.ctx, mq.NewTask(exec.taskID, exec.payload, exec.nodeID))
|
||||
tm.currentNodeResult.Set(exec.nodeID, result)
|
||||
state.Result = result
|
||||
result.Topic = node.ID
|
||||
if result.Error != nil {
|
||||
tm.result = &result
|
||||
tm.resultCh <- result
|
||||
tm.processFinalResult(state)
|
||||
return
|
||||
}
|
||||
if node.NodeType == Page {
|
||||
tm.result = &result
|
||||
tm.resultCh <- result
|
||||
return
|
||||
}
|
||||
tm.handleNext(exec.ctx, node, state, result)
|
||||
}
|
||||
|
||||
func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, result mq.Result, childNode string, dispatchFinal bool) {
|
||||
state.targetResults.Set(childNode, result)
|
||||
state.targetResults.Del(state.NodeID)
|
||||
targetsCount, _ := tm.childNodes.Get(state.NodeID)
|
||||
size := state.targetResults.Size()
|
||||
nodeID := strings.Split(state.NodeID, Delimiter)
|
||||
if size == targetsCount {
|
||||
if size > 1 {
|
||||
aggregatedData := make([]json.RawMessage, size)
|
||||
i := 0
|
||||
state.targetResults.ForEach(func(_ string, rs mq.Result) bool {
|
||||
aggregatedData[i] = rs.Payload
|
||||
i++
|
||||
return true
|
||||
})
|
||||
aggregatedPayload, err := json.Marshal(aggregatedData)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
err := tm.dag.server.Publish(ctx, mq.NewTask(tm.taskID, payload, node.Key), node.Key)
|
||||
if err != nil {
|
||||
tm.reportNodeResult(mq.Result{Error: err}, true)
|
||||
tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
|
||||
return
|
||||
state.Result = mq.Result{Payload: aggregatedPayload, Status: mq.Completed, Ctx: ctx, Topic: state.NodeID}
|
||||
} else if size == 1 {
|
||||
state.Result = state.targetResults.Values()[0]
|
||||
}
|
||||
state.Status = result.Status
|
||||
state.Result.Status = result.Status
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result {
|
||||
defer mq.RecoverPanic(mq.RecoverTitle)
|
||||
node, ok := tm.dag.nodes[nodeID]
|
||||
if !ok {
|
||||
return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)}
|
||||
if state.Result.Payload == nil {
|
||||
state.Result.Payload = result.Payload
|
||||
}
|
||||
if tm.createdAt.IsZero() {
|
||||
tm.createdAt = time.Now()
|
||||
}
|
||||
tm.wg.Add(1)
|
||||
go func() {
|
||||
ctxx := context.Background()
|
||||
if headers, ok := mq.GetHeaders(ctx); ok {
|
||||
headers.Set(consts.QueueKey, node.Key)
|
||||
headers.Set("index", fmt.Sprintf("%s__%d", node.Key, 0))
|
||||
ctxx = mq.SetHeaders(ctxx, headers.AsMap())
|
||||
}
|
||||
go tm.processNode(ctx, node, payload)
|
||||
}()
|
||||
tm.wg.Wait()
|
||||
requestType, ok := mq.GetHeader(ctx, "request_type")
|
||||
if ok && requestType == "render" {
|
||||
return tm.renderResult(ctx)
|
||||
}
|
||||
return tm.dispatchFinalResult(ctx)
|
||||
}
|
||||
|
||||
func (tm *TaskManager) handleNextTask(ctx context.Context, result mq.Result) mq.Result {
|
||||
tm.topic = result.Topic
|
||||
defer func() {
|
||||
tm.wg.Done()
|
||||
mq.RecoverPanic(mq.RecoverTitle)
|
||||
}()
|
||||
if result.Ctx != nil {
|
||||
if headers, ok := mq.GetHeaders(ctx); ok {
|
||||
ctx = mq.SetHeaders(result.Ctx, headers.AsMap())
|
||||
}
|
||||
}
|
||||
node, ok := tm.dag.nodes[result.Topic]
|
||||
if !ok {
|
||||
return result
|
||||
state.UpdatedAt = time.Now()
|
||||
if result.Ctx == nil {
|
||||
result.Ctx = ctx
|
||||
}
|
||||
if result.Error != nil {
|
||||
tm.reportNodeResult(result, true)
|
||||
tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
|
||||
return result
|
||||
state.Status = mq.Failed
|
||||
}
|
||||
edges := tm.getConditionalEdges(node, result)
|
||||
if len(edges) == 0 {
|
||||
tm.reportNodeResult(result, true)
|
||||
tm.ChangeNodeStatus(ctx, node.Key, Completed, result)
|
||||
return result
|
||||
} else {
|
||||
tm.reportNodeResult(result, false)
|
||||
}
|
||||
if node.Type == Page {
|
||||
return result
|
||||
}
|
||||
for _, edge := range edges {
|
||||
switch edge.Type {
|
||||
case Iterator:
|
||||
var items []json.RawMessage
|
||||
err := json.Unmarshal(result.Payload, &items)
|
||||
if err != nil {
|
||||
tm.reportNodeResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err}, false)
|
||||
result.Error = err
|
||||
tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
|
||||
return result
|
||||
pn, ok := tm.parentNodes.Get(state.NodeID)
|
||||
if edges, exists := tm.iteratorNodes.Get(nodeID[0]); exists && state.Status == mq.Completed {
|
||||
state.Status = mq.Processing
|
||||
tm.iteratorNodes.Del(nodeID[0])
|
||||
state.targetResults.Clear()
|
||||
if len(nodeID) == 2 {
|
||||
ctx = context.WithValue(ctx, ContextIndex, nodeID[1])
|
||||
}
|
||||
toProcess := nodeResult{
|
||||
ctx: ctx,
|
||||
nodeID: state.NodeID,
|
||||
status: state.Status,
|
||||
result: state.Result,
|
||||
}
|
||||
tm.handleEdges(toProcess, edges)
|
||||
} else if ok {
|
||||
if targetsCount == size {
|
||||
parentState, _ := tm.taskStates.Get(pn)
|
||||
if parentState != nil {
|
||||
state.Result.Topic = state.NodeID
|
||||
tm.handlePrevious(ctx, parentState, state.Result, state.NodeID, dispatchFinal)
|
||||
}
|
||||
tm.SetTotalItems(getTopic(ctx, edge.From.Key), len(items)*len(edge.To))
|
||||
for _, target := range edge.To {
|
||||
for i, item := range items {
|
||||
tm.wg.Add(1)
|
||||
go func(ctx context.Context, target *Node, item json.RawMessage, i int) {
|
||||
ctxx := context.Background()
|
||||
if headers, ok := mq.GetHeaders(ctx); ok {
|
||||
headers.Set(consts.QueueKey, target.Key)
|
||||
headers.Set("index", fmt.Sprintf("%s__%d", target.Key, i))
|
||||
ctxx = mq.SetHeaders(ctxx, headers.AsMap())
|
||||
}
|
||||
tm.processNode(ctxx, target, item)
|
||||
}(ctx, target, item, i)
|
||||
}
|
||||
} else {
|
||||
tm.result = &state.Result
|
||||
state.Result.Topic = strings.Split(state.NodeID, Delimiter)[0]
|
||||
tm.resultCh <- state.Result
|
||||
tm.processFinalResult(state)
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *TaskManager) handleNext(ctx context.Context, node *Node, state *TaskState, result mq.Result) {
|
||||
state.UpdatedAt = time.Now()
|
||||
if result.Ctx == nil {
|
||||
result.Ctx = ctx
|
||||
}
|
||||
if result.Error != nil {
|
||||
state.Status = mq.Failed
|
||||
} else {
|
||||
edges := tm.getConditionalEdges(node, result)
|
||||
if len(edges) == 0 {
|
||||
state.Status = mq.Completed
|
||||
}
|
||||
}
|
||||
if result.Status == "" {
|
||||
result.Status = state.Status
|
||||
}
|
||||
select {
|
||||
case tm.resultQueue <- nodeResult{
|
||||
ctx: ctx,
|
||||
nodeID: state.NodeID,
|
||||
result: result,
|
||||
status: state.Status,
|
||||
}:
|
||||
default:
|
||||
log.Println("Result queue is full, dropping result.")
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *TaskManager) onNodeCompleted(rs nodeResult) {
|
||||
nodeID := strings.Split(rs.nodeID, Delimiter)[0]
|
||||
node, ok := tm.dag.nodes.Get(nodeID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
edges := tm.getConditionalEdges(node, rs.result)
|
||||
hasErrorOrCompleted := rs.result.Error != nil || len(edges) == 0
|
||||
if hasErrorOrCompleted {
|
||||
if index, ok := rs.ctx.Value(ContextIndex).(string); ok {
|
||||
childNode := fmt.Sprintf("%s%s%s", node.ID, Delimiter, index)
|
||||
pn, ok := tm.parentNodes.Get(childNode)
|
||||
if ok {
|
||||
parentState, _ := tm.taskStates.Get(pn)
|
||||
if parentState != nil {
|
||||
pn = strings.Split(pn, Delimiter)[0]
|
||||
tm.handlePrevious(rs.ctx, parentState, rs.result, rs.nodeID, true)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
for _, edge := range edges {
|
||||
switch edge.Type {
|
||||
case Simple:
|
||||
if _, ok := tm.iteratorNodes.Get(edge.From.Key); ok {
|
||||
continue
|
||||
}
|
||||
tm.processEdge(ctx, edge, result)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (tm *TaskManager) processEdge(ctx context.Context, edge Edge, result mq.Result) {
|
||||
tm.SetTotalItems(getTopic(ctx, edge.From.Key), len(edge.To))
|
||||
index, _ := mq.GetHeader(ctx, "index")
|
||||
if index != "" && strings.Contains(index, "__") {
|
||||
index = strings.Split(index, "__")[1]
|
||||
} else {
|
||||
index = "0"
|
||||
}
|
||||
for _, target := range edge.To {
|
||||
tm.wg.Add(1)
|
||||
go func(ctx context.Context, target *Node, result mq.Result) {
|
||||
ctxx := context.Background()
|
||||
if headers, ok := mq.GetHeaders(ctx); ok {
|
||||
headers.Set(consts.QueueKey, target.Key)
|
||||
headers.Set("index", fmt.Sprintf("%s__%s", target.Key, index))
|
||||
ctxx = mq.SetHeaders(ctxx, headers.AsMap())
|
||||
}
|
||||
tm.processNode(ctxx, target, result.Payload)
|
||||
}(ctx, target, result)
|
||||
}
|
||||
tm.handleEdges(rs, edges)
|
||||
}
|
||||
|
||||
func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge {
|
||||
edges := make([]Edge, len(node.Edges))
|
||||
copy(edges, node.Edges)
|
||||
if result.ConditionStatus != "" {
|
||||
if conditions, ok := tm.dag.conditions[FromNode(result.Topic)]; ok {
|
||||
if targetNodeKey, ok := conditions[When(result.ConditionStatus)]; ok {
|
||||
if targetNode, ok := tm.dag.nodes[string(targetNodeKey)]; ok {
|
||||
edges = append(edges, Edge{From: node, To: []*Node{targetNode}})
|
||||
if conditions, ok := tm.dag.conditions[result.Topic]; ok {
|
||||
if targetNodeKey, ok := conditions[result.ConditionStatus]; ok {
|
||||
if targetNode, ok := tm.dag.nodes.Get(targetNodeKey); ok {
|
||||
edges = append(edges, Edge{From: node, To: targetNode})
|
||||
}
|
||||
} else if targetNodeKey, ok = conditions["default"]; ok {
|
||||
if targetNode, ok := tm.dag.nodes[string(targetNodeKey)]; ok {
|
||||
edges = append(edges, Edge{From: node, To: []*Node{targetNode}})
|
||||
if targetNode, ok := tm.dag.nodes.Get(targetNodeKey); ok {
|
||||
edges = append(edges, Edge{From: node, To: targetNode})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -252,123 +312,77 @@ func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge
|
||||
return edges
|
||||
}
|
||||
|
||||
func (tm *TaskManager) renderResult(ctx context.Context) mq.Result {
|
||||
var rs mq.Result
|
||||
tm.updateTS(&rs)
|
||||
tm.dag.callbackToConsumer(ctx, rs)
|
||||
tm.topic = rs.Topic
|
||||
return rs
|
||||
}
|
||||
|
||||
func (tm *TaskManager) ChangeNodeStatus(ctx context.Context, nodeID string, status NodeStatus, rs mq.Result) {
|
||||
topic := nodeID
|
||||
if !strings.Contains(nodeID, "__") {
|
||||
nodeID = getTopic(ctx, nodeID)
|
||||
} else {
|
||||
topic = strings.Split(nodeID, "__")[0]
|
||||
}
|
||||
nodeStatus, ok := tm.taskNodeStatus.Get(nodeID)
|
||||
if !ok || nodeStatus == nil {
|
||||
return
|
||||
}
|
||||
|
||||
nodeStatus.markAs(rs, status)
|
||||
switch status {
|
||||
case Completed:
|
||||
canProceed := false
|
||||
edges, ok := tm.iteratorNodes.Get(topic)
|
||||
if ok {
|
||||
if len(edges) == 0 {
|
||||
canProceed = true
|
||||
} else {
|
||||
nodeStatus.status = Processing
|
||||
nodeStatus.totalItems = 1
|
||||
nodeStatus.itemResults.Clear()
|
||||
for _, edge := range edges {
|
||||
tm.processEdge(ctx, edge, rs)
|
||||
func (tm *TaskManager) handleEdges(currentResult nodeResult, edges []Edge) {
|
||||
for _, edge := range edges {
|
||||
index, ok := currentResult.ctx.Value(ContextIndex).(string)
|
||||
if !ok {
|
||||
index = "0"
|
||||
}
|
||||
parentNode := fmt.Sprintf("%s%s%s", edge.From.ID, Delimiter, index)
|
||||
if edge.Type == Simple {
|
||||
if _, ok := tm.iteratorNodes.Get(edge.From.ID); ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if edge.Type == Iterator {
|
||||
var items []json.RawMessage
|
||||
err := json.Unmarshal(currentResult.result.Payload, &items)
|
||||
if err != nil {
|
||||
log.Printf("Error unmarshalling data for node %s: %v\n", edge.To.ID, err)
|
||||
tm.resultQueue <- nodeResult{
|
||||
ctx: currentResult.ctx,
|
||||
nodeID: edge.To.ID,
|
||||
status: mq.Failed,
|
||||
result: mq.Result{Error: err},
|
||||
}
|
||||
tm.iteratorNodes.Del(topic)
|
||||
return
|
||||
}
|
||||
}
|
||||
if canProceed || !ok {
|
||||
if topic == tm.dag.startNode {
|
||||
tm.result = rs
|
||||
} else {
|
||||
tm.markParentTask(ctx, topic, nodeID, status, rs)
|
||||
tm.childNodes.Set(parentNode, len(items))
|
||||
for i, item := range items {
|
||||
childNode := fmt.Sprintf("%s%s%d", edge.To.ID, Delimiter, i)
|
||||
ctx := context.WithValue(currentResult.ctx, ContextIndex, fmt.Sprintf("%d", i))
|
||||
tm.parentNodes.Set(childNode, parentNode)
|
||||
tm.send(ctx, edge.To.ID, tm.taskID, item)
|
||||
}
|
||||
}
|
||||
case Failed:
|
||||
if topic == tm.dag.startNode {
|
||||
tm.result = rs
|
||||
} else {
|
||||
tm.markParentTask(ctx, topic, nodeID, status, rs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *TaskManager) markParentTask(ctx context.Context, topic, nodeID string, status NodeStatus, rs mq.Result) {
|
||||
parentNodes, err := tm.dag.GetPreviousNodes(topic)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var index string
|
||||
nodeParts := strings.Split(nodeID, "__")
|
||||
if len(nodeParts) == 2 {
|
||||
index = nodeParts[1]
|
||||
}
|
||||
for _, parentNode := range parentNodes {
|
||||
parentKey := fmt.Sprintf("%s__%s", parentNode.Key, index)
|
||||
parentNodeStatus, exists := tm.taskNodeStatus.Get(parentKey)
|
||||
if !exists {
|
||||
parentKey = fmt.Sprintf("%s__%s", parentNode.Key, "0")
|
||||
parentNodeStatus, exists = tm.taskNodeStatus.Get(parentKey)
|
||||
}
|
||||
if exists {
|
||||
parentNodeStatus.itemResults.Set(nodeID, rs)
|
||||
if parentNodeStatus.IsDone() {
|
||||
rt := tm.prepareResult(ctx, parentNodeStatus)
|
||||
tm.ChangeNodeStatus(ctx, parentKey, status, rt)
|
||||
tm.childNodes.Set(parentNode, 1)
|
||||
idx, ok := currentResult.ctx.Value(ContextIndex).(string)
|
||||
if !ok {
|
||||
idx = "0"
|
||||
}
|
||||
childNode := fmt.Sprintf("%s%s%s", edge.To.ID, Delimiter, idx)
|
||||
ctx := context.WithValue(currentResult.ctx, ContextIndex, idx)
|
||||
tm.parentNodes.Set(childNode, parentNode)
|
||||
tm.send(ctx, edge.To.ID, tm.taskID, currentResult.result.Payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *TaskManager) prepareResult(ctx context.Context, nodeStatus *taskNodeStatus) mq.Result {
|
||||
aggregatedOutput := make([]json.RawMessage, 0)
|
||||
var status mq.Status
|
||||
var topic string
|
||||
var err1 error
|
||||
if nodeStatus.totalItems == 1 {
|
||||
rs := nodeStatus.itemResults.Values()[0]
|
||||
if rs.Ctx == nil {
|
||||
rs.Ctx = ctx
|
||||
func (tm *TaskManager) retryDeferredTasks() {
|
||||
const maxRetries = 5
|
||||
retries := 0
|
||||
for retries < maxRetries {
|
||||
select {
|
||||
case <-tm.stopCh:
|
||||
log.Println("Stopping Deferred task Retrier")
|
||||
return
|
||||
case <-time.After(RetryInterval):
|
||||
tm.deferredTasks.ForEach(func(taskID string, task *task) bool {
|
||||
tm.send(task.ctx, task.nodeID, taskID, task.payload)
|
||||
retries++
|
||||
return true
|
||||
})
|
||||
}
|
||||
return rs
|
||||
}
|
||||
nodeStatus.itemResults.ForEach(func(key string, result mq.Result) bool {
|
||||
if topic == "" {
|
||||
topic = result.Topic
|
||||
status = result.Status
|
||||
}
|
||||
if result.Error != nil {
|
||||
err1 = result.Error
|
||||
return false
|
||||
}
|
||||
var item json.RawMessage
|
||||
err := json.Unmarshal(result.Payload, &item)
|
||||
if err != nil {
|
||||
err1 = err
|
||||
return false
|
||||
}
|
||||
aggregatedOutput = append(aggregatedOutput, item)
|
||||
return true
|
||||
})
|
||||
if err1 != nil {
|
||||
return mq.HandleError(ctx, err1)
|
||||
}
|
||||
finalOutput, err := json.Marshal(aggregatedOutput)
|
||||
if err != nil {
|
||||
return mq.HandleError(ctx, err)
|
||||
}
|
||||
return mq.Result{TaskID: tm.taskID, Payload: finalOutput, Status: status, Topic: topic, Ctx: ctx}
|
||||
}
|
||||
|
||||
func (tm *TaskManager) processFinalResult(state *TaskState) {
|
||||
state.targetResults.Clear()
|
||||
if tm.dag.finalResult != nil {
|
||||
tm.dag.finalResult(tm.taskID, state.Result)
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *TaskManager) Stop() {
|
||||
close(tm.stopCh)
|
||||
}
|
||||
|
Reference in New Issue
Block a user