diff --git a/examples/dag.go b/examples/dag.go index 76bf297..a228aea 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -57,11 +57,11 @@ func handler6(ctx context.Context, task *mq.Task) mq.Result { } var ( - d = v2.NewDAG(mq.WithSyncMode(false)) + d = v2.NewDAG(mq.WithSyncMode(true)) ) func main() { - d.AddNode("A", handler1) + d.AddNode("A", handler1, true) d.AddNode("B", handler2) d.AddNode("C", handler3) d.AddNode("D", handler4) @@ -72,7 +72,6 @@ func main() { d.AddEdge("B", "C") d.AddEdge("D", "F") d.AddEdge("E", "F") - // fmt.Println(rs.TaskID, "Task", string(rs.Payload)) http.HandleFunc("POST /publish", requestHandler("publish")) http.HandleFunc("POST /request", requestHandler("request")) @@ -101,7 +100,9 @@ func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Requ http.Error(w, "Empty request body", http.StatusBadRequest) return } - rs := d.ProcessTask(context.Background(), "A", payload) + ctx := context.Background() + // ctx = context.WithValue(ctx, "initial_node", "E") + rs := d.ProcessTask(ctx, payload) w.Header().Set("Content-Type", "application/json") result := map[string]any{ "message_id": rs.TaskID, diff --git a/v2/dag.go b/v2/dag.go index 5163fa2..2115ad4 100644 --- a/v2/dag.go +++ b/v2/dag.go @@ -3,6 +3,7 @@ package v2 import ( "context" "encoding/json" + "fmt" "log" "net/http" "sync" @@ -20,12 +21,6 @@ func NewTask(id string, payload json.RawMessage, nodeKey string) *mq.Task { return &mq.Task{ID: id, Payload: payload, Topic: nodeKey} } -type Node struct { - Key string - Edges []Edge - consumer *mq.Consumer -} - type EdgeType int func (c EdgeType) IsValid() bool { return c >= SimpleEdge && c <= LoopEdge } @@ -35,6 +30,12 @@ const ( LoopEdge ) +type Node struct { + Key string + Edges []Edge + consumer *mq.Consumer +} + type Edge struct { From *Node To *Node @@ -42,6 +43,7 @@ type Edge struct { } type DAG struct { + FirstNode string Nodes map[string]*Node server *mq.Broker taskContext map[string]*TaskManager @@ -68,21 +70,21 @@ func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result { } func (tm *DAG) Start(ctx context.Context, addr string) error { - if tm.server.SyncMode() { - return nil - } - go func() { - err := tm.server.Start(ctx) - if err != nil { - panic(err) + if !tm.server.SyncMode() { + go func() { + err := tm.server.Start(ctx) + if err != nil { + panic(err) + } + }() + for _, con := range tm.Nodes { + go func(con *Node) { + time.Sleep(1 * time.Second) + con.consumer.Consume(ctx) + }(con) } - }() - for _, con := range tm.Nodes { - go func(con *Node) { - time.Sleep(1 * time.Second) - con.consumer.Consume(ctx) - }(con) } + log.Printf("HTTP server started on %s", addr) config := tm.server.TLSConfig() if config.UseTLS { @@ -91,7 +93,7 @@ func (tm *DAG) Start(ctx context.Context, addr string) error { return http.ListenAndServe(addr, nil) } -func (tm *DAG) AddNode(key string, handler mq.Handler) { +func (tm *DAG) AddNode(key string, handler mq.Handler, firstNode ...bool) { tm.mu.Lock() defer tm.mu.Unlock() con := mq.NewConsumer(key, key, handler) @@ -99,6 +101,9 @@ func (tm *DAG) AddNode(key string, handler mq.Handler) { Key: key, consumer: con, } + if len(firstNode) > 0 && firstNode[0] { + tm.FirstNode = key + } } func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) { @@ -125,11 +130,51 @@ func (tm *DAG) AddEdge(from, to string, edgeTypes ...EdgeType) { fromNode.Edges = append(fromNode.Edges, edge) } -func (tm *DAG) ProcessTask(ctx context.Context, node string, payload []byte) mq.Result { +func (tm *DAG) ProcessTask(ctx context.Context, payload []byte) mq.Result { + val := ctx.Value("initial_node") + initialNode, ok := val.(string) + if !ok { + if tm.FirstNode == "" { + firstNode := tm.FindInitialNode() + if firstNode != nil { + tm.FirstNode = firstNode.Key + } + } + if tm.FirstNode == "" { + return mq.Result{Error: fmt.Errorf("initial node not found")} + } + initialNode = tm.FirstNode + } tm.mu.Lock() defer tm.mu.Unlock() taskID := xid.New().String() manager := NewTaskManager(tm, taskID) tm.taskContext[taskID] = manager - return manager.processTask(ctx, node, payload) + return manager.processTask(ctx, initialNode, payload) +} + +func (tm *DAG) FindInitialNode() *Node { + incomingEdges := make(map[string]bool) + connectedNodes := make(map[string]bool) + for _, node := range tm.Nodes { + for _, edge := range node.Edges { + if edge.Type.IsValid() { + connectedNodes[node.Key] = true + connectedNodes[edge.To.Key] = true + incomingEdges[edge.To.Key] = true + } + } + if cond, ok := tm.conditions[node.Key]; ok { + for _, target := range cond { + connectedNodes[target] = true + incomingEdges[target] = true + } + } + } + for nodeID, node := range tm.Nodes { + if !incomingEdges[nodeID] && connectedNodes[nodeID] { + return node + } + } + return nil } diff --git a/v2/task_manager.go b/v2/task_manager.go index b9a5f27..92f89c2 100644 --- a/v2/task_manager.go +++ b/v2/task_manager.go @@ -32,42 +32,63 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager { } } +func (tm *TaskManager) handleSyncTask(ctx context.Context, node *Node, payload json.RawMessage) mq.Result { + tm.done = make(chan struct{}) + tm.wg.Add(1) + go tm.processNode(ctx, node, payload) + go func() { + tm.wg.Wait() + close(tm.done) + }() + select { + case <-ctx.Done(): + return mq.Result{Error: ctx.Err()} + case <-tm.done: + tm.mutex.Lock() + defer tm.mutex.Unlock() + if len(tm.results) == 1 { + return tm.handleResult(ctx, tm.results[0]) + } + return tm.handleResult(ctx, tm.results) + } +} + +func (tm *TaskManager) handleAsyncTask(ctx context.Context, node *Node, payload json.RawMessage) mq.Result { + tm.finalResult = make(chan mq.Result) + tm.wg.Add(1) + go tm.processNode(ctx, node, payload) + go func() { + tm.wg.Wait() + }() + select { + case result := <-tm.finalResult: // Block until a result is available + return result + case <-ctx.Done(): // Handle context cancellation + return mq.Result{Error: ctx.Err()} + } +} + func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result { node, ok := tm.dag.Nodes[nodeID] if !ok { return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)} } if tm.dag.server.SyncMode() { - tm.done = make(chan struct{}) - tm.wg.Add(1) - go tm.processNode(ctx, node, payload) - go func() { - tm.wg.Wait() - close(tm.done) - }() - select { - case <-ctx.Done(): - return mq.Result{Error: ctx.Err()} - case <-tm.done: - tm.mutex.Lock() - defer tm.mutex.Unlock() - if len(tm.results) == 1 { - return tm.handleResult(ctx, tm.results[0]) - } - return tm.handleResult(ctx, tm.results) + return tm.handleSyncTask(ctx, node, payload) + } + return tm.handleAsyncTask(ctx, node, payload) +} + +func (tm *TaskManager) dispatchFinalResult(ctx context.Context) { + if !tm.dag.server.SyncMode() { + var rs mq.Result + if len(tm.results) == 1 { + rs = tm.handleResult(ctx, tm.results[0]) + } else { + rs = tm.handleResult(ctx, tm.results) } - } else { - tm.finalResult = make(chan mq.Result) - tm.wg.Add(1) - go tm.processNode(ctx, node, payload) - go func() { - tm.wg.Wait() - }() - select { - case result := <-tm.finalResult: // Block until a result is available - return result - case <-ctx.Done(): // Handle context cancellation - return mq.Result{Error: ctx.Err()} + if tm.waitingCallback == 0 { + tm.finalResult <- rs } } } @@ -93,17 +114,7 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq. } if len(edges) == 0 { tm.appendFinalResult(result) - if !tm.dag.server.SyncMode() { - var rs mq.Result - if len(tm.results) == 1 { - rs = tm.handleResult(ctx, tm.results[0]) - } else { - rs = tm.handleResult(ctx, tm.results) - } - if tm.waitingCallback == 0 { - tm.finalResult <- rs - } - } + tm.dispatchFinalResult(ctx) return result } for _, edge := range edges { @@ -205,3 +216,10 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json tm.mutex.Unlock() tm.handleCallback(ctx, result) } + +func (tm *TaskManager) Clear() error { + tm.waitingCallback = 0 + clear(tm.results) + tm.nodeResults = make(map[string]mq.Result) + return nil +}