feat: implement TLS support

This commit is contained in:
sujit
2024-10-05 15:23:48 +05:45
parent b600a8f89a
commit 138e2ed8c5
11 changed files with 44 additions and 330 deletions

View File

@@ -3,8 +3,8 @@ package main
import (
"context"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/examples/tasks"
mq "github.com/oarkflow/mq/v2"
)
func main() {

View File

@@ -1,266 +0,0 @@
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"net"
"strings"
"sync"
"time"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/codec"
"github.com/oarkflow/mq/consts"
)
type Broker struct {
aesKey json.RawMessage
hmacKey json.RawMessage
subscribers map[string][]net.Conn
mu sync.RWMutex
}
func NewBroker(aesKey, hmacKey json.RawMessage) *Broker {
return &Broker{
aesKey: aesKey,
hmacKey: hmacKey,
subscribers: make(map[string][]net.Conn),
}
}
func (b *Broker) addSubscriber(topic string, conn net.Conn) {
b.mu.Lock()
defer b.mu.Unlock()
b.subscribers[topic] = append(b.subscribers[topic], conn)
}
func (b *Broker) removeSubscriber(conn net.Conn) {
b.mu.Lock()
defer b.mu.Unlock()
for topic, conns := range b.subscribers {
for i, c := range conns {
if c == conn {
b.subscribers[topic] = append(conns[:i], conns[i+1:]...)
break
}
}
}
}
func (b *Broker) broadcastToSubscribers(topic string, msg *codec.Message) {
b.mu.RLock()
defer b.mu.RUnlock()
subscribers, ok := b.subscribers[topic]
if !ok || len(subscribers) == 0 {
log.Printf("No subscribers for topic: %s", topic)
return
}
for _, conn := range subscribers {
err := codec.SendMessage(conn, msg, b.aesKey, b.hmacKey, true)
if err != nil {
log.Printf("Error sending message to subscriber: %v", err)
}
}
}
func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
switch msg.Command {
case consts.PUBLISH:
b.broadcastToSubscribers(msg.Queue, msg)
ack := &codec.Message{
Headers: msg.Headers,
Queue: msg.Queue,
Command: consts.PUBLISH_ACK,
}
if err := codec.SendMessage(conn, ack, b.aesKey, b.hmacKey, true); err != nil {
log.Printf("Error sending PUBLISH_ACK: %v\n", err)
}
case consts.SUBSCRIBE:
b.addSubscriber(msg.Queue, conn)
ack := &codec.Message{
Headers: msg.Headers,
Queue: msg.Queue,
Command: consts.SUBSCRIBE_ACK,
}
if err := codec.SendMessage(conn, ack, b.aesKey, b.hmacKey, true); err != nil {
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
}
}
}
func (b *Broker) OnClose(ctx context.Context, conn net.Conn) {
log.Println("Connection closed")
b.removeSubscriber(conn)
}
func (b *Broker) OnError(ctx context.Context, err error) {
log.Printf("Connection Error: %v\n", err)
}
func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
msg, err := codec.ReadMessage(c, b.aesKey, b.hmacKey, true)
if err == nil {
ctx = mq.SetHeaders(ctx, msg.Headers)
b.OnMessage(ctx, msg, c)
return nil
}
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
b.OnClose(ctx, c)
return err
}
b.OnError(ctx, err)
return err
}
func (b *Broker) Serve(ctx context.Context, addr string) error {
listener, err := net.Listen("tcp", addr)
if err != nil {
return err
}
defer listener.Close()
for {
conn, err := listener.Accept()
if err != nil {
b.OnError(ctx, err)
continue
}
go func(c net.Conn) {
defer c.Close()
for {
err := b.readMessage(ctx, c)
if err != nil {
break
}
}
}(conn)
}
}
type Publisher struct {
aesKey json.RawMessage
hmacKey json.RawMessage
}
func NewPublisher(aesKey, hmacKey json.RawMessage) *Publisher {
return &Publisher{aesKey: aesKey, hmacKey: hmacKey}
}
func (p *Publisher) Publish(ctx context.Context, addr, topic string, payload json.RawMessage) error {
conn, err := net.Dial("tcp", addr)
if err != nil {
return err
}
defer conn.Close()
headers, _ := mq.GetHeaders(ctx)
msg := &codec.Message{
Headers: headers,
Queue: topic,
Command: consts.PUBLISH,
Payload: payload,
}
if err := codec.SendMessage(conn, msg, p.aesKey, p.hmacKey, true); err != nil {
return err
}
return p.waitForAck(conn)
}
func (p *Publisher) waitForAck(conn net.Conn) error {
msg, err := codec.ReadMessage(conn, p.aesKey, p.hmacKey, true)
if err != nil {
return err
}
if msg.Command == consts.PUBLISH_ACK {
log.Println("Received PUBLISH_ACK: Message published successfully")
return nil
}
return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command)
}
type Consumer struct {
aesKey json.RawMessage
hmacKey json.RawMessage
}
func NewConsumer(aesKey, hmacKey json.RawMessage) *Consumer {
return &Consumer{aesKey: aesKey, hmacKey: hmacKey}
}
func (c *Consumer) Subscribe(ctx context.Context, addr, topic string) error {
conn, err := net.Dial("tcp", addr)
if err != nil {
return err
}
defer conn.Close()
headers, _ := mq.GetHeaders(ctx)
msg := &codec.Message{
Headers: headers,
Queue: topic,
Command: consts.SUBSCRIBE,
}
if err := codec.SendMessage(conn, msg, c.aesKey, c.hmacKey, true); err != nil {
return err
}
err = c.waitForAck(conn)
if err != nil {
return err
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
for {
msg, err := codec.ReadMessage(conn, c.aesKey, c.hmacKey, true)
if err != nil {
log.Printf("Error reading message: %v\n", err)
break
}
log.Printf("Received task on topic %s: %s\n", msg.Queue, msg.Payload)
}
}()
wg.Wait()
return nil
}
func (c *Consumer) waitForAck(conn net.Conn) error {
msg, err := codec.ReadMessage(conn, c.aesKey, c.hmacKey, true)
if err != nil {
return err
}
if msg.Command == consts.SUBSCRIBE_ACK {
log.Println("Received SUBSCRIBE_ACK: Subscribed successfully")
return nil
}
return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command)
}
func main() {
addr := ":8081"
aesKey := []byte("thisis32bytekeyforaesencryption1")
hmacKey := []byte("thisisasecrethmackey1")
broker := NewBroker(aesKey, hmacKey)
publisher := NewPublisher(aesKey, hmacKey)
consumer := NewConsumer(aesKey, hmacKey)
go broker.Serve(context.Background(), addr)
time.Sleep(1 * time.Second)
go consumer.Subscribe(context.Background(), addr, "sensor_data")
time.Sleep(3 * time.Second)
data := map[string]interface{}{"temperature": 23.5, "humidity": 60}
payload, _ := json.Marshal(data)
go publisher.Publish(context.Background(), addr, "sensor_data", payload)
time.Sleep(10 * time.Second)
}

