diff --git a/broker.go b/broker.go index f35ef57..36ffa1c 100644 --- a/broker.go +++ b/broker.go @@ -151,6 +151,9 @@ func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec if err := b.send(conn, ack); err != nil { log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err) } + if b.opts.consumerSubscribeHandler != nil { + b.opts.consumerSubscribeHandler(ctx, msg.Queue, consumerID) + } go func() { select { case <-ctx.Done(): diff --git a/dag/dag.go b/dag/dag.go index 8e83e6a..97c4f60 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -4,13 +4,12 @@ import ( "context" "encoding/json" "fmt" + "github.com/oarkflow/xid" "log" "net/http" "sync" "time" - "github.com/oarkflow/xid" - "github.com/oarkflow/mq" ) @@ -33,6 +32,7 @@ const ( type Node struct { Key string Edges []Edge + isReady bool consumer *mq.Consumer } @@ -57,7 +57,7 @@ func NewDAG(opts ...mq.Option) *DAG { taskContext: make(map[string]*TaskManager), conditions: make(map[string]map[string]string), } - opts = append(opts, mq.WithCallback(d.onTaskCallback)) + opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerSubscribe(d.onConsumerJoin)) d.server = mq.NewBroker(opts...) return d } @@ -69,6 +69,13 @@ func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result { return mq.Result{} } +func (tm *DAG) onConsumerJoin(_ context.Context, topic, _ string) { + if node, ok := tm.Nodes[topic]; ok { + log.Printf("Consumer is ready on %s", topic) + node.isReady = true + } +} + func (tm *DAG) Start(ctx context.Context, addr string) error { if !tm.server.SyncMode() { go func() { @@ -78,10 +85,14 @@ func (tm *DAG) Start(ctx context.Context, addr string) error { } }() for _, con := range tm.Nodes { - go func(con *Node) { - time.Sleep(1 * time.Second) - con.consumer.Consume(ctx) - }(con) + if con.isReady { + go func(con *Node) { + time.Sleep(1 * time.Second) + con.consumer.Consume(ctx) + }(con) + } else { + log.Printf("[WARNING] - %s is not ready yet", con.Key) + } } } @@ -100,12 +111,39 @@ func (tm *DAG) AddNode(key string, handler mq.Handler, firstNode ...bool) { tm.Nodes[key] = &Node{ Key: key, consumer: con, + isReady: true, } if len(firstNode) > 0 && firstNode[0] { tm.FirstNode = key } } +func (tm *DAG) AddDeferredNode(key string, firstNode ...bool) error { + if tm.server.SyncMode() { + return fmt.Errorf("DAG cannot have deferred node in Sync Mode") + } + tm.mu.Lock() + defer tm.mu.Unlock() + tm.Nodes[key] = &Node{ + Key: key, + } + if len(firstNode) > 0 && firstNode[0] { + tm.FirstNode = key + } + return nil +} + +func (tm *DAG) IsReady() bool { + tm.mu.Lock() + defer tm.mu.Unlock() + for _, node := range tm.Nodes { + if !node.isReady { + return false + } + } + return true +} + func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) { tm.mu.Lock() defer tm.mu.Unlock() @@ -131,6 +169,9 @@ func (tm *DAG) AddEdge(from, to string, edgeTypes ...EdgeType) { } func (tm *DAG) ProcessTask(ctx context.Context, payload []byte) mq.Result { + if !tm.IsReady() { + return mq.Result{Error: fmt.Errorf("DAG is not ready yet")} + } val := ctx.Value("initial_node") initialNode, ok := val.(string) if !ok { diff --git a/examples/consumer.go b/examples/consumer.go index 7b9575e..1ba0934 100644 --- a/examples/consumer.go +++ b/examples/consumer.go @@ -9,9 +9,6 @@ import ( ) func main() { - consumer1 := mq.NewConsumer("consumer-1", "queue1", tasks.Node1) - consumer2 := mq.NewConsumer("consumer-2", "queue2", tasks.Node2) - // consumer := mq.NewConsumer("consumer-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key")) - go consumer1.Consume(context.Background()) - consumer2.Consume(context.Background()) + consumer1 := mq.NewConsumer("F", "F", tasks.Node6) + consumer1.Consume(context.Background()) } diff --git a/examples/dag.go b/examples/dag.go index 527185c..b8c8a83 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -3,6 +3,7 @@ package main import ( "context" "encoding/json" + "fmt" "github.com/oarkflow/mq/consts" "github.com/oarkflow/mq/examples/tasks" "io" @@ -23,7 +24,10 @@ func main() { d.AddNode("C", tasks.Node3) d.AddNode("D", tasks.Node4) d.AddNode("E", tasks.Node5) - d.AddNode("F", tasks.Node6) + err := d.AddDeferredNode("F") + if err != nil { + panic(err) + } d.AddEdge("A", "B", dag.LoopEdge) d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"}) d.AddEdge("B", "C") @@ -31,7 +35,7 @@ func main() { d.AddEdge("E", "F") http.HandleFunc("POST /publish", requestHandler("publish")) http.HandleFunc("POST /request", requestHandler("request")) - err := d.Start(context.TODO(), ":8083") + err = d.Start(context.TODO(), ":8083") if err != nil { panic(err) } @@ -62,6 +66,10 @@ func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Requ } // ctx = context.WithValue(ctx, "initial_node", "E") rs := d.ProcessTask(ctx, payload) + if rs.Error != nil { + http.Error(w, fmt.Sprintf("[DAG Error] - %v", rs.Error), http.StatusInternalServerError) + return + } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(rs) } diff --git a/options.go b/options.go index 5d278d2..c48855c 100644 --- a/options.go +++ b/options.go @@ -57,19 +57,20 @@ type TLSConfig struct { } type Options struct { - syncMode bool - brokerAddr string - callback []func(context.Context, Result) Result - maxRetries int - notifyResponse func(context.Context, Result) - initialDelay time.Duration - maxBackoff time.Duration - jitterPercent float64 - tlsConfig TLSConfig - aesKey json.RawMessage - hmacKey json.RawMessage - enableEncryption bool - queueSize int + syncMode bool + brokerAddr string + callback []func(context.Context, Result) Result + maxRetries int + consumerSubscribeHandler func(ctx context.Context, topic, consumerName string) + notifyResponse func(context.Context, Result) + initialDelay time.Duration + maxBackoff time.Duration + jitterPercent float64 + tlsConfig TLSConfig + aesKey json.RawMessage + hmacKey json.RawMessage + enableEncryption bool + queueSize int } func defaultOptions() Options { @@ -101,6 +102,12 @@ func WithNotifyResponse(handler func(ctx context.Context, result Result)) Option } } +func WithConsumerSubscribe(handler func(ctx context.Context, topic, consumerName string)) Option { + return func(opts *Options) { + opts.consumerSubscribeHandler = handler + } +} + func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option { return func(opts *Options) { opts.aesKey = aesKey