From 04d706e82b3653acfdae6e51d8c3bfc7e5630aff Mon Sep 17 00:00:00 2001 From: Luca Casonato Date: Sun, 20 Oct 2019 18:43:56 +0200 Subject: [PATCH] Added routing and subscribing --- README.md | 36 ++++++++++ mqtt.go | 19 ++++- publish_test.go | 16 ++--- router.go | 77 ++++++++++++++++++++ subscribe.go | 58 +++++++++++++++ subscribe_test.go | 178 ++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 375 insertions(+), 9 deletions(-) create mode 100644 router.go create mode 100644 subscribe.go create mode 100644 subscribe_test.go diff --git a/README.md b/README.md index 0c6d845..1cbb255 100644 --- a/README.md +++ b/README.md @@ -81,3 +81,39 @@ if err != nil { panic(err) } ``` + +### subscribing + +```go +err := client.Subscribe(context.WithTimeout(1 * time.Second), func(message mqtt.Message) { + fmt.Printf("recieved a message with content %v\n", message.PayloadString()) +}, "api/v0/main/client1", mqtt.AtLeastOnce) +if err != nil { + panic(err) +} +``` + +### subscribing without listening + +```go +err := client.Subscribe(context.WithTimeout(1 * time.Second), nil, "api/v0/main/client1", mqtt.AtLeastOnce) +if err != nil { + panic(err) +} +``` + +### listening without subscribing + +```go +err := client.Listen(func(message mqtt.Message) { + v := interface{}{} + err := message.PayloadJSON(&v) + if err != nil { + panic(err) + } + fmt.Printf("recieved a message with content %v\n", v) +}, "api/v0/main/client1") +if err != nil { + panic(err) +} +``` diff --git a/mqtt.go b/mqtt.go index da878b6..f6061b2 100644 --- a/mqtt.go +++ b/mqtt.go @@ -12,6 +12,7 @@ import ( type Client struct { Options ClientOptions // The options that were used to create this client client paho.Client + router *router } // ClientOptions is the list of options used to create a client @@ -37,6 +38,14 @@ var ( ErrMinimumOneServer = errors.New("mqtt: at least one server needs to be specified") ) +func handle(callback MessageHandler) paho.MessageHandler { + return func(client paho.Client, message paho.Message) { + if callback != nil { + callback(Message{message: message}) + } + } +} + // NewClient creates a new client with the specified options func NewClient(options ClientOptions) (*Client, error) { pahoOptions := paho.NewClientOptions() @@ -66,7 +75,15 @@ func NewClient(options ClientOptions) (*Client, error) { pahoOptions.SetAutoReconnect(options.AutoReconnect) pahoClient := paho.NewClient(pahoOptions) - return &Client{client: pahoClient, Options: options}, nil + router := newRouter() + pahoClient.AddRoute("#", handle(func(message Message) { + routes := router.match(&message) + for _, route := range routes { + route.handler(message) + } + })) + + return &Client{client: pahoClient, Options: options, router: router}, nil } // Connect tries to establish a conenction with the mqtt servers diff --git a/publish_test.go b/publish_test.go index ec0a3ae..ba79a5f 100644 --- a/publish_test.go +++ b/publish_test.go @@ -29,7 +29,7 @@ func TestPublishSuccess(t *testing.T) { t.Fatalf("connect should not have failed: %v", err) } - err = client.Publish(context.Background(), testUUID+"/test_publish", []byte("hello"), mqtt.AtLeastOnce) + err = client.Publish(context.Background(), testUUID+"/TestPublishSuccess", []byte("hello"), mqtt.AtLeastOnce) if err != nil { t.Fatalf("publish should not have failed: %v", err) } @@ -52,7 +52,7 @@ func TestPublishContextTimeout(t *testing.T) { } ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) defer cancel() - err = client.Publish(ctx, testUUID+"/test_publish", []byte("hello"), mqtt.AtLeastOnce) + err = client.Publish(ctx, testUUID+"/TestPublishContextTimeout", []byte("hello"), mqtt.AtLeastOnce) if !errors.Is(err, context.DeadlineExceeded) { t.Fatalf("publish should have returned the error context.DeadlineExceeded") } @@ -79,7 +79,7 @@ func TestPublishContextCancelled(t *testing.T) { cancel() }() defer cancel() - err = client.Publish(ctx, testUUID+"/test_publish", []byte("hello"), mqtt.AtLeastOnce) + err = client.Publish(ctx, testUUID+"/TestPublishContextCancelled", []byte("hello"), mqtt.AtLeastOnce) if !errors.Is(err, context.Canceled) { t.Fatalf("publish should have returned the error context.Canceled") } @@ -100,7 +100,7 @@ func TestPublishFailed(t *testing.T) { if err != nil { t.Fatalf("connect should not have failed: %v", err) } - err = client.Publish(context.Background(), testUUID+"/test_publish", nil, 3) + err = client.Publish(context.Background(), testUUID+"/TestPublishFailed", nil, 3) if err == nil { t.Fatalf("publish should have failed") } @@ -122,7 +122,7 @@ func TestPublishSuccessRetained(t *testing.T) { t.Fatalf("connect should not have failed: %v", err) } - err = client.Publish(context.Background(), testUUID+"/test_publish", []byte("hello"), mqtt.AtLeastOnce, mqtt.Retain) + err = client.Publish(context.Background(), testUUID+"/TestPublishSuccessRetained", []byte("hello"), mqtt.AtLeastOnce, mqtt.Retain) if err != nil { t.Fatalf("publish should not have failed: %v", err) } @@ -144,7 +144,7 @@ func TestPublisStringSuccess(t *testing.T) { t.Fatalf("connect should not have failed: %v", err) } - err = client.PublishString(context.Background(), testUUID+"/test_publish", "world", mqtt.AtLeastOnce) + err = client.PublishString(context.Background(), testUUID+"/TestPublisStringSuccess", "world", mqtt.AtLeastOnce) if err != nil { t.Fatalf("publish should not have failed: %v", err) } @@ -166,7 +166,7 @@ func TestPublisJSONSuccess(t *testing.T) { t.Fatalf("connect should not have failed: %v", err) } - err = client.PublishJSON(context.Background(), testUUID+"/test_publish", []string{"hello", "world"}, mqtt.AtLeastOnce) + err = client.PublishJSON(context.Background(), testUUID+"/TestPublisJSONSuccess", []string{"hello", "world"}, mqtt.AtLeastOnce) if err != nil { t.Fatalf("publish should not have failed: %v", err) } @@ -188,7 +188,7 @@ func TestPublisJSONFailed(t *testing.T) { t.Fatalf("connect should not have failed: %v", err) } - err = client.PublishJSON(context.Background(), testUUID+"/test_publish", make(chan int), mqtt.AtLeastOnce) + err = client.PublishJSON(context.Background(), testUUID+"/TestPublisJSONFailed", make(chan int), mqtt.AtLeastOnce) if _, ok := err.(*json.UnsupportedTypeError); !ok { t.Fatalf("publish error should be of type *json.UnsupportedTypeError: %v", err) } diff --git a/router.go b/router.go new file mode 100644 index 0000000..793ba34 --- /dev/null +++ b/router.go @@ -0,0 +1,77 @@ +package mqtt + +import ( + "strings" + "sync" +) + +type router struct { + routes []route + lock sync.RWMutex +} + +func newRouter() *router { + return &router{routes: []route{}, lock: sync.RWMutex{}} +} + +type route struct { + topic string + handler MessageHandler +} + +func match(route []string, topic []string) bool { + if len(route) == 0 { + return len(topic) == 0 + } + + if len(topic) == 0 { + return route[0] == "#" + } + + if route[0] == "#" { + return true + } + + if (route[0] == "+") || (route[0] == topic[0]) { + return match(route[1:], topic[1:]) + } + return false +} + +func routeIncludesTopic(route, topic string) bool { + return match(routeSplit(route), strings.Split(topic, "/")) +} + +func routeSplit(route string) []string { + var result []string + if strings.HasPrefix(route, "$share") { + result = strings.Split(route, "/")[2:] + } else { + result = strings.Split(route, "/") + } + return result +} + +func (r *route) match(message *Message) bool { + return r.topic == message.Topic() || routeIncludesTopic(r.topic, message.Topic()) +} + +func (r *router) addRoute(topic string, handler MessageHandler) { + if handler != nil { + r.lock.Lock() + r.routes = append(r.routes, route{topic: topic, handler: handler}) + r.lock.Unlock() + } +} + +func (r *router) match(message *Message) []route { + routes := []route{} + r.lock.RLock() + for _, route := range r.routes { + if route.match(message) { + routes = append(routes, route) + } + } + r.lock.RUnlock() + return routes +} diff --git a/subscribe.go b/subscribe.go new file mode 100644 index 0000000..f2b7e71 --- /dev/null +++ b/subscribe.go @@ -0,0 +1,58 @@ +package mqtt + +import ( + "context" + "encoding/json" + + paho "github.com/eclipse/paho.mqtt.golang" +) + +type Message struct { + message paho.Message +} + +type MessageHandler func(Message) + +func (m *Message) Topic() string { + return m.message.Topic() +} + +func (m *Message) QOS() QOS { + return QOS(m.message.Qos()) +} + +func (m *Message) IsDuplicate() bool { + return m.message.Duplicate() +} + +func (m *Message) Acknowledge() { + m.message.Ack() +} + +func (m *Message) Payload() []byte { + return m.message.Payload() +} + +func (m *Message) PayloadString() string { + return string(m.message.Payload()) +} + +func (m *Message) PayloadJSON(v interface{}) error { + return json.Unmarshal(m.message.Payload(), v) +} + +func (c *Client) Listen(handler MessageHandler, topics ...string) { + for _, topic := range topics { + c.router.addRoute(topic, handler) + } +} + +func (c *Client) Subscribe(ctx context.Context, handler MessageHandler, topic string, qos QOS) error { + token := c.client.Subscribe(topic, byte(qos), nil) + err := tokenWithContext(ctx, token) + if err != nil { + return err + } + c.router.addRoute(topic, handler) + return nil +} diff --git a/subscribe_test.go b/subscribe_test.go new file mode 100644 index 0000000..50b936a --- /dev/null +++ b/subscribe_test.go @@ -0,0 +1,178 @@ +package mqtt_test + +import ( + "context" + "testing" + "time" + + "github.com/lucacasonato/mqtt" +) + +// TestSubcribeSuccess checks that a message gets recieved correctly +func TestSubcribeSuccess(t *testing.T) { + client, err := mqtt.NewClient(mqtt.ClientOptions{ + Servers: []string{ + "tcp://test.mosquitto.org:1883", + }, + }) + if err != nil { + t.Fatalf("creating client should not have failed: %v", err) + } + err = client.Connect(context.Background()) + defer client.DisconnectImmediately() + if err != nil { + t.Fatalf("connect should not have failed: %v", err) + } + + reciever := make(chan mqtt.Message) + err = client.Subscribe(context.Background(), func(message mqtt.Message) { + reciever <- message + }, testUUID+"/TestSubcribeSuccess", mqtt.ExactlyOnce) + if err != nil { + t.Fatalf("subscribe should not have failed: %v", err) + } + err = client.PublishString(context.Background(), testUUID+"/TestSubcribeSuccess", "[1, 2]", mqtt.ExactlyOnce) + if err != nil { + t.Fatalf("publish should not have failed: %v", err) + } + message := <-reciever + if string(message.Payload()) != "[1, 2]" { + t.Fatalf("message payload should have been byte array '%v' but is %v", []byte("[1, 2]"), message.Payload()) + } + if message.PayloadString() != "[1, 2]" { + t.Fatalf("message payload should have been '[1, 2]' but is %v", message.PayloadString()) + } + v := []int{} + err = message.PayloadJSON(&v) + if err != nil { + t.Fatalf("json should have unmarshalled: %v", err) + } + if len(v) != 2 || v[0] != 1 || v[1] != 2 { + t.Fatalf("message payload should have been []int{1, 2} but is %v", v) + } + if message.Topic() != testUUID+"/TestSubcribeSuccess" { + t.Fatalf("message topic should be %v but is %v", testUUID+"/TestSubcribeSuccess", message.Topic()) + } + if message.QOS() != mqtt.ExactlyOnce { + t.Fatalf("message qos should be mqtt.ExactlyOnce but is %v", message.QOS()) + } + if message.IsDuplicate() != false { + t.Fatalf("message IsDuplicate should be false but is %v", message.IsDuplicate()) + } + message.Acknowledge() +} + +// TestListenSuccess checks that a listener recieves a message correctly +func TestListenSuccess(t *testing.T) { + client, err := mqtt.NewClient(mqtt.ClientOptions{ + Servers: []string{ + "tcp://test.mosquitto.org:1883", + }, + }) + if err != nil { + t.Fatalf("creating client should not have failed: %v", err) + } + err = client.Connect(context.Background()) + defer client.DisconnectImmediately() + if err != nil { + t.Fatalf("connect should not have failed: %v", err) + } + reciever := make(chan mqtt.Message) + err = client.Subscribe(context.Background(), func(message mqtt.Message) {}, testUUID+"/TestListenSuccess", mqtt.ExactlyOnce) + if err != nil { + t.Fatalf("subscribe should not have failed: %v", err) + } + client.Listen(func(message mqtt.Message) { + reciever <- message + }, testUUID+"/TestListenSuccess") + err = client.PublishString(context.Background(), testUUID+"/TestListenSuccess", "hello", mqtt.ExactlyOnce) + if err != nil { + t.Fatalf("publish should not have failed: %v", err) + } + message := <-reciever + if message.PayloadString() != "hello" { + t.Fatalf("message payload should have been 'hello' but is %v", message) + } +} + +// TestSubcribeSuccess checks that a message gets recieved correctly +func TestSubcribeFailure(t *testing.T) { + client, err := mqtt.NewClient(mqtt.ClientOptions{ + Servers: []string{ + "tcp://test.mosquitto.org:1883", + }, + }) + if err != nil { + t.Fatalf("creating client should not have failed: %v", err) + } + err = client.Connect(context.Background()) + defer client.DisconnectImmediately() + if err != nil { + t.Fatalf("connect should not have failed: %v", err) + } + err = client.Subscribe(context.Background(), func(message mqtt.Message) {}, testUUID+"/#/test_publish", mqtt.ExactlyOnce) // # in the middle of a subscribe is not allowed + if err == nil { + t.Fatalf("subscribe should have failed: %v", err) + } +} + +// TestSubcribeSuccess checks that a message gets recieved correctly +func TestSubcribeSuccessAdvancedRouting(t *testing.T) { + client, err := mqtt.NewClient(mqtt.ClientOptions{ + Servers: []string{ + "tcp://test.mosquitto.org:1883", + }, + }) + if err != nil { + t.Fatalf("creating client should not have failed: %v", err) + } + err = client.Connect(context.Background()) + defer client.DisconnectImmediately() + if err != nil { + t.Fatalf("connect should not have failed: %v", err) + } + reciever := make(chan mqtt.Message) + err = client.Subscribe(context.Background(), func(message mqtt.Message) { + reciever <- message + }, testUUID+"/TestSubcribeSuccessAdvancedRouting/#", mqtt.ExactlyOnce) + if err != nil { + t.Fatalf("subscribe should not have failed: %v", err) + } + err = client.PublishString(context.Background(), testUUID+"/TestSubcribeSuccessAdvancedRouting/abc", "hello world", mqtt.ExactlyOnce) + if err != nil { + t.Fatalf("publish should not have failed: %v", err) + } + message := <-reciever + if message.PayloadString() != "hello world" { + t.Fatalf("message payload should have been 'hello world' but is %v", message.PayloadString()) + } +} + +// TestSubcribeSuccess checks that a message gets recieved correctly +func TestSubcribeNoRecieve(t *testing.T) { + client, err := mqtt.NewClient(mqtt.ClientOptions{ + Servers: []string{ + "tcp://test.mosquitto.org:1883", + }, + }) + if err != nil { + t.Fatalf("creating client should not have failed: %v", err) + } + err = client.Connect(context.Background()) + defer client.DisconnectImmediately() + if err != nil { + t.Fatalf("connect should not have failed: %v", err) + } + client.Listen(func(message mqtt.Message) { + t.Fatalf("recieved a message which was not meant to happen: %v", err) + }, testUUID+"/TestSubcribeSuccessAdvancedRouting/abc") + err = client.Subscribe(context.Background(), nil, testUUID+"/TestSubcribeSuccessAdvancedRouting/def", mqtt.ExactlyOnce) + if err != nil { + t.Fatalf("subscribe should not have failed: %v", err) + } + err = client.PublishString(context.Background(), testUUID+"/TestSubcribeSuccessAdvancedRouting/def", "hello world", mqtt.ExactlyOnce) + if err != nil { + t.Fatalf("publish should not have failed: %v", err) + } + <-time.After(500 * time.Millisecond) +}