package dag

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"strings"
	"time"

	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/middleware/recover"
	"github.com/oarkflow/form"
	"golang.org/x/time/rate"

	"github.com/oarkflow/mq"
	"github.com/oarkflow/mq/logger"
	"github.com/oarkflow/mq/sio"
	"github.com/oarkflow/mq/storage"
	"github.com/oarkflow/mq/storage/memory"
)

// Node is a single vertex in the DAG. It wraps an mq.Processor together with
// the outgoing edges that connect it to downstream nodes.
type Node struct {
	processor mq.Processor
	Label     string
	ID        string
	Edges     []Edge
	NodeType  NodeType
	isReady   bool
}

// Edge is a directed connection between two nodes.
type Edge struct {
	From  *Node
	To    *Node
	Label string
	Type  EdgeType
}

// DAG wires nodes and edges into an executable workflow backed by an mq broker,
// a worker pool and per-task managers.
type DAG struct {
	nodes                    storage.IMap[string, *Node]
	taskManager              storage.IMap[string, *TaskManager]
	iteratorNodes            storage.IMap[string, []Edge]
	Error                    error
	conditions               map[string]map[string]string
	consumer                 *mq.Consumer
	finalResult              func(taskID string, result mq.Result)
	pool                     *mq.Pool
	Notifier                 *sio.Server
	server                   *mq.Broker
	reportNodeResultCallback func(mq.Result)
	key                      string
	consumerTopic            string
	startNode                string
	name                     string
	report                   string
	hasPageNode              bool
	paused                   bool
}

// NewDAG creates a DAG with the given name and key, registers the task
// callback hooks on the broker and starts the worker pool.
// finalResultCallback is invoked once a task has fully traversed the graph.
func NewDAG(name, key string, finalResultCallback func(taskID string, result mq.Result), opts ...mq.Option) *DAG {
	callback := func(ctx context.Context, result mq.Result) error { return nil }
	d := &DAG{
		name:          name,
		key:           key,
		nodes:         memory.New[string, *Node](),
		taskManager:   memory.New[string, *TaskManager](),
		iteratorNodes: memory.New[string, []Edge](),
		conditions:    make(map[string]map[string]string),
		finalResult:   finalResultCallback,
	}
	opts = append(opts,
		mq.WithCallback(d.onTaskCallback),
		mq.WithConsumerOnSubscribe(d.onConsumerJoin),
		mq.WithConsumerOnClose(d.onConsumerClose),
	)
	d.server = mq.NewBroker(opts...)
	options := d.server.Options()
	d.pool = mq.NewPool(
		options.NumOfWorkers(),
		mq.WithTaskQueueSize(options.QueueSize()),
		mq.WithMaxMemoryLoad(options.MaxMemoryLoad()),
		mq.WithHandler(d.ProcessTask),
		mq.WithPoolCallback(callback),
		mq.WithTaskStorage(options.Storage()),
	)
	d.pool.Start(options.NumOfWorkers())
	return d
}

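// Example (illustrative sketch only): wiring a two-node flow. The processor
// implementations and the Function and Simple constants are assumptions for
// this example and are not defined in this file.
//
//	flow := dag.NewDAG("sample", "sample-key", func(taskID string, result mq.Result) {
//		fmt.Println("final:", taskID, string(result.Payload))
//	})
//	flow.AddNode(dag.Function, "Prepare", "prepare", &PrepareProcessor{}, true) // true marks the start node
//	flow.AddNode(dag.Function, "Store", "store", &StoreProcessor{})
//	flow.AddEdge(dag.Simple, "prepare to store", "prepare", "store")
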
func (tm *DAG) SetKey(key string) {
	tm.key = key
}

// ReportNodeResult registers a callback that receives the result of each
// individual node as it completes.
func (tm *DAG) ReportNodeResult(callback func(mq.Result)) {
	tm.reportNodeResultCallback = callback
}

func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result {
	if manager, ok := tm.taskManager.Get(result.TaskID); ok && result.Topic != "" {
		manager.onNodeCompleted(nodeResult{
			ctx:    ctx,
			nodeID: result.Topic,
			status: result.Status,
			result: result,
		})
	}
	return mq.Result{}
}

func (tm *DAG) GetType() string {
	return tm.key
}

// Stop shuts down every node processor. Deferred nodes have no processor
// attached and are skipped; the first stop error encountered is returned.
func (tm *DAG) Stop(ctx context.Context) error {
	var stopErr error
	tm.nodes.ForEach(func(_ string, n *Node) bool {
		if n.processor == nil {
			return true
		}
		if err := n.processor.Stop(ctx); err != nil {
			stopErr = err
			return false
		}
		return true
	})
	return stopErr
}

func (tm *DAG) GetKey() string {
	return tm.key
}

func (tm *DAG) SetNotifyResponse(callback mq.Callback) {
	tm.server.SetNotifyHandler(callback)
}

func (tm *DAG) SetStartNode(node string) {
	tm.startNode = node
}

func (tm *DAG) GetStartNode() string {
	return tm.startNode
}

func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) *DAG {
	tm.conditions[fromNode] = conditions
	return tm
}

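// Example (illustrative sketch only): conditional routing out of a node. The
// node IDs, and the assumption that each condition key is a reported status
// mapped to the ID of the node to run next, are illustrative only.
//
//	flow.AddCondition("check", map[string]string{
//		"pass": "approve",
//		"fail": "reject",
//	})
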
// AddNode registers a node and creates a consumer for its processor. Passing
// true as the optional startNode argument marks it as the DAG's entry point.
// Errors are accumulated on tm.Error so calls can be chained.
func (tm *DAG) AddNode(nodeType NodeType, name, nodeID string, handler mq.Processor, startNode ...bool) *DAG {
	if tm.Error != nil {
		return tm
	}
	con := mq.NewConsumer(nodeID, nodeID, handler.ProcessTask)
	n := &Node{
		Label:     name,
		ID:        nodeID,
		NodeType:  nodeType,
		processor: con,
	}
	if tm.server != nil && tm.server.SyncMode() {
		n.isReady = true
	}
	tm.nodes.Set(nodeID, n)
	if len(startNode) > 0 && startNode[0] {
		tm.startNode = nodeID
	}
	if nodeType == Page && !tm.hasPageNode {
		tm.hasPageNode = true
	}
	return tm
}

// AddDeferredNode registers a node without a processor; its consumer is
// expected to join later over the broker. Deferred nodes are not allowed in
// sync mode.
func (tm *DAG) AddDeferredNode(nodeType NodeType, name, key string, firstNode ...bool) error {
	if tm.server.SyncMode() {
		return fmt.Errorf("DAG cannot have deferred node in Sync Mode")
	}
	tm.nodes.Set(key, &Node{
		Label:    name,
		ID:       key,
		NodeType: nodeType,
	})
	if len(firstNode) > 0 && firstNode[0] {
		tm.startNode = key
	}
	return nil
}

// IsReady reports whether the DAG has at least one node and every node's
// consumer is ready. The previous version stopped iterating on the first
// unready node but could still return true if an earlier node was ready.
func (tm *DAG) IsReady() bool {
	hasNodes := false
	ready := true
	tm.nodes.ForEach(func(_ string, n *Node) bool {
		hasNodes = true
		if !n.isReady {
			ready = false
			return false
		}
		return true
	})
	return hasNodes && ready
}

func (tm *DAG) AddEdge(edgeType EdgeType, label, from string, targets ...string) *DAG {
	if tm.Error != nil {
		return tm
	}
	if edgeType == Iterator {
		tm.iteratorNodes.Set(from, []Edge{})
	}
	node, ok := tm.nodes.Get(from)
	if !ok {
		tm.Error = fmt.Errorf("node not found %s", from)
		return tm
	}
	for _, target := range targets {
		if targetNode, ok := tm.nodes.Get(target); ok {
			edge := Edge{From: node, To: targetNode, Type: edgeType, Label: label}
			node.Edges = append(node.Edges, edge)
			if edgeType != Iterator {
				if edges, ok := tm.iteratorNodes.Get(node.ID); ok {
					edges = append(edges, edge)
					tm.iteratorNodes.Set(node.ID, edges)
				}
			}
		}
	}
	return tm
}

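// Example (illustrative sketch only): an Iterator edge fans a list payload out
// to a per-item node. The node IDs below are assumptions for the example;
// Iterator is the edge type handled specially above.
//
//	flow.AddEdge(dag.Iterator, "each item", "load-items", "process-item")
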
func (tm *DAG) getCurrentNode(manager *TaskManager) string {
	if manager.currentNodePayload.Size() == 0 {
		return ""
	}
	return manager.currentNodePayload.Keys()[0]
}

func (tm *DAG) Logger() logger.Logger {
	return tm.server.Options().Logger()
}

// ProcessTask resolves the node a task should start (or resume) from, merges
// any previously recorded node result into the payload and dispatches the task
// to its task manager. For DAGs containing a Page node it blocks until a
// result arrives; otherwise it times out after 30 seconds.
func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
	ctx = context.WithValue(ctx, "task_id", task.ID)
	userContext := form.UserContext(ctx)
	next := userContext.Get("next")
	manager, ok := tm.taskManager.Get(task.ID)
	resultCh := make(chan mq.Result, 1)
	if !ok {
		manager = NewTaskManager(tm, task.ID, resultCh, tm.iteratorNodes.Clone())
		tm.taskManager.Set(task.ID, manager)
	} else {
		manager.resultCh = resultCh
	}
	currentKey := tm.getCurrentNode(manager)
	currentNode := strings.Split(currentKey, Delimiter)[0]
	node, exists := tm.nodes.Get(currentNode)
	method, ok := ctx.Value("method").(string)
	if method == "GET" && exists && node.NodeType == Page {
		ctx = context.WithValue(ctx, "initial_node", currentNode)
		if manager.result != nil {
			task.Payload = manager.result.Payload
		}
	} else if next == "true" {
		nodes, err := tm.GetNextNodes(currentNode)
		if err != nil {
			return mq.Result{Error: err, Ctx: ctx}
		}
		if len(nodes) > 0 {
			ctx = context.WithValue(ctx, "initial_node", nodes[0].ID)
		}
	}
	if currentNodeResult, hasResult := manager.currentNodeResult.Get(currentKey); hasResult {
		var taskPayload, resultPayload map[string]any
		if err := json.Unmarshal(task.Payload, &taskPayload); err == nil {
			if err = json.Unmarshal(currentNodeResult.Payload, &resultPayload); err == nil {
				for key, val := range resultPayload {
					taskPayload[key] = val
				}
				task.Payload, _ = json.Marshal(taskPayload)
			}
		}
	}
	firstNode, err := tm.parseInitialNode(ctx)
	if err != nil {
		return mq.Result{Error: err, Ctx: ctx}
	}
	// Ensure the resolved start node is actually registered before dispatching.
	node, ok = tm.nodes.Get(firstNode)
	if !ok {
		return mq.Result{Error: fmt.Errorf("node %s not found", firstNode), Ctx: ctx}
	}
	task.Topic = firstNode
	ctx = context.WithValue(ctx, ContextIndex, "0")
	manager.ProcessTask(ctx, firstNode, task.Payload)
	if tm.hasPageNode {
		return <-resultCh
	}
	select {
	case result := <-resultCh:
		return result
	case <-time.After(30 * time.Second):
		return mq.Result{
			Error: fmt.Errorf("timeout waiting for task result"),
			Ctx:   ctx,
		}
	}
}

// Process executes the DAG with the given payload. The task ID is taken from
// the user context when present, otherwise a new one is generated.
func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result {
	var taskID string
	userCtx := form.UserContext(ctx)
	if val := userCtx.Get("task_id"); val != "" {
		taskID = val
	} else {
		taskID = mq.NewID()
	}
	return tm.ProcessTask(ctx, mq.NewTask(taskID, payload, ""))
}

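// Example (illustrative sketch only): running the flow synchronously with a
// JSON payload. The payload shape is an assumption for the example.
//
//	result := flow.Process(context.Background(), []byte(`{"user_id": 1}`))
//	if result.Error != nil {
//		log.Fatal(result.Error)
//	}
//	fmt.Println(string(result.Payload))
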
// Validate classifies the edges of the graph and fails if a cycle is detected.
// The resulting classification report is retained and can be read via GetReport.
func (tm *DAG) Validate() error {
	report, hasCycle, err := tm.ClassifyEdges()
	if hasCycle || err != nil {
		tm.Error = err
		return err
	}
	tm.report = report
	return nil
}

func (tm *DAG) GetReport() string {
	return tm.report
}

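// Example (illustrative sketch only): validating the graph before serving it.
//
//	if err := flow.Validate(); err != nil {
//		log.Fatalf("invalid DAG: %v", err)
//	}
//	fmt.Println(flow.GetReport())
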
// AddDAGNode mounts another DAG as a single node of this graph. The sub-DAG is
// assigned the node key as its topic and is marked ready immediately.
func (tm *DAG) AddDAGNode(nodeType NodeType, name string, key string, dag *DAG, firstNode ...bool) *DAG {
	dag.AssignTopic(key)
	tm.nodes.Set(key, &Node{
		Label:     name,
		ID:        key,
		NodeType:  nodeType,
		processor: dag,
		isReady:   true,
	})
	if len(firstNode) > 0 && firstNode[0] {
		tm.startNode = key
	}
	return tm
}

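// Example (illustrative sketch only): nesting a previously built sub-flow as a
// single node. subFlow and the Function and Simple constants are assumptions
// for the example.
//
//	flow.AddDAGNode(dag.Function, "Billing", "billing", subFlow)
//	flow.AddEdge(dag.Simple, "then bill", "store", "billing")
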
// Start launches the broker, starts a consumer for every node (retrying once
// per second until each consumer comes up) and then serves the DAG's HTTP
// handlers on addr.
func (tm *DAG) Start(ctx context.Context, addr string) error {
	go func() {
		defer mq.RecoverPanic(mq.RecoverTitle)
		if err := tm.server.Start(ctx); err != nil {
			panic(err)
		}
	}()

	if !tm.server.SyncMode() {
		tm.nodes.ForEach(func(_ string, con *Node) bool {
			go func(con *Node) {
				defer mq.RecoverPanic(mq.RecoverTitle)
				if con.processor == nil {
					// Deferred nodes have no local consumer to start.
					return
				}
				limiter := rate.NewLimiter(rate.Every(1*time.Second), 1)
				for {
					err := con.processor.Consume(ctx)
					if err != nil {
						log.Printf("[ERROR] - Consumer %s failed to start: %v", con.ID, err)
					} else {
						log.Printf("[INFO] - Consumer %s started successfully", con.ID)
						break
					}
					limiter.Wait(ctx)
				}
			}(con)
			return true
		})
	}

	app := fiber.New()
	app.Use(recover.New(recover.Config{
		EnableStackTrace: true,
	}))
	tm.Handlers(app)
	return app.Listen(addr)
}

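// Example (illustrative sketch only): serving the flow over HTTP.
//
//	if err := flow.Start(context.Background(), ":8080"); err != nil {
//		log.Fatal(err)
//	}
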
// ScheduleTask prepares a task exactly like ProcessTask but, instead of
// executing it inline, hands it to the pool's scheduler and returns a pending
// result immediately.
func (tm *DAG) ScheduleTask(ctx context.Context, payload []byte, opts ...mq.SchedulerOption) mq.Result {
	var taskID string
	userCtx := form.UserContext(ctx)
	if val := userCtx.Get("task_id"); val != "" {
		taskID = val
	} else {
		taskID = mq.NewID()
	}
	t := mq.NewTask(taskID, payload, "")

	ctx = context.WithValue(ctx, "task_id", taskID)
	userContext := form.UserContext(ctx)
	next := userContext.Get("next")
	manager, ok := tm.taskManager.Get(taskID)
	resultCh := make(chan mq.Result, 1)
	if !ok {
		manager = NewTaskManager(tm, taskID, resultCh, tm.iteratorNodes.Clone())
		tm.taskManager.Set(taskID, manager)
	} else {
		manager.resultCh = resultCh
	}
	currentKey := tm.getCurrentNode(manager)
	currentNode := strings.Split(currentKey, Delimiter)[0]
	node, exists := tm.nodes.Get(currentNode)
	method, ok := ctx.Value("method").(string)
	if method == "GET" && exists && node.NodeType == Page {
		ctx = context.WithValue(ctx, "initial_node", currentNode)
	} else if next == "true" {
		nodes, err := tm.GetNextNodes(currentNode)
		if err != nil {
			return mq.Result{Error: err, Ctx: ctx}
		}
		if len(nodes) > 0 {
			ctx = context.WithValue(ctx, "initial_node", nodes[0].ID)
		}
	}
	if currentNodeResult, hasResult := manager.currentNodeResult.Get(currentKey); hasResult {
		var taskPayload, resultPayload map[string]any
		if err := json.Unmarshal(payload, &taskPayload); err == nil {
			if err = json.Unmarshal(currentNodeResult.Payload, &resultPayload); err == nil {
				for key, val := range resultPayload {
					taskPayload[key] = val
				}
				payload, _ = json.Marshal(taskPayload)
				// Keep the scheduled task's payload in sync with the merged result.
				t.Payload = payload
			}
		}
	}
	firstNode, err := tm.parseInitialNode(ctx)
	if err != nil {
		return mq.Result{Error: err, Ctx: ctx}
	}
	// Ensure the resolved start node is actually registered before scheduling.
	node, ok = tm.nodes.Get(firstNode)
	if !ok {
		return mq.Result{Error: fmt.Errorf("node %s not found", firstNode), Ctx: ctx}
	}
	t.Topic = firstNode
	ctx = context.WithValue(ctx, ContextIndex, "0")

	headers, ok := mq.GetHeaders(ctx)
	ctxx := context.Background()
	if ok {
		ctxx = mq.SetHeaders(ctxx, headers.AsMap())
	}
	tm.pool.Scheduler().AddTask(ctxx, t, opts...)
	return mq.Result{CreatedAt: t.CreatedAt, TaskID: t.ID, Topic: t.Topic, Status: mq.Pending}
}

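// Example (illustrative sketch only): queueing a task for later execution. The
// scheduler options are omitted here because their names depend on the mq
// package; pass any mq.SchedulerOption values you need.
//
//	res := flow.ScheduleTask(context.Background(), []byte(`{"user_id": 1}`))
//	fmt.Println("scheduled task:", res.TaskID, res.Status)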