mirror of
https://github.com/oarkflow/mq.git
synced 2025-09-26 20:11:16 +08:00
feat: add example
This commit is contained in:
@@ -154,6 +154,7 @@ func (c *Consumer) Consume(ctx context.Context) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.subscribe(ctx, c.queue); err != nil {
|
||||
return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err)
|
||||
}
|
||||
|
2
ctx.go
2
ctx.go
@@ -118,8 +118,6 @@ func GetPublisherID(ctx context.Context) (string, bool) {
|
||||
// Helper function to convert HeaderMap to a regular map
|
||||
func getMapAsRegularMap(hd *HeaderMap) map[string]string {
|
||||
result := make(map[string]string)
|
||||
hd.mu.RLock()
|
||||
defer hd.mu.RUnlock()
|
||||
for key, value := range hd.headers {
|
||||
result[key] = value
|
||||
}
|
||||
|
381
dag/dag.go
381
dag/dag.go
@@ -1,381 +0,0 @@
|
||||
package dag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq/consts"
|
||||
|
||||
"github.com/oarkflow/mq"
|
||||
)
|
||||
|
||||
type taskContext struct {
|
||||
totalItems int
|
||||
completed int
|
||||
results []json.RawMessage
|
||||
result json.RawMessage
|
||||
multipleResults bool
|
||||
}
|
||||
|
||||
type DAG struct {
|
||||
FirstNode string
|
||||
server *mq.Broker
|
||||
nodes map[string]*mq.Consumer
|
||||
edges map[string]string
|
||||
conditions map[string]map[string]string
|
||||
loopEdges map[string][]string
|
||||
taskChMap map[string]chan mq.Result
|
||||
taskResults map[string]map[string]*taskContext
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func New(opts ...mq.Option) *DAG {
|
||||
d := &DAG{
|
||||
nodes: make(map[string]*mq.Consumer),
|
||||
edges: make(map[string]string),
|
||||
conditions: make(map[string]map[string]string),
|
||||
loopEdges: make(map[string][]string),
|
||||
taskChMap: make(map[string]chan mq.Result),
|
||||
taskResults: make(map[string]map[string]*taskContext),
|
||||
}
|
||||
opts = append(opts, mq.WithCallback(d.TaskCallback))
|
||||
d.server = mq.NewBroker(opts...)
|
||||
return d
|
||||
}
|
||||
|
||||
func (d *DAG) AddNode(name string, handler mq.Handler, firstNode ...bool) {
|
||||
tlsConfig := d.server.TLSConfig()
|
||||
con := mq.NewConsumer(name, name, handler, mq.WithTLS(tlsConfig.UseTLS, tlsConfig.CertPath, tlsConfig.KeyPath), mq.WithCAPath(tlsConfig.CAPath))
|
||||
if len(firstNode) > 0 {
|
||||
d.FirstNode = name
|
||||
}
|
||||
d.nodes[name] = con
|
||||
}
|
||||
|
||||
func (d *DAG) AddCondition(fromNode string, conditions map[string]string) {
|
||||
d.conditions[fromNode] = conditions
|
||||
}
|
||||
|
||||
func (d *DAG) AddEdge(fromNode string, toNodes string) {
|
||||
d.edges[fromNode] = toNodes
|
||||
}
|
||||
|
||||
func (d *DAG) AddLoop(fromNode string, toNode ...string) {
|
||||
d.loopEdges[fromNode] = toNode
|
||||
}
|
||||
|
||||
func (d *DAG) Prepare() {
|
||||
if d.FirstNode == "" {
|
||||
firstNode, ok := d.FindFirstNode()
|
||||
if ok && firstNode != "" {
|
||||
d.FirstNode = firstNode
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DAG) Start(ctx context.Context, addr string) error {
|
||||
d.Prepare()
|
||||
if d.server.SyncMode() {
|
||||
return nil
|
||||
}
|
||||
go func() {
|
||||
err := d.server.Start(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
for _, con := range d.nodes {
|
||||
go func(con *mq.Consumer) {
|
||||
con.Consume(ctx)
|
||||
}(con)
|
||||
}
|
||||
log.Printf("HTTP server started on %s", addr)
|
||||
config := d.server.TLSConfig()
|
||||
if config.UseTLS {
|
||||
return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil)
|
||||
}
|
||||
return http.ListenAndServe(addr, nil)
|
||||
}
|
||||
|
||||
func (d *DAG) PublishTask(ctx context.Context, payload json.RawMessage, taskID ...string) mq.Result {
|
||||
queue, ok := mq.GetQueue(ctx)
|
||||
if !ok {
|
||||
queue = d.FirstNode
|
||||
}
|
||||
var id string
|
||||
if len(taskID) > 0 {
|
||||
id = taskID[0]
|
||||
} else {
|
||||
id = mq.NewID()
|
||||
}
|
||||
task := &mq.Task{
|
||||
ID: id,
|
||||
Payload: payload,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
err := d.server.Publish(ctx, task, queue)
|
||||
if err != nil {
|
||||
return mq.Result{Error: err}
|
||||
}
|
||||
return mq.Result{
|
||||
Payload: payload,
|
||||
Topic: queue,
|
||||
TaskID: id,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DAG) FindFirstNode() (string, bool) {
|
||||
inDegree := make(map[string]int)
|
||||
for n, _ := range d.nodes {
|
||||
inDegree[n] = 0
|
||||
}
|
||||
for _, outNode := range d.edges {
|
||||
inDegree[outNode]++
|
||||
}
|
||||
for _, targets := range d.loopEdges {
|
||||
for _, outNode := range targets {
|
||||
inDegree[outNode]++
|
||||
}
|
||||
}
|
||||
for n, count := range inDegree {
|
||||
if count == 0 {
|
||||
return n, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (d *DAG) Request(ctx context.Context, payload []byte) mq.Result {
|
||||
return d.sendSync(ctx, mq.Result{Payload: payload})
|
||||
}
|
||||
|
||||
func (d *DAG) Send(ctx context.Context, payload []byte) mq.Result {
|
||||
if d.FirstNode == "" {
|
||||
return mq.Result{Error: fmt.Errorf("initial node not defined")}
|
||||
}
|
||||
if d.server.SyncMode() {
|
||||
return d.sendSync(ctx, mq.Result{Payload: payload})
|
||||
}
|
||||
resultCh := make(chan mq.Result)
|
||||
result := d.PublishTask(ctx, payload)
|
||||
if result.Error != nil {
|
||||
return result
|
||||
}
|
||||
d.mu.Lock()
|
||||
d.taskChMap[result.TaskID] = resultCh
|
||||
d.mu.Unlock()
|
||||
finalResult := <-resultCh
|
||||
return finalResult
|
||||
}
|
||||
|
||||
func (d *DAG) processNode(ctx context.Context, task mq.Result) mq.Result {
|
||||
if con, ok := d.nodes[task.Topic]; ok {
|
||||
return con.ProcessTask(ctx, &mq.Task{
|
||||
ID: task.TaskID,
|
||||
Payload: task.Payload,
|
||||
})
|
||||
}
|
||||
return mq.Result{Error: fmt.Errorf("no consumer to process %s", task.Topic)}
|
||||
}
|
||||
|
||||
func (d *DAG) sendSync(ctx context.Context, task mq.Result) mq.Result {
|
||||
if task.TaskID == "" {
|
||||
task.TaskID = mq.NewID()
|
||||
}
|
||||
if task.Topic == "" {
|
||||
task.Topic = d.FirstNode
|
||||
}
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{
|
||||
consts.QueueKey: task.Topic,
|
||||
})
|
||||
result := d.processNode(ctx, task)
|
||||
if result.Error != nil {
|
||||
return result
|
||||
}
|
||||
for _, target := range d.loopEdges[task.Topic] {
|
||||
var items, results []json.RawMessage
|
||||
if err := json.Unmarshal(result.Payload, &items); err != nil {
|
||||
return mq.Result{Error: err}
|
||||
}
|
||||
for _, item := range items {
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{
|
||||
consts.QueueKey: target,
|
||||
})
|
||||
result = d.sendSync(ctx, mq.Result{
|
||||
Payload: item,
|
||||
Topic: target,
|
||||
TaskID: result.TaskID,
|
||||
})
|
||||
if result.Error != nil {
|
||||
return result
|
||||
}
|
||||
results = append(results, result.Payload)
|
||||
}
|
||||
bt, err := json.Marshal(results)
|
||||
if err != nil {
|
||||
return mq.Result{Error: err}
|
||||
}
|
||||
result.Payload = bt
|
||||
}
|
||||
if conditions, ok := d.conditions[task.Topic]; ok {
|
||||
if target, exists := conditions[result.Status]; exists {
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{
|
||||
consts.QueueKey: target,
|
||||
})
|
||||
result = d.sendSync(ctx, mq.Result{
|
||||
Payload: result.Payload,
|
||||
Topic: target,
|
||||
TaskID: result.TaskID,
|
||||
})
|
||||
if result.Error != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
if target, ok := d.edges[task.Topic]; ok {
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{
|
||||
consts.QueueKey: target,
|
||||
})
|
||||
result = d.sendSync(ctx, mq.Result{
|
||||
Payload: result.Payload,
|
||||
Topic: target,
|
||||
TaskID: result.TaskID,
|
||||
})
|
||||
if result.Error != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (d *DAG) getCompletedResults(task mq.Result, ok bool, triggeredNode string) ([]byte, bool, bool) {
|
||||
var result any
|
||||
var payload []byte
|
||||
completed := false
|
||||
multipleResults := false
|
||||
if ok && triggeredNode != "" {
|
||||
taskResults, ok := d.taskResults[task.TaskID]
|
||||
if ok {
|
||||
nodeResult, exists := taskResults[triggeredNode]
|
||||
if exists {
|
||||
multipleResults = nodeResult.multipleResults
|
||||
nodeResult.completed++
|
||||
if nodeResult.completed == nodeResult.totalItems {
|
||||
completed = true
|
||||
}
|
||||
if multipleResults {
|
||||
nodeResult.results = append(nodeResult.results, task.Payload)
|
||||
if completed {
|
||||
result = nodeResult.results
|
||||
}
|
||||
} else {
|
||||
nodeResult.result = task.Payload
|
||||
if completed {
|
||||
result = nodeResult.result
|
||||
}
|
||||
}
|
||||
}
|
||||
if completed {
|
||||
delete(taskResults, triggeredNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
if completed {
|
||||
payload, _ = json.Marshal(result)
|
||||
} else {
|
||||
payload = task.Payload
|
||||
}
|
||||
return payload, completed, multipleResults
|
||||
}
|
||||
|
||||
func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
|
||||
if task.Error != nil {
|
||||
return mq.Result{Error: task.Error}
|
||||
}
|
||||
triggeredNode, ok := mq.GetTriggerNode(ctx)
|
||||
payload, completed, multipleResults := d.getCompletedResults(task, ok, triggeredNode)
|
||||
if loopNodes, exists := d.loopEdges[task.Topic]; exists {
|
||||
var items []json.RawMessage
|
||||
if err := json.Unmarshal(payload, &items); err != nil {
|
||||
return mq.Result{Error: task.Error}
|
||||
}
|
||||
d.taskResults[task.TaskID] = map[string]*taskContext{
|
||||
task.Topic: {
|
||||
totalItems: len(items),
|
||||
multipleResults: true,
|
||||
},
|
||||
}
|
||||
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Topic})
|
||||
for _, loopNode := range loopNodes {
|
||||
for _, item := range items {
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{
|
||||
consts.QueueKey: loopNode,
|
||||
})
|
||||
result := d.PublishTask(ctx, item, task.TaskID)
|
||||
if result.Error != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return task
|
||||
}
|
||||
if multipleResults && completed {
|
||||
task.Topic = triggeredNode
|
||||
}
|
||||
if conditions, ok := d.conditions[task.Topic]; ok {
|
||||
if target, exists := conditions[task.Status]; exists {
|
||||
d.taskResults[task.TaskID] = map[string]*taskContext{
|
||||
task.Topic: {
|
||||
totalItems: len(conditions),
|
||||
},
|
||||
}
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{
|
||||
consts.QueueKey: target,
|
||||
consts.TriggerNode: task.Topic,
|
||||
})
|
||||
result := d.PublishTask(ctx, payload, task.TaskID)
|
||||
if result.Error != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Topic})
|
||||
edge, exists := d.edges[task.Topic]
|
||||
if exists {
|
||||
d.taskResults[task.TaskID] = map[string]*taskContext{
|
||||
task.Topic: {
|
||||
totalItems: 1,
|
||||
},
|
||||
}
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{
|
||||
consts.QueueKey: edge,
|
||||
})
|
||||
result := d.PublishTask(ctx, payload, task.TaskID)
|
||||
if result.Error != nil {
|
||||
return result
|
||||
}
|
||||
} else if completed {
|
||||
d.mu.Lock()
|
||||
if resultCh, ok := d.taskChMap[task.TaskID]; ok {
|
||||
resultCh <- mq.Result{
|
||||
Payload: payload,
|
||||
Topic: task.Topic,
|
||||
TaskID: task.TaskID,
|
||||
Status: "done",
|
||||
}
|
||||
delete(d.taskChMap, task.TaskID)
|
||||
delete(d.taskResults, task.TaskID)
|
||||
}
|
||||
d.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
return task
|
||||
}
|
108
examples/dag.go
108
examples/dag.go
@@ -1,53 +1,79 @@
|
||||
package main
|
||||
|
||||
/*
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq"
|
||||
"github.com/oarkflow/mq/dag"
|
||||
"github.com/oarkflow/mq/examples/tasks"
|
||||
"github.com/oarkflow/mq/v2"
|
||||
)
|
||||
|
||||
var d *dag.DAG
|
||||
func handler1(ctx context.Context, task *mq.Task) mq.Result {
|
||||
return mq.Result{Payload: task.Payload}
|
||||
}
|
||||
|
||||
func handler2(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
return mq.Result{Payload: task.Payload}
|
||||
}
|
||||
|
||||
func handler3(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
age := int(user["age"].(float64))
|
||||
status := "FAIL"
|
||||
if age > 20 {
|
||||
status = "PASS"
|
||||
}
|
||||
user["status"] = status
|
||||
resultPayload, _ := json.Marshal(user)
|
||||
return mq.Result{Payload: resultPayload, Status: status}
|
||||
}
|
||||
|
||||
func handler4(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
user["final"] = "D"
|
||||
resultPayload, _ := json.Marshal(user)
|
||||
return mq.Result{Payload: resultPayload}
|
||||
}
|
||||
|
||||
func handler5(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
user["salary"] = "E"
|
||||
resultPayload, _ := json.Marshal(user)
|
||||
return mq.Result{Payload: resultPayload}
|
||||
}
|
||||
|
||||
func handler6(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
resultPayload, _ := json.Marshal(map[string]any{"storage": user})
|
||||
return mq.Result{Payload: resultPayload}
|
||||
}
|
||||
|
||||
var (
|
||||
d = v2.NewDAG(mq.WithSyncMode(false))
|
||||
)
|
||||
|
||||
func main() {
|
||||
d = dag.New(mq.WithSyncMode(true), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.crt"))
|
||||
d.AddNode("queue1", tasks.Node1, true)
|
||||
d.AddNode("queue2", tasks.Node2)
|
||||
d.AddNode("queue3", tasks.Node3)
|
||||
d.AddNode("queue4", tasks.Node4)
|
||||
d.AddNode("A", handler1)
|
||||
d.AddNode("B", handler2)
|
||||
d.AddNode("C", handler3)
|
||||
d.AddNode("D", handler4)
|
||||
d.AddNode("E", handler5)
|
||||
d.AddNode("F", handler6)
|
||||
d.AddEdge("A", "B", v2.LoopEdge)
|
||||
d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"})
|
||||
d.AddEdge("B", "C")
|
||||
d.AddEdge("D", "F")
|
||||
d.AddEdge("E", "F")
|
||||
|
||||
d.AddNode("queue5", tasks.CheckCondition)
|
||||
d.AddNode("queue6", tasks.Pass)
|
||||
d.AddNode("queue7", tasks.Fail)
|
||||
|
||||
d.AddCondition("queue5", map[string]string{"pass": "queue6", "fail": "queue7"})
|
||||
d.AddEdge("queue1", "queue2")
|
||||
d.AddEdge("queue2", "queue4")
|
||||
d.AddEdge("queue3", "queue5")
|
||||
|
||||
d.AddLoop("queue2", "queue3")
|
||||
d.Prepare()
|
||||
go func() {
|
||||
d.Start(context.Background(), ":8081")
|
||||
}()
|
||||
go func() {
|
||||
time.Sleep(3 * time.Second)
|
||||
result := d.Send(context.Background(), []byte(`[{"user_id": 1}, {"user_id": 2}]`))
|
||||
if result.Error != nil {
|
||||
panic(result.Error)
|
||||
}
|
||||
fmt.Println("Response", string(result.Payload))
|
||||
}()
|
||||
|
||||
time.Sleep(10 * time.Second)
|
||||
d.Prepare()
|
||||
// fmt.Println(rs.TaskID, "Task", string(rs.Payload))
|
||||
http.HandleFunc("POST /publish", requestHandler("publish"))
|
||||
http.HandleFunc("POST /request", requestHandler("request"))
|
||||
err := d.Start(context.TODO(), ":8083")
|
||||
@@ -75,19 +101,13 @@ func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Requ
|
||||
http.Error(w, "Empty request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
var rs mq.Result
|
||||
if requestType == "request" {
|
||||
rs = d.Request(context.Background(), payload)
|
||||
} else {
|
||||
rs = d.Send(context.Background(), payload)
|
||||
}
|
||||
rs := d.ProcessTask(context.Background(), "A", payload)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
result := map[string]any{
|
||||
"message_id": rs.TaskID,
|
||||
"payload": string(rs.Payload),
|
||||
"payload": rs.Payload,
|
||||
"error": rs.Error,
|
||||
}
|
||||
json.NewEncoder(w).Encode(result)
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
@@ -1,125 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/oarkflow/mq"
|
||||
"github.com/oarkflow/mq/v2"
|
||||
)
|
||||
|
||||
func handler1(ctx context.Context, task *mq.Task) mq.Result {
|
||||
return mq.Result{Payload: task.Payload}
|
||||
}
|
||||
|
||||
func handler2(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
return mq.Result{Payload: task.Payload}
|
||||
}
|
||||
|
||||
func handler3(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
age := int(user["age"].(float64))
|
||||
status := "FAIL"
|
||||
if age > 20 {
|
||||
status = "PASS"
|
||||
}
|
||||
user["status"] = status
|
||||
resultPayload, _ := json.Marshal(user)
|
||||
return mq.Result{Payload: resultPayload, Status: status}
|
||||
}
|
||||
|
||||
func handler4(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
user["final"] = "D"
|
||||
resultPayload, _ := json.Marshal(user)
|
||||
return mq.Result{Payload: resultPayload}
|
||||
}
|
||||
|
||||
func handler5(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
user["salary"] = "E"
|
||||
resultPayload, _ := json.Marshal(user)
|
||||
return mq.Result{Payload: resultPayload}
|
||||
}
|
||||
|
||||
func handler6(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
resultPayload, _ := json.Marshal(map[string]any{"storage": user})
|
||||
return mq.Result{Payload: resultPayload}
|
||||
}
|
||||
|
||||
var (
|
||||
d = v2.NewDAG(mq.WithSyncMode(true))
|
||||
)
|
||||
|
||||
func main() {
|
||||
d.AddNode("A", handler1)
|
||||
d.AddNode("B", handler2)
|
||||
d.AddNode("C", handler3)
|
||||
d.AddNode("D", handler4)
|
||||
d.AddNode("E", handler5)
|
||||
d.AddNode("F", handler6)
|
||||
d.AddEdge("A", "B", v2.LoopEdge)
|
||||
d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"})
|
||||
d.AddEdge("B", "C")
|
||||
d.AddEdge("D", "F")
|
||||
d.AddEdge("E", "F")
|
||||
|
||||
initialPayload, _ := json.Marshal([]map[string]any{
|
||||
{"user_id": 1, "age": 12},
|
||||
{"user_id": 2, "age": 34},
|
||||
})
|
||||
/*for i := 0; i < 100; i++ {
|
||||
|
||||
}*/
|
||||
rs := d.ProcessTask(context.Background(), "A", initialPayload)
|
||||
if rs.Error != nil {
|
||||
panic(rs.Error)
|
||||
}
|
||||
fmt.Println(rs.TaskID, "Task", string(rs.Payload))
|
||||
/*http.HandleFunc("POST /publish", requestHandler("publish"))
|
||||
http.HandleFunc("POST /request", requestHandler("request"))
|
||||
err := d.Start(context.TODO(), ":8083")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}*/
|
||||
}
|
||||
|
||||
func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
var payload []byte
|
||||
if r.Body != nil {
|
||||
defer r.Body.Close()
|
||||
var err error
|
||||
payload, err = io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
http.Error(w, "Empty request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
rs := d.ProcessTask(context.Background(), "A", payload)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
result := map[string]any{
|
||||
"message_id": rs.TaskID,
|
||||
"payload": string(rs.Payload),
|
||||
"error": rs.Error,
|
||||
}
|
||||
json.NewEncoder(w).Encode(result)
|
||||
}
|
||||
}
|
@@ -6,6 +6,7 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/xid"
|
||||
|
||||
@@ -78,6 +79,7 @@ func (tm *DAG) Start(ctx context.Context, addr string) error {
|
||||
}()
|
||||
for _, con := range tm.Nodes {
|
||||
go func(con *Node) {
|
||||
time.Sleep(1 * time.Second)
|
||||
con.consumer.Consume(ctx)
|
||||
}(con)
|
||||
}
|
||||
|
@@ -5,20 +5,22 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/oarkflow/mq"
|
||||
"github.com/oarkflow/mq/consts"
|
||||
)
|
||||
|
||||
type TaskManager struct {
|
||||
taskID string
|
||||
dag *DAG
|
||||
wg sync.WaitGroup
|
||||
mutex sync.Mutex
|
||||
results []mq.Result
|
||||
nodeResults map[string]mq.Result
|
||||
done chan struct{}
|
||||
finalResult chan mq.Result // Channel to collect final results
|
||||
taskID string
|
||||
dag *DAG
|
||||
wg sync.WaitGroup
|
||||
mutex sync.Mutex
|
||||
results []mq.Result
|
||||
waitingCallback int64
|
||||
nodeResults map[string]mq.Result
|
||||
done chan struct{}
|
||||
finalResult chan mq.Result // Channel to collect final results
|
||||
}
|
||||
|
||||
func NewTaskManager(d *DAG, taskID string) *TaskManager {
|
||||
@@ -26,9 +28,7 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager {
|
||||
dag: d,
|
||||
nodeResults: make(map[string]mq.Result),
|
||||
results: make([]mq.Result, 0),
|
||||
done: make(chan struct{}),
|
||||
taskID: taskID,
|
||||
finalResult: make(chan mq.Result), // Initialize finalResult channel
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,26 +37,97 @@ func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload j
|
||||
if !ok {
|
||||
return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)}
|
||||
}
|
||||
tm.wg.Add(1)
|
||||
go tm.processNode(ctx, node, payload)
|
||||
go func() {
|
||||
tm.wg.Wait()
|
||||
close(tm.done)
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return mq.Result{Error: ctx.Err()}
|
||||
case <-tm.done:
|
||||
tm.mutex.Lock()
|
||||
defer tm.mutex.Unlock()
|
||||
if len(tm.results) == 1 {
|
||||
return tm.handleResult(ctx, tm.results[0])
|
||||
if tm.dag.server.SyncMode() {
|
||||
tm.done = make(chan struct{})
|
||||
tm.wg.Add(1)
|
||||
go tm.processNode(ctx, node, payload)
|
||||
go func() {
|
||||
tm.wg.Wait()
|
||||
close(tm.done)
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return mq.Result{Error: ctx.Err()}
|
||||
case <-tm.done:
|
||||
tm.mutex.Lock()
|
||||
defer tm.mutex.Unlock()
|
||||
if len(tm.results) == 1 {
|
||||
return tm.handleResult(ctx, tm.results[0])
|
||||
}
|
||||
return tm.handleResult(ctx, tm.results)
|
||||
}
|
||||
} else {
|
||||
tm.finalResult = make(chan mq.Result)
|
||||
tm.wg.Add(1)
|
||||
go tm.processNode(ctx, node, payload)
|
||||
go func() {
|
||||
tm.wg.Wait()
|
||||
}()
|
||||
select {
|
||||
case result := <-tm.finalResult: // Block until a result is available
|
||||
return result
|
||||
case <-ctx.Done(): // Handle context cancellation
|
||||
return mq.Result{Error: ctx.Err()}
|
||||
}
|
||||
return tm.handleResult(ctx, tm.results)
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.Result {
|
||||
if result.Topic != "" {
|
||||
atomic.AddInt64(&tm.waitingCallback, -1)
|
||||
}
|
||||
node, ok := tm.dag.Nodes[result.Topic]
|
||||
if !ok {
|
||||
return result
|
||||
}
|
||||
edges := make([]Edge, len(node.Edges))
|
||||
copy(edges, node.Edges)
|
||||
if result.Status != "" {
|
||||
if conditions, ok := tm.dag.conditions[result.Topic]; ok {
|
||||
if targetNodeKey, ok := conditions[result.Status]; ok {
|
||||
if targetNode, ok := tm.dag.Nodes[targetNodeKey]; ok {
|
||||
edges = append(edges, Edge{From: node, To: targetNode})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(edges) == 0 {
|
||||
tm.appendFinalResult(result)
|
||||
if !tm.dag.server.SyncMode() {
|
||||
var rs mq.Result
|
||||
if len(tm.results) == 1 {
|
||||
rs = tm.handleResult(ctx, tm.results[0])
|
||||
} else {
|
||||
rs = tm.handleResult(ctx, tm.results)
|
||||
}
|
||||
if tm.waitingCallback == 0 {
|
||||
tm.finalResult <- rs
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
for _, edge := range edges {
|
||||
switch edge.Type {
|
||||
case LoopEdge:
|
||||
var items []json.RawMessage
|
||||
err := json.Unmarshal(result.Payload, &items)
|
||||
if err != nil {
|
||||
tm.appendFinalResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err})
|
||||
return result
|
||||
}
|
||||
for _, item := range items {
|
||||
tm.wg.Add(1)
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key})
|
||||
go tm.processNode(ctx, edge.To, item)
|
||||
}
|
||||
case SimpleEdge:
|
||||
if edge.To != nil {
|
||||
tm.wg.Add(1)
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key})
|
||||
go tm.processNode(ctx, edge.To, result.Payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
return mq.Result{}
|
||||
}
|
||||
|
||||
@@ -103,6 +174,7 @@ func (tm *TaskManager) appendFinalResult(result mq.Result) {
|
||||
}
|
||||
|
||||
func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) {
|
||||
atomic.AddInt64(&tm.waitingCallback, 1)
|
||||
defer tm.wg.Done()
|
||||
var result mq.Result
|
||||
select {
|
||||
@@ -115,6 +187,7 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json
|
||||
if tm.dag.server.SyncMode() {
|
||||
result = node.consumer.ProcessTask(ctx, NewTask(tm.taskID, payload, node.Key))
|
||||
result.Topic = node.Key
|
||||
result.TaskID = tm.taskID
|
||||
if result.Error != nil {
|
||||
tm.appendFinalResult(result)
|
||||
return
|
||||
@@ -130,41 +203,5 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json
|
||||
tm.mutex.Lock()
|
||||
tm.nodeResults[node.Key] = result
|
||||
tm.mutex.Unlock()
|
||||
edges := make([]Edge, len(node.Edges))
|
||||
copy(edges, node.Edges)
|
||||
if result.Status != "" {
|
||||
if conditions, ok := tm.dag.conditions[result.Topic]; ok {
|
||||
if targetNodeKey, ok := conditions[result.Status]; ok {
|
||||
if targetNode, ok := tm.dag.Nodes[targetNodeKey]; ok {
|
||||
edges = append(edges, Edge{From: node, To: targetNode})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(edges) == 0 {
|
||||
tm.appendFinalResult(result)
|
||||
return
|
||||
}
|
||||
for _, edge := range edges {
|
||||
switch edge.Type {
|
||||
case LoopEdge:
|
||||
var items []json.RawMessage
|
||||
err := json.Unmarshal(result.Payload, &items)
|
||||
if err != nil {
|
||||
tm.appendFinalResult(mq.Result{TaskID: tm.taskID, Topic: node.Key, Error: err})
|
||||
return
|
||||
}
|
||||
for _, item := range items {
|
||||
tm.wg.Add(1)
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key})
|
||||
go tm.processNode(ctx, edge.To, item)
|
||||
}
|
||||
case SimpleEdge:
|
||||
if edge.To != nil {
|
||||
tm.wg.Add(1)
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key})
|
||||
go tm.processNode(ctx, edge.To, result.Payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
tm.handleCallback(ctx, result)
|
||||
}
|
||||
|
Reference in New Issue
Block a user