mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-04 15:42:49 +08:00
feat: separate broker
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq"
|
||||
"github.com/oarkflow/mq/consts"
|
||||
@@ -14,13 +15,13 @@ import (
|
||||
type TaskManager struct {
|
||||
taskID string
|
||||
dag *DAG
|
||||
wg sync.WaitGroup
|
||||
mutex sync.Mutex
|
||||
createdAt time.Time
|
||||
processedAt time.Time
|
||||
results []mq.Result
|
||||
waitingCallback int64
|
||||
nodeResults map[string]mq.Result
|
||||
done chan struct{}
|
||||
finalResult chan mq.Result // Channel to collect final results
|
||||
finalResult chan mq.Result
|
||||
}
|
||||
|
||||
func NewTaskManager(d *DAG, taskID string) *TaskManager {
|
||||
@@ -32,55 +33,37 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager {
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *TaskManager) handleSyncTask(ctx context.Context, node *Node, payload json.RawMessage) mq.Result {
|
||||
tm.done = make(chan struct{})
|
||||
tm.wg.Add(1)
|
||||
go tm.processNode(ctx, node, payload)
|
||||
go func() {
|
||||
tm.wg.Wait()
|
||||
close(tm.done)
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return mq.Result{Error: ctx.Err()}
|
||||
case <-tm.done:
|
||||
tm.mutex.Lock()
|
||||
defer tm.mutex.Unlock()
|
||||
if len(tm.results) == 1 {
|
||||
return tm.handleResult(ctx, tm.results[0])
|
||||
}
|
||||
return tm.handleResult(ctx, tm.results)
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *TaskManager) handleAsyncTask(ctx context.Context, node *Node, payload json.RawMessage) mq.Result {
|
||||
tm.finalResult = make(chan mq.Result)
|
||||
tm.wg.Add(1)
|
||||
go tm.processNode(ctx, node, payload)
|
||||
go func() {
|
||||
tm.wg.Wait()
|
||||
}()
|
||||
select {
|
||||
case result := <-tm.finalResult: // Block until a result is available
|
||||
return result
|
||||
case <-ctx.Done(): // Handle context cancellation
|
||||
return mq.Result{Error: ctx.Err()}
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result {
|
||||
node, ok := tm.dag.Nodes[nodeID]
|
||||
if !ok {
|
||||
return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)}
|
||||
}
|
||||
if tm.dag.server.SyncMode() {
|
||||
return tm.handleSyncTask(ctx, node, payload)
|
||||
tm.createdAt = time.Now()
|
||||
tm.finalResult = make(chan mq.Result, 0)
|
||||
go tm.processNode(ctx, node, payload)
|
||||
awaitResponse, ok := mq.GetAwaitResponse(ctx)
|
||||
if awaitResponse != "true" {
|
||||
go func() {
|
||||
finalResult := <-tm.finalResult
|
||||
finalResult.CreatedAt = tm.createdAt
|
||||
finalResult.ProcessedAt = time.Now()
|
||||
if tm.dag.server.NotifyHandler() != nil {
|
||||
tm.dag.server.NotifyHandler()(ctx, finalResult)
|
||||
}
|
||||
}()
|
||||
return mq.Result{CreatedAt: tm.createdAt, TaskID: tm.taskID, Topic: nodeID, Status: "PENDING"}
|
||||
} else {
|
||||
finalResult := <-tm.finalResult
|
||||
finalResult.CreatedAt = tm.createdAt
|
||||
finalResult.ProcessedAt = time.Now()
|
||||
if tm.dag.server.NotifyHandler() != nil {
|
||||
tm.dag.server.NotifyHandler()(ctx, finalResult)
|
||||
}
|
||||
return finalResult
|
||||
}
|
||||
return tm.handleAsyncTask(ctx, node, payload)
|
||||
}
|
||||
|
||||
func (tm *TaskManager) dispatchFinalResult(ctx context.Context) {
|
||||
if !tm.dag.server.SyncMode() {
|
||||
var rs mq.Result
|
||||
if len(tm.results) == 1 {
|
||||
rs = tm.handleResult(ctx, tm.results[0])
|
||||
@@ -90,7 +73,6 @@ func (tm *TaskManager) dispatchFinalResult(ctx context.Context) {
|
||||
if tm.waitingCallback == 0 {
|
||||
tm.finalResult <- rs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.Result {
|
||||
@@ -127,13 +109,11 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.
|
||||
return result
|
||||
}
|
||||
for _, item := range items {
|
||||
tm.wg.Add(1)
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key})
|
||||
go tm.processNode(ctx, edge.To, item)
|
||||
}
|
||||
case SimpleEdge:
|
||||
if edge.To != nil {
|
||||
tm.wg.Add(1)
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key})
|
||||
go tm.processNode(ctx, edge.To, result.Payload)
|
||||
}
|
||||
@@ -147,10 +127,11 @@ func (tm *TaskManager) handleResult(ctx context.Context, results any) mq.Result
|
||||
switch res := results.(type) {
|
||||
case []mq.Result:
|
||||
aggregatedOutput := make([]json.RawMessage, 0)
|
||||
status := ""
|
||||
var status, topic string
|
||||
for i, result := range res {
|
||||
if i == 0 {
|
||||
status = result.Status
|
||||
topic = result.Topic
|
||||
}
|
||||
if result.Error != nil {
|
||||
return mq.HandleError(ctx, result.Error)
|
||||
@@ -170,6 +151,7 @@ func (tm *TaskManager) handleResult(ctx context.Context, results any) mq.Result
|
||||
TaskID: tm.taskID,
|
||||
Payload: finalOutput,
|
||||
Status: status,
|
||||
Topic: topic,
|
||||
}
|
||||
case mq.Result:
|
||||
return res
|
||||
@@ -186,7 +168,6 @@ func (tm *TaskManager) appendFinalResult(result mq.Result) {
|
||||
|
||||
func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) {
|
||||
atomic.AddInt64(&tm.waitingCallback, 1)
|
||||
defer tm.wg.Done()
|
||||
var result mq.Result
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
|
@@ -3,6 +3,8 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"github.com/oarkflow/mq/consts"
|
||||
"github.com/oarkflow/mq/examples/tasks"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
@@ -10,70 +12,23 @@ import (
|
||||
"github.com/oarkflow/mq/dag"
|
||||
)
|
||||
|
||||
func handler1(ctx context.Context, task *mq.Task) mq.Result {
|
||||
return mq.Result{Payload: task.Payload}
|
||||
}
|
||||
|
||||
func handler2(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
return mq.Result{Payload: task.Payload}
|
||||
}
|
||||
|
||||
func handler3(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
age := int(user["age"].(float64))
|
||||
status := "FAIL"
|
||||
if age > 20 {
|
||||
status = "PASS"
|
||||
}
|
||||
user["status"] = status
|
||||
resultPayload, _ := json.Marshal(user)
|
||||
return mq.Result{Payload: resultPayload, Status: status}
|
||||
}
|
||||
|
||||
func handler4(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
user["final"] = "D"
|
||||
resultPayload, _ := json.Marshal(user)
|
||||
return mq.Result{Payload: resultPayload}
|
||||
}
|
||||
|
||||
func handler5(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
user["salary"] = "E"
|
||||
resultPayload, _ := json.Marshal(user)
|
||||
return mq.Result{Payload: resultPayload}
|
||||
}
|
||||
|
||||
func handler6(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
resultPayload, _ := json.Marshal(map[string]any{"storage": user})
|
||||
return mq.Result{Payload: resultPayload}
|
||||
}
|
||||
|
||||
var (
|
||||
d = dag.NewDAG(mq.WithSyncMode(true))
|
||||
d = dag.NewDAG(mq.WithSyncMode(false), mq.WithNotifyResponse(tasks.NotifyResponse))
|
||||
// d = dag.NewDAG(mq.WithSyncMode(true), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert"))
|
||||
)
|
||||
|
||||
func main() {
|
||||
d.AddNode("A", handler1, true)
|
||||
d.AddNode("B", handler2)
|
||||
d.AddNode("C", handler3)
|
||||
d.AddNode("D", handler4)
|
||||
d.AddNode("E", handler5)
|
||||
d.AddNode("F", handler6)
|
||||
d.AddNode("A", tasks.Node1, true)
|
||||
d.AddNode("B", tasks.Node2)
|
||||
d.AddNode("C", tasks.Node3)
|
||||
d.AddNode("D", tasks.Node4)
|
||||
d.AddNode("E", tasks.Node5)
|
||||
d.AddNode("F", tasks.Node6)
|
||||
d.AddEdge("A", "B", dag.LoopEdge)
|
||||
d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"})
|
||||
d.AddEdge("B", "C")
|
||||
d.AddEdge("D", "F")
|
||||
d.AddEdge("E", "F")
|
||||
// fmt.Println(rs.TaskID, "Task", string(rs.Payload))
|
||||
http.HandleFunc("POST /publish", requestHandler("publish"))
|
||||
http.HandleFunc("POST /request", requestHandler("request"))
|
||||
err := d.Start(context.TODO(), ":8083")
|
||||
@@ -102,14 +57,12 @@ func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
ctx := context.Background()
|
||||
if requestType == "request" {
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
|
||||
}
|
||||
// ctx = context.WithValue(ctx, "initial_node", "E")
|
||||
rs := d.ProcessTask(ctx, payload)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
result := map[string]any{
|
||||
"message_id": rs.TaskID,
|
||||
"payload": rs.Payload,
|
||||
"error": rs.Error,
|
||||
}
|
||||
json.NewEncoder(w).Encode(result)
|
||||
json.NewEncoder(w).Encode(rs)
|
||||
}
|
||||
}
|
||||
|
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/oarkflow/mq"
|
||||
)
|
||||
@@ -17,53 +18,46 @@ func Node2(ctx context.Context, task *mq.Task) mq.Result {
|
||||
}
|
||||
|
||||
func Node3(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var data map[string]any
|
||||
err := json.Unmarshal(task.Payload, &data)
|
||||
if err != nil {
|
||||
return mq.Result{Error: err}
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
age := int(user["age"].(float64))
|
||||
status := "FAIL"
|
||||
if age > 20 {
|
||||
status = "PASS"
|
||||
}
|
||||
data["salary"] = fmt.Sprintf("12000%v", data["user_id"])
|
||||
bt, _ := json.Marshal(data)
|
||||
return mq.Result{Payload: bt, TaskID: task.ID}
|
||||
user["status"] = status
|
||||
resultPayload, _ := json.Marshal(user)
|
||||
return mq.Result{Payload: resultPayload, Status: status}
|
||||
}
|
||||
|
||||
func Node4(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var data []map[string]any
|
||||
err := json.Unmarshal(task.Payload, &data)
|
||||
if err != nil {
|
||||
return mq.Result{Error: err}
|
||||
}
|
||||
payload := map[string]any{"storage": data}
|
||||
bt, _ := json.Marshal(payload)
|
||||
return mq.Result{Payload: bt, TaskID: task.ID}
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
user["final"] = "D"
|
||||
resultPayload, _ := json.Marshal(user)
|
||||
return mq.Result{Payload: resultPayload}
|
||||
}
|
||||
|
||||
func CheckCondition(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var data map[string]any
|
||||
err := json.Unmarshal(task.Payload, &data)
|
||||
if err != nil {
|
||||
return mq.Result{Error: err}
|
||||
}
|
||||
var status string
|
||||
if data["user_id"].(float64) == 2 {
|
||||
status = "pass"
|
||||
} else {
|
||||
status = "fail"
|
||||
}
|
||||
return mq.Result{Status: status, Payload: task.Payload, TaskID: task.ID}
|
||||
func Node5(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
user["salary"] = "E"
|
||||
resultPayload, _ := json.Marshal(user)
|
||||
return mq.Result{Payload: resultPayload}
|
||||
}
|
||||
|
||||
func Pass(ctx context.Context, task *mq.Task) mq.Result {
|
||||
fmt.Println("Pass")
|
||||
return mq.Result{Payload: task.Payload}
|
||||
}
|
||||
|
||||
func Fail(ctx context.Context, task *mq.Task) mq.Result {
|
||||
fmt.Println("Fail")
|
||||
return mq.Result{Payload: []byte(`{"test2": "asdsa"}`)}
|
||||
func Node6(ctx context.Context, task *mq.Task) mq.Result {
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
resultPayload, _ := json.Marshal(map[string]any{"storage": user})
|
||||
return mq.Result{Payload: resultPayload}
|
||||
}
|
||||
|
||||
func Callback(ctx context.Context, task mq.Result) mq.Result {
|
||||
fmt.Println("Received task", task.TaskID, "Payload", string(task.Payload), task.Error, task.Topic)
|
||||
return mq.Result{}
|
||||
}
|
||||
|
||||
func NotifyResponse(ctx context.Context, result mq.Result) {
|
||||
log.Printf("DAG Final response: TaskID: %s, Payload: %s, Topic: %s", result.TaskID, result.Payload, result.Topic)
|
||||
}
|
||||
|
13
options.go
13
options.go
@@ -10,6 +10,8 @@ import (
|
||||
type Result struct {
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
Topic string `json:"topic"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ProcessedAt time.Time `json:"processed_at,omitempty"`
|
||||
TaskID string `json:"task_id"`
|
||||
Error error `json:"error,omitempty"`
|
||||
Status string `json:"status"`
|
||||
@@ -22,10 +24,6 @@ func (r Result) Unmarshal(data any) error {
|
||||
return json.Unmarshal(r.Payload, data)
|
||||
}
|
||||
|
||||
func (r Result) String() string {
|
||||
return string(r.Payload)
|
||||
}
|
||||
|
||||
func HandleError(ctx context.Context, err error, status ...string) Result {
|
||||
st := "Failed"
|
||||
if len(status) > 0 {
|
||||
@@ -63,6 +61,7 @@ type Options struct {
|
||||
brokerAddr string
|
||||
callback []func(context.Context, Result) Result
|
||||
maxRetries int
|
||||
notifyResponse func(context.Context, Result)
|
||||
initialDelay time.Duration
|
||||
maxBackoff time.Duration
|
||||
jitterPercent float64
|
||||
@@ -96,6 +95,12 @@ func setupOptions(opts ...Option) Options {
|
||||
return options
|
||||
}
|
||||
|
||||
func WithNotifyResponse(handler func(ctx context.Context, result Result)) Option {
|
||||
return func(opts *Options) {
|
||||
opts.notifyResponse = handler
|
||||
}
|
||||
}
|
||||
|
||||
func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option {
|
||||
return func(opts *Options) {
|
||||
opts.aesKey = aesKey
|
||||
|
4
util.go
4
util.go
@@ -15,6 +15,10 @@ func (b *Broker) SyncMode() bool {
|
||||
return b.opts.syncMode
|
||||
}
|
||||
|
||||
func (b *Broker) NotifyHandler() func(context.Context, Result) {
|
||||
return b.opts.notifyResponse
|
||||
}
|
||||
|
||||
func (b *Broker) HandleCallback(ctx context.Context, msg *codec.Message) {
|
||||
if b.opts.callback != nil {
|
||||
var result Result
|
||||
|
Reference in New Issue
Block a user