package dag

import (
	"context"
	"fmt"
	"log"
	"net/http"
	"sync"
	"time"

	"golang.org/x/time/rate"

	"github.com/oarkflow/mq"
	"github.com/oarkflow/mq/consts"
)

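// EdgeType describes how a task's output travels along an edge. Simple
// forwards the result payload to the target nodes as-is; Iterator is intended
// to fan a slice payload out item by item (the per-item dispatch itself lives
// in the TaskManager, not in this file).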
type EdgeType int

func (c EdgeType) IsValid() bool { return c >= Simple && c <= Iterator }

const (
	Simple EdgeType = iota
	Iterator
)

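// Node is a vertex in the DAG: a named processor plus the outgoing edges that
// decide where its results go next. isReady tracks whether the backing
// consumer is currently able to accept work.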
type Node struct {
	processor mq.Processor
	Name      string
	Key       string
	Edges     []Edge
	isReady   bool
}

func (n *Node) ProcessTask(ctx context.Context, msg *mq.Task) mq.Result {
	return n.processor.ProcessTask(ctx, msg)
}

func (n *Node) Close() error {
	return n.processor.Close()
}

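// Edge connects one node to one or more target nodes. Label is purely
// descriptive; Type selects Simple or Iterator dispatch.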
type Edge struct {
	Label string
	From  *Node
	To    []*Node
	Type  EdgeType
}

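// FromNode, When, and Then are string aliases that make conditional-routing
// maps self-describing: conditions[from][status] = next node key.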
type (
	FromNode string
	When     string
	Then     string
)

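// DAG wires named processors into a directed acyclic graph and executes tasks
// across it, either synchronously or via the broker, worker pool, and
// per-task TaskManagers. The mutex guards the node map, conditions, task
// contexts, and the paused flag.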
type DAG struct {
	nodes         map[string]*Node
	server        *mq.Broker
	consumer      *mq.Consumer
	taskContext   map[string]*TaskManager
	conditions    map[FromNode]map[When]Then
	pool          *mq.Pool
	taskCleanupCh chan string
	name          string
	key           string
	startNode     string
	consumerTopic string
	opts          []mq.Option
	mu            sync.RWMutex
	paused        bool
}

func (tm *DAG) SetKey(key string) {
	tm.key = key
}

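// GetType returns the DAG's key; it appears to double as the type identifier
// expected by the mq broker.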
func (tm *DAG) GetType() string {
	return tm.key
}

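// listenForTaskCleanup drains the cleanup channel and drops finished task
// contexts, but only when the broker is configured to clean tasks on
// completion.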
func (tm *DAG) listenForTaskCleanup() {
	for taskID := range tm.taskCleanupCh {
		if tm.server.Options().CleanTaskOnComplete() {
			tm.taskCleanup(taskID)
		}
	}
}

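// taskCleanup removes a task's TaskManager from the context map under lock.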
func (tm *DAG) taskCleanup(taskID string) {
	tm.mu.Lock()
	defer tm.mu.Unlock()
	delete(tm.taskContext, taskID)
	log.Printf("DAG - Task %s cleaned up", taskID)
}

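// Consume starts the DAG's own consumer when one has been assigned via
// AssignTopic, forcing the broker into sync mode (presumably so results are
// returned inline). It is a no-op when the DAG has no consumer.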
func (tm *DAG) Consume(ctx context.Context) error {
	if tm.consumer != nil {
		tm.server.Options().SetSyncMode(true)
		return tm.consumer.Consume(ctx)
	}
	return nil
}

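// Stop halts every node's processor, returning the first error encountered.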
func (tm *DAG) Stop(ctx context.Context) error {
	for _, n := range tm.nodes {
		err := n.processor.Stop(ctx)
		if err != nil {
			return err
		}
	}
	return nil
}

func (tm *DAG) GetKey() string {
	return tm.key
}

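// AssignTopic attaches a consumer to the DAG so the whole graph can itself be
// consumed as a single topic, e.g. when it is embedded in a parent DAG via
// AddDAGNode.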
func (tm *DAG) AssignTopic(topic string) {
	tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false))
	tm.consumerTopic = topic
}

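// NewDAG builds a DAG bound to a fresh broker and worker pool and starts the
// background cleanup listener. A minimal wiring sketch (node keys and
// handlers below are hypothetical):
//
//	d := NewDAG("email-flow", "email-flow")
//	d.AddNode("Fetch", "fetch", fetchHandler, true)
//	d.AddNode("Send", "send", sendHandler)
//	d.AddEdge("fetch to send", "fetch", "send")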
func NewDAG(name, key string, opts ...mq.Option) *DAG {
	callback := func(ctx context.Context, result mq.Result) error { return nil }
	d := &DAG{
		name:          name,
		key:           key,
		nodes:         make(map[string]*Node),
		taskContext:   make(map[string]*TaskManager),
		conditions:    make(map[FromNode]map[When]Then),
		taskCleanupCh: make(chan string),
	}
	opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerOnSubscribe(d.onConsumerJoin), mq.WithConsumerOnClose(d.onConsumerClose))
	d.server = mq.NewBroker(opts...)
	d.opts = opts
	options := d.server.Options()
	d.pool = mq.NewPool(options.NumOfWorkers(), options.QueueSize(), options.MaxMemoryLoad(), d.ProcessTask, callback, options.Storage())
	d.pool.Start(options.NumOfWorkers())
	go d.listenForTaskCleanup()
	return d
}

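// callbackToConsumer relays a result back to this DAG's consumer. When the
// consumer has no live broker connection the result is handled locally
// instead, which appears to keep nested (sync-mode) DAGs working without a
// broker.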
func (tm *DAG) callbackToConsumer(ctx context.Context, result mq.Result) {
	if tm.consumer != nil {
		result.Topic = tm.consumerTopic
		if tm.consumer.Conn() == nil {
			tm.onTaskCallback(ctx, result)
		} else {
			tm.consumer.OnResponse(ctx, result)
		}
	}
}

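// onTaskCallback routes a node's result to the TaskManager that owns the
// task, so the manager can continue the flow.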
func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result {
	tm.mu.RLock() // taskContext is mutated elsewhere under lock; read it under RLock too
	taskContext, ok := tm.taskContext[result.TaskID]
	tm.mu.RUnlock()
	if ok && result.Topic != "" {
		return taskContext.handleCallback(ctx, result)
	}
	return mq.Result{}
}

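// onConsumerJoin and onConsumerClose flip a node's readiness as its consumer
// comes and goes; IsReady consults these flags before accepting work.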
func (tm *DAG) onConsumerJoin(_ context.Context, topic, _ string) {
	tm.mu.Lock()
	defer tm.mu.Unlock()
	if node, ok := tm.nodes[topic]; ok {
		log.Printf("DAG - CONSUMER ~> ready on %s", topic)
		node.isReady = true
	}
}

func (tm *DAG) onConsumerClose(_ context.Context, topic, _ string) {
	tm.mu.Lock()
	defer tm.mu.Unlock()
	if node, ok := tm.nodes[topic]; ok {
		log.Printf("DAG - CONSUMER ~> down on %s", topic)
		node.isReady = false
	}
}

func (tm *DAG) SetStartNode(node string) {
	tm.startNode = node
}

func (tm *DAG) GetStartNode() string {
	return tm.startNode
}

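// Start runs the broker (unless in sync mode), brings every node's consumer
// up with a once-per-second retry, then registers the DAG's HTTP handlers and
// serves on addr, with TLS when the broker is configured for it.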
func (tm *DAG) Start(ctx context.Context, addr string) error {
	if !tm.server.SyncMode() {
		go func() {
			err := tm.server.Start(ctx)
			if err != nil {
				panic(err)
			}
		}()
		for _, con := range tm.nodes {
			go func(con *Node) {
				limiter := rate.NewLimiter(rate.Every(1*time.Second), 1) // retry at most once per second
				for {
					err := con.processor.Consume(ctx)
					if err != nil {
						log.Printf("[ERROR] - Consumer %s failed to start: %v", con.Key, err)
					} else {
						log.Printf("[INFO] - Consumer %s started successfully", con.Key)
						break
					}
					if err := limiter.Wait(ctx); err != nil {
						return // context cancelled; stop retrying
					}
				}
			}(con)
		}
	}
	log.Printf("DAG - HTTP_SERVER ~> started on %s", addr)
	tm.Handlers()
	config := tm.server.TLSConfig()
	if config.UseTLS {
		return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil)
	}
	return http.ListenAndServe(addr, nil)
}

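// AddDAGNode mounts another DAG as a single node of this one: the sub-DAG is
// given its own topic and acts as the node's processor. Pass true as
// firstNode to make it the start node.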
func (tm *DAG) AddDAGNode(name string, key string, dag *DAG, firstNode ...bool) {
	dag.AssignTopic(key)
	tm.mu.Lock()
	defer tm.mu.Unlock()
	tm.nodes[key] = &Node{
		Name:      name,
		Key:       key,
		processor: dag,
		isReady:   true,
	}
	if len(firstNode) > 0 && firstNode[0] {
		tm.startNode = key
	}
}

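// AddNode registers a processor under key, backed by a consumer on a topic of
// the same name. Pass true as firstNode to make it the start node. It returns
// the DAG for chaining.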
func (tm *DAG) AddNode(name, key string, handler mq.Processor, firstNode ...bool) *DAG {
	tm.mu.Lock()
	defer tm.mu.Unlock()
	con := mq.NewConsumer(key, key, handler.ProcessTask, tm.opts...)
	tm.nodes[key] = &Node{
		Name:      name,
		Key:       key,
		processor: con,
		isReady:   true,
	}
	if len(firstNode) > 0 && firstNode[0] {
		tm.startNode = key
	}
	return tm
}

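// AddDeferredNode declares a node whose processor is expected to be supplied
// later by a remote consumer joining the topic; the node stays not-ready (and
// the DAG with it) until that consumer subscribes. Deferred nodes make no
// sense in sync mode, hence the error.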
func (tm *DAG) AddDeferredNode(name, key string, firstNode ...bool) error {
	if tm.server.SyncMode() {
		return fmt.Errorf("DAG cannot have deferred node in Sync Mode")
	}
	tm.mu.Lock()
	defer tm.mu.Unlock()
	tm.nodes[key] = &Node{
		Name: name,
		Key:  key,
	}
	if len(firstNode) > 0 && firstNode[0] {
		tm.startNode = key
	}
	return nil
}

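// IsReady reports whether every node has a live consumer; Process and
// ScheduleTask refuse work until it returns true.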
func (tm *DAG) IsReady() bool {
	tm.mu.RLock()
	defer tm.mu.RUnlock()
	for _, node := range tm.nodes {
		if !node.isReady {
			return false
		}
	}
	return true
}

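// AddCondition registers conditional routing: when fromNode finishes with
// status When, the task continues at node Then. A minimal sketch (node keys
// and statuses are hypothetical):
//
//	d.AddCondition("validate", map[When]Then{"pass": "store", "fail": "reject"})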
func (tm *DAG) AddCondition(fromNode FromNode, conditions map[When]Then) *DAG {
	tm.mu.Lock()
	defer tm.mu.Unlock()
	tm.conditions[fromNode] = conditions
	return tm
}

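// AddIterator adds an Iterator edge from one node to the given targets.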
func (tm *DAG) AddIterator(label, from string, targets ...string) *DAG {
	tm.addEdge(Iterator, label, from, targets...)
	return tm
}

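// AddEdge adds a Simple edge from one node to the given targets.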
func (tm *DAG) AddEdge(label, from string, targets ...string) *DAG {
	tm.addEdge(Simple, label, from, targets...)
	return tm
}

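// addEdge wires an edge of the given type under lock. Note that it returns
// silently if the source or any target key is unknown, so a typo in a node
// key produces a missing edge rather than an error.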
func (tm *DAG) addEdge(edgeType EdgeType, label, from string, targets ...string) {
	tm.mu.Lock()
	defer tm.mu.Unlock()
	fromNode, ok := tm.nodes[from]
	if !ok {
		return
	}
	var nodes []*Node
	for _, target := range targets {
		toNode, ok := tm.nodes[target]
		if !ok {
			return
		}
		nodes = append(nodes, toNode)
	}
	edge := Edge{From: fromNode, To: nodes, Type: edgeType, Label: label}
	fromNode.Edges = append(fromNode.Edges, edge)
}

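// ProcessTask creates a TaskManager for the incoming task and hands the
// payload to it. When the DAG itself runs as a consumer, the entry node is
// resolved from the context rather than taken from the task's topic.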
func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
	tm.mu.Lock()
	taskID := mq.NewID()
	manager := NewTaskManager(tm, taskID)
	manager.createdAt = task.CreatedAt
	tm.taskContext[taskID] = manager
	tm.mu.Unlock()

	if tm.consumer != nil {
		initialNode, err := tm.parseInitialNode(ctx)
		if err != nil {
			return mq.Result{Error: err}
		}
		task.Topic = initialNode
	}
	return manager.processTask(ctx, task.Topic, task.Payload)
}

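// Process is the public entry point for running a payload through the DAG.
// It rejects work while paused or not ready, then either runs the task inline
// (when the caller awaits a response or the broker is in sync mode) or
// enqueues it on the worker pool and returns a PENDING result.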
func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result {
	tm.mu.RLock()
	if tm.paused {
		tm.mu.RUnlock()
		return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")}
	}
	tm.mu.RUnlock()
	if !tm.IsReady() {
		return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not ready yet")}
	}
	initialNode, err := tm.parseInitialNode(ctx)
	if err != nil {
		return mq.Result{Error: err}
	}
	if tm.server.SyncMode() {
		ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
	}
	task := mq.NewTask(mq.NewID(), payload, initialNode)
	awaitResponse, _ := mq.GetAwaitResponse(ctx)
	if awaitResponse != "true" {
		headers, ok := mq.GetHeaders(ctx)
		ctxx := context.Background()
		if ok {
			ctxx = mq.SetHeaders(ctxx, headers.AsMap())
		}
		if err := tm.pool.EnqueueTask(ctxx, task, 0); err != nil {
			return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: initialNode, Status: "FAILED", Error: err}
		}
		return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: initialNode, Status: "PENDING"}
	}
	return tm.ProcessTask(ctx, task)
}

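// ScheduleTask is the deferred counterpart of Process: after the same paused
// and readiness checks, the task is handed to the pool's scheduler (honouring
// any mq.SchedulerOption given) instead of being executed immediately.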
func (tm *DAG) ScheduleTask(ctx context.Context, payload []byte, opts ...mq.SchedulerOption) mq.Result {
	tm.mu.RLock()
	if tm.paused {
		tm.mu.RUnlock()
		return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")}
	}
	tm.mu.RUnlock()
	if !tm.IsReady() {
		return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not ready yet")}
	}
	initialNode, err := tm.parseInitialNode(ctx)
	if err != nil {
		return mq.Result{Error: err}
	}
	if tm.server.SyncMode() {
		ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
	}
	task := mq.NewTask(mq.NewID(), payload, initialNode)
	headers, ok := mq.GetHeaders(ctx)
	ctxx := context.Background()
	if ok {
		ctxx = mq.SetHeaders(ctxx, headers.AsMap())
	}
	tm.pool.Scheduler().AddTask(ctxx, task, opts...)
	return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: initialNode, Status: "PENDING"}
}

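// parseInitialNode picks the entry node for a task: an explicit
// "initial_node" context value wins, then the configured start node, then a
// root inferred by findStartNode. The inferred root is cached in startNode.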
func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) {
	tm.mu.Lock() // full lock: this may write tm.startNode below
	defer tm.mu.Unlock()

	val := ctx.Value("initial_node")
	initialNode, ok := val.(string)
	if ok {
		return initialNode, nil
	}
	if tm.startNode == "" {
		firstNode := tm.findStartNode()
		if firstNode != nil {
			tm.startNode = firstNode.Key
		}
	}

	if tm.startNode == "" {
		return "", fmt.Errorf("initial node not found")
	}
	return tm.startNode, nil
}

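// findStartNode infers the graph's root: the first node that participates in
// an edge or condition but has no incoming edge. Isolated nodes are ignored,
// and map iteration order makes the choice arbitrary if several roots exist.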
func (tm *DAG) findStartNode() *Node {
	incomingEdges := make(map[string]bool)
	connectedNodes := make(map[string]bool)
	for _, node := range tm.nodes {
		for _, edge := range node.Edges {
			if edge.Type.IsValid() {
				connectedNodes[node.Key] = true
				for _, to := range edge.To {
					connectedNodes[to.Key] = true
					incomingEdges[to.Key] = true
				}
			}
		}
		if cond, ok := tm.conditions[FromNode(node.Key)]; ok {
			for _, target := range cond {
				connectedNodes[string(target)] = true
				incomingEdges[string(target)] = true
			}
		}
	}
	for nodeID, node := range tm.nodes {
		if !incomingEdges[nodeID] && connectedNodes[nodeID] {
			return node
		}
	}
	return nil
}

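// Pause and Resume toggle whether Process and ScheduleTask accept new tasks;
// work already in flight is unaffected.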
func (tm *DAG) Pause(_ context.Context) error {
	tm.mu.Lock()
	defer tm.mu.Unlock()
	tm.paused = true
	return nil
}

func (tm *DAG) Resume(_ context.Context) error {
	tm.mu.Lock()
	defer tm.mu.Unlock()
	tm.paused = false
	return nil
}

func (tm *DAG) Close() error {
	for _, n := range tm.nodes {
		err := n.Close()
		if err != nil {
			return err
		}
	}
	return nil
}

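// PauseConsumer and ResumeConsumer pause or resume a single node's consumer
// by its key, leaving the rest of the graph running.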
func (tm *DAG) PauseConsumer(ctx context.Context, id string) {
	tm.doConsumer(ctx, id, consts.CONSUMER_PAUSE)
}

func (tm *DAG) ResumeConsumer(ctx context.Context, id string) {
	tm.doConsumer(ctx, id, consts.CONSUMER_RESUME)
}

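// doConsumer dispatches a pause or resume command to one node's processor and
// mirrors the outcome into the node's readiness flag.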
func (tm *DAG) doConsumer(ctx context.Context, id string, action consts.CMD) {
	if node, ok := tm.nodes[id]; ok {
		switch action {
		case consts.CONSUMER_PAUSE:
			err := node.processor.Pause(ctx)
			if err == nil {
				node.isReady = false
				log.Printf("[INFO] - Consumer %s paused successfully", node.Key)
			} else {
				log.Printf("[ERROR] - Failed to pause consumer %s: %v", node.Key, err)
			}
		case consts.CONSUMER_RESUME:
			err := node.processor.Resume(ctx)
			if err == nil {
				node.isReady = true
				log.Printf("[INFO] - Consumer %s resumed successfully", node.Key)
			} else {
				log.Printf("[ERROR] - Failed to resume consumer %s: %v", node.Key, err)
			}
		}
	} else {
		log.Printf("[WARNING] - Consumer %s not found", id)
	}
}