feat: Add connection

This commit is contained in:
sujit
2024-11-11 09:33:01 +05:45
parent 9557053f30
commit bf3ab4bca6
5 changed files with 39 additions and 89 deletions

View File

@@ -5,9 +5,9 @@ import (
"context" "context"
"encoding/binary" "encoding/binary"
"net" "net"
"sync"
"github.com/oarkflow/mq/consts" "github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/internal/bpool"
) )
type Message struct { type Message struct {
@@ -45,33 +45,28 @@ func Deserialize(data []byte) (*Message, error) {
return &msg, nil return &msg, nil
} }
var byteBufferPool = sync.Pool{
New: func() interface{} {
return make([]byte, 0, 4096)
},
}
func SendMessage(ctx context.Context, conn net.Conn, msg *Message) error { func SendMessage(ctx context.Context, conn net.Conn, msg *Message) error {
data, err := msg.Serialize() data, err := msg.Serialize()
if err != nil { if err != nil {
return err return err
} }
totalLength := 4 + len(data) totalLength := 4 + len(data)
buffer := byteBufferPool.Get().([]byte) buffer := bpool.Get()
if cap(buffer) < totalLength { defer bpool.Put(buffer)
buffer = make([]byte, totalLength) buffer.Reset()
if cap(buffer.B) < totalLength {
buffer.B = make([]byte, totalLength)
} else { } else {
buffer = buffer[:totalLength] buffer.B = buffer.B[:totalLength]
} }
defer byteBufferPool.Put(buffer) binary.BigEndian.PutUint32(buffer.B[:4], uint32(len(data)))
binary.BigEndian.PutUint32(buffer[:4], uint32(len(data))) copy(buffer.B[4:], data)
copy(buffer[4:], data)
writer := bufio.NewWriter(conn) writer := bufio.NewWriter(conn)
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
default: default:
if _, err := writer.Write(buffer); err != nil { if _, err := writer.Write(buffer.B[:totalLength]); err != nil {
return err return err
} }
} }
@@ -84,20 +79,26 @@ func ReadMessage(ctx context.Context, conn net.Conn) (*Message, error) {
return nil, err return nil, err
} }
length := binary.BigEndian.Uint32(lengthBytes) length := binary.BigEndian.Uint32(lengthBytes)
data := byteBufferPool.Get().([]byte)[:length] buffer := bpool.Get()
defer byteBufferPool.Put(data) defer bpool.Put(buffer)
buffer.Reset()
if cap(buffer.B) < int(length) {
buffer.B = make([]byte, length)
} else {
buffer.B = buffer.B[:length]
}
totalRead := 0 totalRead := 0
for totalRead < int(length) { for totalRead < int(length) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, ctx.Err() return nil, ctx.Err()
default: default:
n, err := conn.Read(data[totalRead:]) n, err := conn.Read(buffer.B[totalRead:])
if err != nil { if err != nil {
return nil, err return nil, err
} }
totalRead += n totalRead += n
} }
} }
return Deserialize(data[:length]) return Deserialize(buffer.B[:length])
} }

View File

