diff --git a/broker.go b/broker.go index 95ef01a..f35ef57 100644 --- a/broker.go +++ b/broker.go @@ -97,16 +97,12 @@ func (b *Broker) MessageAck(ctx context.Context, msg *codec.Message) { func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) { msg.Command = consts.RESPONSE - headers, ok := GetHeaders(ctx) - if !ok { - return - } b.HandleCallback(ctx, msg) - awaitResponse, ok := headers[consts.AwaitResponseKey] + awaitResponse, ok := GetAwaitResponse(ctx) if !(ok && awaitResponse == "true") { return } - publisherID, exists := headers[consts.PublisherKey] + publisherID, exists := GetPublisherID(ctx) if !exists { return } @@ -120,13 +116,13 @@ func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) } } -func (b *Broker) Publish(ctx context.Context, task Task, queue string) error { +func (b *Broker) Publish(ctx context.Context, task *Task, queue string) error { headers, _ := GetHeaders(ctx) payload, err := json.Marshal(task) if err != nil { return err } - msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers) + msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers.headers) b.broadcastToConsumers(ctx, msg) return nil } diff --git a/codec/codec.go b/codec/codec.go index 2d397bf..ac71082 100644 --- a/codec/codec.go +++ b/codec/codec.go @@ -31,7 +31,7 @@ func NewMessage(cmd consts.CMD, payload json.RawMessage, queue string, headers m func (m *Message) Serialize(aesKey, hmacKey []byte, encrypt bool) ([]byte, string, error) { var buf bytes.Buffer - // Serialize Headers, Queue, Command, Payload, and Metadata + // Serialize Headers, Topic, Command, Payload, and Metadata if err := writeLengthPrefixedJSON(&buf, m.Headers); err != nil { return nil, "", fmt.Errorf("error serializing headers: %v", err) } @@ -62,7 +62,7 @@ func Deserialize(data, aesKey, hmacKey []byte, receivedHMAC string, decrypt bool buf := bytes.NewReader(data) - // Deserialize Headers, Queue, Command, Payload, and Metadata + // Deserialize Headers, Topic, Command, Payload, and Metadata headers := make(map[string]string) if err := readLengthPrefixedJSON(buf, &headers); err != nil { return nil, fmt.Errorf("error deserializing headers: %v", err) diff --git a/consts/constants.go b/consts/constants.go index 2776c94..9abbc6b 100644 --- a/consts/constants.go +++ b/consts/constants.go @@ -54,7 +54,7 @@ var ( PublisherKey = "Publisher-Key" ContentType = "Content-Type" AwaitResponseKey = "Await-Response" - QueueKey = "Queue" + QueueKey = "Topic" TypeJson = "application/json" HeaderKey = "headers" TriggerNode = "triggerNode" diff --git a/consumer.go b/consumer.go index 2aa2926..26b62ab 100644 --- a/consumer.go +++ b/consumer.go @@ -3,7 +3,6 @@ package mq import ( "context" "encoding/json" - "errors" "fmt" "log" "net" @@ -19,20 +18,21 @@ import ( // Consumer structure to hold consumer-specific configurations and state. type Consumer struct { - id string - handlers map[string]Handler - conn net.Conn - queues []string - opts Options + id string + handler Handler + conn net.Conn + queue string + opts Options } // NewConsumer initializes a new consumer with the provided options. -func NewConsumer(id string, opts ...Option) *Consumer { +func NewConsumer(id string, queue string, handler Handler, opts ...Option) *Consumer { options := setupOptions(opts...) return &Consumer{ - handlers: make(map[string]Handler), - id: id, - opts: options, + id: id, + opts: options, + queue: queue, + handler: handler, } } @@ -89,9 +89,9 @@ func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.C return } ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue}) - result := c.ProcessTask(ctx, task) - result.MessageID = task.ID - result.Queue = msg.Queue + result := c.ProcessTask(ctx, &task) + result.TaskID = task.ID + result.Topic = msg.Queue if result.Status == "" { if result.Error != nil { result.Status = "FAILED" @@ -107,13 +107,8 @@ func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.C } // ProcessTask handles a received task message and invokes the appropriate handler. -func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result { - queue, _ := GetQueue(ctx) - handler, exists := c.handlers[queue] - if !exists { - return Result{Error: errors.New("No handler for queue " + queue)} - } - return handler(ctx, msg) +func (c *Consumer) ProcessTask(ctx context.Context, msg *Task) Result { + return c.handler(ctx, msg) } // AttemptConnect tries to establish a connection to the server, with TLS or without, based on the configuration. @@ -159,10 +154,9 @@ func (c *Consumer) Consume(ctx context.Context) error { if err != nil { return err } - for _, q := range c.queues { - if err := c.subscribe(ctx, q); err != nil { - return fmt.Errorf("failed to connect to server for queue %s: %v", q, err) - } + + if err := c.subscribe(ctx, c.queue); err != nil { + return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err) } var wg sync.WaitGroup wg.Add(1) @@ -191,9 +185,3 @@ func (c *Consumer) waitForAck(conn net.Conn) error { } return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command) } - -// RegisterHandler registers a handler for a queue. -func (c *Consumer) RegisterHandler(queue string, handler Handler) { - c.queues = append(c.queues, queue) - c.handlers[queue] = handler -} diff --git a/ctx.go b/ctx.go index 2817b3e..eaf139a 100644 --- a/ctx.go +++ b/ctx.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "os" + "sync" "time" "github.com/oarkflow/xid" @@ -17,6 +18,7 @@ import ( type Task struct { ID string `json:"id"` + Topic string `json:"topic"` Payload json.RawMessage `json:"payload"` CreatedAt time.Time `json:"created_at"` ProcessedAt time.Time `json:"processed_at"` @@ -24,7 +26,7 @@ type Task struct { Error error `json:"error"` } -type Handler func(context.Context, Task) Result +type Handler func(context.Context, *Task) Result func IsClosed(conn net.Conn) bool { _, err := conn.Read(make([]byte, 1)) @@ -34,87 +36,92 @@ func IsClosed(conn net.Conn) bool { } } return false +} // HeaderMap wraps a map and a mutex for thread-safe access +type HeaderMap struct { + mu sync.RWMutex + headers map[string]string +} + +// NewHeaderMap initializes a new HeaderMap +func NewHeaderMap() *HeaderMap { + return &HeaderMap{ + headers: make(map[string]string), + } } func SetHeaders(ctx context.Context, headers map[string]string) context.Context { - hd, ok := GetHeaders(ctx) - if !ok { - hd = make(map[string]string) + hd, _ := GetHeaders(ctx) + if hd == nil { + hd = NewHeaderMap() } + hd.mu.Lock() + defer hd.mu.Unlock() for key, val := range headers { - hd[key] = val + hd.headers[key] = val } return context.WithValue(ctx, consts.HeaderKey, hd) } func WithHeaders(ctx context.Context, headers map[string]string) map[string]string { - hd, ok := GetHeaders(ctx) - if !ok { - hd = make(map[string]string) + hd, _ := GetHeaders(ctx) + if hd == nil { + hd = NewHeaderMap() } + hd.mu.Lock() + defer hd.mu.Unlock() for key, val := range headers { - hd[key] = val + hd.headers[key] = val } - return hd + return getMapAsRegularMap(hd) } -func GetHeaders(ctx context.Context) (map[string]string, bool) { - headers, ok := ctx.Value(consts.HeaderKey).(map[string]string) +func GetHeaders(ctx context.Context) (*HeaderMap, bool) { + headers, ok := ctx.Value(consts.HeaderKey).(*HeaderMap) return headers, ok } func GetHeader(ctx context.Context, key string) (string, bool) { - headers, ok := ctx.Value(consts.HeaderKey).(map[string]string) + headers, ok := GetHeaders(ctx) if !ok { return "", false } - val, ok := headers[key] + headers.mu.RLock() + defer headers.mu.RUnlock() + val, ok := headers.headers[key] return val, ok } func GetContentType(ctx context.Context) (string, bool) { - headers, ok := GetHeaders(ctx) - if !ok { - return "", false - } - contentType, ok := headers[consts.ContentType] - return contentType, ok + return GetHeader(ctx, consts.ContentType) } func GetQueue(ctx context.Context) (string, bool) { - headers, ok := GetHeaders(ctx) - if !ok { - return "", false - } - contentType, ok := headers[consts.QueueKey] - return contentType, ok + return GetHeader(ctx, consts.QueueKey) } func GetConsumerID(ctx context.Context) (string, bool) { - headers, ok := GetHeaders(ctx) - if !ok { - return "", false - } - contentType, ok := headers[consts.ConsumerKey] - return contentType, ok + return GetHeader(ctx, consts.ConsumerKey) } func GetTriggerNode(ctx context.Context) (string, bool) { - headers, ok := GetHeaders(ctx) - if !ok { - return "", false - } - contentType, ok := headers[consts.TriggerNode] - return contentType, ok + return GetHeader(ctx, consts.TriggerNode) +} + +func GetAwaitResponse(ctx context.Context) (string, bool) { + return GetHeader(ctx, consts.AwaitResponseKey) } func GetPublisherID(ctx context.Context) (string, bool) { - headers, ok := GetHeaders(ctx) - if !ok { - return "", false + return GetHeader(ctx, consts.PublisherKey) +} + +// Helper function to convert HeaderMap to a regular map +func getMapAsRegularMap(hd *HeaderMap) map[string]string { + result := make(map[string]string) + for key, value := range hd.headers { + result[key] = value } - contentType, ok := headers[consts.PublisherKey] - return contentType, ok + return result } func NewID() string { diff --git a/dag/dag.go b/dag/dag.go index c0745ef..8e83e6a 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -9,374 +9,172 @@ import ( "sync" "time" - "github.com/oarkflow/mq/consts" + "github.com/oarkflow/xid" "github.com/oarkflow/mq" ) -type taskContext struct { - totalItems int - completed int - results []json.RawMessage - result json.RawMessage - multipleResults bool +func NewTask(id string, payload json.RawMessage, nodeKey string) *mq.Task { + if id == "" { + id = xid.New().String() + } + return &mq.Task{ID: id, Payload: payload, Topic: nodeKey} +} + +type EdgeType int + +func (c EdgeType) IsValid() bool { return c >= SimpleEdge && c <= LoopEdge } + +const ( + SimpleEdge EdgeType = iota + LoopEdge +) + +type Node struct { + Key string + Edges []Edge + consumer *mq.Consumer +} + +type Edge struct { + From *Node + To *Node + Type EdgeType } type DAG struct { FirstNode string + Nodes map[string]*Node server *mq.Broker - nodes map[string]*mq.Consumer - edges map[string]string + taskContext map[string]*TaskManager conditions map[string]map[string]string - loopEdges map[string][]string - taskChMap map[string]chan mq.Result - taskResults map[string]map[string]*taskContext - mu sync.Mutex + mu sync.RWMutex } -func New(opts ...mq.Option) *DAG { +func NewDAG(opts ...mq.Option) *DAG { d := &DAG{ - nodes: make(map[string]*mq.Consumer), - edges: make(map[string]string), + Nodes: make(map[string]*Node), + taskContext: make(map[string]*TaskManager), conditions: make(map[string]map[string]string), - loopEdges: make(map[string][]string), - taskChMap: make(map[string]chan mq.Result), - taskResults: make(map[string]map[string]*taskContext), } - opts = append(opts, mq.WithCallback(d.TaskCallback)) + opts = append(opts, mq.WithCallback(d.onTaskCallback)) d.server = mq.NewBroker(opts...) return d } -func (d *DAG) AddNode(name string, handler mq.Handler, firstNode ...bool) { - tlsConfig := d.server.TLSConfig() - con := mq.NewConsumer(name, mq.WithTLS(tlsConfig.UseTLS, tlsConfig.CertPath, tlsConfig.KeyPath), mq.WithCAPath(tlsConfig.CAPath)) - if len(firstNode) > 0 { - d.FirstNode = name +func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result { + if taskContext, ok := tm.taskContext[result.TaskID]; ok && result.Topic != "" { + return taskContext.handleCallback(ctx, result) } - con.RegisterHandler(name, handler) - d.nodes[name] = con + return mq.Result{} } -func (d *DAG) AddCondition(fromNode string, conditions map[string]string) { - d.conditions[fromNode] = conditions -} - -func (d *DAG) AddEdge(fromNode string, toNodes string) { - d.edges[fromNode] = toNodes -} - -func (d *DAG) AddLoop(fromNode string, toNode ...string) { - d.loopEdges[fromNode] = toNode -} - -func (d *DAG) Prepare() { - if d.FirstNode == "" { - firstNode, ok := d.FindFirstNode() - if ok && firstNode != "" { - d.FirstNode = firstNode +func (tm *DAG) Start(ctx context.Context, addr string) error { + if !tm.server.SyncMode() { + go func() { + err := tm.server.Start(ctx) + if err != nil { + panic(err) + } + }() + for _, con := range tm.Nodes { + go func(con *Node) { + time.Sleep(1 * time.Second) + con.consumer.Consume(ctx) + }(con) } } -} -func (d *DAG) Start(ctx context.Context, addr string) error { - d.Prepare() - if d.server.SyncMode() { - return nil - } - go func() { - err := d.server.Start(ctx) - if err != nil { - panic(err) - } - }() - for _, con := range d.nodes { - go func(con *mq.Consumer) { - con.Consume(ctx) - }(con) - } log.Printf("HTTP server started on %s", addr) - config := d.server.TLSConfig() + config := tm.server.TLSConfig() if config.UseTLS { return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil) } return http.ListenAndServe(addr, nil) } -func (d *DAG) PublishTask(ctx context.Context, payload json.RawMessage, taskID ...string) mq.Result { - queue, ok := mq.GetQueue(ctx) +func (tm *DAG) AddNode(key string, handler mq.Handler, firstNode ...bool) { + tm.mu.Lock() + defer tm.mu.Unlock() + con := mq.NewConsumer(key, key, handler) + tm.Nodes[key] = &Node{ + Key: key, + consumer: con, + } + if len(firstNode) > 0 && firstNode[0] { + tm.FirstNode = key + } +} + +func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) { + tm.mu.Lock() + defer tm.mu.Unlock() + tm.conditions[fromNode] = conditions +} + +func (tm *DAG) AddEdge(from, to string, edgeTypes ...EdgeType) { + tm.mu.Lock() + defer tm.mu.Unlock() + fromNode, ok := tm.Nodes[from] if !ok { - queue = d.FirstNode + return } - var id string - if len(taskID) > 0 { - id = taskID[0] - } else { - id = mq.NewID() + toNode, ok := tm.Nodes[to] + if !ok { + return } - task := mq.Task{ - ID: id, - Payload: payload, - CreatedAt: time.Now(), - } - err := d.server.Publish(ctx, task, queue) - if err != nil { - return mq.Result{Error: err} - } - return mq.Result{ - Payload: payload, - Queue: queue, - MessageID: id, + edge := Edge{From: fromNode, To: toNode} + if len(edgeTypes) > 0 && edgeTypes[0].IsValid() { + edge.Type = edgeTypes[0] } + fromNode.Edges = append(fromNode.Edges, edge) } -func (d *DAG) FindFirstNode() (string, bool) { - inDegree := make(map[string]int) - for n, _ := range d.nodes { - inDegree[n] = 0 - } - for _, outNode := range d.edges { - inDegree[outNode]++ - } - for _, targets := range d.loopEdges { - for _, outNode := range targets { - inDegree[outNode]++ +func (tm *DAG) ProcessTask(ctx context.Context, payload []byte) mq.Result { + val := ctx.Value("initial_node") + initialNode, ok := val.(string) + if !ok { + if tm.FirstNode == "" { + firstNode := tm.FindInitialNode() + if firstNode != nil { + tm.FirstNode = firstNode.Key + } } - } - for n, count := range inDegree { - if count == 0 { - return n, true + if tm.FirstNode == "" { + return mq.Result{Error: fmt.Errorf("initial node not found")} } + initialNode = tm.FirstNode } - return "", false + tm.mu.Lock() + defer tm.mu.Unlock() + taskID := xid.New().String() + manager := NewTaskManager(tm, taskID) + tm.taskContext[taskID] = manager + return manager.processTask(ctx, initialNode, payload) } -func (d *DAG) Request(ctx context.Context, payload []byte) mq.Result { - return d.sendSync(ctx, mq.Result{Payload: payload}) -} - -func (d *DAG) Send(ctx context.Context, payload []byte) mq.Result { - if d.FirstNode == "" { - return mq.Result{Error: fmt.Errorf("initial node not defined")} - } - if d.server.SyncMode() { - return d.sendSync(ctx, mq.Result{Payload: payload}) - } - resultCh := make(chan mq.Result) - result := d.PublishTask(ctx, payload) - if result.Error != nil { - return result - } - d.mu.Lock() - d.taskChMap[result.MessageID] = resultCh - d.mu.Unlock() - finalResult := <-resultCh - return finalResult -} - -func (d *DAG) processNode(ctx context.Context, task mq.Result) mq.Result { - if con, ok := d.nodes[task.Queue]; ok { - return con.ProcessTask(ctx, mq.Task{ - ID: task.MessageID, - Payload: task.Payload, - }) - } - return mq.Result{Error: fmt.Errorf("no consumer to process %s", task.Queue)} -} - -func (d *DAG) sendSync(ctx context.Context, task mq.Result) mq.Result { - if task.MessageID == "" { - task.MessageID = mq.NewID() - } - if task.Queue == "" { - task.Queue = d.FirstNode - } - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: task.Queue, - }) - result := d.processNode(ctx, task) - if result.Error != nil { - return result - } - for _, target := range d.loopEdges[task.Queue] { - var items, results []json.RawMessage - if err := json.Unmarshal(result.Payload, &items); err != nil { - return mq.Result{Error: err} - } - for _, item := range items { - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: target, - }) - result = d.sendSync(ctx, mq.Result{ - Payload: item, - Queue: target, - MessageID: result.MessageID, - }) - if result.Error != nil { - return result - } - results = append(results, result.Payload) - } - bt, err := json.Marshal(results) - if err != nil { - return mq.Result{Error: err} - } - result.Payload = bt - } - if conditions, ok := d.conditions[task.Queue]; ok { - if target, exists := conditions[result.Status]; exists { - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: target, - }) - result = d.sendSync(ctx, mq.Result{ - Payload: result.Payload, - Queue: target, - MessageID: result.MessageID, - }) - if result.Error != nil { - return result - } - } - } - if target, ok := d.edges[task.Queue]; ok { - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: target, - }) - result = d.sendSync(ctx, mq.Result{ - Payload: result.Payload, - Queue: target, - MessageID: result.MessageID, - }) - if result.Error != nil { - return result - } - } - return result -} - -func (d *DAG) getCompletedResults(task mq.Result, ok bool, triggeredNode string) ([]byte, bool, bool) { - var result any - var payload []byte - completed := false - multipleResults := false - if ok && triggeredNode != "" { - taskResults, ok := d.taskResults[task.MessageID] - if ok { - nodeResult, exists := taskResults[triggeredNode] - if exists { - multipleResults = nodeResult.multipleResults - nodeResult.completed++ - if nodeResult.completed == nodeResult.totalItems { - completed = true - } - if multipleResults { - nodeResult.results = append(nodeResult.results, task.Payload) - if completed { - result = nodeResult.results - } - } else { - nodeResult.result = task.Payload - if completed { - result = nodeResult.result - } - } - } - if completed { - delete(taskResults, triggeredNode) - } - } - } - if completed { - payload, _ = json.Marshal(result) - } else { - payload = task.Payload - } - return payload, completed, multipleResults -} - -func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result { - if task.Error != nil { - return mq.Result{Error: task.Error} - } - triggeredNode, ok := mq.GetTriggerNode(ctx) - payload, completed, multipleResults := d.getCompletedResults(task, ok, triggeredNode) - if loopNodes, exists := d.loopEdges[task.Queue]; exists { - var items []json.RawMessage - if err := json.Unmarshal(payload, &items); err != nil { - return mq.Result{Error: task.Error} - } - d.taskResults[task.MessageID] = map[string]*taskContext{ - task.Queue: { - totalItems: len(items), - multipleResults: true, - }, - } - - ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue}) - for _, loopNode := range loopNodes { - for _, item := range items { - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: loopNode, - }) - result := d.PublishTask(ctx, item, task.MessageID) - if result.Error != nil { - return result - } - } - } - - return task - } - if multipleResults && completed { - task.Queue = triggeredNode - } - if conditions, ok := d.conditions[task.Queue]; ok { - if target, exists := conditions[task.Status]; exists { - d.taskResults[task.MessageID] = map[string]*taskContext{ - task.Queue: { - totalItems: len(conditions), - }, - } - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: target, - consts.TriggerNode: task.Queue, - }) - result := d.PublishTask(ctx, payload, task.MessageID) - if result.Error != nil { - return result - } - } - } else { - ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue}) - edge, exists := d.edges[task.Queue] - if exists { - d.taskResults[task.MessageID] = map[string]*taskContext{ - task.Queue: { - totalItems: 1, - }, - } - ctx = mq.SetHeaders(ctx, map[string]string{ - consts.QueueKey: edge, - }) - result := d.PublishTask(ctx, payload, task.MessageID) - if result.Error != nil { - return result - } - } else if completed { - d.mu.Lock() - if resultCh, ok := d.taskChMap[task.MessageID]; ok { - resultCh <- mq.Result{ - Payload: payload, - Queue: task.Queue, - MessageID: task.MessageID, - Status: "done", - } - delete(d.taskChMap, task.MessageID) - delete(d.taskResults, task.MessageID) - } - d.mu.Unlock() - } - } - - return task +func (tm *DAG) FindInitialNode() *Node { + incomingEdges := make(map[string]bool) + connectedNodes := make(map[string]bool) + for _, node := range tm.Nodes { + for _, edge := range node.Edges { + if edge.Type.IsValid() { + connectedNodes[node.Key] = true + connectedNodes[edge.To.Key] = true + incomingEdges[edge.To.Key] = true + } + } + if cond, ok := tm.conditions[node.Key]; ok { + for _, target := range cond { + connectedNodes[target] = true + incomingEdges[target] = true + } + } + } + for nodeID, node := range tm.Nodes { + if !incomingEdges[nodeID] && connectedNodes[nodeID] { + return node + } + } + return nil } diff --git a/dag/task_manager.go b/dag/task_manager.go new file mode 100644 index 0000000..3a2347b --- /dev/null +++ b/dag/task_manager.go @@ -0,0 +1,225 @@ +package dag + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "sync/atomic" + + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/consts" +) + +type TaskManager struct { + taskID string + dag *DAG + wg sync.WaitGroup + mutex sync.Mutex + results []mq.Result + waitingCallback int64 + nodeResults map[string]mq.Result + done chan struct{} + finalResult chan mq.Result // Channel to collect final results +} + +func NewTaskManager(d *DAG, taskID string) *TaskManager { + return &TaskManager{ + dag: d, + nodeResults: make(map[string]mq.Result), + results: make([]mq.Result, 0), + taskID: taskID, + } +} + +func (tm *TaskManager) handleSyncTask(ctx context.Context, node *Node, payload json.RawMessage) mq.Result { + tm.done = make(chan struct{}) + tm.wg.Add(1) + go tm.processNode(ctx, node, payload) + go func() { + tm.wg.Wait() + close(tm.done) + }() + select { + case <-ctx.Done(): + return mq.Result{Error: ctx.Err()} + case <-tm.done: + tm.mutex.Lock() + defer tm.mutex.Unlock() + if len(tm.results) == 1 { + return tm.handleResult(ctx, tm.results[0]) + } + return tm.handleResult(ctx, tm.results) + } +} + +func (tm *TaskManager) handleAsyncTask(ctx context.Context, node *Node, payload json.RawMessage) mq.Result { + tm.finalResult = make(chan mq.Result) + tm.wg.Add(1) + go tm.processNode(ctx, node, payload) + go func() { + tm.wg.Wait() + }() + select { + case result := <-tm.finalResult: // Block until a result is available + return result + case <-ctx.Done(): // Handle context cancellation + return mq.Result{Error: ctx.Err()} + } +} + +func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result { + node, ok := tm.dag.Nodes[nodeID] + if !ok { + return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)} + } + if tm.dag.server.SyncMode() { + return tm.handleSyncTask(ctx, node, payload) + } + return tm.handleAsyncTask(ctx, node, payload) +} + +func (tm *TaskManager) dispatchFinalResult(ctx context.Context) { + if !tm.dag.server.SyncMode() { + var rs mq.Result + if len(tm.results) == 1 { + rs = tm.handleResult(ctx, tm.results[0]) + } else { + rs = tm.handleResult(ctx, tm.results) + } + if tm.waitingCallback == 0 { + tm.finalResult <- rs + } + } +} + +func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.Result { + if result.Topic != "" { + atomic.AddInt64(&tm.waitingCallback, -1) + } + node, ok := tm.dag.Nodes[result.Topic] + if !ok { + return result + } + edges := make([]Edge, len(node.Edges)) + copy(edges, node.Edges) + if result.Status != "" { + if conditions, ok := tm.dag.conditions[result.Topic]; ok { + if targetNodeKey, ok := conditions[result.Status]; ok { + if targetNode, ok := tm.dag.Nodes[targetNodeKey]; ok { + edges = append(edges, Edge{From: node, To: targetNode}) + } + } + } + } + if len(edges) == 0 { + tm.appendFinalResult(result) + tm.dispatchFinalResult(ctx) + return result + } + for _, edge := range edges { + switch edge.Type { + case LoopEdge: + var items []json.RawMessage + err := json.Unmarshal(result.Payload, &items) + if err != nil { + tm.appendFinalResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err}) + return result + } + for _, item := range items { + tm.wg.Add(1) + ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) + go tm.processNode(ctx, edge.To, item) + } + case SimpleEdge: + if edge.To != nil { + tm.wg.Add(1) + ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) + go tm.processNode(ctx, edge.To, result.Payload) + } + } + } + return mq.Result{} +} + +func (tm *TaskManager) handleResult(ctx context.Context, results any) mq.Result { + var rs mq.Result + switch res := results.(type) { + case []mq.Result: + aggregatedOutput := make([]json.RawMessage, 0) + status := "" + for i, result := range res { + if i == 0 { + status = result.Status + } + if result.Error != nil { + return mq.HandleError(ctx, result.Error) + } + var item json.RawMessage + err := json.Unmarshal(result.Payload, &item) + if err != nil { + return mq.HandleError(ctx, err) + } + aggregatedOutput = append(aggregatedOutput, item) + } + finalOutput, err := json.Marshal(aggregatedOutput) + if err != nil { + return mq.HandleError(ctx, err) + } + return mq.Result{ + TaskID: tm.taskID, + Payload: finalOutput, + Status: status, + } + case mq.Result: + return res + } + return rs +} + +func (tm *TaskManager) appendFinalResult(result mq.Result) { + tm.mutex.Lock() + tm.results = append(tm.results, result) + tm.nodeResults[result.Topic] = result + tm.mutex.Unlock() +} + +func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) { + atomic.AddInt64(&tm.waitingCallback, 1) + defer tm.wg.Done() + var result mq.Result + select { + case <-ctx.Done(): + result = mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: ctx.Err()} + tm.appendFinalResult(result) + return + default: + ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: node.Key}) + if tm.dag.server.SyncMode() { + result = node.consumer.ProcessTask(ctx, NewTask(tm.taskID, payload, node.Key)) + result.Topic = node.Key + result.TaskID = tm.taskID + if result.Error != nil { + tm.appendFinalResult(result) + return + } + } else { + err := tm.dag.server.Publish(ctx, NewTask(tm.taskID, payload, node.Key), node.Key) + if err != nil { + tm.appendFinalResult(mq.Result{Error: err}) + return + } + } + } + tm.mutex.Lock() + tm.nodeResults[node.Key] = result + tm.mutex.Unlock() + tm.handleCallback(ctx, result) +} + +func (tm *TaskManager) Clear() error { + tm.waitingCallback = 0 + clear(tm.results) + tm.nodeResults = make(map[string]mq.Result) + return nil +} diff --git a/examples/consumer.go b/examples/consumer.go index a312348..7b9575e 100644 --- a/examples/consumer.go +++ b/examples/consumer.go @@ -2,15 +2,16 @@ package main import ( "context" + "github.com/oarkflow/mq" "github.com/oarkflow/mq/examples/tasks" ) func main() { - consumer := mq.NewConsumer("consumer-1") + consumer1 := mq.NewConsumer("consumer-1", "queue1", tasks.Node1) + consumer2 := mq.NewConsumer("consumer-2", "queue2", tasks.Node2) // consumer := mq.NewConsumer("consumer-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key")) - consumer.RegisterHandler("queue1", tasks.Node1) - consumer.RegisterHandler("queue2", tasks.Node2) - consumer.Consume(context.Background()) + go consumer1.Consume(context.Background()) + consumer2.Consume(context.Background()) } diff --git a/examples/dag.go b/examples/dag.go index d255407..c3fa15a 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -3,56 +3,83 @@ package main import ( "context" "encoding/json" - "fmt" "io" "net/http" - "time" "github.com/oarkflow/mq" "github.com/oarkflow/mq/dag" - "github.com/oarkflow/mq/examples/tasks" ) -var d *dag.DAG +func handler1(ctx context.Context, task *mq.Task) mq.Result { + return mq.Result{Payload: task.Payload} +} + +func handler2(ctx context.Context, task *mq.Task) mq.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + return mq.Result{Payload: task.Payload} +} + +func handler3(ctx context.Context, task *mq.Task) mq.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + age := int(user["age"].(float64)) + status := "FAIL" + if age > 20 { + status = "PASS" + } + user["status"] = status + resultPayload, _ := json.Marshal(user) + return mq.Result{Payload: resultPayload, Status: status} +} + +func handler4(ctx context.Context, task *mq.Task) mq.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + user["final"] = "D" + resultPayload, _ := json.Marshal(user) + return mq.Result{Payload: resultPayload} +} + +func handler5(ctx context.Context, task *mq.Task) mq.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + user["salary"] = "E" + resultPayload, _ := json.Marshal(user) + return mq.Result{Payload: resultPayload} +} + +func handler6(ctx context.Context, task *mq.Task) mq.Result { + var user map[string]any + json.Unmarshal(task.Payload, &user) + resultPayload, _ := json.Marshal(map[string]any{"storage": user}) + return mq.Result{Payload: resultPayload} +} + +var ( + d = dag.NewDAG(mq.WithSyncMode(true)) + // d = dag.NewDAG(mq.WithSyncMode(true), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert")) +) func main() { - d = dag.New(mq.WithSyncMode(false), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.crt")) - d.AddNode("queue1", tasks.Node1, true) - d.AddNode("queue2", tasks.Node2) - d.AddNode("queue3", tasks.Node3) - d.AddNode("queue4", tasks.Node4) - - d.AddNode("queue5", tasks.CheckCondition) - d.AddNode("queue6", tasks.Pass) - d.AddNode("queue7", tasks.Fail) - - d.AddCondition("queue5", map[string]string{"pass": "queue6", "fail": "queue7"}) - d.AddEdge("queue1", "queue2") - d.AddEdge("queue2", "queue4") - d.AddEdge("queue3", "queue5") - - d.AddLoop("queue2", "queue3") - d.Prepare() - go func() { - d.Start(context.Background(), ":8081") - }() - go func() { - time.Sleep(3 * time.Second) - result := d.Send(context.Background(), []byte(`[{"user_id": 1}, {"user_id": 2}]`)) - if result.Error != nil { - panic(result.Error) - } - fmt.Println("Response", string(result.Payload)) - }() - - time.Sleep(10 * time.Second) - /*d.Prepare() + d.AddNode("A", handler1, true) + d.AddNode("B", handler2) + d.AddNode("C", handler3) + d.AddNode("D", handler4) + d.AddNode("E", handler5) + d.AddNode("F", handler6) + d.AddEdge("A", "B", dag.LoopEdge) + d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"}) + d.AddEdge("B", "C") + d.AddEdge("D", "F") + d.AddEdge("E", "F") + // fmt.Println(rs.TaskID, "Task", string(rs.Payload)) http.HandleFunc("POST /publish", requestHandler("publish")) http.HandleFunc("POST /request", requestHandler("request")) err := d.Start(context.TODO(), ":8083") if err != nil { panic(err) - }*/ + } } func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Request) { @@ -74,16 +101,13 @@ func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Requ http.Error(w, "Empty request body", http.StatusBadRequest) return } - var rs mq.Result - if requestType == "request" { - rs = d.Request(context.Background(), payload) - } else { - rs = d.Send(context.Background(), payload) - } + ctx := context.Background() + // ctx = context.WithValue(ctx, "initial_node", "E") + rs := d.ProcessTask(ctx, payload) w.Header().Set("Content-Type", "application/json") result := map[string]any{ - "message_id": rs.MessageID, - "payload": string(rs.Payload), + "message_id": rs.TaskID, + "payload": rs.Payload, "error": rs.Error, } json.NewEncoder(w).Encode(result) diff --git a/examples/dag_v2.go b/examples/dag_v2.go deleted file mode 100644 index 8757147..0000000 --- a/examples/dag_v2.go +++ /dev/null @@ -1,275 +0,0 @@ -package main - -import ( - "context" - "encoding/json" - "fmt" - "sync" - "time" -) - -type Task struct { - ID string `json:"id"` - Payload json.RawMessage `json:"payload"` - CreatedAt time.Time `json:"created_at"` - ProcessedAt time.Time `json:"processed_at"` - Status string `json:"status"` - Error error `json:"error"` -} - -type Result struct { - Payload json.RawMessage `json:"payload"` - Queue string `json:"queue"` - MessageID string `json:"message_id"` - Error error `json:"error,omitempty"` - Status string `json:"status"` -} - -const ( - SimpleEdge = iota - LoopEdge - ConditionEdge -) - -type Edge struct { - edgeType int - to string - conditions map[string]string -} - -type Node struct { - key string - handler func(context.Context, Task) Result - edges []Edge -} - -type RadixTrie struct { - children map[rune]*RadixTrie - node *Node - mu sync.RWMutex -} - -func NewRadixTrie() *RadixTrie { - return &RadixTrie{ - children: make(map[rune]*RadixTrie), - } -} - -func (trie *RadixTrie) Insert(key string, node *Node) { - trie.mu.Lock() - defer trie.mu.Unlock() - - current := trie - for _, char := range key { - if _, exists := current.children[char]; !exists { - current.children[char] = NewRadixTrie() - } - current = current.children[char] - } - current.node = node -} - -func (trie *RadixTrie) Search(key string) (*Node, bool) { - trie.mu.RLock() - defer trie.mu.RUnlock() - current := trie - for _, char := range key { - if _, exists := current.children[char]; !exists { - return nil, false - } - current = current.children[char] - } - if current.node != nil { - return current.node, true - } - return nil, false -} - -type DAG struct { - trie *RadixTrie - mu sync.RWMutex -} - -func NewDAG() *DAG { - return &DAG{ - trie: NewRadixTrie(), - } -} - -func (d *DAG) AddNode(key string, handler func(context.Context, Task) Result, isRoot ...bool) { - node := &Node{key: key, handler: handler} - d.trie.Insert(key, node) -} - -func (d *DAG) AddEdge(fromKey string, toKey string) { - d.mu.Lock() - defer d.mu.Unlock() - node, exists := d.trie.Search(fromKey) - if !exists { - fmt.Printf("Node %s not found to add edge.\n", fromKey) - return - } - edge := Edge{edgeType: SimpleEdge, to: toKey} - node.edges = append(node.edges, edge) -} - -func (d *DAG) AddLoop(fromKey string, toKey string) { - d.mu.Lock() - defer d.mu.Unlock() - node, exists := d.trie.Search(fromKey) - if !exists { - fmt.Printf("Node %s not found to add loop edge.\n", fromKey) - return - } - edge := Edge{edgeType: LoopEdge, to: toKey} - node.edges = append(node.edges, edge) -} - -func (d *DAG) AddCondition(fromKey string, conditions map[string]string) { - d.mu.Lock() - defer d.mu.Unlock() - node, exists := d.trie.Search(fromKey) - if !exists { - fmt.Printf("Node %s not found to add condition edge.\n", fromKey) - return - } - edge := Edge{edgeType: ConditionEdge, conditions: conditions} - node.edges = append(node.edges, edge) -} - -type ProcessCallback func(ctx context.Context, key string, result Result) string - -func (d *DAG) ProcessTask(ctx context.Context, key string, task Task) { - node, exists := d.trie.Search(key) - if !exists { - fmt.Printf("Node %s not found.\n", key) - return - } - result := node.handler(ctx, task) - nextKey := d.callback(ctx, key, result) - if nextKey != "" { - d.ProcessTask(ctx, nextKey, task) - } -} - -func (d *DAG) ProcessLoop(ctx context.Context, key string, task Task) { - _, exists := d.trie.Search(key) - if !exists { - fmt.Printf("Node %s not found.\n", key) - return - } - var items []json.RawMessage - err := json.Unmarshal(task.Payload, &items) - if err != nil { - fmt.Printf("Error unmarshaling payload as slice: %v\n", err) - return - } - for _, item := range items { - newTask := Task{ - ID: task.ID, - Payload: item, - } - - d.ProcessTask(ctx, key, newTask) - } -} - -func (d *DAG) callback(ctx context.Context, currentKey string, result Result) string { - fmt.Printf("Callback received result from %s: %s\n", currentKey, string(result.Payload)) - node, exists := d.trie.Search(currentKey) - if !exists { - return "" - } - for _, edge := range node.edges { - switch edge.edgeType { - case SimpleEdge: - return edge.to - case LoopEdge: - - d.ProcessLoop(ctx, edge.to, Task{Payload: result.Payload}) - return "" - case ConditionEdge: - if nextKey, conditionMet := edge.conditions[result.Status]; conditionMet { - return nextKey - } - } - } - return "" -} - -func Node1(ctx context.Context, task Task) Result { - return Result{Payload: task.Payload, MessageID: task.ID} -} - -func Node2(ctx context.Context, task Task) Result { - return Result{Payload: task.Payload, MessageID: task.ID} -} - -func Node3(ctx context.Context, task Task) Result { - var data map[string]any - err := json.Unmarshal(task.Payload, &data) - if err != nil { - return Result{Error: err} - } - data["salary"] = fmt.Sprintf("12000%v", data["user_id"]) - bt, _ := json.Marshal(data) - return Result{Payload: bt, MessageID: task.ID} -} - -func Node4(ctx context.Context, task Task) Result { - var data []map[string]any - err := json.Unmarshal(task.Payload, &data) - if err != nil { - return Result{Error: err} - } - payload := map[string]any{"storage": data} - bt, _ := json.Marshal(payload) - return Result{Payload: bt, MessageID: task.ID} -} - -func CheckCondition(ctx context.Context, task Task) Result { - var data map[string]any - err := json.Unmarshal(task.Payload, &data) - if err != nil { - return Result{Error: err} - } - var status string - if data["user_id"].(float64) == 2 { - status = "pass" - } else { - status = "fail" - } - return Result{Status: status, Payload: task.Payload, MessageID: task.ID} -} - -func Pass(ctx context.Context, task Task) Result { - fmt.Println("Pass") - return Result{Payload: task.Payload} -} - -func Fail(ctx context.Context, task Task) Result { - fmt.Println("Fail") - return Result{Payload: []byte(`{"test2": "asdsa"}`)} -} - -func main() { - dag := NewDAG() - dag.AddNode("queue1", Node1, true) - dag.AddNode("queue2", Node2) - dag.AddNode("queue3", Node3) - dag.AddNode("queue4", Node4) - dag.AddNode("queue5", CheckCondition) - dag.AddNode("queue6", Pass) - dag.AddNode("queue7", Fail) - dag.AddEdge("queue1", "queue2") - dag.AddEdge("queue2", "queue4") - dag.AddEdge("queue3", "queue5") - dag.AddLoop("queue2", "queue3") - dag.AddCondition("queue5", map[string]string{"pass": "queue6", "fail": "queue7"}) - ctx := context.Background() - task := Task{ - ID: "task1", - Payload: []byte(`[{"user_id": 1}, {"user_id": 2}]`), - } - dag.ProcessTask(ctx, "queue1", task) -} diff --git a/examples/tasks/tasks.go b/examples/tasks/tasks.go index fd7d534..3a6f64c 100644 --- a/examples/tasks/tasks.go +++ b/examples/tasks/tasks.go @@ -4,18 +4,19 @@ import ( "context" "encoding/json" "fmt" + "github.com/oarkflow/mq" ) -func Node1(ctx context.Context, task mq.Task) mq.Result { - return mq.Result{Payload: task.Payload, MessageID: task.ID} +func Node1(ctx context.Context, task *mq.Task) mq.Result { + return mq.Result{Payload: task.Payload, TaskID: task.ID} } -func Node2(ctx context.Context, task mq.Task) mq.Result { - return mq.Result{Payload: task.Payload, MessageID: task.ID} +func Node2(ctx context.Context, task *mq.Task) mq.Result { + return mq.Result{Payload: task.Payload, TaskID: task.ID} } -func Node3(ctx context.Context, task mq.Task) mq.Result { +func Node3(ctx context.Context, task *mq.Task) mq.Result { var data map[string]any err := json.Unmarshal(task.Payload, &data) if err != nil { @@ -23,10 +24,10 @@ func Node3(ctx context.Context, task mq.Task) mq.Result { } data["salary"] = fmt.Sprintf("12000%v", data["user_id"]) bt, _ := json.Marshal(data) - return mq.Result{Payload: bt, MessageID: task.ID} + return mq.Result{Payload: bt, TaskID: task.ID} } -func Node4(ctx context.Context, task mq.Task) mq.Result { +func Node4(ctx context.Context, task *mq.Task) mq.Result { var data []map[string]any err := json.Unmarshal(task.Payload, &data) if err != nil { @@ -34,10 +35,10 @@ func Node4(ctx context.Context, task mq.Task) mq.Result { } payload := map[string]any{"storage": data} bt, _ := json.Marshal(payload) - return mq.Result{Payload: bt, MessageID: task.ID} + return mq.Result{Payload: bt, TaskID: task.ID} } -func CheckCondition(ctx context.Context, task mq.Task) mq.Result { +func CheckCondition(ctx context.Context, task *mq.Task) mq.Result { var data map[string]any err := json.Unmarshal(task.Payload, &data) if err != nil { @@ -49,20 +50,20 @@ func CheckCondition(ctx context.Context, task mq.Task) mq.Result { } else { status = "fail" } - return mq.Result{Status: status, Payload: task.Payload, MessageID: task.ID} + return mq.Result{Status: status, Payload: task.Payload, TaskID: task.ID} } -func Pass(ctx context.Context, task mq.Task) mq.Result { +func Pass(ctx context.Context, task *mq.Task) mq.Result { fmt.Println("Pass") return mq.Result{Payload: task.Payload} } -func Fail(ctx context.Context, task mq.Task) mq.Result { +func Fail(ctx context.Context, task *mq.Task) mq.Result { fmt.Println("Fail") return mq.Result{Payload: []byte(`{"test2": "asdsa"}`)} } func Callback(ctx context.Context, task mq.Result) mq.Result { - fmt.Println("Received task", task.MessageID, "Payload", string(task.Payload), task.Error, task.Queue) + fmt.Println("Received task", task.TaskID, "Payload", string(task.Payload), task.Error, task.Topic) return mq.Result{} } diff --git a/options.go b/options.go index 76ea6ea..096cc18 100644 --- a/options.go +++ b/options.go @@ -3,15 +3,52 @@ package mq import ( "context" "encoding/json" + "fmt" "time" ) type Result struct { - Payload json.RawMessage `json:"payload"` - Queue string `json:"queue"` - MessageID string `json:"message_id"` - Error error `json:"error,omitempty"` - Status string `json:"status"` + Payload json.RawMessage `json:"payload"` + Topic string `json:"topic"` + TaskID string `json:"task_id"` + Error error `json:"error,omitempty"` + Status string `json:"status"` +} + +func (r Result) Unmarshal(data any) error { + if r.Payload == nil { + return fmt.Errorf("payload is nil") + } + return json.Unmarshal(r.Payload, data) +} + +func (r Result) String() string { + return string(r.Payload) +} + +func HandleError(ctx context.Context, err error, status ...string) Result { + st := "Failed" + if len(status) > 0 { + st = status[0] + } + if err == nil { + return Result{} + } + return Result{ + Status: st, + Error: err, + } +} + +func (r Result) WithData(status string, data []byte) Result { + if r.Error != nil { + return r + } + return Result{ + Status: status, + Payload: data, + Error: nil, + } } type TLSConfig struct {