diff --git a/codec/codec.go b/codec/codec.go index 92ee8d4..a61f0b5 100644 --- a/codec/codec.go +++ b/codec/codec.go @@ -5,9 +5,9 @@ import ( "context" "encoding/binary" "net" - "sync" "github.com/oarkflow/mq/consts" + "github.com/oarkflow/mq/internal/bpool" ) type Message struct { @@ -45,33 +45,28 @@ func Deserialize(data []byte) (*Message, error) { return &msg, nil } -var byteBufferPool = sync.Pool{ - New: func() interface{} { - return make([]byte, 0, 4096) - }, -} - func SendMessage(ctx context.Context, conn net.Conn, msg *Message) error { data, err := msg.Serialize() if err != nil { return err } totalLength := 4 + len(data) - buffer := byteBufferPool.Get().([]byte) - if cap(buffer) < totalLength { - buffer = make([]byte, totalLength) + buffer := bpool.Get() + defer bpool.Put(buffer) + buffer.Reset() + if cap(buffer.B) < totalLength { + buffer.B = make([]byte, totalLength) } else { - buffer = buffer[:totalLength] + buffer.B = buffer.B[:totalLength] } - defer byteBufferPool.Put(buffer) - binary.BigEndian.PutUint32(buffer[:4], uint32(len(data))) - copy(buffer[4:], data) + binary.BigEndian.PutUint32(buffer.B[:4], uint32(len(data))) + copy(buffer.B[4:], data) writer := bufio.NewWriter(conn) select { case <-ctx.Done(): return ctx.Err() default: - if _, err := writer.Write(buffer); err != nil { + if _, err := writer.Write(buffer.B[:totalLength]); err != nil { return err } } @@ -84,20 +79,26 @@ func ReadMessage(ctx context.Context, conn net.Conn) (*Message, error) { return nil, err } length := binary.BigEndian.Uint32(lengthBytes) - data := byteBufferPool.Get().([]byte)[:length] - defer byteBufferPool.Put(data) + buffer := bpool.Get() + defer bpool.Put(buffer) + buffer.Reset() + if cap(buffer.B) < int(length) { + buffer.B = make([]byte, length) + } else { + buffer.B = buffer.B[:length] + } totalRead := 0 for totalRead < int(length) { select { case <-ctx.Done(): return nil, ctx.Err() default: - n, err := conn.Read(data[totalRead:]) + n, err := conn.Read(buffer.B[totalRead:]) if err != nil { return nil, err } totalRead += n } } - return Deserialize(data[:length]) + return Deserialize(buffer.B[:length]) } diff --git a/dag/dag.go b/dag/dag.go index 4fac96e..9559bd4 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -234,7 +234,7 @@ func (tm *DAG) Start(ctx context.Context, addr string) error { return http.ListenAndServe(addr, nil) } -func (tm *DAG) AddDAGNode(name string, key string, dag *DAG, firstNode ...bool) { +func (tm *DAG) AddDAGNode(name string, key string, dag *DAG, firstNode ...bool) *DAG { dag.AssignTopic(key) tm.mu.Lock() defer tm.mu.Unlock() @@ -247,6 +247,7 @@ func (tm *DAG) AddDAGNode(name string, key string, dag *DAG, firstNode ...bool) if len(firstNode) > 0 && firstNode[0] { tm.startNode = key } + return tm } func (tm *DAG) AddNode(name, key string, handler mq.Processor, firstNode ...bool) *DAG { diff --git a/examples/dag.go b/examples/dag.go index d5b4514..03d927b 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -15,6 +15,16 @@ func main() { aSync() } +func subDAG() *dag.DAG { + f := dag.NewDAG("Sub DAG", "sub-dag", mq.WithCleanTaskOnComplete(), mq.WithSyncMode(true)) + f. + AddNode("Store data", "store:data", &tasks.StoreData{Operation: dag.Operation{Type: "process"}}, true). + AddNode("Send SMS", "send:sms", &tasks.SendSms{Operation: dag.Operation{Type: "process"}}). + AddNode("Notification", "notification", &tasks.InAppNotification{Operation: dag.Operation{Type: "process"}}). + AddEdge("Store Data to send sms and notification", "store:data", "send:sms", "notification") + return f +} + func setup(f *dag.DAG) { f. AddNode("Email Delivery", "email:deliver", &tasks.EmailDelivery{Operation: dag.Operation{Type: "process"}}). @@ -22,14 +32,11 @@ func setup(f *dag.DAG) { AddNode("Get Input", "get:input", &tasks.GetData{Operation: dag.Operation{Type: "input"}}, true). AddNode("Iterator Processor", "loop", &tasks.Loop{Operation: dag.Operation{Type: "loop"}}). AddNode("Condition", "condition", &tasks.Condition{Operation: dag.Operation{Type: "condition"}}). - AddNode("Store data", "store:data", &tasks.StoreData{Operation: dag.Operation{Type: "process"}}). - AddNode("Send SMS", "send:sms", &tasks.SendSms{Operation: dag.Operation{Type: "process"}}). - AddNode("Notification", "notification", &tasks.InAppNotification{Operation: dag.Operation{Type: "process"}}). - AddCondition("condition", map[dag.When]dag.Then{"pass": "email:deliver", "fail": "store:data"}). + AddDAGNode("Persistent", "persistent", subDAG()). + AddCondition("condition", map[dag.When]dag.Then{"pass": "email:deliver", "fail": "persistent"}). AddEdge("Get input to loop", "get:input", "loop"). AddIterator("Loop to prepare email", "loop", "prepare:email"). - AddEdge("Prepare Email to condition", "prepare:email", "condition"). - AddEdge("Store Data to send sms and notification", "store:data", "send:sms", "notification") + AddEdge("Prepare Email to condition", "prepare:email", "condition") } func sendData(f *dag.DAG) { diff --git a/examples/subdag.go b/examples/subdag.go index 87ea254..55d9d4d 100644 --- a/examples/subdag.go +++ b/examples/subdag.go @@ -2,11 +2,7 @@ package main import ( "context" - "encoding/json" "fmt" - "github.com/oarkflow/mq/consts" - "io" - "net/http" "github.com/oarkflow/mq/examples/tasks" @@ -42,64 +38,8 @@ func main() { d.AddCondition("C", map[dag.When]dag.Then{"PASS": "D", "FAIL": "E"}) d.AddEdge("Label 1", "B", "C") - // Classify edges - // d.ClassifyEdges() fmt.Println(d.ExportDOT()) - requestHandler := func(requestType string) func(w http.ResponseWriter, r *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - http.Error(w, "Invalid request method", http.StatusMethodNotAllowed) - return - } - var payload []byte - if r.Body != nil { - defer r.Body.Close() - var err error - payload, err = io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Failed to read request body", http.StatusBadRequest) - return - } - } else { - http.Error(w, "Empty request body", http.StatusBadRequest) - return - } - ctx := r.Context() - if requestType == "request" { - ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"}) - } - // ctx = context.WithValue(ctx, "initial_node", "E") - rs := d.Process(ctx, payload) - if rs.Error != nil { - http.Error(w, fmt.Sprintf("[DAG Error] - %v", rs.Error), http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(rs) - } - } - - http.HandleFunc("POST /publish", requestHandler("publish")) - http.HandleFunc("POST /request", requestHandler("request")) - http.HandleFunc("/pause-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) { - id := request.PathValue("id") - if id != "" { - d.PauseConsumer(request.Context(), id) - } - }) - http.HandleFunc("/resume-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) { - id := request.PathValue("id") - if id != "" { - d.ResumeConsumer(request.Context(), id) - } - }) - http.HandleFunc("/pause", func(writer http.ResponseWriter, request *http.Request) { - d.Pause(request.Context()) - }) - http.HandleFunc("/resume", func(writer http.ResponseWriter, request *http.Request) { - d.Resume(request.Context()) - }) err := d.Start(context.TODO(), ":8083") if err != nil { panic(err) diff --git a/sio/socket.go b/sio/socket.go index f6148b2..9a93204 100644 --- a/sio/socket.go +++ b/sio/socket.go @@ -4,14 +4,15 @@ import ( "encoding/base64" "encoding/json" "fmt" - "github.com/oarkflow/mq/internal/bpool" - "github.com/oarkflow/mq/storage" - "github.com/oarkflow/mq/storage/memory" "log/slog" "net/http" "sync" "time" + "github.com/oarkflow/mq/internal/bpool" + "github.com/oarkflow/mq/storage" + "github.com/oarkflow/mq/storage/memory" + "github.com/gorilla/websocket" )