mirror of
https://github.com/oarkflow/mq.git
synced 2025-09-27 04:15:52 +08:00
375 lines
10 KiB
Go
375 lines
10 KiB
Go
package v1
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/oarkflow/mq/consts"
|
|
"github.com/oarkflow/mq/storage"
|
|
"github.com/oarkflow/mq/storage/memory"
|
|
|
|
"github.com/oarkflow/mq"
|
|
)
|
|
|
|
type TaskManager struct {
|
|
createdAt time.Time
|
|
processedAt time.Time
|
|
status string
|
|
dag *DAG
|
|
taskID string
|
|
wg *WaitGroup
|
|
topic string
|
|
result mq.Result
|
|
|
|
iteratorNodes storage.IMap[string, []Edge]
|
|
taskNodeStatus storage.IMap[string, *taskNodeStatus]
|
|
}
|
|
|
|
func NewTaskManager(d *DAG, taskID string, iteratorNodes storage.IMap[string, []Edge]) *TaskManager {
|
|
return &TaskManager{
|
|
dag: d,
|
|
taskNodeStatus: memory.New[string, *taskNodeStatus](),
|
|
taskID: taskID,
|
|
iteratorNodes: iteratorNodes,
|
|
wg: NewWaitGroup(),
|
|
}
|
|
}
|
|
|
|
func (tm *TaskManager) dispatchFinalResult(ctx context.Context) mq.Result {
|
|
tm.updateTS(&tm.result)
|
|
tm.dag.callbackToConsumer(ctx, tm.result)
|
|
if tm.dag.server.NotifyHandler() != nil {
|
|
_ = tm.dag.server.NotifyHandler()(ctx, tm.result)
|
|
}
|
|
tm.dag.taskCleanupCh <- tm.taskID
|
|
tm.topic = tm.result.Topic
|
|
return tm.result
|
|
}
|
|
|
|
func (tm *TaskManager) reportNodeResult(result mq.Result, final bool) {
|
|
if tm.dag.reportNodeResultCallback != nil {
|
|
tm.dag.reportNodeResultCallback(result)
|
|
}
|
|
}
|
|
|
|
func (tm *TaskManager) SetTotalItems(topic string, i int) {
|
|
if nodeStatus, ok := tm.taskNodeStatus.Get(topic); ok {
|
|
nodeStatus.totalItems = i
|
|
}
|
|
}
|
|
|
|
func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) {
|
|
topic := getTopic(ctx, node.Key)
|
|
tm.taskNodeStatus.Set(topic, newNodeStatus(topic))
|
|
defer mq.RecoverPanic(mq.RecoverTitle)
|
|
dag, isDAG := isDAGNode(node)
|
|
if isDAG {
|
|
if tm.dag.server.SyncMode() && !dag.server.SyncMode() {
|
|
dag.server.Options().SetSyncMode(true)
|
|
}
|
|
}
|
|
tm.ChangeNodeStatus(ctx, node.Key, Processing, mq.Result{Payload: payload, Topic: node.Key})
|
|
var result mq.Result
|
|
if tm.dag.server.SyncMode() {
|
|
defer func() {
|
|
if isDAG {
|
|
result.Topic = dag.consumerTopic
|
|
result.TaskID = tm.taskID
|
|
tm.reportNodeResult(result, false)
|
|
tm.handleNextTask(result.Ctx, result)
|
|
} else {
|
|
result.Topic = node.Key
|
|
tm.reportNodeResult(result, false)
|
|
tm.handleNextTask(ctx, result)
|
|
}
|
|
}()
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
result = mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: ctx.Err(), Ctx: ctx}
|
|
tm.reportNodeResult(result, true)
|
|
tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
|
|
return
|
|
default:
|
|
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: node.Key})
|
|
if tm.dag.server.SyncMode() {
|
|
result = node.ProcessTask(ctx, mq.NewTask(tm.taskID, payload, node.Key))
|
|
if result.Error != nil {
|
|
tm.reportNodeResult(result, true)
|
|
tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
|
|
return
|
|
}
|
|
return
|
|
}
|
|
err := tm.dag.server.Publish(ctx, mq.NewTask(tm.taskID, payload, node.Key), node.Key)
|
|
if err != nil {
|
|
tm.reportNodeResult(mq.Result{Error: err}, true)
|
|
tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result {
|
|
defer mq.RecoverPanic(mq.RecoverTitle)
|
|
node, ok := tm.dag.nodes[nodeID]
|
|
if !ok {
|
|
return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)}
|
|
}
|
|
if tm.createdAt.IsZero() {
|
|
tm.createdAt = time.Now()
|
|
}
|
|
tm.wg.Add(1)
|
|
go func() {
|
|
ctxx := context.Background()
|
|
if headers, ok := mq.GetHeaders(ctx); ok {
|
|
headers.Set(consts.QueueKey, node.Key)
|
|
headers.Set("index", fmt.Sprintf("%s__%d", node.Key, 0))
|
|
ctxx = mq.SetHeaders(ctxx, headers.AsMap())
|
|
}
|
|
go tm.processNode(ctx, node, payload)
|
|
}()
|
|
tm.wg.Wait()
|
|
requestType, ok := mq.GetHeader(ctx, "request_type")
|
|
if ok && requestType == "render" {
|
|
return tm.renderResult(ctx)
|
|
}
|
|
return tm.dispatchFinalResult(ctx)
|
|
}
|
|
|
|
func (tm *TaskManager) handleNextTask(ctx context.Context, result mq.Result) mq.Result {
|
|
tm.topic = result.Topic
|
|
defer func() {
|
|
tm.wg.Done()
|
|
mq.RecoverPanic(mq.RecoverTitle)
|
|
}()
|
|
if result.Ctx != nil {
|
|
if headers, ok := mq.GetHeaders(ctx); ok {
|
|
ctx = mq.SetHeaders(result.Ctx, headers.AsMap())
|
|
}
|
|
}
|
|
node, ok := tm.dag.nodes[result.Topic]
|
|
if !ok {
|
|
return result
|
|
}
|
|
if result.Error != nil {
|
|
tm.reportNodeResult(result, true)
|
|
tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
|
|
return result
|
|
}
|
|
edges := tm.getConditionalEdges(node, result)
|
|
if len(edges) == 0 {
|
|
tm.reportNodeResult(result, true)
|
|
tm.ChangeNodeStatus(ctx, node.Key, Completed, result)
|
|
return result
|
|
} else {
|
|
tm.reportNodeResult(result, false)
|
|
}
|
|
if node.Type == Page {
|
|
return result
|
|
}
|
|
for _, edge := range edges {
|
|
switch edge.Type {
|
|
case Iterator:
|
|
var items []json.RawMessage
|
|
err := json.Unmarshal(result.Payload, &items)
|
|
if err != nil {
|
|
tm.reportNodeResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err}, false)
|
|
result.Error = err
|
|
tm.ChangeNodeStatus(ctx, node.Key, Failed, result)
|
|
return result
|
|
}
|
|
tm.SetTotalItems(getTopic(ctx, edge.From.Key), len(items)*len(edge.To))
|
|
for _, target := range edge.To {
|
|
for i, item := range items {
|
|
tm.wg.Add(1)
|
|
go func(ctx context.Context, target *Node, item json.RawMessage, i int) {
|
|
ctxx := context.Background()
|
|
if headers, ok := mq.GetHeaders(ctx); ok {
|
|
headers.Set(consts.QueueKey, target.Key)
|
|
headers.Set("index", fmt.Sprintf("%s__%d", target.Key, i))
|
|
ctxx = mq.SetHeaders(ctxx, headers.AsMap())
|
|
}
|
|
tm.processNode(ctxx, target, item)
|
|
}(ctx, target, item, i)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
for _, edge := range edges {
|
|
switch edge.Type {
|
|
case Simple:
|
|
if _, ok := tm.iteratorNodes.Get(edge.From.Key); ok {
|
|
continue
|
|
}
|
|
tm.processEdge(ctx, edge, result)
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (tm *TaskManager) processEdge(ctx context.Context, edge Edge, result mq.Result) {
|
|
tm.SetTotalItems(getTopic(ctx, edge.From.Key), len(edge.To))
|
|
index, _ := mq.GetHeader(ctx, "index")
|
|
if index != "" && strings.Contains(index, "__") {
|
|
index = strings.Split(index, "__")[1]
|
|
} else {
|
|
index = "0"
|
|
}
|
|
for _, target := range edge.To {
|
|
tm.wg.Add(1)
|
|
go func(ctx context.Context, target *Node, result mq.Result) {
|
|
ctxx := context.Background()
|
|
if headers, ok := mq.GetHeaders(ctx); ok {
|
|
headers.Set(consts.QueueKey, target.Key)
|
|
headers.Set("index", fmt.Sprintf("%s__%s", target.Key, index))
|
|
ctxx = mq.SetHeaders(ctxx, headers.AsMap())
|
|
}
|
|
tm.processNode(ctxx, target, result.Payload)
|
|
}(ctx, target, result)
|
|
}
|
|
}
|
|
|
|
func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge {
|
|
edges := make([]Edge, len(node.Edges))
|
|
copy(edges, node.Edges)
|
|
if result.ConditionStatus != "" {
|
|
if conditions, ok := tm.dag.conditions[FromNode(result.Topic)]; ok {
|
|
if targetNodeKey, ok := conditions[When(result.ConditionStatus)]; ok {
|
|
if targetNode, ok := tm.dag.nodes[string(targetNodeKey)]; ok {
|
|
edges = append(edges, Edge{From: node, To: []*Node{targetNode}})
|
|
}
|
|
} else if targetNodeKey, ok = conditions["default"]; ok {
|
|
if targetNode, ok := tm.dag.nodes[string(targetNodeKey)]; ok {
|
|
edges = append(edges, Edge{From: node, To: []*Node{targetNode}})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return edges
|
|
}
|
|
|
|
func (tm *TaskManager) renderResult(ctx context.Context) mq.Result {
|
|
var rs mq.Result
|
|
tm.updateTS(&rs)
|
|
tm.dag.callbackToConsumer(ctx, rs)
|
|
tm.topic = rs.Topic
|
|
return rs
|
|
}
|
|
|
|
func (tm *TaskManager) ChangeNodeStatus(ctx context.Context, nodeID string, status NodeStatus, rs mq.Result) {
|
|
topic := nodeID
|
|
if !strings.Contains(nodeID, "__") {
|
|
nodeID = getTopic(ctx, nodeID)
|
|
} else {
|
|
topic = strings.Split(nodeID, "__")[0]
|
|
}
|
|
nodeStatus, ok := tm.taskNodeStatus.Get(nodeID)
|
|
if !ok || nodeStatus == nil {
|
|
return
|
|
}
|
|
|
|
nodeStatus.markAs(rs, status)
|
|
switch status {
|
|
case Completed:
|
|
canProceed := false
|
|
edges, ok := tm.iteratorNodes.Get(topic)
|
|
if ok {
|
|
if len(edges) == 0 {
|
|
canProceed = true
|
|
} else {
|
|
nodeStatus.status = Processing
|
|
nodeStatus.totalItems = 1
|
|
nodeStatus.itemResults.Clear()
|
|
for _, edge := range edges {
|
|
tm.processEdge(ctx, edge, rs)
|
|
}
|
|
tm.iteratorNodes.Del(topic)
|
|
}
|
|
}
|
|
if canProceed || !ok {
|
|
if topic == tm.dag.startNode {
|
|
tm.result = rs
|
|
} else {
|
|
tm.markParentTask(ctx, topic, nodeID, status, rs)
|
|
}
|
|
}
|
|
case Failed:
|
|
if topic == tm.dag.startNode {
|
|
tm.result = rs
|
|
} else {
|
|
tm.markParentTask(ctx, topic, nodeID, status, rs)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (tm *TaskManager) markParentTask(ctx context.Context, topic, nodeID string, status NodeStatus, rs mq.Result) {
|
|
parentNodes, err := tm.dag.GetPreviousNodes(topic)
|
|
if err != nil {
|
|
return
|
|
}
|
|
var index string
|
|
nodeParts := strings.Split(nodeID, "__")
|
|
if len(nodeParts) == 2 {
|
|
index = nodeParts[1]
|
|
}
|
|
for _, parentNode := range parentNodes {
|
|
parentKey := fmt.Sprintf("%s__%s", parentNode.Key, index)
|
|
parentNodeStatus, exists := tm.taskNodeStatus.Get(parentKey)
|
|
if !exists {
|
|
parentKey = fmt.Sprintf("%s__%s", parentNode.Key, "0")
|
|
parentNodeStatus, exists = tm.taskNodeStatus.Get(parentKey)
|
|
}
|
|
if exists {
|
|
parentNodeStatus.itemResults.Set(nodeID, rs)
|
|
if parentNodeStatus.IsDone() {
|
|
rt := tm.prepareResult(ctx, parentNodeStatus)
|
|
tm.ChangeNodeStatus(ctx, parentKey, status, rt)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (tm *TaskManager) prepareResult(ctx context.Context, nodeStatus *taskNodeStatus) mq.Result {
|
|
aggregatedOutput := make([]json.RawMessage, 0)
|
|
var status mq.Status
|
|
var topic string
|
|
var err1 error
|
|
if nodeStatus.totalItems == 1 {
|
|
rs := nodeStatus.itemResults.Values()[0]
|
|
if rs.Ctx == nil {
|
|
rs.Ctx = ctx
|
|
}
|
|
return rs
|
|
}
|
|
nodeStatus.itemResults.ForEach(func(key string, result mq.Result) bool {
|
|
if topic == "" {
|
|
topic = result.Topic
|
|
status = result.Status
|
|
}
|
|
if result.Error != nil {
|
|
err1 = result.Error
|
|
return false
|
|
}
|
|
var item json.RawMessage
|
|
err := json.Unmarshal(result.Payload, &item)
|
|
if err != nil {
|
|
err1 = err
|
|
return false
|
|
}
|
|
aggregatedOutput = append(aggregatedOutput, item)
|
|
return true
|
|
})
|
|
if err1 != nil {
|
|
return mq.HandleError(ctx, err1)
|
|
}
|
|
finalOutput, err := json.Marshal(aggregatedOutput)
|
|
if err != nil {
|
|
return mq.HandleError(ctx, err)
|
|
}
|
|
return mq.Result{TaskID: tm.taskID, Payload: finalOutput, Status: status, Topic: topic, Ctx: ctx}
|
|
}
|