feat: decouple Pool from net.Conn, attach a worker pool to DAG, and support asynchronous task submission in Process

This commit is contained in:
sujit
2024-10-13 23:19:32 +05:45
parent 20936bebca
commit d910a2656b
6 changed files with 57 additions and 33 deletions

View File

@@ -223,7 +223,7 @@ func (c *Consumer) Consume(ctx context.Context) error {
if err != nil {
return err
}
c.pool = NewPool(c.opts.numOfWorkers, c.opts.queueSize, c.opts.maxMemoryLoad, c.ProcessTask, c.OnResponse, c.conn)
c.pool = NewPool(c.opts.numOfWorkers, c.opts.queueSize, c.opts.maxMemoryLoad, c.ProcessTask, c.OnResponse)
if err := c.subscribe(ctx, c.queue); err != nil {
return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err)
}

View File

@@ -7,8 +7,7 @@ import (
"log"
"net/http"
"sync"
"github.com/oarkflow/xid"
"time"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/consts"
@@ -16,9 +15,9 @@ import (
func NewTask(id string, payload json.RawMessage, nodeKey string) *mq.Task {
if id == "" {
id = xid.New().String()
id = mq.NewID()
}
return &mq.Task{ID: id, Payload: payload, Topic: nodeKey}
return &mq.Task{ID: id, Payload: payload, Topic: nodeKey, CreatedAt: time.Now()}
}
type EdgeType int
@@ -72,6 +71,7 @@ type DAG struct {
mu sync.RWMutex
paused bool
opts []mq.Option
pool *mq.Pool
}
func (tm *DAG) Consume(ctx context.Context) error {
@@ -112,6 +112,10 @@ func NewDAG(name, key string, opts ...mq.Option) *DAG {
opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerOnSubscribe(d.onConsumerJoin), mq.WithConsumerOnClose(d.onConsumerClose))
d.server = mq.NewBroker(opts...)
d.opts = opts
d.pool = mq.NewPool(d.server.Options().NumOfWorkers(), d.server.Options().QueueSize(), d.server.Options().MaxMemoryLoad(), d.ProcessTask, func(ctx context.Context, result mq.Result) error {
return nil
})
d.pool.Start(d.server.Options().NumOfWorkers())
return d
}
@@ -258,8 +262,9 @@ func (tm *DAG) addEdge(edgeType EdgeType, label, from string, targets ...string)
func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
tm.mu.Lock()
defer tm.mu.Unlock()
taskID := xid.New().String()
taskID := mq.NewID()
manager := NewTaskManager(tm, taskID)
manager.createdAt = task.CreatedAt
tm.taskContext[taskID] = manager
if tm.consumer != nil {
initialNode, err := tm.parseInitialNode(ctx)
@@ -271,6 +276,34 @@ func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
return manager.processTask(ctx, task.Topic, task.Payload)
}
// Process submits a raw payload to the DAG. Unless the context requests an
// awaited ("await response") result, the task is handed to the DAG's worker
// pool and a PENDING result is returned immediately; otherwise the task is
// processed synchronously via ProcessTask.
func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result {
// Refuse new work while the DAG is paused.
// NOTE(review): paused may flip between this check and the enqueue below
// (check-then-act under a released RLock) — confirm that is acceptable.
tm.mu.RLock()
if tm.paused {
tm.mu.RUnlock()
return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")}
}
tm.mu.RUnlock()
// The DAG must report ready before any task is accepted.
if !tm.IsReady() {
return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not ready yet")}
}
// Entry node comes from the context (see parseInitialNode) or the DAG's
// detected start node.
initialNode, err := tm.parseInitialNode(ctx)
if err != nil {
return mq.Result{Error: err}
}
task := NewTask(mq.NewID(), payload, initialNode)
// The "await response" context value selects sync vs async handling.
awaitResponse, _ := mq.GetAwaitResponse(ctx)
if awaitResponse != "true" {
// Async path: detach from the caller's context by copying headers onto
// a fresh background context, so cancelling the request does not cancel
// the queued task.
headers, ok := mq.GetHeaders(ctx)
ctxx := context.Background()
if ok {
ctxx = mq.SetHeaders(ctxx, headers.AsMap())
}
// NOTE(review): any failure signal from AddTask is not surfaced here —
// confirm enqueue errors need no handling.
tm.pool.AddTask(ctxx, task)
return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: initialNode, Status: "PENDING"}
}
// Sync path: run the task to completion before returning.
return tm.ProcessTask(ctx, task)
}
func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) {
val := ctx.Value("initial_node")
initialNode, ok := val.(string)
@@ -290,24 +323,6 @@ func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) {
return tm.startNode, nil
}
func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result {
tm.mu.RLock()
if tm.paused {
tm.mu.RUnlock()
return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")}
}
tm.mu.RUnlock()
if !tm.IsReady() {
return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not ready yet")}
}
initialNode, err := tm.parseInitialNode(ctx)
if err != nil {
return mq.Result{Error: err}
}
task := NewTask(xid.New().String(), payload, initialNode)
return tm.ProcessTask(ctx, task)
}
func (tm *DAG) findStartNode() *Node {
incomingEdges := make(map[string]bool)
connectedNodes := make(map[string]bool)

View File

@@ -19,7 +19,6 @@ type TaskManager struct {
processedAt time.Time
results []mq.Result
nodeResults map[string]mq.Result
finalResult chan mq.Result
wg *WaitGroup
}
@@ -29,7 +28,6 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager {
nodeResults: make(map[string]mq.Result),
results: make([]mq.Result, 0),
taskID: taskID,
finalResult: make(chan mq.Result, 1),
wg: NewWaitGroup(),
}
}
@@ -44,7 +42,9 @@ func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload j
if !ok {
return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)}
}
tm.createdAt = time.Now()
if tm.createdAt.IsZero() {
tm.createdAt = time.Now()
}
tm.wg.Add(1)
go func() {
go tm.processNode(ctx, node, payload)

View File

@@ -11,7 +11,7 @@ import (
func main() {
d := dag.NewDAG("Sample DAG", "sample-dag",
// mq.WithSyncMode(true),
mq.WithSyncMode(true),
mq.WithNotifyResponse(tasks.NotifyResponse),
mq.WithSecretKey([]byte("wKWa6GKdBd0njDKNQoInBbh6P0KTjmob")),
)

View File

@@ -85,6 +85,18 @@ func (o *Options) SetSyncMode(sync bool) {
o.syncMode = sync
}
// NumOfWorkers returns the configured number of pool workers.
func (o *Options) NumOfWorkers() int {
return o.numOfWorkers
}
// QueueSize returns the configured capacity of the task queue.
func (o *Options) QueueSize() int {
return o.queueSize
}
// MaxMemoryLoad returns the configured maximum memory load for the
// worker pool. NOTE(review): units (bytes?) are not visible here — confirm.
func (o *Options) MaxMemoryLoad() int64 {
return o.maxMemoryLoad
}
func defaultOptions() *Options {
return &Options{
brokerAddr: ":8080",

View File

@@ -3,7 +3,6 @@ package mq
import (
"context"
"fmt"
"net"
"sync"
"sync/atomic"
@@ -18,7 +17,6 @@ type QueueTask struct {
type Callback func(ctx context.Context, result Result) error
type Pool struct {
conn net.Conn
taskQueue chan QueueTask
stop chan struct{}
handler Handler
@@ -37,7 +35,7 @@ func NewPool(
numOfWorkers, taskQueueSize int,
maxMemoryLoad int64,
handler Handler,
callback Callback, conn net.Conn) *Pool {
callback Callback) *Pool {
pool := &Pool{
numOfWorkers: int32(numOfWorkers),
taskQueue: make(chan QueueTask, taskQueueSize),
@@ -45,10 +43,9 @@ func NewPool(
maxMemoryLoad: maxMemoryLoad,
handler: handler,
callback: callback,
conn: conn,
workerAdjust: make(chan int),
}
pool.Start(int(numOfWorkers))
pool.Start(numOfWorkers)
return pool
}