diff --git a/broker.go b/broker.go index 8967f47..c840e63 100644 --- a/broker.go +++ b/broker.go @@ -129,6 +129,10 @@ func (b *Broker) Send(ctx context.Context, cmd Command) error { return nil } +func (b *Broker) SyncMode() bool { + return b.opts.syncMode +} + func (b *Broker) sendToPublisher(ctx context.Context, publisherID string, result Result) error { pub, ok := b.publishers.Get(publisherID) if !ok { diff --git a/dag/dag.go b/dag/dag.go index 5cc7b53..51948fa 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -3,6 +3,7 @@ package dag import ( "context" "encoding/json" + "fmt" "sync" "github.com/oarkflow/mq" @@ -17,6 +18,7 @@ type taskContext struct { } type DAG struct { + FirstNode string server *mq.Broker nodes map[string]*mq.Consumer edges map[string][]string @@ -39,8 +41,11 @@ func New(opts ...mq.Option) *DAG { return d } -func (d *DAG) AddNode(name string, handler mq.Handler) { +func (d *DAG) AddNode(name string, handler mq.Handler, firstNode ...bool) { con := mq.NewConsumer(name) + if len(firstNode) > 0 { + d.FirstNode = name + } con.RegisterHandler(name, handler) d.nodes[name] = con } @@ -54,6 +59,15 @@ func (d *DAG) AddLoop(fromNode string, toNode ...string) { } func (d *DAG) Start(ctx context.Context) error { + if d.FirstNode == "" { + firstNode, ok := d.FindFirstNode() + if ok && firstNode != "" { + d.FirstNode = firstNode + } + } + if d.server.SyncMode() { + return nil + } for _, con := range d.nodes { go con.Consume(ctx) } @@ -70,11 +84,37 @@ func (d *DAG) PublishTask(ctx context.Context, payload []byte, queueName string, return d.server.Publish(ctx, task, queueName) } +func (d *DAG) FindFirstNode() (string, bool) { + inDegree := make(map[string]int) + for n, _ := range d.nodes { + inDegree[n] = 0 + } + for _, targets := range d.edges { + for _, outNode := range targets { + inDegree[outNode]++ + } + } + for _, targets := range d.loopEdges { + for _, outNode := range targets { + inDegree[outNode]++ + } + } + for n, count := range inDegree { + if count == 0 { + return n, true + } + } + return "", false +} + func (d *DAG) Send(payload []byte) mq.Result { + if d.FirstNode == "" { + return mq.Result{Error: fmt.Errorf("initial node not defined")} + } resultCh := make(chan mq.Result) - task, err := d.PublishTask(context.TODO(), payload, "queue2") + task, err := d.PublishTask(context.TODO(), payload, d.FirstNode) if err != nil { - panic(err) + return mq.Result{Error: err} } d.mu.Lock() d.taskChMap[task.ID] = resultCh diff --git a/examples/dag.go b/examples/dag.go index 0b7ebc4..b3c0c20 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -43,7 +43,6 @@ func main() { d.AddEdge("queue1", "queue2") d.AddLoop("queue2", "queue3") d.AddEdge("queue2", "queue4") - go func() { time.Sleep(2 * time.Second) finalResult := d.Send([]byte(`[{"user_id": 1}, {"user_id": 2}]`)) diff --git a/examples/publisher.go b/examples/publisher.go index 9145ccc..f1a9e9b 100644 --- a/examples/publisher.go +++ b/examples/publisher.go @@ -13,42 +13,32 @@ func main() { publishSync() } -// publishAsync sends a task in Fire-and-Forget (async) mode func publishAsync() error { taskPayload := map[string]string{"message": "Fire-and-Forget \n Task"} payload, _ := json.Marshal(taskPayload) - task := mq.Task{ Payload: payload, } - - // Create publisher and send the task without waiting for a result publisher := mq.NewPublisher("publish-1") err := publisher.Publish(context.Background(), "queue1", task) if err != nil { return fmt.Errorf("failed to publish async task: %w", err) } - fmt.Println("Async task published successfully") return nil } -// publishSync sends a task in Request/Response (sync) mode func publishSync() error { taskPayload := map[string]string{"message": "Request/Response \n Task"} payload, _ := json.Marshal(taskPayload) - task := mq.Task{ Payload: payload, } - - // Create publisher and send the task, waiting for the result publisher := mq.NewPublisher("publish-2") result, err := publisher.Request(context.Background(), "queue1", task) if err != nil { return fmt.Errorf("failed to publish sync task: %w", err) } - fmt.Printf("Sync task published. Result: %v\n", string(result.Payload)) return nil } diff --git a/options.go b/options.go index dc539e6..cf23d7c 100644 --- a/options.go +++ b/options.go @@ -6,6 +6,7 @@ import ( ) type Options struct { + syncMode bool brokerAddr string messageHandler MessageHandler closeHandler CloseHandler @@ -19,6 +20,7 @@ type Options struct { func defaultOptions() Options { return Options{ + syncMode: true, brokerAddr: ":8080", maxRetries: 5, initialDelay: 2 * time.Second, @@ -37,6 +39,13 @@ func WithBrokerURL(url string) Option { } } +// WithSyncMode - +func WithSyncMode(mode bool) Option { + return func(opts *Options) { + opts.syncMode = mode + } +} + // WithMaxRetries - func WithMaxRetries(val int) Option { return func(opts *Options) {