feat: add example

This commit is contained in:
sujit
2024-10-08 23:57:44 +05:45
parent 3e8f47086f
commit e477acf91c
3 changed files with 130 additions and 66 deletions

View File

@@ -57,11 +57,11 @@ func handler6(ctx context.Context, task *mq.Task) mq.Result {
} }
var ( var (
d = v2.NewDAG(mq.WithSyncMode(false)) d = v2.NewDAG(mq.WithSyncMode(true))
) )
func main() { func main() {
d.AddNode("A", handler1) d.AddNode("A", handler1, true)
d.AddNode("B", handler2) d.AddNode("B", handler2)
d.AddNode("C", handler3) d.AddNode("C", handler3)
d.AddNode("D", handler4) d.AddNode("D", handler4)
@@ -72,7 +72,6 @@ func main() {
d.AddEdge("B", "C") d.AddEdge("B", "C")
d.AddEdge("D", "F") d.AddEdge("D", "F")
d.AddEdge("E", "F") d.AddEdge("E", "F")
// fmt.Println(rs.TaskID, "Task", string(rs.Payload)) // fmt.Println(rs.TaskID, "Task", string(rs.Payload))
http.HandleFunc("POST /publish", requestHandler("publish")) http.HandleFunc("POST /publish", requestHandler("publish"))
http.HandleFunc("POST /request", requestHandler("request")) http.HandleFunc("POST /request", requestHandler("request"))
@@ -101,7 +100,9 @@ func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Requ
http.Error(w, "Empty request body", http.StatusBadRequest) http.Error(w, "Empty request body", http.StatusBadRequest)
return return
} }
rs := d.ProcessTask(context.Background(), "A", payload) ctx := context.Background()
// ctx = context.WithValue(ctx, "initial_node", "E")
rs := d.ProcessTask(ctx, payload)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
result := map[string]any{ result := map[string]any{
"message_id": rs.TaskID, "message_id": rs.TaskID,

View File

@@ -3,6 +3,7 @@ package v2
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"log" "log"
"net/http" "net/http"
"sync" "sync"
@@ -20,12 +21,6 @@ func NewTask(id string, payload json.RawMessage, nodeKey string) *mq.Task {
return &mq.Task{ID: id, Payload: payload, Topic: nodeKey} return &mq.Task{ID: id, Payload: payload, Topic: nodeKey}
} }
type Node struct {
Key string
Edges []Edge
consumer *mq.Consumer
}
type EdgeType int type EdgeType int
func (c EdgeType) IsValid() bool { return c >= SimpleEdge && c <= LoopEdge } func (c EdgeType) IsValid() bool { return c >= SimpleEdge && c <= LoopEdge }
@@ -35,6 +30,12 @@ const (
LoopEdge LoopEdge
) )
type Node struct {
Key string
Edges []Edge
consumer *mq.Consumer
}
type Edge struct { type Edge struct {
From *Node From *Node
To *Node To *Node
@@ -42,6 +43,7 @@ type Edge struct {
} }
type DAG struct { type DAG struct {
FirstNode string
Nodes map[string]*Node Nodes map[string]*Node
server *mq.Broker server *mq.Broker
taskContext map[string]*TaskManager taskContext map[string]*TaskManager
@@ -68,21 +70,21 @@ func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result {
} }
func (tm *DAG) Start(ctx context.Context, addr string) error { func (tm *DAG) Start(ctx context.Context, addr string) error {
if tm.server.SyncMode() { if !tm.server.SyncMode() {
return nil go func() {
} err := tm.server.Start(ctx)
go func() { if err != nil {
err := tm.server.Start(ctx) panic(err)
if err != nil { }
panic(err) }()
for _, con := range tm.Nodes {
go func(con *Node) {
time.Sleep(1 * time.Second)
con.consumer.Consume(ctx)
}(con)
} }
}()
for _, con := range tm.Nodes {
go func(con *Node) {
time.Sleep(1 * time.Second)
con.consumer.Consume(ctx)
}(con)
} }
log.Printf("HTTP server started on %s", addr) log.Printf("HTTP server started on %s", addr)
config := tm.server.TLSConfig() config := tm.server.TLSConfig()
if config.UseTLS { if config.UseTLS {
@@ -91,7 +93,7 @@ func (tm *DAG) Start(ctx context.Context, addr string) error {
return http.ListenAndServe(addr, nil) return http.ListenAndServe(addr, nil)
} }
func (tm *DAG) AddNode(key string, handler mq.Handler) { func (tm *DAG) AddNode(key string, handler mq.Handler, firstNode ...bool) {
tm.mu.Lock() tm.mu.Lock()
defer tm.mu.Unlock() defer tm.mu.Unlock()
con := mq.NewConsumer(key, key, handler) con := mq.NewConsumer(key, key, handler)
@@ -99,6 +101,9 @@ func (tm *DAG) AddNode(key string, handler mq.Handler) {
Key: key, Key: key,
consumer: con, consumer: con,
} }
if len(firstNode) > 0 && firstNode[0] {
tm.FirstNode = key
}
} }
func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) { func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) {
@@ -125,11 +130,51 @@ func (tm *DAG) AddEdge(from, to string, edgeTypes ...EdgeType) {
fromNode.Edges = append(fromNode.Edges, edge) fromNode.Edges = append(fromNode.Edges, edge)
} }
func (tm *DAG) ProcessTask(ctx context.Context, node string, payload []byte) mq.Result { func (tm *DAG) ProcessTask(ctx context.Context, payload []byte) mq.Result {
val := ctx.Value("initial_node")
initialNode, ok := val.(string)
if !ok {
if tm.FirstNode == "" {
firstNode := tm.FindInitialNode()
if firstNode != nil {
tm.FirstNode = firstNode.Key
}
}
if tm.FirstNode == "" {
return mq.Result{Error: fmt.Errorf("initial node not found")}
}
initialNode = tm.FirstNode
}
tm.mu.Lock() tm.mu.Lock()
defer tm.mu.Unlock() defer tm.mu.Unlock()
taskID := xid.New().String() taskID := xid.New().String()
manager := NewTaskManager(tm, taskID) manager := NewTaskManager(tm, taskID)
tm.taskContext[taskID] = manager tm.taskContext[taskID] = manager
return manager.processTask(ctx, node, payload) return manager.processTask(ctx, initialNode, payload)
}
func (tm *DAG) FindInitialNode() *Node {
incomingEdges := make(map[string]bool)
connectedNodes := make(map[string]bool)
for _, node := range tm.Nodes {
for _, edge := range node.Edges {
if edge.Type.IsValid() {
connectedNodes[node.Key] = true
connectedNodes[edge.To.Key] = true
incomingEdges[edge.To.Key] = true
}
}
if cond, ok := tm.conditions[node.Key]; ok {
for _, target := range cond {
connectedNodes[target] = true
incomingEdges[target] = true
}
}
}
for nodeID, node := range tm.Nodes {
if !incomingEdges[nodeID] && connectedNodes[nodeID] {
return node
}
}
return nil
} }

View File

@@ -32,42 +32,63 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager {
} }
} }
func (tm *TaskManager) handleSyncTask(ctx context.Context, node *Node, payload json.RawMessage) mq.Result {
tm.done = make(chan struct{})
tm.wg.Add(1)
go tm.processNode(ctx, node, payload)
go func() {
tm.wg.Wait()
close(tm.done)
}()
select {
case <-ctx.Done():
return mq.Result{Error: ctx.Err()}
case <-tm.done:
tm.mutex.Lock()
defer tm.mutex.Unlock()
if len(tm.results) == 1 {
return tm.handleResult(ctx, tm.results[0])
}
return tm.handleResult(ctx, tm.results)
}
}
func (tm *TaskManager) handleAsyncTask(ctx context.Context, node *Node, payload json.RawMessage) mq.Result {
tm.finalResult = make(chan mq.Result)
tm.wg.Add(1)
go tm.processNode(ctx, node, payload)
go func() {
tm.wg.Wait()
}()
select {
case result := <-tm.finalResult: // Block until a result is available
return result
case <-ctx.Done(): // Handle context cancellation
return mq.Result{Error: ctx.Err()}
}
}
func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result { func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result {
node, ok := tm.dag.Nodes[nodeID] node, ok := tm.dag.Nodes[nodeID]
if !ok { if !ok {
return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)} return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)}
} }
if tm.dag.server.SyncMode() { if tm.dag.server.SyncMode() {
tm.done = make(chan struct{}) return tm.handleSyncTask(ctx, node, payload)
tm.wg.Add(1) }
go tm.processNode(ctx, node, payload) return tm.handleAsyncTask(ctx, node, payload)
go func() { }
tm.wg.Wait()
close(tm.done) func (tm *TaskManager) dispatchFinalResult(ctx context.Context) {
}() if !tm.dag.server.SyncMode() {
select { var rs mq.Result
case <-ctx.Done(): if len(tm.results) == 1 {
return mq.Result{Error: ctx.Err()} rs = tm.handleResult(ctx, tm.results[0])
case <-tm.done: } else {
tm.mutex.Lock() rs = tm.handleResult(ctx, tm.results)
defer tm.mutex.Unlock()
if len(tm.results) == 1 {
return tm.handleResult(ctx, tm.results[0])
}
return tm.handleResult(ctx, tm.results)
} }
} else { if tm.waitingCallback == 0 {
tm.finalResult = make(chan mq.Result) tm.finalResult <- rs
tm.wg.Add(1)
go tm.processNode(ctx, node, payload)
go func() {
tm.wg.Wait()
}()
select {
case result := <-tm.finalResult: // Block until a result is available
return result
case <-ctx.Done(): // Handle context cancellation
return mq.Result{Error: ctx.Err()}
} }
} }
} }
@@ -93,17 +114,7 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.
} }
if len(edges) == 0 { if len(edges) == 0 {
tm.appendFinalResult(result) tm.appendFinalResult(result)
if !tm.dag.server.SyncMode() { tm.dispatchFinalResult(ctx)
var rs mq.Result
if len(tm.results) == 1 {
rs = tm.handleResult(ctx, tm.results[0])
} else {
rs = tm.handleResult(ctx, tm.results)
}
if tm.waitingCallback == 0 {
tm.finalResult <- rs
}
}
return result return result
} }
for _, edge := range edges { for _, edge := range edges {
@@ -205,3 +216,10 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json
tm.mutex.Unlock() tm.mutex.Unlock()
tm.handleCallback(ctx, result) tm.handleCallback(ctx, result)
} }
func (tm *TaskManager) Clear() error {
tm.waitingCallback = 0
clear(tm.results)
tm.nodeResults = make(map[string]mq.Result)
return nil
}