feat: update

This commit is contained in:
sujit
2024-10-13 23:19:32 +05:45
parent 20936bebca
commit d910a2656b
6 changed files with 57 additions and 33 deletions

View File

@@ -223,7 +223,7 @@ func (c *Consumer) Consume(ctx context.Context) error {
if err != nil { if err != nil {
return err return err
} }
c.pool = NewPool(c.opts.numOfWorkers, c.opts.queueSize, c.opts.maxMemoryLoad, c.ProcessTask, c.OnResponse, c.conn) c.pool = NewPool(c.opts.numOfWorkers, c.opts.queueSize, c.opts.maxMemoryLoad, c.ProcessTask, c.OnResponse)
if err := c.subscribe(ctx, c.queue); err != nil { if err := c.subscribe(ctx, c.queue); err != nil {
return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err) return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err)
} }

View File

@@ -7,8 +7,7 @@ import (
"log" "log"
"net/http" "net/http"
"sync" "sync"
"time"
"github.com/oarkflow/xid"
"github.com/oarkflow/mq" "github.com/oarkflow/mq"
"github.com/oarkflow/mq/consts" "github.com/oarkflow/mq/consts"
@@ -16,9 +15,9 @@ import (
func NewTask(id string, payload json.RawMessage, nodeKey string) *mq.Task { func NewTask(id string, payload json.RawMessage, nodeKey string) *mq.Task {
if id == "" { if id == "" {
id = xid.New().String() id = mq.NewID()
} }
return &mq.Task{ID: id, Payload: payload, Topic: nodeKey} return &mq.Task{ID: id, Payload: payload, Topic: nodeKey, CreatedAt: time.Now()}
} }
type EdgeType int type EdgeType int
@@ -72,6 +71,7 @@ type DAG struct {
mu sync.RWMutex mu sync.RWMutex
paused bool paused bool
opts []mq.Option opts []mq.Option
pool *mq.Pool
} }
func (tm *DAG) Consume(ctx context.Context) error { func (tm *DAG) Consume(ctx context.Context) error {
@@ -112,6 +112,10 @@ func NewDAG(name, key string, opts ...mq.Option) *DAG {
opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerOnSubscribe(d.onConsumerJoin), mq.WithConsumerOnClose(d.onConsumerClose)) opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerOnSubscribe(d.onConsumerJoin), mq.WithConsumerOnClose(d.onConsumerClose))
d.server = mq.NewBroker(opts...) d.server = mq.NewBroker(opts...)
d.opts = opts d.opts = opts
d.pool = mq.NewPool(d.server.Options().NumOfWorkers(), d.server.Options().QueueSize(), d.server.Options().MaxMemoryLoad(), d.ProcessTask, func(ctx context.Context, result mq.Result) error {
return nil
})
d.pool.Start(d.server.Options().NumOfWorkers())
return d return d
} }
@@ -258,8 +262,9 @@ func (tm *DAG) addEdge(edgeType EdgeType, label, from string, targets ...string)
func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
tm.mu.Lock() tm.mu.Lock()
defer tm.mu.Unlock() defer tm.mu.Unlock()
taskID := xid.New().String() taskID := mq.NewID()
manager := NewTaskManager(tm, taskID) manager := NewTaskManager(tm, taskID)
manager.createdAt = task.CreatedAt
tm.taskContext[taskID] = manager tm.taskContext[taskID] = manager
if tm.consumer != nil { if tm.consumer != nil {
initialNode, err := tm.parseInitialNode(ctx) initialNode, err := tm.parseInitialNode(ctx)
@@ -271,6 +276,34 @@ func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
return manager.processTask(ctx, task.Topic, task.Payload) return manager.processTask(ctx, task.Topic, task.Payload)
} }
func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result {
tm.mu.RLock()
if tm.paused {
tm.mu.RUnlock()
return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")}
}
tm.mu.RUnlock()
if !tm.IsReady() {
return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not ready yet")}
}
initialNode, err := tm.parseInitialNode(ctx)
if err != nil {
return mq.Result{Error: err}
}
task := NewTask(mq.NewID(), payload, initialNode)
awaitResponse, _ := mq.GetAwaitResponse(ctx)
if awaitResponse != "true" {
headers, ok := mq.GetHeaders(ctx)
ctxx := context.Background()
if ok {
ctxx = mq.SetHeaders(ctxx, headers.AsMap())
}
tm.pool.AddTask(ctxx, task)
return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: initialNode, Status: "PENDING"}
}
return tm.ProcessTask(ctx, task)
}
func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) { func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) {
val := ctx.Value("initial_node") val := ctx.Value("initial_node")
initialNode, ok := val.(string) initialNode, ok := val.(string)
@@ -290,24 +323,6 @@ func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) {
return tm.startNode, nil return tm.startNode, nil
} }
func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result {
tm.mu.RLock()
if tm.paused {
tm.mu.RUnlock()
return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")}
}
tm.mu.RUnlock()
if !tm.IsReady() {
return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not ready yet")}
}
initialNode, err := tm.parseInitialNode(ctx)
if err != nil {
return mq.Result{Error: err}
}
task := NewTask(xid.New().String(), payload, initialNode)
return tm.ProcessTask(ctx, task)
}
func (tm *DAG) findStartNode() *Node { func (tm *DAG) findStartNode() *Node {
incomingEdges := make(map[string]bool) incomingEdges := make(map[string]bool)
connectedNodes := make(map[string]bool) connectedNodes := make(map[string]bool)

View File

@@ -19,7 +19,6 @@ type TaskManager struct {
processedAt time.Time processedAt time.Time
results []mq.Result results []mq.Result
nodeResults map[string]mq.Result nodeResults map[string]mq.Result
finalResult chan mq.Result
wg *WaitGroup wg *WaitGroup
} }
@@ -29,7 +28,6 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager {
nodeResults: make(map[string]mq.Result), nodeResults: make(map[string]mq.Result),
results: make([]mq.Result, 0), results: make([]mq.Result, 0),
taskID: taskID, taskID: taskID,
finalResult: make(chan mq.Result, 1),
wg: NewWaitGroup(), wg: NewWaitGroup(),
} }
} }
@@ -44,7 +42,9 @@ func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload j
if !ok { if !ok {
return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)} return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)}
} }
if tm.createdAt.IsZero() {
tm.createdAt = time.Now() tm.createdAt = time.Now()
}
tm.wg.Add(1) tm.wg.Add(1)
go func() { go func() {
go tm.processNode(ctx, node, payload) go tm.processNode(ctx, node, payload)

View File

@@ -11,7 +11,7 @@ import (
func main() { func main() {
d := dag.NewDAG("Sample DAG", "sample-dag", d := dag.NewDAG("Sample DAG", "sample-dag",
// mq.WithSyncMode(true), mq.WithSyncMode(true),
mq.WithNotifyResponse(tasks.NotifyResponse), mq.WithNotifyResponse(tasks.NotifyResponse),
mq.WithSecretKey([]byte("wKWa6GKdBd0njDKNQoInBbh6P0KTjmob")), mq.WithSecretKey([]byte("wKWa6GKdBd0njDKNQoInBbh6P0KTjmob")),
) )

View File

@@ -85,6 +85,18 @@ func (o *Options) SetSyncMode(sync bool) {
o.syncMode = sync o.syncMode = sync
} }
func (o *Options) NumOfWorkers() int {
return o.numOfWorkers
}
func (o *Options) QueueSize() int {
return o.queueSize
}
func (o *Options) MaxMemoryLoad() int64 {
return o.maxMemoryLoad
}
func defaultOptions() *Options { func defaultOptions() *Options {
return &Options{ return &Options{
brokerAddr: ":8080", brokerAddr: ":8080",

View File

@@ -3,7 +3,6 @@ package mq
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -18,7 +17,6 @@ type QueueTask struct {
type Callback func(ctx context.Context, result Result) error type Callback func(ctx context.Context, result Result) error
type Pool struct { type Pool struct {
conn net.Conn
taskQueue chan QueueTask taskQueue chan QueueTask
stop chan struct{} stop chan struct{}
handler Handler handler Handler
@@ -37,7 +35,7 @@ func NewPool(
numOfWorkers, taskQueueSize int, numOfWorkers, taskQueueSize int,
maxMemoryLoad int64, maxMemoryLoad int64,
handler Handler, handler Handler,
callback Callback, conn net.Conn) *Pool { callback Callback) *Pool {
pool := &Pool{ pool := &Pool{
numOfWorkers: int32(numOfWorkers), numOfWorkers: int32(numOfWorkers),
taskQueue: make(chan QueueTask, taskQueueSize), taskQueue: make(chan QueueTask, taskQueueSize),
@@ -45,10 +43,9 @@ func NewPool(
maxMemoryLoad: maxMemoryLoad, maxMemoryLoad: maxMemoryLoad,
handler: handler, handler: handler,
callback: callback, callback: callback,
conn: conn,
workerAdjust: make(chan int), workerAdjust: make(chan int),
} }
pool.Start(int(numOfWorkers)) pool.Start(numOfWorkers)
return pool return pool
} }