@@ -234,7 +234,7 @@ func (tm *DAG) Start(ctx context.Context, addr string) error {
return http.ListenAndServe(addr, nil) return http.ListenAndServe(addr, nil)
} }
func (tm *DAG) AddDAGNode(name string, key string, dag *DAG, firstNode ...bool) { func (tm *DAG) AddDAGNode(name string, key string, dag *DAG, firstNode ...bool) *DAG {
dag.AssignTopic(key) dag.AssignTopic(key)
tm.mu.Lock() tm.mu.Lock()
defer tm.mu.Unlock() defer tm.mu.Unlock()
@@ -247,6 +247,7 @@ func (tm *DAG) AddDAGNode(name string, key string, dag *DAG, firstNode ...bool)
if len(firstNode) > 0 && firstNode[0] { if len(firstNode) > 0 && firstNode[0] {
tm.startNode = key tm.startNode = key
} }
return tm
} }
func (tm *DAG) AddNode(name, key string, handler mq.Processor, firstNode ...bool) *DAG { func (tm *DAG) AddNode(name, key string, handler mq.Processor, firstNode ...bool) *DAG {

View File

@@ -15,6 +15,16 @@ func main() {
aSync() aSync()
} }
func subDAG() *dag.DAG {
f := dag.NewDAG("Sub DAG", "sub-dag", mq.WithCleanTaskOnComplete(), mq.WithSyncMode(true))
f.
AddNode("Store data", "store:data", &tasks.StoreData{Operation: dag.Operation{Type: "process"}}, true).
AddNode("Send SMS", "send:sms", &tasks.SendSms{Operation: dag.Operation{Type: "process"}}).
AddNode("Notification", "notification", &tasks.InAppNotification{Operation: dag.Operation{Type: "process"}}).
AddEdge("Store Data to send sms and notification", "store:data", "send:sms", "notification")
return f
}
func setup(f *dag.DAG) { func setup(f *dag.DAG) {
f. f.
AddNode("Email Delivery", "email:deliver", &tasks.EmailDelivery{Operation: dag.Operation{Type: "process"}}). AddNode("Email Delivery", "email:deliver", &tasks.EmailDelivery{Operation: dag.Operation{Type: "process"}}).
@@ -22,14 +32,11 @@ func setup(f *dag.DAG) {
AddNode("Get Input", "get:input", &tasks.GetData{Operation: dag.Operation{Type: "input"}}, true). AddNode("Get Input", "get:input", &tasks.GetData{Operation: dag.Operation{Type: "input"}}, true).
AddNode("Iterator Processor", "loop", &tasks.Loop{Operation: dag.Operation{Type: "loop"}}). AddNode("Iterator Processor", "loop", &tasks.Loop{Operation: dag.Operation{Type: "loop"}}).
AddNode("Condition", "condition", &tasks.Condition{Operation: dag.Operation{Type: "condition"}}). AddNode("Condition", "condition", &tasks.Condition{Operation: dag.Operation{Type: "condition"}}).
AddNode("Store data", "store:data", &tasks.StoreData{Operation: dag.Operation{Type: "process"}}). AddDAGNode("Persistent", "persistent", subDAG()).
AddNode("Send SMS", "send:sms", &tasks.SendSms{Operation: dag.Operation{Type: "process"}}). AddCondition("condition", map[dag.When]dag.Then{"pass": "email:deliver", "fail": "persistent"}).
AddNode("Notification", "notification", &tasks.InAppNotification{Operation: dag.Operation{Type: "process"}}).
AddCondition("condition", map[dag.When]dag.Then{"pass": "email:deliver", "fail": "store:data"}).
AddEdge("Get input to loop", "get:input", "loop"). AddEdge("Get input to loop", "get:input", "loop").
AddIterator("Loop to prepare email", "loop", "prepare:email"). AddIterator("Loop to prepare email", "loop", "prepare:email").
AddEdge("Prepare Email to condition", "prepare:email", "condition"). AddEdge("Prepare Email to condition", "prepare:email", "condition")
AddEdge("Store Data to send sms and notification", "store:data", "send:sms", "notification")
} }
func sendData(f *dag.DAG) { func sendData(f *dag.DAG) {

View File

@@ -2,11 +2,7 @@ package main
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"github.com/oarkflow/mq/consts"
"io"
"net/http"
"github.com/oarkflow/mq/examples/tasks" "github.com/oarkflow/mq/examples/tasks"
@@ -42,64 +38,8 @@ func main() {
d.AddCondition("C", map[dag.When]dag.Then{"PASS": "D", "FAIL": "E"}) d.AddCondition("C", map[dag.When]dag.Then{"PASS": "D", "FAIL": "E"})
d.AddEdge("Label 1", "B", "C") d.AddEdge("Label 1", "B", "C")
// Classify edges
// d.ClassifyEdges()
fmt.Println(d.ExportDOT()) fmt.Println(d.ExportDOT())
requestHandler := func(requestType string) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
return
}
var payload []byte
if r.Body != nil {
defer r.Body.Close()
var err error
payload, err = io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
} else {
http.Error(w, "Empty request body", http.StatusBadRequest)
return
}
ctx := r.Context()
if requestType == "request" {
ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"})
}
// ctx = context.WithValue(ctx, "initial_node", "E")
rs := d.Process(ctx, payload)
if rs.Error != nil {
http.Error(w, fmt.Sprintf("[DAG Error] - %v", rs.Error), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(rs)
}
}
http.HandleFunc("POST /publish", requestHandler("publish"))
http.HandleFunc("POST /request", requestHandler("request"))
http.HandleFunc("/pause-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) {
id := request.PathValue("id")
if id != "" {
d.PauseConsumer(request.Context(), id)
}
})
http.HandleFunc("/resume-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) {
id := request.PathValue("id")
if id != "" {
d.ResumeConsumer(request.Context(), id)
}
})
http.HandleFunc("/pause", func(writer http.ResponseWriter, request *http.Request) {
d.Pause(request.Context())
})
http.HandleFunc("/resume", func(writer http.ResponseWriter, request *http.Request) {
d.Resume(request.Context())
})
err := d.Start(context.TODO(), ":8083") err := d.Start(context.TODO(), ":8083")
if err != nil { if err != nil {
panic(err) panic(err)

View File

@@ -4,14 +4,15 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/oarkflow/mq/internal/bpool"
"github.com/oarkflow/mq/storage"
"github.com/oarkflow/mq/storage/memory"
"log/slog" "log/slog"
"net/http" "net/http"
"sync" "sync"
"time" "time"
"github.com/oarkflow/mq/internal/bpool"
"github.com/oarkflow/mq/storage"
"github.com/oarkflow/mq/storage/memory"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )