mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-05 16:06:55 +08:00
feat: add example
This commit is contained in:
@@ -57,11 +57,11 @@ func handler6(ctx context.Context, task *mq.Task) mq.Result {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
d = v2.NewDAG(mq.WithSyncMode(false))
|
d = v2.NewDAG(mq.WithSyncMode(true))
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
d.AddNode("A", handler1)
|
d.AddNode("A", handler1, true)
|
||||||
d.AddNode("B", handler2)
|
d.AddNode("B", handler2)
|
||||||
d.AddNode("C", handler3)
|
d.AddNode("C", handler3)
|
||||||
d.AddNode("D", handler4)
|
d.AddNode("D", handler4)
|
||||||
@@ -72,7 +72,6 @@ func main() {
|
|||||||
d.AddEdge("B", "C")
|
d.AddEdge("B", "C")
|
||||||
d.AddEdge("D", "F")
|
d.AddEdge("D", "F")
|
||||||
d.AddEdge("E", "F")
|
d.AddEdge("E", "F")
|
||||||
|
|
||||||
// fmt.Println(rs.TaskID, "Task", string(rs.Payload))
|
// fmt.Println(rs.TaskID, "Task", string(rs.Payload))
|
||||||
http.HandleFunc("POST /publish", requestHandler("publish"))
|
http.HandleFunc("POST /publish", requestHandler("publish"))
|
||||||
http.HandleFunc("POST /request", requestHandler("request"))
|
http.HandleFunc("POST /request", requestHandler("request"))
|
||||||
@@ -101,7 +100,9 @@ func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Requ
|
|||||||
http.Error(w, "Empty request body", http.StatusBadRequest)
|
http.Error(w, "Empty request body", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
rs := d.ProcessTask(context.Background(), "A", payload)
|
ctx := context.Background()
|
||||||
|
// ctx = context.WithValue(ctx, "initial_node", "E")
|
||||||
|
rs := d.ProcessTask(ctx, payload)
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
result := map[string]any{
|
result := map[string]any{
|
||||||
"message_id": rs.TaskID,
|
"message_id": rs.TaskID,
|
||||||
|
89
v2/dag.go
89
v2/dag.go
@@ -3,6 +3,7 @@ package v2
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -20,12 +21,6 @@ func NewTask(id string, payload json.RawMessage, nodeKey string) *mq.Task {
|
|||||||
return &mq.Task{ID: id, Payload: payload, Topic: nodeKey}
|
return &mq.Task{ID: id, Payload: payload, Topic: nodeKey}
|
||||||
}
|
}
|
||||||
|
|
||||||
type Node struct {
|
|
||||||
Key string
|
|
||||||
Edges []Edge
|
|
||||||
consumer *mq.Consumer
|
|
||||||
}
|
|
||||||
|
|
||||||
type EdgeType int
|
type EdgeType int
|
||||||
|
|
||||||
func (c EdgeType) IsValid() bool { return c >= SimpleEdge && c <= LoopEdge }
|
func (c EdgeType) IsValid() bool { return c >= SimpleEdge && c <= LoopEdge }
|
||||||
@@ -35,6 +30,12 @@ const (
|
|||||||
LoopEdge
|
LoopEdge
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Node struct {
|
||||||
|
Key string
|
||||||
|
Edges []Edge
|
||||||
|
consumer *mq.Consumer
|
||||||
|
}
|
||||||
|
|
||||||
type Edge struct {
|
type Edge struct {
|
||||||
From *Node
|
From *Node
|
||||||
To *Node
|
To *Node
|
||||||
@@ -42,6 +43,7 @@ type Edge struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type DAG struct {
|
type DAG struct {
|
||||||
|
FirstNode string
|
||||||
Nodes map[string]*Node
|
Nodes map[string]*Node
|
||||||
server *mq.Broker
|
server *mq.Broker
|
||||||
taskContext map[string]*TaskManager
|
taskContext map[string]*TaskManager
|
||||||
@@ -68,21 +70,21 @@ func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tm *DAG) Start(ctx context.Context, addr string) error {
|
func (tm *DAG) Start(ctx context.Context, addr string) error {
|
||||||
if tm.server.SyncMode() {
|
if !tm.server.SyncMode() {
|
||||||
return nil
|
go func() {
|
||||||
}
|
err := tm.server.Start(ctx)
|
||||||
go func() {
|
if err != nil {
|
||||||
err := tm.server.Start(ctx)
|
panic(err)
|
||||||
if err != nil {
|
}
|
||||||
panic(err)
|
}()
|
||||||
|
for _, con := range tm.Nodes {
|
||||||
|
go func(con *Node) {
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
con.consumer.Consume(ctx)
|
||||||
|
}(con)
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
for _, con := range tm.Nodes {
|
|
||||||
go func(con *Node) {
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
con.consumer.Consume(ctx)
|
|
||||||
}(con)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("HTTP server started on %s", addr)
|
log.Printf("HTTP server started on %s", addr)
|
||||||
config := tm.server.TLSConfig()
|
config := tm.server.TLSConfig()
|
||||||
if config.UseTLS {
|
if config.UseTLS {
|
||||||
@@ -91,7 +93,7 @@ func (tm *DAG) Start(ctx context.Context, addr string) error {
|
|||||||
return http.ListenAndServe(addr, nil)
|
return http.ListenAndServe(addr, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *DAG) AddNode(key string, handler mq.Handler) {
|
func (tm *DAG) AddNode(key string, handler mq.Handler, firstNode ...bool) {
|
||||||
tm.mu.Lock()
|
tm.mu.Lock()
|
||||||
defer tm.mu.Unlock()
|
defer tm.mu.Unlock()
|
||||||
con := mq.NewConsumer(key, key, handler)
|
con := mq.NewConsumer(key, key, handler)
|
||||||
@@ -99,6 +101,9 @@ func (tm *DAG) AddNode(key string, handler mq.Handler) {
|
|||||||
Key: key,
|
Key: key,
|
||||||
consumer: con,
|
consumer: con,
|
||||||
}
|
}
|
||||||
|
if len(firstNode) > 0 && firstNode[0] {
|
||||||
|
tm.FirstNode = key
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) {
|
func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) {
|
||||||
@@ -125,11 +130,51 @@ func (tm *DAG) AddEdge(from, to string, edgeTypes ...EdgeType) {
|
|||||||
fromNode.Edges = append(fromNode.Edges, edge)
|
fromNode.Edges = append(fromNode.Edges, edge)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *DAG) ProcessTask(ctx context.Context, node string, payload []byte) mq.Result {
|
func (tm *DAG) ProcessTask(ctx context.Context, payload []byte) mq.Result {
|
||||||
|
val := ctx.Value("initial_node")
|
||||||
|
initialNode, ok := val.(string)
|
||||||
|
if !ok {
|
||||||
|
if tm.FirstNode == "" {
|
||||||
|
firstNode := tm.FindInitialNode()
|
||||||
|
if firstNode != nil {
|
||||||
|
tm.FirstNode = firstNode.Key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tm.FirstNode == "" {
|
||||||
|
return mq.Result{Error: fmt.Errorf("initial node not found")}
|
||||||
|
}
|
||||||
|
initialNode = tm.FirstNode
|
||||||
|
}
|
||||||
tm.mu.Lock()
|
tm.mu.Lock()
|
||||||
defer tm.mu.Unlock()
|
defer tm.mu.Unlock()
|
||||||
taskID := xid.New().String()
|
taskID := xid.New().String()
|
||||||
manager := NewTaskManager(tm, taskID)
|
manager := NewTaskManager(tm, taskID)
|
||||||
tm.taskContext[taskID] = manager
|
tm.taskContext[taskID] = manager
|
||||||
return manager.processTask(ctx, node, payload)
|
return manager.processTask(ctx, initialNode, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tm *DAG) FindInitialNode() *Node {
|
||||||
|
incomingEdges := make(map[string]bool)
|
||||||
|
connectedNodes := make(map[string]bool)
|
||||||
|
for _, node := range tm.Nodes {
|
||||||
|
for _, edge := range node.Edges {
|
||||||
|
if edge.Type.IsValid() {
|
||||||
|
connectedNodes[node.Key] = true
|
||||||
|
connectedNodes[edge.To.Key] = true
|
||||||
|
incomingEdges[edge.To.Key] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cond, ok := tm.conditions[node.Key]; ok {
|
||||||
|
for _, target := range cond {
|
||||||
|
connectedNodes[target] = true
|
||||||
|
incomingEdges[target] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for nodeID, node := range tm.Nodes {
|
||||||
|
if !incomingEdges[nodeID] && connectedNodes[nodeID] {
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
@@ -32,42 +32,63 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (tm *TaskManager) handleSyncTask(ctx context.Context, node *Node, payload json.RawMessage) mq.Result {
|
||||||
|
tm.done = make(chan struct{})
|
||||||
|
tm.wg.Add(1)
|
||||||
|
go tm.processNode(ctx, node, payload)
|
||||||
|
go func() {
|
||||||
|
tm.wg.Wait()
|
||||||
|
close(tm.done)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return mq.Result{Error: ctx.Err()}
|
||||||
|
case <-tm.done:
|
||||||
|
tm.mutex.Lock()
|
||||||
|
defer tm.mutex.Unlock()
|
||||||
|
if len(tm.results) == 1 {
|
||||||
|
return tm.handleResult(ctx, tm.results[0])
|
||||||
|
}
|
||||||
|
return tm.handleResult(ctx, tm.results)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tm *TaskManager) handleAsyncTask(ctx context.Context, node *Node, payload json.RawMessage) mq.Result {
|
||||||
|
tm.finalResult = make(chan mq.Result)
|
||||||
|
tm.wg.Add(1)
|
||||||
|
go tm.processNode(ctx, node, payload)
|
||||||
|
go func() {
|
||||||
|
tm.wg.Wait()
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case result := <-tm.finalResult: // Block until a result is available
|
||||||
|
return result
|
||||||
|
case <-ctx.Done(): // Handle context cancellation
|
||||||
|
return mq.Result{Error: ctx.Err()}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result {
|
func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result {
|
||||||
node, ok := tm.dag.Nodes[nodeID]
|
node, ok := tm.dag.Nodes[nodeID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)}
|
return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)}
|
||||||
}
|
}
|
||||||
if tm.dag.server.SyncMode() {
|
if tm.dag.server.SyncMode() {
|
||||||
tm.done = make(chan struct{})
|
return tm.handleSyncTask(ctx, node, payload)
|
||||||
tm.wg.Add(1)
|
}
|
||||||
go tm.processNode(ctx, node, payload)
|
return tm.handleAsyncTask(ctx, node, payload)
|
||||||
go func() {
|
}
|
||||||
tm.wg.Wait()
|
|
||||||
close(tm.done)
|
func (tm *TaskManager) dispatchFinalResult(ctx context.Context) {
|
||||||
}()
|
if !tm.dag.server.SyncMode() {
|
||||||
select {
|
var rs mq.Result
|
||||||
case <-ctx.Done():
|
if len(tm.results) == 1 {
|
||||||
return mq.Result{Error: ctx.Err()}
|
rs = tm.handleResult(ctx, tm.results[0])
|
||||||
case <-tm.done:
|
} else {
|
||||||
tm.mutex.Lock()
|
rs = tm.handleResult(ctx, tm.results)
|
||||||
defer tm.mutex.Unlock()
|
|
||||||
if len(tm.results) == 1 {
|
|
||||||
return tm.handleResult(ctx, tm.results[0])
|
|
||||||
}
|
|
||||||
return tm.handleResult(ctx, tm.results)
|
|
||||||
}
|
}
|
||||||
} else {
|
if tm.waitingCallback == 0 {
|
||||||
tm.finalResult = make(chan mq.Result)
|
tm.finalResult <- rs
|
||||||
tm.wg.Add(1)
|
|
||||||
go tm.processNode(ctx, node, payload)
|
|
||||||
go func() {
|
|
||||||
tm.wg.Wait()
|
|
||||||
}()
|
|
||||||
select {
|
|
||||||
case result := <-tm.finalResult: // Block until a result is available
|
|
||||||
return result
|
|
||||||
case <-ctx.Done(): // Handle context cancellation
|
|
||||||
return mq.Result{Error: ctx.Err()}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -93,17 +114,7 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.
|
|||||||
}
|
}
|
||||||
if len(edges) == 0 {
|
if len(edges) == 0 {
|
||||||
tm.appendFinalResult(result)
|
tm.appendFinalResult(result)
|
||||||
if !tm.dag.server.SyncMode() {
|
tm.dispatchFinalResult(ctx)
|
||||||
var rs mq.Result
|
|
||||||
if len(tm.results) == 1 {
|
|
||||||
rs = tm.handleResult(ctx, tm.results[0])
|
|
||||||
} else {
|
|
||||||
rs = tm.handleResult(ctx, tm.results)
|
|
||||||
}
|
|
||||||
if tm.waitingCallback == 0 {
|
|
||||||
tm.finalResult <- rs
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
for _, edge := range edges {
|
for _, edge := range edges {
|
||||||
@@ -205,3 +216,10 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json
|
|||||||
tm.mutex.Unlock()
|
tm.mutex.Unlock()
|
||||||
tm.handleCallback(ctx, result)
|
tm.handleCallback(ctx, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (tm *TaskManager) Clear() error {
|
||||||
|
tm.waitingCallback = 0
|
||||||
|
clear(tm.results)
|
||||||
|
tm.nodeResults = make(map[string]mq.Result)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user