From d910a2656b8c5da69515411b2c38bb31e66d3eee Mon Sep 17 00:00:00 2001 From: sujit Date: Sun, 13 Oct 2024 23:19:32 +0545 Subject: [PATCH] feat: update --- consumer.go | 2 +- dag/dag.go | 61 +++++++++++++++++++++++++--------------- dag/task_manager.go | 6 ++-- examples/dag_consumer.go | 2 +- options.go | 12 ++++++++ pool.go | 7 ++--- 6 files changed, 57 insertions(+), 33 deletions(-) diff --git a/consumer.go b/consumer.go index 8ff8426..d0ba460 100644 --- a/consumer.go +++ b/consumer.go @@ -223,7 +223,7 @@ func (c *Consumer) Consume(ctx context.Context) error { if err != nil { return err } - c.pool = NewPool(c.opts.numOfWorkers, c.opts.queueSize, c.opts.maxMemoryLoad, c.ProcessTask, c.OnResponse, c.conn) + c.pool = NewPool(c.opts.numOfWorkers, c.opts.queueSize, c.opts.maxMemoryLoad, c.ProcessTask, c.OnResponse) if err := c.subscribe(ctx, c.queue); err != nil { return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err) } diff --git a/dag/dag.go b/dag/dag.go index 4134d58..fb53e25 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -7,8 +7,7 @@ import ( "log" "net/http" "sync" - - "github.com/oarkflow/xid" + "time" "github.com/oarkflow/mq" "github.com/oarkflow/mq/consts" @@ -16,9 +15,9 @@ import ( func NewTask(id string, payload json.RawMessage, nodeKey string) *mq.Task { if id == "" { - id = xid.New().String() + id = mq.NewID() } - return &mq.Task{ID: id, Payload: payload, Topic: nodeKey} + return &mq.Task{ID: id, Payload: payload, Topic: nodeKey, CreatedAt: time.Now()} } type EdgeType int @@ -72,6 +71,7 @@ type DAG struct { mu sync.RWMutex paused bool opts []mq.Option + pool *mq.Pool } func (tm *DAG) Consume(ctx context.Context) error { @@ -112,6 +112,10 @@ func NewDAG(name, key string, opts ...mq.Option) *DAG { opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerOnSubscribe(d.onConsumerJoin), mq.WithConsumerOnClose(d.onConsumerClose)) d.server = mq.NewBroker(opts...) d.opts = opts + d.pool = mq.NewPool(d.server.Options().NumOfWorkers(), d.server.Options().QueueSize(), d.server.Options().MaxMemoryLoad(), d.ProcessTask, func(ctx context.Context, result mq.Result) error { + return nil + }) + d.pool.Start(d.server.Options().NumOfWorkers()) return d } @@ -258,8 +262,9 @@ func (tm *DAG) addEdge(edgeType EdgeType, label, from string, targets ...string) func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { tm.mu.Lock() defer tm.mu.Unlock() - taskID := xid.New().String() + taskID := mq.NewID() manager := NewTaskManager(tm, taskID) + manager.createdAt = task.CreatedAt tm.taskContext[taskID] = manager if tm.consumer != nil { initialNode, err := tm.parseInitialNode(ctx) @@ -271,6 +276,34 @@ func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { return manager.processTask(ctx, task.Topic, task.Payload) } +func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result { + tm.mu.RLock() + if tm.paused { + tm.mu.RUnlock() + return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")} + } + tm.mu.RUnlock() + if !tm.IsReady() { + return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not ready yet")} + } + initialNode, err := tm.parseInitialNode(ctx) + if err != nil { + return mq.Result{Error: err} + } + task := NewTask(mq.NewID(), payload, initialNode) + awaitResponse, _ := mq.GetAwaitResponse(ctx) + if awaitResponse != "true" { + headers, ok := mq.GetHeaders(ctx) + ctxx := context.Background() + if ok { + ctxx = mq.SetHeaders(ctxx, headers.AsMap()) + } + tm.pool.AddTask(ctxx, task) + return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: initialNode, Status: "PENDING"} + } + return tm.ProcessTask(ctx, task) +} + func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) { val := ctx.Value("initial_node") initialNode, ok := val.(string) @@ -290,24 +323,6 @@ func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) { return tm.startNode, nil } -func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result { - tm.mu.RLock() - if tm.paused { - tm.mu.RUnlock() - return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")} - } - tm.mu.RUnlock() - if !tm.IsReady() { - return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not ready yet")} - } - initialNode, err := tm.parseInitialNode(ctx) - if err != nil { - return mq.Result{Error: err} - } - task := NewTask(xid.New().String(), payload, initialNode) - return tm.ProcessTask(ctx, task) -} - func (tm *DAG) findStartNode() *Node { incomingEdges := make(map[string]bool) connectedNodes := make(map[string]bool) diff --git a/dag/task_manager.go b/dag/task_manager.go index e5a9f6d..6a81e9f 100644 --- a/dag/task_manager.go +++ b/dag/task_manager.go @@ -19,7 +19,6 @@ type TaskManager struct { processedAt time.Time results []mq.Result nodeResults map[string]mq.Result - finalResult chan mq.Result wg *WaitGroup } @@ -29,7 +28,6 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager { nodeResults: make(map[string]mq.Result), results: make([]mq.Result, 0), taskID: taskID, - finalResult: make(chan mq.Result, 1), wg: NewWaitGroup(), } } @@ -44,7 +42,9 @@ func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload j if !ok { return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)} } - tm.createdAt = time.Now() + if tm.createdAt.IsZero() { + tm.createdAt = time.Now() + } tm.wg.Add(1) go func() { go tm.processNode(ctx, node, payload) diff --git a/examples/dag_consumer.go b/examples/dag_consumer.go index b0f59eb..388e8d2 100644 --- a/examples/dag_consumer.go +++ b/examples/dag_consumer.go @@ -11,7 +11,7 @@ import ( func main() { d := dag.NewDAG("Sample DAG", "sample-dag", - // mq.WithSyncMode(true), + mq.WithSyncMode(true), mq.WithNotifyResponse(tasks.NotifyResponse), mq.WithSecretKey([]byte("wKWa6GKdBd0njDKNQoInBbh6P0KTjmob")), ) diff --git a/options.go b/options.go index 11533a8..083b477 100644 --- a/options.go +++ b/options.go @@ -85,6 +85,18 @@ func (o *Options) SetSyncMode(sync bool) { o.syncMode = sync } +func (o *Options) NumOfWorkers() int { + return o.numOfWorkers +} + +func (o *Options) QueueSize() int { + return o.queueSize +} + +func (o *Options) MaxMemoryLoad() int64 { + return o.maxMemoryLoad +} + func defaultOptions() *Options { return &Options{ brokerAddr: ":8080", diff --git a/pool.go b/pool.go index 7f617d8..8f8f755 100644 --- a/pool.go +++ b/pool.go @@ -3,7 +3,6 @@ package mq import ( "context" "fmt" - "net" "sync" "sync/atomic" @@ -18,7 +17,6 @@ type QueueTask struct { type Callback func(ctx context.Context, result Result) error type Pool struct { - conn net.Conn taskQueue chan QueueTask stop chan struct{} handler Handler @@ -37,7 +35,7 @@ func NewPool( numOfWorkers, taskQueueSize int, maxMemoryLoad int64, handler Handler, - callback Callback, conn net.Conn) *Pool { + callback Callback) *Pool { pool := &Pool{ numOfWorkers: int32(numOfWorkers), taskQueue: make(chan QueueTask, taskQueueSize), @@ -45,10 +43,9 @@ func NewPool( maxMemoryLoad: maxMemoryLoad, handler: handler, callback: callback, - conn: conn, workerAdjust: make(chan int), } - pool.Start(int(numOfWorkers)) + pool.Start(numOfWorkers) return pool }