mirror of https://github.com/oarkflow/mq.git
synced 2025-11-02 20:04:02 +08:00
feat: sig
@@ -15,6 +15,17 @@ import (
     "github.com/oarkflow/mq/storage/memory"
 )

+// TaskError is used by node processors to indicate whether an error is recoverable.
+type TaskError struct {
+    Err         error
+    Recoverable bool
+}
+
+func (te TaskError) Error() string {
+    return te.Err.Error()
+}
+
+// TaskState holds state and intermediate results for a given task (identified by a node ID).
 type TaskState struct {
     UpdatedAt     time.Time
     targetResults storage.IMap[string, mq.Result]
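
For illustration only (not part of this commit): a minimal sketch of how a node processor might use the new TaskError type to mark a failure as retryable. The processor method shape and the doWork helper are assumptions inferred from the ProcessTask call site further down in this diff.

    // Sketch of a processor signalling a retryable failure via TaskError.
    type flakyNode struct{}

    func (flakyNode) ProcessTask(ctx context.Context, t *mq.Task) mq.Result {
        if err := doWork(ctx); err != nil {
            // Recoverable: true asks the TaskManager to retry with backoff;
            // a plain error (or Recoverable: false) fails the node immediately.
            return mq.Result{Error: TaskError{Err: err, Recoverable: true}}
        }
        return mq.Result{Status: mq.Completed}
    }

    func doWork(ctx context.Context) error { return nil } // placeholder for real work
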
@@ -39,26 +50,6 @@ type nodeResult struct {
     result mq.Result
 }

-type TaskManager struct {
-    createdAt          time.Time
-    taskStates         storage.IMap[string, *TaskState]
-    parentNodes        storage.IMap[string, string]
-    childNodes         storage.IMap[string, int]
-    deferredTasks      storage.IMap[string, *task]
-    iteratorNodes      storage.IMap[string, []Edge]
-    currentNodePayload storage.IMap[string, json.RawMessage]
-    currentNodeResult  storage.IMap[string, mq.Result]
-    taskQueue          chan *task
-    result             *mq.Result
-    resultQueue        chan nodeResult
-    resultCh           chan mq.Result
-    stopCh             chan struct{}
-    taskID             string
-    dag                *DAG
-
-    wg sync.WaitGroup
-}
-
 type task struct {
     ctx    context.Context
     taskID string
@@ -75,7 +66,41 @@ func newTask(ctx context.Context, taskID, nodeID string, payload json.RawMessage
     }
 }

+type TaskManagerConfig struct {
+    MaxRetries      int
+    BaseBackoff     time.Duration
+    RecoveryHandler func(ctx context.Context, result mq.Result) error
+}
+
+type TaskManager struct {
+    createdAt          time.Time
+    taskStates         storage.IMap[string, *TaskState]
+    parentNodes        storage.IMap[string, string]
+    childNodes         storage.IMap[string, int]
+    deferredTasks      storage.IMap[string, *task]
+    iteratorNodes      storage.IMap[string, []Edge]
+    currentNodePayload storage.IMap[string, json.RawMessage]
+    currentNodeResult  storage.IMap[string, mq.Result]
+    taskQueue          chan *task
+    result             *mq.Result
+    resultQueue        chan nodeResult
+    resultCh           chan mq.Result
+    stopCh             chan struct{}
+    taskID             string
+    dag                *DAG
+    maxRetries         int
+    baseBackoff        time.Duration
+    recoveryHandler    func(ctx context.Context, result mq.Result) error
+    pauseMu            sync.Mutex
+    pauseCh            chan struct{}
+    wg                 sync.WaitGroup
+}
+
 func NewTaskManager(dag *DAG, taskID string, resultCh chan mq.Result, iteratorNodes storage.IMap[string, []Edge]) *TaskManager {
+    config := TaskManagerConfig{
+        MaxRetries:  3,
+        BaseBackoff: time.Second,
+    }
     tm := &TaskManager{
         createdAt:  time.Now(),
         taskStates: memory.New[string, *TaskState](),
@@ -86,16 +111,21 @@ func NewTaskManager(dag *DAG, taskID string, resultCh chan mq.Result, iteratorNo
         currentNodeResult: memory.New[string, mq.Result](),
         taskQueue:         make(chan *task, DefaultChannelSize),
         resultQueue:       make(chan nodeResult, DefaultChannelSize),
-        iteratorNodes:     iteratorNodes,
-        stopCh:            make(chan struct{}),
         resultCh:          resultCh,
+        stopCh:            make(chan struct{}),
         taskID:            taskID,
         dag:               dag,
+        maxRetries:        config.MaxRetries,
+        baseBackoff:       config.BaseBackoff,
+        recoveryHandler:   config.RecoveryHandler,
+        iteratorNodes:     iteratorNodes,
     }

-    tm.wg.Add(2)
+    tm.wg.Add(3)
     go tm.run()
     go tm.waitForResult()
+    go tm.retryDeferredTasks()

     return tm
 }
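
The constructor above still hardcodes the defaults (MaxRetries 3, BaseBackoff one second, no RecoveryHandler), and the handler is only consulted once the recoverable retries are exhausted. A hedged sketch of what a populated config could look like; how a caller would actually inject it is not shown in this commit.

    // Sketch only: a config with a recovery handler. The handler signature comes from
    // TaskManagerConfig above; returning nil tells processNode to clear result.Error
    // and mark the result Completed after the retries are used up.
    var cfg = TaskManagerConfig{
        MaxRetries:  5,
        BaseBackoff: 500 * time.Millisecond,
        RecoveryHandler: func(ctx context.Context, result mq.Result) error {
            log.Printf("attempting recovery after error: %v", result.Error)
            return nil // nil = treat the task as recovered
        },
    }
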
@@ -104,7 +134,6 @@ func (tm *TaskManager) ProcessTask(ctx context.Context, startNode string, payloa
 }

 func (tm *TaskManager) enqueueTask(ctx context.Context, startNode, taskID string, payload json.RawMessage) {
-
     if index, ok := ctx.Value(ContextIndex).(string); ok {
         base := strings.Split(startNode, Delimiter)[0]
         startNode = fmt.Sprintf("%s%s%s", base, Delimiter, index)
@@ -112,12 +141,11 @@ func (tm *TaskManager) enqueueTask(ctx context.Context, startNode, taskID string
     if _, exists := tm.taskStates.Get(startNode); !exists {
         tm.taskStates.Set(startNode, newTaskState(startNode))
     }
-
     t := newTask(ctx, taskID, startNode, payload)
     select {
     case tm.taskQueue <- t:
     default:
-        log.Println("task queue is full, deferring task")
+        log.Println("Task queue is full, deferring task")
         tm.deferredTasks.Set(taskID, t)
     }
 }
@@ -127,23 +155,36 @@ func (tm *TaskManager) run() {
     for {
         select {
         case <-tm.stopCh:
-            log.Println("Stopping TaskManager")
+            log.Println("Stopping TaskManager run loop")
             return
         case tsk := <-tm.taskQueue:
             tm.processNode(tsk)
+        default:
+            tm.pauseMu.Lock()
+            pch := tm.pauseCh
+            tm.pauseMu.Unlock()
+            if pch != nil {
+                <-pch
+            }
+            select {
+            case <-tm.stopCh:
+                log.Println("Stopping TaskManager run loop")
+                return
+            case tsk := <-tm.taskQueue:
+                tm.processNode(tsk)
+            }
         }
     }
 }

+// waitForResult listens for node results on resultQueue and processes them.
 func (tm *TaskManager) waitForResult() {
     defer tm.wg.Done()
     for {
         select {
         case <-tm.stopCh:
-            log.Println("Stopping Result Listener")
+            log.Println("Stopping TaskManager result listener")
             return
-        case res := <-tm.resultQueue:
-            tm.onNodeCompleted(res)
+        case nr := <-tm.resultQueue:
+            tm.onNodeCompleted(nr)
         }
     }
 }
@@ -151,60 +192,82 @@ func (tm *TaskManager) waitForResult() {
 func (tm *TaskManager) processNode(exec *task) {
     startTime := time.Now()
     pureNodeID := strings.Split(exec.nodeID, Delimiter)[0]

     node, exists := tm.dag.nodes.Get(pureNodeID)
     if !exists {
-        tm.dag.Logger().Error("Node not found while processing node",
-            logger.Field{Key: "nodeID", Value: pureNodeID})
+        tm.dag.Logger().Error("Node not found", logger.Field{Key: "nodeID", Value: pureNodeID})
         return
     }

     state, _ := tm.taskStates.Get(exec.nodeID)
     if state == nil {
-        tm.dag.Logger().Warn("State not found; creating new state",
-            logger.Field{Key: "nodeID", Value: exec.nodeID})
+        tm.dag.Logger().Warn("State not found; creating new state", logger.Field{Key: "nodeID", Value: exec.nodeID})
         state = newTaskState(exec.nodeID)
         tm.taskStates.Set(exec.nodeID, state)
     }

     state.Status = mq.Processing
     state.UpdatedAt = time.Now()
     tm.currentNodePayload.Clear()
     tm.currentNodeResult.Clear()
     tm.currentNodePayload.Set(exec.nodeID, exec.payload)

-    result := node.processor.ProcessTask(exec.ctx, mq.NewTask(exec.taskID, exec.payload, exec.nodeID))
-    nodeLatency := time.Since(startTime)
-    tm.logNodeExecution(exec, pureNodeID, result, nodeLatency)
-
-    if isLast, err := tm.dag.IsLastNode(pureNodeID); err != nil {
-        tm.dag.Logger().Error("Error checking if node is last",
-            logger.Field{Key: "nodeID", Value: pureNodeID},
-            logger.Field{Key: "error", Value: err.Error()})
-    } else if isLast {
-        result.Last = true
+    var result mq.Result
+    attempts := 0
+    for {
+        // Tracing start (stubbed)
+        log.Printf("Tracing: Start processing node %s (attempt %d)", exec.nodeID, attempts+1)
+        result = node.processor.ProcessTask(exec.ctx, mq.NewTask(exec.taskID, exec.payload, exec.nodeID))
+        if result.Error != nil {
+            if te, ok := result.Error.(TaskError); ok && te.Recoverable {
+                if attempts < tm.maxRetries {
+                    attempts++
+                    backoff := tm.baseBackoff * time.Duration(1<<attempts)
+                    log.Printf("Recoverable error on node %s, retrying in %s: %v", exec.nodeID, backoff, result.Error)
+                    time.Sleep(backoff)
+                    continue
+                } else if tm.recoveryHandler != nil {
+                    if err := tm.recoveryHandler(exec.ctx, result); err == nil {
+                        result.Error = nil
+                        result.Status = mq.Completed
+                    }
+                }
+            }
+        }
+        break
+    }

-    tm.currentNodeResult.Set(exec.nodeID, result)
-    state.Result = result
-    result.Topic = node.ID
-    tm.updateTimestamps(&result)
+    log.Printf("Tracing: End processing node %s", exec.nodeID)
+    nodeLatency := time.Since(startTime)

     if result.Error != nil {
         result.Status = mq.Failed
         state.Status = mq.Failed
         state.Result.Status = mq.Failed
-        state.Result.Latency = result.Latency
+        state.Result.Latency = nodeLatency.String()
         tm.result = &result
         tm.resultCh <- result
         tm.processFinalResult(state)
         return
     }

     result.Status = mq.Completed
     state.Result = result
     state.Result.Status = mq.Completed
-    state.Result.Latency = result.Latency
+    state.Result.Latency = nodeLatency.String()
+    result.Topic = node.ID
+    tm.updateTimestamps(&result)

+    isLast, err := tm.dag.IsLastNode(pureNodeID)
+    if err != nil {
+        tm.dag.Logger().Error("Error checking if node is last", logger.Field{Key: "nodeID", Value: pureNodeID}, logger.Field{Key: "error", Value: err.Error()})
+    } else if isLast {
+        result.Last = true
+    }
+    tm.currentNodeResult.Set(exec.nodeID, result)
+    tm.logNodeExecution(exec, pureNodeID, result, nodeLatency)

     if result.Error != nil {
         tm.result = &result
         tm.resultCh <- result
         tm.processFinalResult(state)
         return
     }
     if result.Last || node.NodeType == Page {
         tm.result = &result
         tm.resultCh <- result
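
Note on the retry loop above: attempts is incremented before the shift, so with the hardcoded defaults a recoverable error is retried up to three times, waiting 2s, 4s and then 8s. A small standalone sketch (not repository code) that reproduces the schedule:

    package main

    import (
        "fmt"
        "time"
    )

    func main() {
        baseBackoff := time.Second // NewTaskManager's hardcoded default
        maxRetries := 3            // NewTaskManager's hardcoded default
        for attempts := 1; attempts <= maxRetries; attempts++ {
            // Same formula as processNode (attempts has already been incremented):
            backoff := baseBackoff * time.Duration(1<<attempts)
            fmt.Printf("retry %d waits %s\n", attempts, backoff)
        }
        // Prints: retry 1 waits 2s, retry 2 waits 4s, retry 3 waits 8s.
    }
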
@@ -244,20 +307,18 @@ func (tm *TaskManager) updateTimestamps(rs *mq.Result) {
 func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, result mq.Result, childNode string, dispatchFinal bool) {
     state.targetResults.Set(childNode, result)
     state.targetResults.Del(state.NodeID)
-
-    targetCount, _ := tm.childNodes.Get(state.NodeID)
+    targetsCount, _ := tm.childNodes.Get(state.NodeID)
     size := state.targetResults.Size()
-
-    if size == targetCount {
+    if size == targetsCount {
         if size > 1 {
-            aggregatedData := make([]json.RawMessage, size)
+            aggregated := make([]json.RawMessage, size)
             i := 0
             state.targetResults.ForEach(func(_ string, res mq.Result) bool {
-                aggregatedData[i] = res.Payload
+                aggregated[i] = res.Payload
                 i++
                 return true
             })
-            aggregatedPayload, err := json.Marshal(aggregatedData)
+            aggregatedPayload, err := json.Marshal(aggregated)
             if err != nil {
                 panic(err)
             }
@@ -279,16 +340,14 @@ func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, res
     if result.Error != nil {
         state.Status = mq.Failed
     }
-
     if parentKey, ok := tm.parentNodes.Get(state.NodeID); ok {
-
-        nodeIDParts := strings.Split(state.NodeID, Delimiter)
-        if edges, exists := tm.iteratorNodes.Get(nodeIDParts[0]); exists && state.Status == mq.Completed {
+        parts := strings.Split(state.NodeID, Delimiter)
+        if edges, exists := tm.iteratorNodes.Get(parts[0]); exists && state.Status == mq.Completed {
             state.Status = mq.Processing
-            tm.iteratorNodes.Del(nodeIDParts[0])
+            tm.iteratorNodes.Del(parts[0])
             state.targetResults.Clear()
-            if len(nodeIDParts) == 2 {
-                ctx = context.WithValue(ctx, ContextIndex, nodeIDParts[1])
+            if len(parts) == 2 {
+                ctx = context.WithValue(ctx, ContextIndex, parts[1])
             }
             toProcess := nodeResult{
                 ctx: ctx,
@@ -297,7 +356,7 @@ func (tm *TaskManager) handlePrevious(ctx context.Context, state *TaskState, res
                 result: state.Result,
             }
             tm.handleEdges(toProcess, edges)
-        } else if size == targetCount {
+        } else if size == targetsCount {
             if parentState, _ := tm.taskStates.Get(parentKey); parentState != nil {
                 state.Result.Topic = state.NodeID
                 tm.handlePrevious(ctx, parentState, state.Result, state.NodeID, dispatchFinal)
@@ -340,7 +399,7 @@ func (tm *TaskManager) enqueueResult(nr nodeResult) {
     select {
     case tm.resultQueue <- nr:
     default:
-        log.Println("Result queue is full, dropping result.")
+        log.Println("Result queue is full, dropping result")
     }
 }

@@ -377,12 +436,12 @@ func (tm *TaskManager) getConditionalEdges(node *Node, result mq.Result) []Edge
     copy(edges, node.Edges)
     if result.ConditionStatus != "" {
         if conditions, ok := tm.dag.conditions[result.Topic]; ok {
-            if targetNodeKey, exists := conditions[result.ConditionStatus]; exists {
-                if targetNode, found := tm.dag.nodes.Get(targetNodeKey); found {
+            if targetKey, exists := conditions[result.ConditionStatus]; exists {
+                if targetNode, found := tm.dag.nodes.Get(targetKey); found {
                     edges = append(edges, Edge{From: node, To: targetNode})
                 }
-            } else if targetNodeKey, exists := conditions["default"]; exists {
-                if targetNode, found := tm.dag.nodes.Get(targetNodeKey); found {
+            } else if targetKey, exists := conditions["default"]; exists {
+                if targetNode, found := tm.dag.nodes.Get(targetKey); found {
                     edges = append(edges, Edge{From: node, To: targetNode})
                 }
             }
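
For context, getConditionalEdges resolves result.ConditionStatus against the DAG's conditions map and falls back to a "default" entry. A hedged illustration of the shape that map appears to have, inferred purely from the indexing above (node and status names are made up; the concrete type is not shown in this diff):

    // Hypothetical wiring: conditions keyed by the originating node's topic, then by
    // condition status, with each value naming the next node key.
    var conditions = map[string]map[string]string{
        "validate-order": {
            "approved": "charge-card",   // matches result.ConditionStatus == "approved"
            "rejected": "notify-user",
            "default":  "manual-review", // used when the status has no explicit entry
        },
    }
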
@@ -409,7 +468,7 @@ func (tm *TaskManager) handleEdges(currentResult nodeResult, edges []Edge) {
         if edge.Type == Iterator {
             var items []json.RawMessage
             if err := json.Unmarshal(currentResult.result.Payload, &items); err != nil {
-                log.Printf("Error unmarshalling data for node %s: %v\n", edge.To.ID, err)
+                log.Printf("Error unmarshalling payload for node %s: %v", edge.To.ID, err)
                 tm.enqueueResult(nodeResult{
                     ctx:    currentResult.ctx,
                     nodeID: edge.To.ID,
@@ -438,19 +497,19 @@ func (tm *TaskManager) handleEdges(currentResult nodeResult, edges []Edge) {
 }

 func (tm *TaskManager) retryDeferredTasks() {
-    const maxRetries = 5
-    backoff := time.Second
-    for retries := 0; retries < maxRetries; retries++ {
+    defer tm.wg.Done()
+    ticker := time.NewTicker(tm.baseBackoff)
+    defer ticker.Stop()
+    for {
         select {
         case <-tm.stopCh:
-            log.Println("Stopping Deferred Task Retrier")
+            log.Println("Stopping deferred task retrier")
             return
-        case <-time.After(backoff):
+        case <-ticker.C:
             tm.deferredTasks.ForEach(func(taskID string, tsk *task) bool {
                 tm.enqueueTask(tsk.ctx, tsk.nodeID, taskID, tsk.payload)
                 return true
             })
-            backoff *= 2
         }
     }
 }
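
Taken together with enqueueTask earlier in this diff: a full taskQueue parks the task in deferredTasks, and this goroutine re-enqueues the parked tasks on every baseBackoff tick (a fixed interval now, replacing the old doubling backoff that gave up after five rounds). A simplified, self-contained sketch of that park-and-drain pattern, not the commit's code:

    // Simplified model of the park-and-drain pattern (types reduced for clarity).
    func deferAndRetryDemo() {
        queue := make(chan json.RawMessage, 2)
        deferred := map[string]json.RawMessage{}

        enqueue := func(id string, payload json.RawMessage) {
            select {
            case queue <- payload: // normal fast path
            default:
                deferred[id] = payload // queue full: park until the next tick
            }
        }

        enqueue("t1", json.RawMessage(`{"n":1}`))

        ticker := time.NewTicker(time.Second) // retryDeferredTasks ticks at tm.baseBackoff
        defer ticker.Stop()
        <-ticker.C
        for id, payload := range deferred {
            select {
            case queue <- payload:
                delete(deferred, id)
            default:
                // still full; leave it parked for a later tick
            }
        }
    }
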
@@ -463,6 +522,25 @@ func (tm *TaskManager) processFinalResult(state *TaskState) {
     }
 }

+func (tm *TaskManager) Pause() {
+    tm.pauseMu.Lock()
+    defer tm.pauseMu.Unlock()
+    if tm.pauseCh == nil {
+        tm.pauseCh = make(chan struct{})
+        log.Println("TaskManager paused")
+    }
+}
+
+func (tm *TaskManager) Resume() {
+    tm.pauseMu.Lock()
+    defer tm.pauseMu.Unlock()
+    if tm.pauseCh != nil {
+        close(tm.pauseCh)
+        tm.pauseCh = nil
+        log.Println("TaskManager resumed")
+    }
+}
+
 func (tm *TaskManager) Stop() {
     close(tm.stopCh)
     tm.wg.Wait()
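
A minimal usage sketch for the new pause control (caller code, not part of the commit; dag, resultCh and iteratorNodes are assumed to exist). Pause arms pauseCh so the run loop blocks in its default branch once no task or stop signal is immediately ready, and Resume closes the channel to release it:

    tm := NewTaskManager(dag, "task-123", resultCh, iteratorNodes)

    tm.Pause()  // run() parks on pauseCh the next time its queue is idle
    // ... perform maintenance while no new nodes are picked up ...
    tm.Resume() // closes pauseCh; the run loop goes back to draining taskQueue

    tm.Stop()   // closes stopCh and waits for the worker goroutines to exit
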