mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-06 00:16:49 +08:00
188 lines
4.5 KiB
Go
188 lines
4.5 KiB
Go
package v2
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"golang.org/x/exp/maps"
|
|
)
|
|
|
|
type TaskState struct {
|
|
NodeID string
|
|
Status TaskStatus
|
|
Timestamp time.Time
|
|
Result Result
|
|
targetResults map[string]Result
|
|
my sync.Mutex
|
|
}
|
|
|
|
type nodeResult struct {
|
|
taskID string
|
|
nodeID string
|
|
result Result
|
|
}
|
|
|
|
type TaskManager struct {
|
|
taskStates map[string]*TaskState
|
|
dag *DAG
|
|
mu sync.Mutex
|
|
taskQueue chan taskExecution
|
|
resultQueue chan nodeResult
|
|
}
|
|
|
|
// taskExecution is one unit of work on TaskManager.taskQueue: run the handler
// of node nodeID for task taskID with the given payload.
type taskExecution struct {
	// taskID identifies the task run; nodeID names the node to execute.
	taskID, nodeID string
	// payload is the raw JSON handed to the node's handler.
	payload json.RawMessage
}
|
|
|
|
func NewTaskManager(dag *DAG) *TaskManager {
|
|
tm := &TaskManager{
|
|
taskStates: make(map[string]*TaskState),
|
|
taskQueue: make(chan taskExecution, 100),
|
|
resultQueue: make(chan nodeResult, 100),
|
|
dag: dag,
|
|
}
|
|
go tm.WaitForResult()
|
|
return tm
|
|
}
|
|
|
|
func (tm *TaskManager) Trigger(taskID, startNode string, payload json.RawMessage) {
|
|
tm.mu.Lock()
|
|
tm.taskStates[startNode] = newTaskState(startNode)
|
|
tm.mu.Unlock()
|
|
tm.taskQueue <- taskExecution{taskID: taskID, nodeID: startNode, payload: payload}
|
|
}
|
|
|
|
func newTaskState(nodeID string) *TaskState {
|
|
return &TaskState{
|
|
NodeID: nodeID,
|
|
Status: StatusPending,
|
|
Timestamp: time.Now(),
|
|
targetResults: make(map[string]Result),
|
|
}
|
|
}
|
|
|
|
func (tm *TaskManager) Run() {
|
|
go func() {
|
|
for task := range tm.taskQueue {
|
|
tm.processNode(task)
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (tm *TaskManager) processNode(exec taskExecution) {
|
|
node, exists := tm.dag.nodes.Get(exec.nodeID)
|
|
if !exists {
|
|
fmt.Printf("Node %s does not exist\n", exec.nodeID)
|
|
return
|
|
}
|
|
tm.mu.Lock()
|
|
state := tm.taskStates[exec.nodeID]
|
|
if state == nil {
|
|
state = newTaskState(exec.nodeID)
|
|
tm.taskStates[exec.nodeID] = state
|
|
}
|
|
state.Status = StatusProcessing
|
|
state.Timestamp = time.Now()
|
|
tm.mu.Unlock()
|
|
result := node.Handler(exec.payload)
|
|
tm.mu.Lock()
|
|
state.Timestamp = time.Now()
|
|
state.Result = result
|
|
state.Status = result.Status
|
|
tm.mu.Unlock()
|
|
if result.Status == StatusFailed {
|
|
fmt.Printf("Task %s failed at node %s: %v\n", exec.taskID, exec.nodeID, result.Error)
|
|
tm.processFinalResult(exec.taskID, state)
|
|
return
|
|
}
|
|
tm.resultQueue <- nodeResult{taskID: exec.taskID, nodeID: exec.nodeID, result: result}
|
|
}
|
|
|
|
func (tm *TaskManager) WaitForResult() {
|
|
go func() {
|
|
for nr := range tm.resultQueue {
|
|
tm.onNodeCompleted(nr)
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (tm *TaskManager) onNodeCompleted(nodeResult nodeResult) {
|
|
node, ok := tm.dag.nodes.Get(nodeResult.nodeID)
|
|
if !ok {
|
|
return
|
|
}
|
|
if len(node.Edges) > 0 {
|
|
for _, edge := range node.Edges {
|
|
tm.mu.Lock()
|
|
if _, exists := tm.taskStates[edge.To.ID]; !exists {
|
|
tm.taskStates[edge.To.ID] = newTaskState(edge.To.ID)
|
|
}
|
|
tm.mu.Unlock()
|
|
tm.taskQueue <- taskExecution{taskID: nodeResult.taskID, nodeID: edge.To.ID, payload: nodeResult.result.Data}
|
|
}
|
|
} else {
|
|
parentNodes, err := tm.dag.GetPreviousNodes(nodeResult.nodeID)
|
|
if err == nil {
|
|
for _, parentNode := range parentNodes {
|
|
tm.mu.Lock()
|
|
state := tm.taskStates[parentNode.ID]
|
|
if state == nil {
|
|
state = newTaskState(parentNode.ID)
|
|
tm.taskStates[parentNode.ID] = state
|
|
}
|
|
state.targetResults[nodeResult.nodeID] = nodeResult.result
|
|
allTargetNodesDone := len(parentNode.Edges) == len(state.targetResults)
|
|
tm.mu.Unlock()
|
|
|
|
if tm.areAllTargetNodesCompleted(parentNode.ID) && allTargetNodesDone {
|
|
tm.aggregateResults(parentNode.ID, nodeResult.taskID)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (tm *TaskManager) areAllTargetNodesCompleted(parentNodeID string) bool {
|
|
parentNode, ok := tm.dag.nodes.Get(parentNodeID)
|
|
if !ok {
|
|
return false
|
|
}
|
|
tm.mu.Lock()
|
|
defer tm.mu.Unlock()
|
|
for _, targetNode := range parentNode.Edges {
|
|
state := tm.taskStates[targetNode.To.ID]
|
|
if state == nil || state.Status != StatusCompleted {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (tm *TaskManager) aggregateResults(parentNode string, taskID string) {
|
|
tm.mu.Lock()
|
|
defer tm.mu.Unlock()
|
|
state := tm.taskStates[parentNode]
|
|
if len(state.targetResults) > 1 {
|
|
aggregatedData := make([]json.RawMessage, len(state.targetResults))
|
|
i := 0
|
|
for _, result := range state.targetResults {
|
|
aggregatedData[i] = result.Data
|
|
i++
|
|
}
|
|
aggregatedPayload, _ := json.Marshal(aggregatedData)
|
|
state.Result = Result{Data: aggregatedPayload, Status: StatusCompleted}
|
|
} else if len(state.targetResults) == 1 {
|
|
state.Result = maps.Values(state.targetResults)[0]
|
|
}
|
|
tm.processFinalResult(taskID, state)
|
|
}
|
|
|
|
func (tm *TaskManager) processFinalResult(taskID string, state *TaskState) {
|
|
clear(state.targetResults)
|
|
tm.dag.finalResult(taskID, state.Result)
|
|
}
|