feat: separate broker

This commit is contained in:
Oarkflow
2024-10-09 11:24:54 +05:45
parent 5decaa247b
commit 9c8712994d
5 changed files with 98 additions and 161 deletions

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/consts"
@@ -14,13 +15,13 @@ import (
type TaskManager struct {
taskID string
dag *DAG
wg sync.WaitGroup
mutex sync.Mutex
createdAt time.Time
processedAt time.Time
results []mq.Result
waitingCallback int64
nodeResults map[string]mq.Result
done chan struct{}
finalResult chan mq.Result // Channel to collect final results
finalResult chan mq.Result
}
func NewTaskManager(d *DAG, taskID string) *TaskManager {
@@ -32,64 +33,45 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager {
}
}
func (tm *TaskManager) handleSyncTask(ctx context.Context, node *Node, payload json.RawMessage) mq.Result {
tm.done = make(chan struct{})
tm.wg.Add(1)
go tm.processNode(ctx, node, payload)
go func() {
tm.wg.Wait()
close(tm.done)
}()
select {
case <-ctx.Done():
return mq.Result{Error: ctx.Err()}
case <-tm.done:
tm.mutex.Lock()
defer tm.mutex.Unlock()
if len(tm.results) == 1 {
return tm.handleResult(ctx, tm.results[0])
}
return tm.handleResult(ctx, tm.results)
}
}
func (tm *TaskManager) handleAsyncTask(ctx context.Context, node *Node, payload json.RawMessage) mq.Result {
tm.finalResult = make(chan mq.Result)
tm.wg.Add(1)
go tm.processNode(ctx, node, payload)
go func() {
tm.wg.Wait()
}()
select {
case result := <-tm.finalResult: // Block until a result is available
return result
case <-ctx.Done(): // Handle context cancellation
return mq.Result{Error: ctx.Err()}
}
}
func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result {
node, ok := tm.dag.Nodes[nodeID]
if !ok {
return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)}
}
if tm.dag.server.SyncMode() {
return tm.handleSyncTask(ctx, node, payload)
tm.createdAt = time.Now()
tm.finalResult = make(chan mq.Result, 0)
go tm.processNode(ctx, node, payload)
awaitResponse, ok := mq.GetAwaitResponse(ctx)
if awaitResponse != "true" {
go func() {
finalResult := <-tm.finalResult
finalResult.CreatedAt = tm.createdAt
finalResult.ProcessedAt = time.Now()
if tm.dag.server.NotifyHandler() != nil {
tm.dag.server.NotifyHandler()(ctx, finalResult)
}
}()
return mq.Result{CreatedAt: tm.createdAt, TaskID: tm.taskID, Topic: nodeID, Status: "PENDING"}
} else {
finalResult := <-tm.finalResult
finalResult.CreatedAt = tm.createdAt
finalResult.ProcessedAt = time.Now()
if tm.dag.server.NotifyHandler() != nil {
tm.dag.server.NotifyHandler()(ctx, finalResult)
}
return finalResult
}
return tm.handleAsyncTask(ctx, node, payload)
}
func (tm *TaskManager) dispatchFinalResult(ctx context.Context) {
if !tm.dag.server.SyncMode() {
var rs mq.Result
if len(tm.results) == 1 {
rs = tm.handleResult(ctx, tm.results[0])
} else {
rs = tm.handleResult(ctx, tm.results)
}
if tm.waitingCallback == 0 {
tm.finalResult <- rs
}
var rs mq.Result
if len(tm.results) == 1 {
rs = tm.handleResult(ctx, tm.results[0])
} else {
rs = tm.handleResult(ctx, tm.results)
}
if tm.waitingCallback == 0 {
tm.finalResult <- rs
}
}
@@ -127,13 +109,11 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.
return result
}
for _, item := range items {
tm.wg.Add(1)
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key})
go tm.processNode(ctx, edge.To, item)
}
case SimpleEdge:
if edge.To != nil {
tm.wg.Add(1)
ctx = mq.SetHeaders(ctx, map[string]string{consts.QueueKey: edge.To.Key})
go tm.processNode(ctx, edge.To, result.Payload)
}
@@ -147,10 +127,11 @@ func (tm *TaskManager) handleResult(ctx context.Context, results any) mq.Result
switch res := results.(type) {
case []mq.Result:
aggregatedOutput := make([]json.RawMessage, 0)
status := ""
var status, topic string
for i, result := range res {
if i == 0 {
status = result.Status
topic = result.Topic
}
if result.Error != nil {
return mq.HandleError(ctx, result.Error)
@@ -170,6 +151,7 @@ func (tm *TaskManager) handleResult(ctx context.Context, results any) mq.Result
TaskID: tm.taskID,
Payload: finalOutput,
Status: status,
Topic: topic,
}
case mq.Result:
return res
@@ -186,7 +168,6 @@ func (tm *TaskManager) appendFinalResult(result mq.Result) {
func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) {
atomic.AddInt64(&tm.waitingCallback, 1)
defer tm.wg.Done()
var result mq.Result
select {
case <-ctx.Done():

View File

@@ -3,6 +3,8 @@ package main
import (
"context"
"encoding/json"
"github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/examples/tasks"
"io"
"net/http"
@@ -10,70 +12,23 @@ import (
"github.com/oarkflow/mq/dag"
)
func handler1(ctx context.Context, task *mq.Task) mq.Result {
return mq.Result{Payload: task.Payload}
}
func handler2(ctx context.Context, task *mq.Task) mq.Result {
var user map[string]any
json.Unmarshal(task.Payload, &user)
return mq.Result{Payload: task.Payload}
}
func handler3(ctx context.Context, task *mq.Task) mq.Result {
var user map[string]any
json.Unmarshal(task.Payload, &user)
age := int(user["age"].(float64))
status := "FAIL"
if age > 20 {
status = "PASS"
}
user["status"] = status
resultPayload, _ := json.Marshal(user)
return mq.Result{Payload: resultPayload, Status: status}
}
func handler4(ctx context.Context, task *mq.Task) mq.Result {
var user map[string]any
json.Unmarshal(task.Payload, &user)
user["final"] = "D"
resultPayload, _ := json.Marshal(user)
return mq.Result{Payload: resultPayload}
}
func handler5(ctx context.Context, task *mq.Task) mq.Result {
var user map[string]any
json.Unmarshal(task.Payload, &user)
user["salary"] = "E"
resultPayload, _ := json.Marshal(user)
return mq.Result{Payload: resultPayload}
}
func handler6(ctx context.Context, task *mq.Task) mq.Result {
var user map[string]any
json.Unmarshal(task.Payload, &user)
resultPayload, _ := json.Marshal(map[string]any{"storage": user})
return mq.Result{Payload: resultPayload}
}
var (
d = dag.NewDAG(mq.WithSyncMode(true))
d = dag.NewDAG(mq.WithSyncMode(false), mq.WithNotifyResponse(tasks.NotifyResponse))
// d = dag.NewDAG(mq.WithSyncMode(true), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert"))
)
func main() {
d.AddNode("A", handler1, true)
d.AddNode("B", handler2)
d.AddNode("C", handler3)
d.AddNode("D", handler4)
d.AddNode("E", handler5)
d.AddNode("F", handler6)
d.AddNode("A", tasks.Node1, true)
d.AddNode("B", tasks.Node2)
d.AddNode("C", tasks.Node3)
d.AddNode("D", tasks.Node4)
d.AddNode("E", tasks.Node5)
d.AddNode("F", tasks.Node6)
d.AddEdge("A", "B", dag.LoopEdge)
d.AddCondition("C", map[string]string{"PASS": "D", "FAIL": "E"})
d.AddEdge("B", "C")
d.AddEdge("D", "F")
d.AddEdge("E", "F")
// fmt.Println(rs.TaskID, "Task", string(rs.Payload))
http.HandleFunc("POST /publish", requestHandler("publish"))
http.HandleFunc("POST /request", requestHandler("request"))
err := d.Start(context.TODO(), ":8083")
@@ -102,14 +57,12 @@ func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Requ
return
}
ctx := context.Background()
if requestType == "request" {
ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
}
// ctx = context.WithValue(ctx, "initial_node", "E")
rs := d.ProcessTask(ctx, payload)
w.Header().Set("Content-Type", "application/json")
result := map[string]any{
"message_id": rs.TaskID,
"payload": rs.Payload,
"error": rs.Error,
}
json.NewEncoder(w).Encode(result)
json.NewEncoder(w).Encode(rs)
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"github.com/oarkflow/mq"
)
@@ -17,53 +18,46 @@ func Node2(ctx context.Context, task *mq.Task) mq.Result {
}
func Node3(ctx context.Context, task *mq.Task) mq.Result {
var data map[string]any
err := json.Unmarshal(task.Payload, &data)
if err != nil {
return mq.Result{Error: err}
var user map[string]any
json.Unmarshal(task.Payload, &user)
age := int(user["age"].(float64))
status := "FAIL"
if age > 20 {
status = "PASS"
}
data["salary"] = fmt.Sprintf("12000%v", data["user_id"])
bt, _ := json.Marshal(data)
return mq.Result{Payload: bt, TaskID: task.ID}
user["status"] = status
resultPayload, _ := json.Marshal(user)
return mq.Result{Payload: resultPayload, Status: status}
}
func Node4(ctx context.Context, task *mq.Task) mq.Result {
var data []map[string]any
err := json.Unmarshal(task.Payload, &data)
if err != nil {
return mq.Result{Error: err}
}
payload := map[string]any{"storage": data}
bt, _ := json.Marshal(payload)
return mq.Result{Payload: bt, TaskID: task.ID}
var user map[string]any
json.Unmarshal(task.Payload, &user)
user["final"] = "D"
resultPayload, _ := json.Marshal(user)
return mq.Result{Payload: resultPayload}
}
func CheckCondition(ctx context.Context, task *mq.Task) mq.Result {
var data map[string]any
err := json.Unmarshal(task.Payload, &data)
if err != nil {
return mq.Result{Error: err}
}
var status string
if data["user_id"].(float64) == 2 {
status = "pass"
} else {
status = "fail"
}
return mq.Result{Status: status, Payload: task.Payload, TaskID: task.ID}
func Node5(ctx context.Context, task *mq.Task) mq.Result {
var user map[string]any
json.Unmarshal(task.Payload, &user)
user["salary"] = "E"
resultPayload, _ := json.Marshal(user)
return mq.Result{Payload: resultPayload}
}
func Pass(ctx context.Context, task *mq.Task) mq.Result {
fmt.Println("Pass")
return mq.Result{Payload: task.Payload}
}
func Fail(ctx context.Context, task *mq.Task) mq.Result {
fmt.Println("Fail")
return mq.Result{Payload: []byte(`{"test2": "asdsa"}`)}
func Node6(ctx context.Context, task *mq.Task) mq.Result {
var user map[string]any
json.Unmarshal(task.Payload, &user)
resultPayload, _ := json.Marshal(map[string]any{"storage": user})
return mq.Result{Payload: resultPayload}
}
func Callback(ctx context.Context, task mq.Result) mq.Result {
fmt.Println("Received task", task.TaskID, "Payload", string(task.Payload), task.Error, task.Topic)
return mq.Result{}
}
func NotifyResponse(ctx context.Context, result mq.Result) {
log.Printf("DAG Final response: TaskID: %s, Payload: %s, Topic: %s", result.TaskID, result.Payload, result.Topic)
}

View File

@@ -8,11 +8,13 @@ import (
)
type Result struct {
Payload json.RawMessage `json:"payload"`
Topic string `json:"topic"`
TaskID string `json:"task_id"`
Error error `json:"error,omitempty"`
Status string `json:"status"`
Payload json.RawMessage `json:"payload"`
Topic string `json:"topic"`
CreatedAt time.Time `json:"created_at"`
ProcessedAt time.Time `json:"processed_at,omitempty"`
TaskID string `json:"task_id"`
Error error `json:"error,omitempty"`
Status string `json:"status"`
}
func (r Result) Unmarshal(data any) error {
@@ -22,10 +24,6 @@ func (r Result) Unmarshal(data any) error {
return json.Unmarshal(r.Payload, data)
}
func (r Result) String() string {
return string(r.Payload)
}
func HandleError(ctx context.Context, err error, status ...string) Result {
st := "Failed"
if len(status) > 0 {
@@ -63,6 +61,7 @@ type Options struct {
brokerAddr string
callback []func(context.Context, Result) Result
maxRetries int
notifyResponse func(context.Context, Result)
initialDelay time.Duration
maxBackoff time.Duration
jitterPercent float64
@@ -96,6 +95,12 @@ func setupOptions(opts ...Option) Options {
return options
}
func WithNotifyResponse(handler func(ctx context.Context, result Result)) Option {
return func(opts *Options) {
opts.notifyResponse = handler
}
}
func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option {
return func(opts *Options) {
opts.aesKey = aesKey

View File

@@ -15,6 +15,10 @@ func (b *Broker) SyncMode() bool {
return b.opts.syncMode
}
func (b *Broker) NotifyHandler() func(context.Context, Result) {
return b.opts.notifyResponse
}
func (b *Broker) HandleCallback(ctx context.Context, msg *codec.Message) {
if b.opts.callback != nil {
var result Result