// Mirror of https://github.com/oarkflow/mq.git
// Synced 2025-10-05 07:57:00 +08:00
// 115 lines, 2.6 KiB, Go
package dag
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/oarkflow/mq/storage"
|
|
"github.com/oarkflow/mq/storage/memory"
|
|
|
|
"github.com/oarkflow/mq"
|
|
)
|
|
|
|
type TaskManager struct {
|
|
createdAt time.Time
|
|
processedAt time.Time
|
|
status string
|
|
dag *DAG
|
|
taskID string
|
|
wg *WaitGroup
|
|
topic string
|
|
result mq.Result
|
|
|
|
taskNodeStatus storage.IMap[string, *taskNodeStatus]
|
|
}
|
|
|
|
func NewTaskManager(d *DAG, taskID string, iteratorNodes map[string][]Edge) *TaskManager {
|
|
if iteratorNodes == nil {
|
|
iteratorNodes = make(map[string][]Edge)
|
|
}
|
|
return &TaskManager{
|
|
dag: d,
|
|
taskNodeStatus: memory.New[string, *taskNodeStatus](),
|
|
taskID: taskID,
|
|
wg: NewWaitGroup(),
|
|
}
|
|
}
|
|
|
|
func (tm *TaskManager) updateTS(result *mq.Result) {
|
|
result.CreatedAt = tm.createdAt
|
|
result.ProcessedAt = time.Now()
|
|
result.Latency = fmt.Sprintf("%s", time.Since(tm.createdAt))
|
|
}
|
|
|
|
func (tm *TaskManager) dispatchFinalResult(ctx context.Context) mq.Result {
|
|
tm.updateTS(&tm.result)
|
|
tm.dag.callbackToConsumer(ctx, tm.result)
|
|
if tm.dag.server.NotifyHandler() != nil {
|
|
_ = tm.dag.server.NotifyHandler()(ctx, tm.result)
|
|
}
|
|
tm.dag.taskCleanupCh <- tm.taskID
|
|
tm.topic = tm.result.Topic
|
|
return tm.result
|
|
}
|
|
|
|
func (tm *TaskManager) handleResult(ctx context.Context, results any) mq.Result {
|
|
var rs mq.Result
|
|
switch res := results.(type) {
|
|
case []mq.Result:
|
|
aggregatedOutput := make([]json.RawMessage, 0)
|
|
var status, topic string
|
|
for i, result := range res {
|
|
if i == 0 {
|
|
status = result.Status
|
|
topic = result.Topic
|
|
}
|
|
if result.Error != nil {
|
|
return mq.HandleError(ctx, result.Error)
|
|
}
|
|
var item json.RawMessage
|
|
err := json.Unmarshal(result.Payload, &item)
|
|
if err != nil {
|
|
return mq.HandleError(ctx, err)
|
|
}
|
|
aggregatedOutput = append(aggregatedOutput, item)
|
|
}
|
|
finalOutput, err := json.Marshal(aggregatedOutput)
|
|
if err != nil {
|
|
return mq.HandleError(ctx, err)
|
|
}
|
|
return mq.Result{TaskID: tm.taskID, Payload: finalOutput, Status: status, Topic: topic, Ctx: ctx}
|
|
case mq.Result:
|
|
if res.Ctx == nil {
|
|
res.Ctx = ctx
|
|
}
|
|
return res
|
|
}
|
|
if rs.Ctx == nil {
|
|
rs.Ctx = ctx
|
|
}
|
|
return rs
|
|
}
|
|
|
|
func (tm *TaskManager) appendResult(result mq.Result, final bool) {
|
|
if tm.dag.reportNodeResultCallback != nil {
|
|
tm.dag.reportNodeResultCallback(result)
|
|
}
|
|
}
|
|
|
|
func (tm *TaskManager) SetTotalItems(topic string, i int) {
|
|
if nodeStatus, ok := tm.taskNodeStatus.Get(topic); ok {
|
|
nodeStatus.totalItems = i
|
|
}
|
|
}
|
|
|
|
func isDAGNode(node *Node) (*DAG, bool) {
|
|
switch node := node.processor.(type) {
|
|
case *DAG:
|
|
return node, true
|
|
default:
|
|
return nil, false
|
|
}
|
|
}
|