feat: add example

This commit is contained in:
sujit
2024-10-11 12:23:36 +05:45
parent 80b1ee81e2
commit 3d17a58345
9 changed files with 150 additions and 128 deletions

View File

@@ -208,7 +208,7 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M
taskID, _ := jsonparser.GetString(msg.Payload, "id")
log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID)
ack := codec.NewMessage(consts.PUBLISH_ACK, []byte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
ack := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
if err := b.send(conn, ack); err != nil {
log.Printf("Error sending PUBLISH_ACK: %v\n", err)
}
@@ -361,7 +361,7 @@ func (b *Broker) handleConsumer(cmd consts.CMD, state consts.ConsumerState, cons
fn := func(queue *Queue) {
con, ok := queue.consumers.Get(consumerID)
if ok {
ack := codec.NewMessage(cmd, []byte("{}"), queue.name, map[string]string{consts.ConsumerKey: consumerID})
ack := codec.NewMessage(cmd, utils.ToByte("{}"), queue.name, map[string]string{consts.ConsumerKey: consumerID})
err := b.send(con.conn, ack)
if err == nil {
con.state = state

View File

@@ -2,13 +2,20 @@ package codec
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"sync"
"github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/utils"
)
type Message struct {
@@ -31,18 +38,18 @@ func NewMessage(cmd consts.CMD, payload json.RawMessage, queue string, headers m
func (m *Message) Serialize(aesKey, hmacKey []byte, encrypt bool) ([]byte, string, error) {
m.m.Lock()
defer m.m.Unlock()
var buf bytes.Buffer
if err := writeLengthPrefixedJSON(&buf, m.Headers); err != nil {
buf := bytes.NewBuffer(make([]byte, 0, 512))
if err := writeLengthPrefixedJSON(buf, m.Headers); err != nil {
return nil, "", fmt.Errorf("error serializing headers: %v", err)
}
if err := writeLengthPrefixed(&buf, []byte(m.Queue)); err != nil {
return nil, "", fmt.Errorf("error serializing topic: %v", err)
if err := writeLengthPrefixed(buf, utils.ToByte(m.Queue)); err != nil {
return nil, "", fmt.Errorf("error serializing queue: %v", err)
}
if err := binary.Write(&buf, binary.LittleEndian, m.Command); err != nil {
if err := binary.Write(buf, binary.LittleEndian, m.Command); err != nil {
return nil, "", fmt.Errorf("error serializing command: %v", err)
}
if err := writePayload(&buf, aesKey, m.Payload, encrypt); err != nil {
return nil, "", err
if err := writePayload(buf, aesKey, m.Payload, encrypt); err != nil {
return nil, "", fmt.Errorf("error serializing payload: %v", err)
}
messageBytes := buf.Bytes()
hmacSignature := CalculateHMAC(hmacKey, messageBytes)
@@ -51,16 +58,17 @@ func (m *Message) Serialize(aesKey, hmacKey []byte, encrypt bool) ([]byte, strin
func Deserialize(data, aesKey, hmacKey []byte, receivedHMAC string, decrypt bool) (*Message, error) {
if !VerifyHMAC(hmacKey, data, receivedHMAC) {
return nil, fmt.Errorf("HMAC verification failed %s", string(hmacKey))
return nil, fmt.Errorf("HMAC verification failed")
}
buf := bytes.NewReader(data)
headers := make(map[string]string)
if err := readLengthPrefixedJSON(buf, &headers); err != nil {
return nil, fmt.Errorf("error deserializing headers: %v", err)
}
topic, err := readLengthPrefixedString(buf)
queue, err := readLengthPrefixedString(buf)
if err != nil {
return nil, fmt.Errorf("error deserializing topic: %v", err)
return nil, fmt.Errorf("error deserializing queue: %v", err)
}
var command consts.CMD
if err := binary.Read(buf, binary.LittleEndian, &command); err != nil {
@@ -72,7 +80,7 @@ func Deserialize(data, aesKey, hmacKey []byte, receivedHMAC string, decrypt bool
}
return &Message{
Headers: headers,
Queue: topic,
Queue: queue,
Command: command,
Payload: payload,
}, nil
@@ -139,7 +147,7 @@ func readLengthPrefixedString(r *bytes.Reader) (string, error) {
if err != nil {
return "", err
}
return string(data), nil
return utils.FromByte(data), nil
}
func writePayload(buf *bytes.Buffer, aesKey []byte, payload json.RawMessage, encrypt bool) error {
@@ -147,7 +155,6 @@ func writePayload(buf *bytes.Buffer, aesKey []byte, payload json.RawMessage, enc
if err != nil {
return fmt.Errorf("error marshalling payload: %v", err)
}
var encryptedPayload, nonce []byte
if encrypt {
encryptedPayload, nonce, err = EncryptPayload(aesKey, payloadBytes)
@@ -157,11 +164,9 @@ func writePayload(buf *bytes.Buffer, aesKey []byte, payload json.RawMessage, enc
} else {
encryptedPayload = payloadBytes
}
if err := writeLengthPrefixed(buf, encryptedPayload); err != nil {
return err
}
if encrypt {
buf.Write(nonce)
}
@@ -192,6 +197,7 @@ func readPayload(r *bytes.Reader, aesKey []byte, decrypt bool) (json.RawMessage,
}
return payload, nil
}
func writeMessageWithHMAC(conn io.Writer, messageBytes []byte, hmacSignature string) error {
if err := binary.Write(conn, binary.LittleEndian, uint32(len(messageBytes))); err != nil {
return err
@@ -199,7 +205,11 @@ func writeMessageWithHMAC(conn io.Writer, messageBytes []byte, hmacSignature str
if _, err := conn.Write(messageBytes); err != nil {
return err
}
if _, err := conn.Write([]byte(hmacSignature)); err != nil {
hmacBytes, err := hex.DecodeString(hmacSignature)
if err != nil {
return err
}
if _, err := conn.Write(hmacBytes); err != nil {
return err
}
return nil
@@ -214,12 +224,54 @@ func readMessageWithHMAC(conn io.Reader) ([]byte, string, error) {
if _, err := io.ReadFull(conn, data); err != nil {
return nil, "", err
}
hmacBytes := make([]byte, 64)
hmacBytes := make([]byte, 32)
if _, err := io.ReadFull(conn, hmacBytes); err != nil {
return nil, "", err
}
receivedHMAC := string(hmacBytes)
receivedHMAC := hex.EncodeToString(hmacBytes)
return data, receivedHMAC, nil
}
func EncryptPayload(key []byte, plaintext []byte) ([]byte, []byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, nil, err
}
aesGCM, err := cipher.NewGCM(block)
if err != nil {
return nil, nil, err
}
nonce := make([]byte, aesGCM.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, nil, err
}
ciphertext := aesGCM.Seal(nil, nonce, plaintext, nil)
return ciphertext, nonce, nil
}
func DecryptPayload(key []byte, ciphertext []byte, nonce []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aesGCM, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, err
}
return plaintext, nil
}
func CalculateHMAC(key []byte, data []byte) string {
h := hmac.New(sha256.New, key)
h.Write(data)
return hex.EncodeToString(h.Sum(nil))
}
func VerifyHMAC(key []byte, data []byte, receivedHMAC string) bool {
expectedHMAC := CalculateHMAC(key, data)
return hmac.Equal(utils.ToByte(receivedHMAC), utils.ToByte(expectedHMAC))
}

View File

@@ -1,51 +0,0 @@
package codec
import (
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"io"
)
func EncryptPayload(key []byte, plaintext []byte) ([]byte, []byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, nil, err
}
aesGCM, err := cipher.NewGCM(block)
if err != nil {
return nil, nil, err
}
nonce := make([]byte, aesGCM.NonceSize())
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
return nil, nil, err
}
ciphertext := aesGCM.Seal(nil, nonce, plaintext, nil)
return ciphertext, nonce, nil
}
func DecryptPayload(key []byte, ciphertext []byte, nonce []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aesGCM, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
return aesGCM.Open(nil, nonce, ciphertext, nil)
}
func CalculateHMAC(key []byte, data []byte) string {
h := hmac.New(sha256.New, key)
h.Write(data)
return hex.EncodeToString(h.Sum(nil))
}
func VerifyHMAC(key []byte, data []byte, receivedHMAC string) bool {
expectedHMAC := CalculateHMAC(key, data)
return hmac.Equal([]byte(expectedHMAC), []byte(receivedHMAC))
}

View File

@@ -57,7 +57,7 @@ func (c *Consumer) subscribe(ctx context.Context, queue string) error {
consts.ConsumerKey: c.id,
consts.ContentType: consts.TypeJson,
})
msg := codec.NewMessage(consts.SUBSCRIBE, []byte("{}"), queue, headers)
msg := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers)
if err := c.send(c.conn, msg); err != nil {
return err
}
@@ -104,7 +104,7 @@ func (c *Consumer) ConsumeMessage(ctx context.Context, msg *codec.Message, conn
consts.QueueKey: msg.Queue,
})
taskID, _ := jsonparser.GetString(msg.Payload, "id")
reply := codec.NewMessage(consts.MESSAGE_ACK, []byte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
reply := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
if err := c.send(conn, reply); err != nil {
fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err)
}
@@ -158,7 +158,7 @@ func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, er
consts.ConsumerKey: c.id,
consts.ContentType: consts.TypeJson,
})
reply := codec.NewMessage(consts.MESSAGE_DENY, []byte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers)
reply := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers)
if sendErr := c.send(c.conn, reply); sendErr != nil {
log.Printf("failed to send MESSAGE_DENY for task %s: %v", taskID, sendErr)
}

View File

@@ -44,8 +44,8 @@ type Edge struct {
}
type DAG struct {
FirstNode string
Nodes map[string]*Node
startNode string
nodes map[string]*Node
server *mq.Broker
taskContext map[string]*TaskManager
conditions map[string]map[string]string
@@ -56,7 +56,7 @@ type DAG struct {
func NewDAG(opts ...mq.Option) *DAG {
d := &DAG{
Nodes: make(map[string]*Node),
nodes: make(map[string]*Node),
taskContext: make(map[string]*TaskManager),
conditions: make(map[string]map[string]string),
}
@@ -74,19 +74,27 @@ func (tm *DAG) onTaskCallback(ctx context.Context, result mq.Result) mq.Result {
}
func (tm *DAG) onConsumerJoin(_ context.Context, topic, _ string) {
if node, ok := tm.Nodes[topic]; ok {
if node, ok := tm.nodes[topic]; ok {
log.Printf("DAG - CONSUMER ~> ready on %s", topic)
node.isReady = true
}
}
func (tm *DAG) onConsumerClose(_ context.Context, topic, _ string) {
if node, ok := tm.Nodes[topic]; ok {
if node, ok := tm.nodes[topic]; ok {
log.Printf("DAG - CONSUMER ~> down on %s", topic)
node.isReady = false
}
}
func (tm *DAG) SetStartNode(node string) {
tm.startNode = node
}
func (tm *DAG) GetStartNode() string {
return tm.startNode
}
func (tm *DAG) Start(ctx context.Context, addr string) error {
if !tm.server.SyncMode() {
go func() {
@@ -95,7 +103,7 @@ func (tm *DAG) Start(ctx context.Context, addr string) error {
panic(err)
}
}()
for _, con := range tm.Nodes {
for _, con := range tm.nodes {
if con.isReady {
go func(con *Node) {
time.Sleep(1 * time.Second)
@@ -122,13 +130,13 @@ func (tm *DAG) AddNode(key string, handler mq.Handler, firstNode ...bool) {
tm.mu.Lock()
defer tm.mu.Unlock()
con := mq.NewConsumer(key, key, handler, tm.opts...)
tm.Nodes[key] = &Node{
tm.nodes[key] = &Node{
Key: key,
consumer: con,
isReady: true,
}
if len(firstNode) > 0 && firstNode[0] {
tm.FirstNode = key
tm.startNode = key
}
}
@@ -138,11 +146,11 @@ func (tm *DAG) AddDeferredNode(key string, firstNode ...bool) error {
}
tm.mu.Lock()
defer tm.mu.Unlock()
tm.Nodes[key] = &Node{
tm.nodes[key] = &Node{
Key: key,
}
if len(firstNode) > 0 && firstNode[0] {
tm.FirstNode = key
tm.startNode = key
}
return nil
}
@@ -150,7 +158,7 @@ func (tm *DAG) AddDeferredNode(key string, firstNode ...bool) error {
func (tm *DAG) IsReady() bool {
tm.mu.Lock()
defer tm.mu.Unlock()
for _, node := range tm.Nodes {
for _, node := range tm.nodes {
if !node.isReady {
return false
}
@@ -167,11 +175,11 @@ func (tm *DAG) AddCondition(fromNode string, conditions map[string]string) {
func (tm *DAG) AddEdge(from, to string, edgeTypes ...EdgeType) {
tm.mu.Lock()
defer tm.mu.Unlock()
fromNode, ok := tm.Nodes[from]
fromNode, ok := tm.nodes[from]
if !ok {
return
}
toNode, ok := tm.Nodes[to]
toNode, ok := tm.nodes[to]
if !ok {
return
}
@@ -183,25 +191,28 @@ func (tm *DAG) AddEdge(from, to string, edgeTypes ...EdgeType) {
}
func (tm *DAG) ProcessTask(ctx context.Context, payload []byte) mq.Result {
tm.mu.RLock() // lock when reading `paused`
if tm.paused {
tm.mu.RUnlock()
return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")}
}
tm.mu.RUnlock()
if !tm.IsReady() {
return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not ready yet")}
}
val := ctx.Value("initial_node")
initialNode, ok := val.(string)
if !ok {
if tm.FirstNode == "" {
if tm.startNode == "" {
firstNode := tm.FindInitialNode()
if firstNode != nil {
tm.FirstNode = firstNode.Key
tm.startNode = firstNode.Key
}
}
if tm.FirstNode == "" {
if tm.startNode == "" {
return mq.Result{Error: fmt.Errorf("initial node not found")}
}
initialNode = tm.FirstNode
initialNode = tm.startNode
}
tm.mu.Lock()
defer tm.mu.Unlock()
@@ -214,7 +225,7 @@ func (tm *DAG) ProcessTask(ctx context.Context, payload []byte) mq.Result {
func (tm *DAG) FindInitialNode() *Node {
incomingEdges := make(map[string]bool)
connectedNodes := make(map[string]bool)
for _, node := range tm.Nodes {
for _, node := range tm.nodes {
for _, edge := range node.Edges {
if edge.Type.IsValid() {
connectedNodes[node.Key] = true
@@ -229,7 +240,7 @@ func (tm *DAG) FindInitialNode() *Node {
}
}
}
for nodeID, node := range tm.Nodes {
for nodeID, node := range tm.nodes {
if !incomingEdges[nodeID] && connectedNodes[nodeID] {
return node
}
@@ -238,24 +249,28 @@ func (tm *DAG) FindInitialNode() *Node {
}
func (tm *DAG) Pause() {
tm.mu.Lock() // lock when modifying `paused`
defer tm.mu.Unlock()
tm.paused = true
log.Printf("DAG - PAUSED")
}
func (tm *DAG) Resume() {
tm.mu.Lock() // lock when modifying `paused`
defer tm.mu.Unlock()
tm.paused = false
log.Printf("DAG - RESUMED")
}
func (tm *DAG) PauseConsumer(id string) {
if node, ok := tm.Nodes[id]; ok {
if node, ok := tm.nodes[id]; ok {
node.consumer.Pause()
node.isReady = false
}
}
func (tm *DAG) ResumeConsumer(id string) {
if node, ok := tm.Nodes[id]; ok {
if node, ok := tm.nodes[id]; ok {
node.consumer.Resume()
node.isReady = true
}

View File

@@ -33,8 +33,13 @@ func NewTaskManager(d *DAG, taskID string) *TaskManager {
}
}
func (tm *TaskManager) updateTS(result *mq.Result) {
result.CreatedAt = tm.createdAt
result.ProcessedAt = time.Now()
}
func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload json.RawMessage) mq.Result {
node, ok := tm.dag.Nodes[nodeID]
node, ok := tm.dag.nodes[nodeID]
if !ok {
return mq.Result{Error: fmt.Errorf("nodeID %s not found", nodeID)}
}
@@ -45,8 +50,7 @@ func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload j
if awaitResponse != "true" {
go func() {
finalResult := <-tm.finalResult
finalResult.CreatedAt = tm.createdAt
finalResult.ProcessedAt = time.Now()
tm.updateTS(&finalResult)
if tm.dag.server.NotifyHandler() != nil {
tm.dag.server.NotifyHandler()(ctx, finalResult)
}
@@ -54,8 +58,7 @@ func (tm *TaskManager) processTask(ctx context.Context, nodeID string, payload j
return mq.Result{CreatedAt: tm.createdAt, TaskID: tm.taskID, Topic: nodeID, Status: "PENDING"}
} else {
finalResult := <-tm.finalResult
finalResult.CreatedAt = tm.createdAt
finalResult.ProcessedAt = time.Now()
tm.updateTS(&finalResult)
if tm.dag.server.NotifyHandler() != nil {
tm.dag.server.NotifyHandler()(ctx, finalResult)
}
@@ -79,7 +82,7 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.
if result.Topic != "" {
atomic.AddInt64(&tm.waitingCallback, -1)
}
node, ok := tm.dag.Nodes[result.Topic]
node, ok := tm.dag.nodes[result.Topic]
if !ok {
return result
}
@@ -88,7 +91,7 @@ func (tm *TaskManager) handleCallback(ctx context.Context, result mq.Result) mq.
if result.Status != "" {
if conditions, ok := tm.dag.conditions[result.Topic]; ok {
if targetNodeKey, ok := conditions[result.Status]; ok {
if targetNode, ok := tm.dag.Nodes[targetNodeKey]; ok {
if targetNode, ok := tm.dag.nodes[targetNodeKey]; ok {
edges = append(edges, Edge{From: node, To: targetNode})
}
}
@@ -147,12 +150,7 @@ func (tm *TaskManager) handleResult(ctx context.Context, results any) mq.Result
if err != nil {
return mq.HandleError(ctx, err)
}
return mq.Result{
TaskID: tm.taskID,
Payload: finalOutput,
Status: status,
Topic: topic,
}
return mq.Result{TaskID: tm.taskID, Payload: finalOutput, Status: status, Topic: topic}
case mq.Result:
return res
}

View File

@@ -2,33 +2,25 @@ package main
import (
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
"log"
)
func GenerateSecretKey() (string, error) {
// Create a byte slice to hold 32 random bytes
key := make([]byte, 32)
// Fill the slice with secure random bytes
func generateHMACKey() ([]byte, error) {
key := make([]byte, 32) // 32 bytes = 256 bits
_, err := rand.Read(key)
if err != nil {
return "", err
return nil, err
}
// Encode the byte slice to a Base64 string
secretKey := base64.StdEncoding.EncodeToString(key)
// Return the first 32 characters
return secretKey[:32], nil
return key, nil
}
func main() {
secretKey, err := GenerateSecretKey()
hmacKey, err := generateHMACKey()
if err != nil {
log.Fatalf("Error generating secret key: %v", err)
fmt.Println("Error generating HMAC key:", err)
return
}
fmt.Println("Generated Secret Key:", secretKey)
fmt.Println("HMAC Key (hex):", hex.EncodeToString(hmacKey))
}

View File

@@ -86,7 +86,7 @@ func defaultOptions() Options {
maxBackoff: 20 * time.Second,
jitterPercent: 0.5,
queueSize: 100,
hmacKey: []byte(`a9f4b9415485b70275673b5920182796ea497b5e093ead844a43ea5d77cbc24f`),
hmacKey: []byte(`475f3adc6be9ee6f5357020e2922ff5b8f971598e175878e617d19df584bc648`),
numOfWorkers: runtime.NumCPU(),
maxMemoryLoad: 5000000,
}

16
utils/str.go Normal file
View File

@@ -0,0 +1,16 @@
package utils
import (
"unsafe"
)
func ToByte(s string) []byte {
p := unsafe.StringData(s)
b := unsafe.Slice(p, len(s))
return b
}
func FromByte(b []byte) string {
p := unsafe.SliceData(b)
return unsafe.String(p, len(b))
}