View File

@@ -1,36 +0,0 @@
package main
import (
"context"
"encoding/json"
"fmt"
"time"
v2 "github.com/oarkflow/mq/v2"
)
func main() {
ctx := context.Background()
broker := v2.NewBroker()
go broker.Start(ctx)
time.Sleep(1 * time.Second)
consumer := v2.NewConsumer("consumer-1")
consumer.RegisterHandler("queue-1", func(ctx context.Context, task v2.Task) v2.Result {
fmt.Println("Handling on queue-1", string(task.Payload))
return v2.Result{Payload: task.Payload}
})
go func() {
err := consumer.Consume(ctx)
if err != nil {
panic(err)
}
}()
publisher := v2.NewPublisher("publisher-1")
time.Sleep(3 * time.Second)
data := map[string]any{"temperature": 23.5, "humidity": 60}
payload, _ := json.Marshal(data)
rs := publisher.Request(ctx, "queue-1", v2.Task{Payload: payload})
fmt.Println("Response:", string(rs.Payload), rs.Error)
}

View File

@@ -4,7 +4,7 @@ import (
"context"
"fmt"
"github.com/oarkflow/mq"
mq "github.com/oarkflow/mq/v2"
)
func main() {
@@ -23,9 +23,9 @@ func main() {
task = mq.Task{
Payload: payload,
}
result, err := publisher.Request(context.Background(), "queue1", task)
if err != nil {
panic(err)
result := publisher.Request(context.Background(), "queue1", task)
if result.Error != nil {
panic(result.Error)
}
fmt.Printf("Sync task published. Result: %v\n", string(result.Payload))
}

View File

@@ -3,8 +3,8 @@ package main
import (
"context"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/examples/tasks"
mq "github.com/oarkflow/mq/v2"
)
func main() {

View File

@@ -5,7 +5,7 @@ import (
"encoding/json"
"fmt"
"github.com/oarkflow/mq"
mq "github.com/oarkflow/mq/v2"
)
func Node1(ctx context.Context, task mq.Task) mq.Result {

View File

@@ -32,17 +32,13 @@ type Broker struct {
}
func NewBroker(opts ...Option) *Broker {
options := defaultOptions()
for _, opt := range opts {
opt(&options)
}
b := &Broker{
options := setupOptions(opts...)
return &Broker{
queues: xsync.NewMap[string, *Queue](),
publishers: xsync.NewMap[string, *publisher](),
consumers: xsync.NewMap[string, *consumer](),
opts: options,
}
return b
}
func (b *Broker) OnClose(ctx context.Context, _ net.Conn) error {
@@ -90,10 +86,12 @@ func (b *Broker) MessageAck(ctx context.Context, msg *codec.Message) {
}
func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) {
msg.Command = consts.RESPONSE
headers, ok := GetHeaders(ctx)
if !ok {
return
}
b.HandleCallback(ctx, msg)
awaitResponse, ok := headers[consts.AwaitResponseKey]
if !(ok && awaitResponse == "true") {
return
@@ -106,7 +104,6 @@ func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message)
if !ok {
return
}
msg.Command = consts.RESPONSE
err := b.send(con.conn, msg)
if err != nil {
panic(err)
@@ -115,7 +112,7 @@ func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message)
func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
pub := b.addPublisher(ctx, msg.Queue, conn)
log.Printf("BROKER - PUBLISH ~> from %s on %s", pub.id, msg.Queue)
log.Printf("BROKER - PUBLISH ~> received from %s on %s", pub.id, msg.Queue)
ack := codec.NewMessage(consts.PUBLISH_ACK, nil, msg.Queue, msg.Headers)
if err := b.send(conn, ack); err != nil {
log.Printf("Error sending PUBLISH_ACK: %v\n", err)

View File

@@ -27,16 +27,12 @@ type Consumer struct {
// NewConsumer initializes a new consumer with the provided options.
func NewConsumer(id string, opts ...Option) *Consumer {
options := defaultOptions()
for _, opt := range opts {
opt(&options)
}
b := &Consumer{
options := setupOptions(opts...)
return &Consumer{
handlers: make(map[string]Handler),
id: id,
opts: options,
}
return b
}
func (c *Consumer) send(conn net.Conn, msg *codec.Message) error {
@@ -127,7 +123,7 @@ func (c *Consumer) AttemptConnect() error {
return nil
}
sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent)
fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration)
log.Printf("CONSUMER - SUBSCRIBE ~> Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration)
time.Sleep(sleepDuration)
delay *= 2
if delay > c.opts.maxBackoff {

View File

@@ -49,6 +49,14 @@ func defaultOptions() Options {
// Option defines a function type for setting options.
type Option func(*Options)
func setupOptions(opts ...Option) Options {
options := defaultOptions()
for _, opt := range opts {
opt(&options)
}
return options
}
func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option {
return func(opts *Options) {
opts.aesKey = aesKey

View File

@@ -18,12 +18,8 @@ type Publisher struct {
}
func NewPublisher(id string, opts ...Option) *Publisher {
options := defaultOptions()
for _, opt := range opts {
opt(&options)
}
b := &Publisher{id: id, opts: options}
return b
options := setupOptions(opts...)
return &Publisher{id: id, opts: options}
}
func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command consts.CMD) error {

View File

@@ -1,5 +1,12 @@
package v2
import (
"context"
"encoding/json"
"github.com/oarkflow/mq/codec"
)
func (b *Broker) TLSConfig() TLSConfig {
return b.opts.tlsConfig
}
@@ -7,3 +14,15 @@ func (b *Broker) TLSConfig() TLSConfig {
func (b *Broker) SyncMode() bool {
return b.opts.syncMode
}
func (b *Broker) HandleCallback(ctx context.Context, msg *codec.Message) {
if b.opts.callback != nil {
var result Result
err := json.Unmarshal(msg.Payload, &result)
if err == nil {
for _, callback := range b.opts.callback {
callback(ctx, result)
}
}
}
}