feat: separate broker

This commit is contained in:
Oarkflow
2024-10-09 11:24:54 +05:45
parent 5decaa247b
commit 9c8712994d
5 changed files with 98 additions and 161 deletions

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
"github.com/oarkflow/mq" "github.com/oarkflow/mq"
"github.com/oarkflow/mq/consts" "github.com/oarkflow/mq/consts"
@@ -14,13 +15,13 @@ import (
type TaskManager struct { type TaskManager struct {
taskID string taskID string
dag *DAG dag *DAG
wg sync.WaitGroup
mutex sync.Mutex mutex sync.Mutex
createdAt time.Time
processedAt time.Time
results []mq.Result results []mq.Result
waitingCallback int64 waitingCallback int64
nodeResults map[string]mq.Result nodeResults map[string]mq.Result
done chan struct{} finalResult chan mq.Result
finalResult chan mq.Result // Channel to collect final results
} }
func NewTaskManager(d *DAG, taskID string) *TaskManager { func NewTaskManager(d *DAG, taskID string) *TaskManager {
@@ -32,64 +33,45 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager {
} }
} }
func (tm *TaskManager) handleSyncTask(ctx context.Context, node *Node, payload json.RawMessage) mq.Result {
tm.done = make(chan struct{})
tm.wg.Add(1)
go tm.processNode(ctx, node, payload)
go func() {
tm.wg.Wait()
close(tm.done)
}()
select {
case <-ctx.Done():
return mq.Result{Error: ctx.Err()}
case <-tm.done:
tm.mutex.Lock()
defer tm.mutex.Unlock()
if len(tm.results) == 1 {
return tm.handleResult(ctx, tm.results[0])
}
return tm.handleResult(ctx, tm.results)
}
}
func (tm *TaskManager) handleAsyncTask(ctx context.Context, node *Node, payload json.RawMessage) mq.Result {
tm.finalResult = make(chan mq.Result)
tm.wg.Add(1)
go tm.processNode(ctx, node, payload)
go func() {
tm.wg.Wait()
}()
select {
case result := <-tm.finalResult: // Block until a result is available
return result
case <-ctx.Done(): // Handle context cancellation
return mq.Result{Error: ctx.Err()}
}
}
func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result { func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result {
node, ok := tm.dag.Nodes[nodeID] node, ok := tm.dag.Nodes[nodeID]
if !ok { if !ok {
return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)} return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)}
} }
if tm.dag.server.SyncMode() { tm.createdAt = time.Now()
return tm.handleSyncTask(ctx, node, payload) tm.finalResult = make(chan mq.Result, 0)
go tm.processNode(ctx, node, payload)
awaitResponse, ok := mq.GetAwaitResponse(ctx)
if awaitResponse != "true" {
go func() {
finalResult := <-tm.finalResult
finalResult.CreatedAt = tm.createdAt
finalResult.ProcessedAt = time.Now()
if tm.dag.server.NotifyHandler() != nil {
tm.dag.server.NotifyHandler()(ctx, finalResult)
}
}()
return mq.Result{CreatedAt: tm.createdAt, TaskID: tm.taskID, Topic: nodeID, Status: "PENDING"}
} else {
finalResult := <-tm.finalResult
finalResult.CreatedAt = tm.createdAt
finalResult.ProcessedAt = time.Now()
if tm.dag.server.NotifyHandler() != nil {
tm.dag.server.NotifyHandler()(ctx, finalResult)
}
return finalResult
} }
return tm.handleAsyncTask(ctx, node, payload)
} }
func (tm *TaskManager) dispatchFinalResult(ctx context.Context) { func (tm *TaskManager) dispatchFinalResult(ctx context.Context) {
if !tm.dag.server.SyncMode() { var rs mq.Result
var rs mq.Result if len(tm.results) == 1 {
if len(tm.results) == 1 { rs = tm.handleResult(ctx, tm.results[0])
rs = tm.handleResult(ctx, tm.results[0]) } else {
} else { rs = tm.handleResult(ctx, tm.results)
rs = tm.handleResult(ctx, tm.results) }
} if tm.waitingCallback == 0 {
if tm.waitingCallback == 0 { tm.finalResult <- rs
tm.finalResult <- rs
}
} }
} }
@@ -127,13 +109,11 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.
return result return result
} }
for _, item := range items { for _, item := range items {
tm.wg.Add(1)
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key})
go tm.processNode(ctx, edge.To, item) go tm.processNode(ctx, edge.To, item)
} }
case SimpleEdge: case SimpleEdge:
if edge.To != nil { if edge.To != nil {
tm.wg.Add(1)
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key}) ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key})
go tm.processNode(ctx, edge.To, result.Payload) go tm.processNode(ctx, edge.To, result.Payload)
} }
@@ -147,10 +127,11 @@ func (tm *TaskManager) handleResult(ctx context.Context, results any) mq.Result
switch res := results.(type) { switch res := results.(type) {
case []mq.Result: case []mq.Result:
aggregatedOutput := make([]json.RawMessage, 0) aggregatedOutput := make([]json.RawMessage, 0)
status := "" var status, topic string
for i, result := range res { for i, result := range res {
if i == 0 { if i == 0 {
status = result.Status status = result.Status
topic = result.Topic
} }
if result.Error != nil { if result.Error != nil {
return mq.HandleError(ctx, result.Error) return mq.HandleError(ctx, result.Error)
@@ -170,6 +151,7 @@ func (tm *TaskManager) handleResult(ctx context.Context, results any) mq.Result
TaskID: tm.taskID, TaskID: tm.taskID,
Payload: finalOutput, Payload: finalOutput,
Status: status, Status: status,
Topic: topic,
} }
case mq.Result: case mq.Result:
return res return res
@@ -186,7 +168,6 @@ func (tm *TaskManager) appendFinalResult(result mq.Result) {
func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) { func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) {
atomic.AddInt64(&tm.waitingCallback, 1) atomic.AddInt64(&tm.waitingCallback, 1)
defer tm.wg.Done()
var result mq.Result var result mq.Result
select { select {
case <-ctx.Done(): case <-ctx.Done():

View File

@@ -3,6 +3,8 @@ package main
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/examples/tasks"
"io" "io"
"net/http" "net/http"
@@ -10,70 +12,23 @@ import (
"github.com/oarkflow/mq/dag" "github.com/oarkflow/mq/dag"
) )
func handler1(ctx context.Context, task *mq.Task) mq.Result {
return mq.Result{Payload: task.Payload}
}
func handler2(ctx context.Context, task *mq.Task) mq.Result {
var user map[string]any
json.Unmarshal(task.Payload, &user)
return mq.Result{Payload: task.Payload}
}
func handler3(ctx context.Context, task *mq.Task) mq.Result {
var user map[string]any
json.Unmarshal(task.Payload, &user)
age := int(user["age"].(float64))
status := "FAIL"
if age > 20 {
status = "PASS"
}
user["status"] = status
resultPayload, _ := json.Marshal(user)
return mq.Result{Payload: resultPayload, Status: status}
}
func handler4(ctx context.Context, task *mq.Task) mq.Result {
var user map[string]any
json.Unmarshal(task.Payload, &user)
user["final"] = "D"
resultPayload, _ := json.Marshal(user)
return mq.Result{Payload: resultPayload}
}
func handler5(ctx context.Context, task *mq.Task) mq.Result {
var user map[string]any
json.Unmarshal(task.Payload, &user)
user["salary"] = "E"
resultPayload, _ := json.Marshal(user)
return mq.Result{Payload: resultPayload}
}
func handler6(ctx context.Context, task *mq.Task) mq.Result {
var user map[string]any
json.Unmarshal(task.Payload, &user)
resultPayload, _ := json.Marshal(map[string]any{"storage": user})
return mq.Result{Payload: resultPayload}
}
var ( var (
d = dag.NewDAG(mq.WithSyncMode(true)) d = dag.NewDAG(mq.WithSyncMode(false), mq.WithNotifyResponse(tasks.NotifyResponse))
// d = dag.NewDAG(mq.WithSyncMode(true), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert")) // d = dag.NewDAG(mq.WithSyncMode(true), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert"))
) )
func main() { func main() {
d.AddNode("A", handler1, true) d.AddNode("A", tasks.Node1, true)
d.AddNode("B", handler2) d.AddNode("B", tasks.Node2)
d.AddNode("C", handler3) d.AddNode("C", tasks.Node3)
d.AddNode("D", handler4) d.AddNode("D", tasks.Node4)
d.AddNode("E", handler5) d.AddNode("E", tasks.Node5)
d.AddNode("F", handler6) d.AddNode("F", tasks.Node6)
d.AddEdge("A", "B", dag.LoopEdge) d.AddEdge("A", "B", dag.LoopEdge)
d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"}) d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"})
d.AddEdge("B", "C") d.AddEdge("B", "C")
d.AddEdge("D", "F") d.AddEdge("D", "F")
d.AddEdge("E", "F") d.AddEdge("E", "F")
// fmt.Println(rs.TaskID, "Task", string(rs.Payload))
http.HandleFunc("POST /publish", requestHandler("publish")) http.HandleFunc("POST /publish", requestHandler("publish"))
http.HandleFunc("POST /request", requestHandler("request")) http.HandleFunc("POST /request", requestHandler("request"))
err := d.Start(context.TODO(), ":8083") err := d.Start(context.TODO(), ":8083")
@@ -102,14 +57,12 @@ func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Requ
return return
} }
ctx := context.Background() ctx := context.Background()
if requestType == "request" {
ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
}
// ctx = context.WithValue(ctx, "initial_node", "E") // ctx = context.WithValue(ctx, "initial_node", "E")
rs := d.ProcessTask(ctx, payload) rs := d.ProcessTask(ctx, payload)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
result := map[string]any{ json.NewEncoder(w).Encode(rs)
"message_id": rs.TaskID,
"payload": rs.Payload,
"error": rs.Error,
}
json.NewEncoder(w).Encode(result)
} }
} }

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"github.com/oarkflow/mq" "github.com/oarkflow/mq"
) )
@@ -17,53 +18,46 @@ func Node2(ctx context.Context, task *mq.Task) mq.Result {
} }
func Node3(ctx context.Context, task *mq.Task) mq.Result { func Node3(ctx context.Context, task *mq.Task) mq.Result {
var data map[string]any var user map[string]any
err := json.Unmarshal(task.Payload, &data) json.Unmarshal(task.Payload, &user)
if err != nil { age := int(user["age"].(float64))
return mq.Result{Error: err} status := "FAIL"
if age > 20 {
status = "PASS"
} }
data["salary"] = fmt.Sprintf("12000%v", data["user_id"]) user["status"] = status
bt, _ := json.Marshal(data) resultPayload, _ := json.Marshal(user)
return mq.Result{Payload: bt, TaskID: task.ID} return mq.Result{Payload: resultPayload, Status: status}
} }
func Node4(ctx context.Context, task *mq.Task) mq.Result { func Node4(ctx context.Context, task *mq.Task) mq.Result {
var data []map[string]any var user map[string]any
err := json.Unmarshal(task.Payload, &data) json.Unmarshal(task.Payload, &user)
if err != nil { user["final"] = "D"
return mq.Result{Error: err} resultPayload, _ := json.Marshal(user)
} return mq.Result{Payload: resultPayload}
payload := map[string]any{"storage": data}
bt, _ := json.Marshal(payload)
return mq.Result{Payload: bt, TaskID: task.ID}
} }
func CheckCondition(ctx context.Context, task *mq.Task) mq.Result { func Node5(ctx context.Context, task *mq.Task) mq.Result {
var data map[string]any var user map[string]any
err := json.Unmarshal(task.Payload, &data) json.Unmarshal(task.Payload, &user)
if err != nil { user["salary"] = "E"
return mq.Result{Error: err} resultPayload, _ := json.Marshal(user)
} return mq.Result{Payload: resultPayload}
var status string
if data["user_id"].(float64) == 2 {
status = "pass"
} else {
status = "fail"
}
return mq.Result{Status: status, Payload: task.Payload, TaskID: task.ID}
} }
func Pass(ctx context.Context, task *mq.Task) mq.Result { func Node6(ctx context.Context, task *mq.Task) mq.Result {
fmt.Println("Pass") var user map[string]any
return mq.Result{Payload: task.Payload} json.Unmarshal(task.Payload, &user)
} resultPayload, _ := json.Marshal(map[string]any{"storage": user})
return mq.Result{Payload: resultPayload}
func Fail(ctx context.Context, task *mq.Task) mq.Result {
fmt.Println("Fail")
return mq.Result{Payload: []byte(`{"test2": "asdsa"}`)}
} }
func Callback(ctx context.Context, task mq.Result) mq.Result { func Callback(ctx context.Context, task mq.Result) mq.Result {
fmt.Println("Received task", task.TaskID, "Payload", string(task.Payload), task.Error, task.Topic) fmt.Println("Received task", task.TaskID, "Payload", string(task.Payload), task.Error, task.Topic)
return mq.Result{} return mq.Result{}
} }
func NotifyResponse(ctx context.Context, result mq.Result) {
log.Printf("DAG Final response: TaskID: %s, Payload: %s, Topic: %s", result.TaskID, result.Payload, result.Topic)
}

