feat: add deferred nodes

This commit is contained in:
Oarkflow
2024-10-09 12:25:27 +05:45
parent f9c4a5e731
commit 715aa22471
5 changed files with 83 additions and 27 deletions

View File

@@ -151,6 +151,9 @@ func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec
if err := b.send(conn, ack); err != nil { if err := b.send(conn, ack); err != nil {
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err) log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
} }
if b.opts.consumerSubscribeHandler != nil {
b.opts.consumerSubscribeHandler(ctx, msg.Queue, consumerID)
}
go func() { go func() {
select { select {
case <-ctx.Done(): case <-ctx.Done():

View File

@@ -4,13 +4,12 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/oarkflow/xid"
"log" "log"
"net/http" "net/http"
"sync" "sync"
"time" "time"
"github.com/oarkflow/xid"
"github.com/oarkflow/mq" "github.com/oarkflow/mq"
) )
@@ -33,6 +32,7 @@ const (
type Node struct { type Node struct {
Key string Key string
Edges []Edge Edges []Edge
isReady bool
consumer *mq.Consumer consumer *mq.Consumer
} }
@@ -57,7 +57,7 @@ func NewDAG(opts ...mq.Option) *DAG {
taskContext: make(map[string]*TaskManager), taskContext: make(map[string]*TaskManager),
conditions: make(map[string]map[string]string), conditions: make(map[string]map[string]string),
} }
opts = append(opts, mq.WithCallback(d.onTaskCallback)) opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerSubscribe(d.onConsumerJoin))
d.server = mq.NewBroker(opts...) d.server = mq.NewBroker(opts...)
return d return d
} }
@@ -69,6 +69,13 @@ func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result {
return mq.Result{} return mq.Result{}
} }
func (tm *DAG) onConsumerJoin(_ context.Context, topic, _ string) {
if node, ok := tm.Nodes[topic]; ok {
log.Printf("Consumer is ready on %s", topic)
node.isReady = true
}
}
func (tm *DAG) Start(ctx context.Context, addr string) error { func (tm *DAG) Start(ctx context.Context, addr string) error {
if !tm.server.SyncMode() { if !tm.server.SyncMode() {
go func() { go func() {
@@ -78,10 +85,14 @@ func (tm *DAG) Start(ctx context.Context, addr string) error {
} }
}() }()
for _, con := range tm.Nodes { for _, con := range tm.Nodes {
go func(con *Node) { if con.isReady {
time.Sleep(1 * time.Second) go func(con *Node) {
con.consumer.Consume(ctx) time.Sleep(1 * time.Second)
}(con) con.consumer.Consume(ctx)
}(con)
} else {
log.Printf("[WARNING] - %s is not ready yet", con.Key)
}
} }
} }
@@ -100,12 +111,39 @@ func (tm *DAG) AddNode(key string, handler mq.Handler, firstNode ...bool) {
tm.Nodes[key] = &Node{ tm.Nodes[key] = &Node{
Key: key, Key: key,
consumer: con, consumer: con,
isReady: true,
} }
if len(firstNode) > 0 && firstNode[0] { if len(firstNode) > 0 && firstNode[0] {
tm.FirstNode = key tm.FirstNode = key
} }
} }
func (tm *DAG) AddDeferredNode(key string, firstNode ...bool) error {
if tm.server.SyncMode() {
return fmt.Errorf("DAG cannot have deferred node in Sync Mode")
}
tm.mu.Lock()
defer tm.mu.Unlock()
tm.Nodes[key] = &Node{
Key: key,
}
if len(firstNode) > 0 && firstNode[0] {
tm.FirstNode = key
}
return nil
}
func (tm *DAG) IsReady() bool {
tm.mu.Lock()
defer tm.mu.Unlock()
for _, node := range tm.Nodes {
if !node.isReady {
return false
}
}
return true
}
func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) { func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) {
tm.mu.Lock() tm.mu.Lock()
defer tm.mu.Unlock() defer tm.mu.Unlock()
@@ -131,6 +169,9 @@ func (tm *DAG) AddEdge(from, to string, edgeTypes ...EdgeType) {
} }
func (tm *DAG) ProcessTask(ctx context.Context, payload []byte) mq.Result { func (tm *DAG) ProcessTask(ctx context.Context, payload []byte) mq.Result {
if !tm.IsReady() {
return mq.Result{Error: fmt.Errorf("DAG is not ready yet")}
}
val := ctx.Value("initial_node") val := ctx.Value("initial_node")
initialNode, ok := val.(string) initialNode, ok := val.(string)
if !ok { if !ok {

View File

@@ -9,9 +9,6 @@ import (
) )
func main() { func main() {
consumer1 := mq.NewConsumer("consumer-1", "queue1", tasks.Node1) consumer1 := mq.NewConsumer("F", "F", tasks.Node6)
consumer2 := mq.NewConsumer("consumer-2", "queue2", tasks.Node2) consumer1.Consume(context.Background())
// consumer := mq.NewConsumer("consumer-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"))
go consumer1.Consume(context.Background())
consumer2.Consume(context.Background())
} }

View File

@@ -3,6 +3,7 @@ package main
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"github.com/oarkflow/mq/consts" "github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/examples/tasks" "github.com/oarkflow/mq/examples/tasks"
"io" "io"
@@ -23,7 +24,10 @@ func main() {
d.AddNode("C", tasks.Node3) d.AddNode("C", tasks.Node3)
d.AddNode("D", tasks.Node4) d.AddNode("D", tasks.Node4)
d.AddNode("E", tasks.Node5) d.AddNode("E", tasks.Node5)
d.AddNode("F", tasks.Node6) err := d.AddDeferredNode("F")
if err != nil {
panic(err)
}
d.AddEdge("A", "B", dag.LoopEdge) d.AddEdge("A", "B", dag.LoopEdge)
d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"}) d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"})
d.AddEdge("B", "C") d.AddEdge("B", "C")
@@ -31,7 +35,7 @@ func main() {
d.AddEdge("E", "F") d.AddEdge("E", "F")
http.HandleFunc("POST /publish", requestHandler("publish")) http.HandleFunc("POST /publish", requestHandler("publish"))
http.HandleFunc("POST /request", requestHandler("request")) http.HandleFunc("POST /request", requestHandler("request"))
err := d.Start(context.TODO(), ":8083") err = d.Start(context.TODO(), ":8083")
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -62,6 +66,10 @@ func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Requ
} }
// ctx = context.WithValue(ctx, "initial_node", "E") // ctx = context.WithValue(ctx, "initial_node", "E")
rs := d.ProcessTask(ctx, payload) rs := d.ProcessTask(ctx, payload)
if rs.Error != nil {
http.Error(w, fmt.Sprintf("[DAG Error] - %v", rs.Error), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(rs) json.NewEncoder(w).Encode(rs)
} }

View File

@@ -57,19 +57,20 @@ type TLSConfig struct {
} }
type Options struct { type Options struct {
syncMode bool syncMode bool
brokerAddr string brokerAddr string
callback []func(context.Context, Result) Result callback []func(context.Context, Result) Result
maxRetries int maxRetries int
notifyResponse func(context.Context, Result) consumerSubscribeHandler func(ctx context.Context, topic, consumerName string)
initialDelay time.Duration notifyResponse func(context.Context, Result)
maxBackoff time.Duration initialDelay time.Duration
jitterPercent float64 maxBackoff time.Duration
tlsConfig TLSConfig jitterPercent float64
aesKey json.RawMessage tlsConfig TLSConfig
hmacKey json.RawMessage aesKey json.RawMessage
enableEncryption bool hmacKey json.RawMessage
queueSize int enableEncryption bool
queueSize int
} }
func defaultOptions() Options { func defaultOptions() Options {
@@ -101,6 +102,12 @@ func WithNotifyResponse(handler func(ctx context.Context, result Result)) Option
} }
} }
func WithConsumerSubscribe(handler func(ctx context.Context, topic, consumerName string)) Option {
return func(opts *Options) {
opts.consumerSubscribeHandler = handler
}
}
func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option { func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option {
return func(opts *Options) { return func(opts *Options) {
opts.aesKey = aesKey opts.aesKey = aesKey