mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-04 15:42:49 +08:00
feat: implement Validate
to check for cycle
This commit is contained in:
@@ -284,9 +284,6 @@ func (b *Broker) Start(ctx context.Context) error {
|
|||||||
c.Close()
|
c.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Optionally set connection timeouts to prevent idle connections
|
|
||||||
c.SetReadDeadline(time.Now().Add(5 * time.Minute))
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// Attempt to read the message
|
// Attempt to read the message
|
||||||
err := b.readMessage(ctx, c)
|
err := b.readMessage(ctx, c)
|
||||||
|
27
dag/dag.go
27
dag/dag.go
@@ -68,6 +68,8 @@ type DAG struct {
|
|||||||
opts []mq.Option
|
opts []mq.Option
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
paused bool
|
paused bool
|
||||||
|
Error error
|
||||||
|
report string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *DAG) SetKey(key string) {
|
func (tm *DAG) SetKey(key string) {
|
||||||
@@ -283,32 +285,47 @@ func (tm *DAG) AddCondition(fromNode FromNode, conditions map[When]Then) *DAG {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tm *DAG) AddIterator(label, from string, targets ...string) *DAG {
|
func (tm *DAG) AddIterator(label, from string, targets ...string) *DAG {
|
||||||
tm.addEdge(Iterator, label, from, targets...)
|
tm.Error = tm.addEdge(Iterator, label, from, targets...)
|
||||||
return tm
|
return tm
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *DAG) AddEdge(label, from string, targets ...string) *DAG {
|
func (tm *DAG) AddEdge(label, from string, targets ...string) *DAG {
|
||||||
tm.addEdge(Simple, label, from, targets...)
|
tm.Error = tm.addEdge(Simple, label, from, targets...)
|
||||||
return tm
|
return tm
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *DAG) addEdge(edgeType EdgeType, label, from string, targets ...string) {
|
func (tm *DAG) addEdge(edgeType EdgeType, label, from string, targets ...string) error {
|
||||||
tm.mu.Lock()
|
tm.mu.Lock()
|
||||||
defer tm.mu.Unlock()
|
defer tm.mu.Unlock()
|
||||||
fromNode, ok := tm.nodes[from]
|
fromNode, ok := tm.nodes[from]
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return fmt.Errorf("Error: 'from' node %s does not exist\n", from)
|
||||||
}
|
}
|
||||||
var nodes []*Node
|
var nodes []*Node
|
||||||
for _, target := range targets {
|
for _, target := range targets {
|
||||||
toNode, ok := tm.nodes[target]
|
toNode, ok := tm.nodes[target]
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return fmt.Errorf("Error: 'from' node %s does not exist\n", target)
|
||||||
}
|
}
|
||||||
nodes = append(nodes, toNode)
|
nodes = append(nodes, toNode)
|
||||||
}
|
}
|
||||||
edge := Edge{From: fromNode, To: nodes, Type: edgeType, Label: label}
|
edge := Edge{From: fromNode, To: nodes, Type: edgeType, Label: label}
|
||||||
fromNode.Edges = append(fromNode.Edges, edge)
|
fromNode.Edges = append(fromNode.Edges, edge)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tm *DAG) Validate() error {
|
||||||
|
report, hasCycle, err := tm.ClassifyEdges()
|
||||||
|
if hasCycle || err != nil {
|
||||||
|
tm.Error = err
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
tm.report = report
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tm *DAG) GetReport() string {
|
||||||
|
return tm.report
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
|
func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
|
||||||
|
@@ -98,8 +98,10 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.
|
|||||||
}
|
}
|
||||||
edges := tm.getConditionalEdges(node, result)
|
edges := tm.getConditionalEdges(node, result)
|
||||||
if len(edges) == 0 {
|
if len(edges) == 0 {
|
||||||
tm.appendFinalResult(result)
|
tm.appendResult(result, true)
|
||||||
return result
|
return result
|
||||||
|
} else {
|
||||||
|
tm.appendResult(result, false)
|
||||||
}
|
}
|
||||||
for _, edge := range edges {
|
for _, edge := range edges {
|
||||||
switch edge.Type {
|
switch edge.Type {
|
||||||
@@ -107,7 +109,7 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.
|
|||||||
var items []json.RawMessage
|
var items []json.RawMessage
|
||||||
err := json.Unmarshal(result.Payload, &items)
|
err := json.Unmarshal(result.Payload, &items)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tm.appendFinalResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err})
|
tm.appendResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err}, false)
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
for _, target := range edge.To {
|
for _, target := range edge.To {
|
||||||
@@ -170,10 +172,12 @@ func (tm *TaskManager) handleResult(ctx context.Context, results any) mq.Result
|
|||||||
return rs
|
return rs
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *TaskManager) appendFinalResult(result mq.Result) {
|
func (tm *TaskManager) appendResult(result mq.Result, final bool) {
|
||||||
tm.mutex.Lock()
|
tm.mutex.Lock()
|
||||||
tm.updateTS(&result)
|
tm.updateTS(&result)
|
||||||
tm.results = append(tm.results, result)
|
if final {
|
||||||
|
tm.results = append(tm.results, result)
|
||||||
|
}
|
||||||
tm.nodeResults[result.Topic] = result
|
tm.nodeResults[result.Topic] = result
|
||||||
tm.mutex.Unlock()
|
tm.mutex.Unlock()
|
||||||
}
|
}
|
||||||
@@ -199,7 +203,7 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json
|
|||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
result = mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: ctx.Err(), Ctx: ctx}
|
result = mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: ctx.Err(), Ctx: ctx}
|
||||||
tm.appendFinalResult(result)
|
tm.appendResult(result, false)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: node.Key})
|
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: node.Key})
|
||||||
@@ -210,14 +214,14 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json
|
|||||||
result.TaskID = tm.taskID
|
result.TaskID = tm.taskID
|
||||||
}
|
}
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
tm.appendFinalResult(result)
|
tm.appendResult(result, false)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err := tm.dag.server.Publish(ctx, mq.NewTask(tm.taskID, payload, node.Key), node.Key)
|
err := tm.dag.server.Publish(ctx, mq.NewTask(tm.taskID, payload, node.Key), node.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tm.appendFinalResult(mq.Result{Error: err})
|
tm.appendResult(mq.Result{Error: err}, false)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
66
dag/ui.go
66
dag/ui.go
@@ -32,7 +32,8 @@ func (tm *DAG) PrintGraph() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *DAG) ClassifyEdges(startNodes ...string) {
|
func (tm *DAG) ClassifyEdges(startNodes ...string) (string, bool, error) {
|
||||||
|
builder := &strings.Builder{}
|
||||||
startNode := tm.GetStartNode()
|
startNode := tm.GetStartNode()
|
||||||
tm.mu.RLock()
|
tm.mu.RLock()
|
||||||
defer tm.mu.RUnlock()
|
defer tm.mu.RUnlock()
|
||||||
@@ -43,57 +44,78 @@ func (tm *DAG) ClassifyEdges(startNodes ...string) {
|
|||||||
discoveryTime := make(map[string]int)
|
discoveryTime := make(map[string]int)
|
||||||
finishedTime := make(map[string]int)
|
finishedTime := make(map[string]int)
|
||||||
timeVal := 0
|
timeVal := 0
|
||||||
|
inRecursionStack := make(map[string]bool) // track nodes in the recursion stack for cycle detection
|
||||||
if startNode == "" {
|
if startNode == "" {
|
||||||
firstNode := tm.findStartNode()
|
firstNode := tm.findStartNode()
|
||||||
if firstNode != nil {
|
if firstNode != nil {
|
||||||
startNode = firstNode.Key
|
startNode = firstNode.Key
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if startNode != "" {
|
if startNode == "" {
|
||||||
tm.dfs(startNode, visited, discoveryTime, finishedTime, &timeVal)
|
return "", false, fmt.Errorf("no start node found")
|
||||||
}
|
}
|
||||||
|
hasCycle, cycleErr := tm.dfs(startNode, visited, discoveryTime, finishedTime, &timeVal, inRecursionStack, builder)
|
||||||
|
if cycleErr != nil {
|
||||||
|
return builder.String(), hasCycle, cycleErr
|
||||||
|
}
|
||||||
|
return builder.String(), hasCycle, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *DAG) dfs(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, timeVal *int) {
|
func (tm *DAG) dfs(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, timeVal *int, inRecursionStack map[string]bool, builder *strings.Builder) (bool, error) {
|
||||||
visited[v] = true
|
visited[v] = true
|
||||||
|
inRecursionStack[v] = true // mark node as part of recursion stack
|
||||||
*timeVal++
|
*timeVal++
|
||||||
discoveryTime[v] = *timeVal
|
discoveryTime[v] = *timeVal
|
||||||
node := tm.nodes[v]
|
node := tm.nodes[v]
|
||||||
|
hasCycle := false
|
||||||
|
var err error
|
||||||
for _, edge := range node.Edges {
|
for _, edge := range node.Edges {
|
||||||
for _, adj := range edge.To {
|
for _, adj := range edge.To {
|
||||||
switch edge.Type {
|
if !visited[adj.Key] {
|
||||||
case Simple:
|
builder.WriteString(fmt.Sprintf("Traversing Edge: %s -> %s\n", v, adj.Key))
|
||||||
if !visited[adj.Key] {
|
hasCycle, err := tm.dfs(adj.Key, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder)
|
||||||
fmt.Printf("Simple Edge: %s -> %s\n", v, adj.Key)
|
if err != nil {
|
||||||
tm.dfs(adj.Key, visited, discoveryTime, finishedTime, timeVal)
|
return true, err
|
||||||
}
|
}
|
||||||
case Iterator:
|
if hasCycle {
|
||||||
if !visited[adj.Key] {
|
return true, nil
|
||||||
fmt.Printf("Iterator Edge: %s -> %s\n", v, adj.Key)
|
|
||||||
tm.dfs(adj.Key, visited, discoveryTime, finishedTime, timeVal)
|
|
||||||
}
|
}
|
||||||
|
} else if inRecursionStack[adj.Key] {
|
||||||
|
cycleMsg := fmt.Sprintf("Cycle detected: %s -> %s\n", v, adj.Key)
|
||||||
|
return true, fmt.Errorf(cycleMsg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tm.handleConditionalEdges(v, visited, discoveryTime, finishedTime, timeVal)
|
hasCycle, err = tm.handleConditionalEdges(v, visited, discoveryTime, finishedTime, timeVal, inRecursionStack, builder)
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
*timeVal++
|
*timeVal++
|
||||||
finishedTime[v] = *timeVal
|
finishedTime[v] = *timeVal
|
||||||
|
inRecursionStack[v] = false // remove from recursion stack after finishing processing
|
||||||
|
return hasCycle, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *DAG) handleConditionalEdges(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, time *int) {
|
func (tm *DAG) handleConditionalEdges(v string, visited map[string]bool, discoveryTime, finishedTime map[string]int, time *int, inRecursionStack map[string]bool, builder *strings.Builder) (bool, error) {
|
||||||
node := tm.nodes[v]
|
node := tm.nodes[v]
|
||||||
for when, then := range tm.conditions[FromNode(node.Key)] {
|
for when, then := range tm.conditions[FromNode(node.Key)] {
|
||||||
if targetNodeKey, ok := tm.nodes[string(then)]; ok {
|
if targetNode, ok := tm.nodes[string(then)]; ok {
|
||||||
if !visited[targetNodeKey.Key] {
|
if !visited[targetNode.Key] {
|
||||||
fmt.Printf("Conditional Edge [%s]: %s -> %s\n", when, v, targetNodeKey.Key)
|
builder.WriteString(fmt.Sprintf("Traversing Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.Key))
|
||||||
tm.dfs(targetNodeKey.Key, visited, discoveryTime, finishedTime, time)
|
hasCycle, err := tm.dfs(targetNode.Key, visited, discoveryTime, finishedTime, time, inRecursionStack, builder)
|
||||||
} else {
|
if err != nil {
|
||||||
if discoveryTime[v] > discoveryTime[targetNodeKey.Key] {
|
return true, err
|
||||||
fmt.Printf("Conditional Loop Edge [%s]: %s -> %s\n", when, v, targetNodeKey.Key)
|
|
||||||
}
|
}
|
||||||
|
if hasCycle {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
} else if inRecursionStack[targetNode.Key] {
|
||||||
|
cycleMsg := fmt.Sprintf("Cycle detected in Conditional Edge [%s]: %s -> %s\n", when, v, targetNode.Key)
|
||||||
|
return true, fmt.Errorf(cycleMsg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *DAG) SaveDOTFile(filename string) error {
|
func (tm *DAG) SaveDOTFile(filename string) error {
|
||||||
|
@@ -52,7 +52,12 @@ func Sync() {
|
|||||||
func aSync() {
|
func aSync() {
|
||||||
f := dag.NewDAG("Sample DAG", "sample-dag", mq.WithCleanTaskOnComplete(), mq.WithNotifyResponse(tasks.NotifyResponse))
|
f := dag.NewDAG("Sample DAG", "sample-dag", mq.WithCleanTaskOnComplete(), mq.WithNotifyResponse(tasks.NotifyResponse))
|
||||||
setup(f)
|
setup(f)
|
||||||
err := f.Start(context.TODO(), ":8083")
|
err := f.Validate()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = f.Start(context.TODO(), ":8083")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@@ -5,8 +5,6 @@ import (
|
|||||||
|
|
||||||
"github.com/oarkflow/json"
|
"github.com/oarkflow/json"
|
||||||
|
|
||||||
"github.com/oarkflow/dipper"
|
|
||||||
|
|
||||||
"github.com/oarkflow/mq"
|
"github.com/oarkflow/mq"
|
||||||
"github.com/oarkflow/mq/services"
|
"github.com/oarkflow/mq/services"
|
||||||
)
|
)
|
||||||
@@ -109,38 +107,3 @@ func (e *InAppNotification) ProcessTask(ctx context.Context, task *mq.Task) mq.R
|
|||||||
}
|
}
|
||||||
return mq.Result{Payload: task.Payload, Ctx: ctx}
|
return mq.Result{Payload: task.Payload, Ctx: ctx}
|
||||||
}
|
}
|
||||||
|
|
||||||
type DataBranchHandler struct{ services.Operation }
|
|
||||||
|
|
||||||
func (v *DataBranchHandler) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
|
|
||||||
ctx = context.WithValue(ctx, "extra_params", map[string]any{"iphone": true})
|
|
||||||
var row map[string]any
|
|
||||||
var result mq.Result
|
|
||||||
result.Payload = task.Payload
|
|
||||||
err := json.Unmarshal(result.Payload, &row)
|
|
||||||
if err != nil {
|
|
||||||
result.Error = err
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
b := make(map[string]any)
|
|
||||||
switch branches := row["data_branch"].(type) {
|
|
||||||
case map[string]any:
|
|
||||||
for field, handler := range branches {
|
|
||||||
data, err := dipper.Get(row, field)
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
b[handler.(string)] = data
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
br, err := json.Marshal(b)
|
|
||||||
if err != nil {
|
|
||||||
result.Error = err
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
result.Status = "branches"
|
|
||||||
result.Payload = br
|
|
||||||
result.Ctx = ctx
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
Reference in New Issue
Block a user