View File

@@ -8,11 +8,13 @@ import (
) )
type Result struct { type Result struct {
Payload json.RawMessage `json:"payload"` Payload json.RawMessage `json:"payload"`
Topic string `json:"topic"` Topic string `json:"topic"`
TaskID string `json:"task_id"` CreatedAt time.Time `json:"created_at"`
Error error `json:"error,omitempty"` ProcessedAt time.Time `json:"processed_at,omitempty"`
Status string `json:"status"` TaskID string `json:"task_id"`
Error error `json:"error,omitempty"`
Status string `json:"status"`
} }
func (r Result) Unmarshal(data any) error { func (r Result) Unmarshal(data any) error {
@@ -22,10 +24,6 @@ func (r Result) Unmarshal(data any) error {
return json.Unmarshal(r.Payload, data) return json.Unmarshal(r.Payload, data)
} }
func (r Result) String() string {
return string(r.Payload)
}
func HandleError(ctx context.Context, err error, status ...string) Result { func HandleError(ctx context.Context, err error, status ...string) Result {
st := "Failed" st := "Failed"
if len(status) > 0 { if len(status) > 0 {
@@ -63,6 +61,7 @@ type Options struct {
brokerAddr string brokerAddr string
callback []func(context.Context, Result) Result callback []func(context.Context, Result) Result
maxRetries int maxRetries int
notifyResponse func(context.Context, Result)
initialDelay time.Duration initialDelay time.Duration
maxBackoff time.Duration maxBackoff time.Duration
jitterPercent float64 jitterPercent float64
@@ -96,6 +95,12 @@ func setupOptions(opts ...Option) Options {
return options return options
} }
func WithNotifyResponse(handler func(ctx context.Context, result Result)) Option {
return func(opts *Options) {
opts.notifyResponse = handler
}
}
func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option { func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option {
return func(opts *Options) { return func(opts *Options) {
opts.aesKey = aesKey opts.aesKey = aesKey

View File

@@ -15,6 +15,10 @@ func (b *Broker) SyncMode() bool {
return b.opts.syncMode return b.opts.syncMode
} }
func (b *Broker) NotifyHandler() func(context.Context, Result) {
return b.opts.notifyResponse
}
func (b *Broker) HandleCallback(ctx context.Context, msg *codec.Message) { func (b *Broker) HandleCallback(ctx context.Context, msg *codec.Message) {
if b.opts.callback != nil { if b.opts.callback != nil {
var result Result var result Result