Added routing and subscribing

This commit is contained in:
Luca Casonato
2019-10-20 18:43:56 +02:00
parent ac8523abed
commit 04d706e82b
6 changed files with 375 additions and 9 deletions

View File

@@ -81,3 +81,39 @@ if err != nil {
panic(err)
}
```
### subscribing
```go
err := client.Subscribe(context.WithTimeout(1 * time.Second), func(message mqtt.Message) {
fmt.Printf("recieved a message with content %v\n", message.PayloadString())
}, "api/v0/main/client1", mqtt.AtLeastOnce)
if err != nil {
panic(err)
}
```
### subscribing without listening
```go
err := client.Subscribe(context.WithTimeout(1 * time.Second), nil, "api/v0/main/client1", mqtt.AtLeastOnce)
if err != nil {
panic(err)
}
```
### listening without subscribing
```go
err := client.Listen(func(message mqtt.Message) {
v := interface{}{}
err := message.PayloadJSON(&v)
if err != nil {
panic(err)
}
fmt.Printf("recieved a message with content %v\n", v)
}, "api/v0/main/client1")
if err != nil {
panic(err)
}
```

19
mqtt.go
View File

@@ -12,6 +12,7 @@ import (
type Client struct {
Options ClientOptions // The options that were used to create this client
client paho.Client
router *router
}
// ClientOptions is the list of options used to create a client
@@ -37,6 +38,14 @@ var (
ErrMinimumOneServer = errors.New("mqtt: at least one server needs to be specified")
)
func handle(callback MessageHandler) paho.MessageHandler {
return func(client paho.Client, message paho.Message) {
if callback != nil {
callback(Message{message: message})
}
}
}
// NewClient creates a new client with the specified options
func NewClient(options ClientOptions) (*Client, error) {
pahoOptions := paho.NewClientOptions()
@@ -66,7 +75,15 @@ func NewClient(options ClientOptions) (*Client, error) {
pahoOptions.SetAutoReconnect(options.AutoReconnect)
pahoClient := paho.NewClient(pahoOptions)
return &Client{client: pahoClient, Options: options}, nil
router := newRouter()
pahoClient.AddRoute("#", handle(func(message Message) {
routes := router.match(&message)
for _, route := range routes {
route.handler(message)
}
}))
return &Client{client: pahoClient, Options: options, router: router}, nil
}
// Connect tries to establish a conenction with the mqtt servers

View File

@@ -29,7 +29,7 @@ func TestPublishSuccess(t *testing.T) {
t.Fatalf("connect should not have failed: %v", err)
}
err = client.Publish(context.Background(), testUUID+"/test_publish", []byte("hello"), mqtt.AtLeastOnce)
err = client.Publish(context.Background(), testUUID+"/TestPublishSuccess", []byte("hello"), mqtt.AtLeastOnce)
if err != nil {
t.Fatalf("publish should not have failed: %v", err)
}
@@ -52,7 +52,7 @@ func TestPublishContextTimeout(t *testing.T) {
}
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
defer cancel()
err = client.Publish(ctx, testUUID+"/test_publish", []byte("hello"), mqtt.AtLeastOnce)
err = client.Publish(ctx, testUUID+"/TestPublishContextTimeout", []byte("hello"), mqtt.AtLeastOnce)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("publish should have returned the error context.DeadlineExceeded")
}
@@ -79,7 +79,7 @@ func TestPublishContextCancelled(t *testing.T) {
cancel()
}()
defer cancel()
err = client.Publish(ctx, testUUID+"/test_publish", []byte("hello"), mqtt.AtLeastOnce)
err = client.Publish(ctx, testUUID+"/TestPublishContextCancelled", []byte("hello"), mqtt.AtLeastOnce)
if !errors.Is(err, context.Canceled) {
t.Fatalf("publish should have returned the error context.Canceled")
}
@@ -100,7 +100,7 @@ func TestPublishFailed(t *testing.T) {
if err != nil {
t.Fatalf("connect should not have failed: %v", err)
}
err = client.Publish(context.Background(), testUUID+"/test_publish", nil, 3)
err = client.Publish(context.Background(), testUUID+"/TestPublishFailed", nil, 3)
if err == nil {
t.Fatalf("publish should have failed")
}
@@ -122,7 +122,7 @@ func TestPublishSuccessRetained(t *testing.T) {
t.Fatalf("connect should not have failed: %v", err)
}
err = client.Publish(context.Background(), testUUID+"/test_publish", []byte("hello"), mqtt.AtLeastOnce, mqtt.Retain)
err = client.Publish(context.Background(), testUUID+"/TestPublishSuccessRetained", []byte("hello"), mqtt.AtLeastOnce, mqtt.Retain)
if err != nil {
t.Fatalf("publish should not have failed: %v", err)
}
@@ -144,7 +144,7 @@ func TestPublisStringSuccess(t *testing.T) {
t.Fatalf("connect should not have failed: %v", err)
}
err = client.PublishString(context.Background(), testUUID+"/test_publish", "world", mqtt.AtLeastOnce)
err = client.PublishString(context.Background(), testUUID+"/TestPublisStringSuccess", "world", mqtt.AtLeastOnce)
if err != nil {
t.Fatalf("publish should not have failed: %v", err)
}
@@ -166,7 +166,7 @@ func TestPublisJSONSuccess(t *testing.T) {
t.Fatalf("connect should not have failed: %v", err)
}
err = client.PublishJSON(context.Background(), testUUID+"/test_publish", []string{"hello", "world"}, mqtt.AtLeastOnce)
err = client.PublishJSON(context.Background(), testUUID+"/TestPublisJSONSuccess", []string{"hello", "world"}, mqtt.AtLeastOnce)
if err != nil {
t.Fatalf("publish should not have failed: %v", err)
}
@@ -188,7 +188,7 @@ func TestPublisJSONFailed(t *testing.T) {
t.Fatalf("connect should not have failed: %v", err)
}
err = client.PublishJSON(context.Background(), testUUID+"/test_publish", make(chan int), mqtt.AtLeastOnce)
err = client.PublishJSON(context.Background(), testUUID+"/TestPublisJSONFailed", make(chan int), mqtt.AtLeastOnce)
if _, ok := err.(*json.UnsupportedTypeError); !ok {
t.Fatalf("publish error should be of type *json.UnsupportedTypeError: %v", err)
}

77
router.go Normal file
View File

@@ -0,0 +1,77 @@
package mqtt
import (
"strings"
"sync"
)
type router struct {
routes []route
lock sync.RWMutex
}
func newRouter() *router {
return &router{routes: []route{}, lock: sync.RWMutex{}}
}
type route struct {
topic string
handler MessageHandler
}
func match(route []string, topic []string) bool {
if len(route) == 0 {
return len(topic) == 0
}
if len(topic) == 0 {
return route[0] == "#"
}
if route[0] == "#" {
return true
}
if (route[0] == "+") || (route[0] == topic[0]) {
return match(route[1:], topic[1:])
}
return false
}
func routeIncludesTopic(route, topic string) bool {
return match(routeSplit(route), strings.Split(topic, "/"))
}
func routeSplit(route string) []string {
var result []string
if strings.HasPrefix(route, "$share") {
result = strings.Split(route, "/")[2:]
} else {
result = strings.Split(route, "/")
}
return result
}
func (r *route) match(message *Message) bool {
return r.topic == message.Topic() || routeIncludesTopic(r.topic, message.Topic())
}
func (r *router) addRoute(topic string, handler MessageHandler) {
if handler != nil {
r.lock.Lock()
r.routes = append(r.routes, route{topic: topic, handler: handler})
r.lock.Unlock()
}
}
func (r *router) match(message *Message) []route {
routes := []route{}
r.lock.RLock()
for _, route := range r.routes {
if route.match(message) {
routes = append(routes, route)
}
}
r.lock.RUnlock()
return routes
}

58
subscribe.go Normal file
View File

@@ -0,0 +1,58 @@
package mqtt
import (
"context"
"encoding/json"
paho "github.com/eclipse/paho.mqtt.golang"
)
type Message struct {
message paho.Message
}
type MessageHandler func(Message)
func (m *Message) Topic() string {
return m.message.Topic()
}
func (m *Message) QOS() QOS {
return QOS(m.message.Qos())
}
func (m *Message) IsDuplicate() bool {
return m.message.Duplicate()
}
func (m *Message) Acknowledge() {
m.message.Ack()
}
func (m *Message) Payload() []byte {
return m.message.Payload()
}
func (m *Message) PayloadString() string {
return string(m.message.Payload())
}
func (m *Message) PayloadJSON(v interface{}) error {
return json.Unmarshal(m.message.Payload(), v)
}
func (c *Client) Listen(handler MessageHandler, topics ...string) {
for _, topic := range topics {
c.router.addRoute(topic, handler)
}
}
func (c *Client) Subscribe(ctx context.Context, handler MessageHandler, topic string, qos QOS) error {
token := c.client.Subscribe(topic, byte(qos), nil)
err := tokenWithContext(ctx, token)
if err != nil {
return err
}
c.router.addRoute(topic, handler)
return nil
}

178
subscribe_test.go Normal file
View File

@@ -0,0 +1,178 @@
package mqtt_test
import (
"context"
"testing"
"time"
"github.com/lucacasonato/mqtt"
)
// TestSubcribeSuccess checks that a message gets recieved correctly
func TestSubcribeSuccess(t *testing.T) {
client, err := mqtt.NewClient(mqtt.ClientOptions{
Servers: []string{
"tcp://test.mosquitto.org:1883",
},
})
if err != nil {
t.Fatalf("creating client should not have failed: %v", err)
}
err = client.Connect(context.Background())
defer client.DisconnectImmediately()
if err != nil {
t.Fatalf("connect should not have failed: %v", err)
}
reciever := make(chan mqtt.Message)
err = client.Subscribe(context.Background(), func(message mqtt.Message) {
reciever <- message
}, testUUID+"/TestSubcribeSuccess", mqtt.ExactlyOnce)
if err != nil {
t.Fatalf("subscribe should not have failed: %v", err)
}
err = client.PublishString(context.Background(), testUUID+"/TestSubcribeSuccess", "[1, 2]", mqtt.ExactlyOnce)
if err != nil {
t.Fatalf("publish should not have failed: %v", err)
}
message := <-reciever
if string(message.Payload()) != "[1, 2]" {
t.Fatalf("message payload should have been byte array '%v' but is %v", []byte("[1, 2]"), message.Payload())
}
if message.PayloadString() != "[1, 2]" {
t.Fatalf("message payload should have been '[1, 2]' but is %v", message.PayloadString())
}
v := []int{}
err = message.PayloadJSON(&v)
if err != nil {
t.Fatalf("json should have unmarshalled: %v", err)
}
if len(v) != 2 || v[0] != 1 || v[1] != 2 {
t.Fatalf("message payload should have been []int{1, 2} but is %v", v)
}
if message.Topic() != testUUID+"/TestSubcribeSuccess" {
t.Fatalf("message topic should be %v but is %v", testUUID+"/TestSubcribeSuccess", message.Topic())
}
if message.QOS() != mqtt.ExactlyOnce {
t.Fatalf("message qos should be mqtt.ExactlyOnce but is %v", message.QOS())
}
if message.IsDuplicate() != false {
t.Fatalf("message IsDuplicate should be false but is %v", message.IsDuplicate())
}
message.Acknowledge()
}
// TestListenSuccess checks that a listener recieves a message correctly
func TestListenSuccess(t *testing.T) {
client, err := mqtt.NewClient(mqtt.ClientOptions{
Servers: []string{
"tcp://test.mosquitto.org:1883",
},
})
if err != nil {
t.Fatalf("creating client should not have failed: %v", err)
}
err = client.Connect(context.Background())
defer client.DisconnectImmediately()
if err != nil {
t.Fatalf("connect should not have failed: %v", err)
}
reciever := make(chan mqtt.Message)
err = client.Subscribe(context.Background(), func(message mqtt.Message) {}, testUUID+"/TestListenSuccess", mqtt.ExactlyOnce)
if err != nil {
t.Fatalf("subscribe should not have failed: %v", err)
}
client.Listen(func(message mqtt.Message) {
reciever <- message
}, testUUID+"/TestListenSuccess")
err = client.PublishString(context.Background(), testUUID+"/TestListenSuccess", "hello", mqtt.ExactlyOnce)
if err != nil {
t.Fatalf("publish should not have failed: %v", err)
}
message := <-reciever
if message.PayloadString() != "hello" {
t.Fatalf("message payload should have been 'hello' but is %v", message)
}
}
// TestSubcribeSuccess checks that a message gets recieved correctly
func TestSubcribeFailure(t *testing.T) {
client, err := mqtt.NewClient(mqtt.ClientOptions{
Servers: []string{
"tcp://test.mosquitto.org:1883",
},
})
if err != nil {
t.Fatalf("creating client should not have failed: %v", err)
}
err = client.Connect(context.Background())
defer client.DisconnectImmediately()
if err != nil {
t.Fatalf("connect should not have failed: %v", err)
}
err = client.Subscribe(context.Background(), func(message mqtt.Message) {}, testUUID+"/#/test_publish", mqtt.ExactlyOnce) // # in the middle of a subscribe is not allowed
if err == nil {
t.Fatalf("subscribe should have failed: %v", err)
}
}
// TestSubcribeSuccess checks that a message gets recieved correctly
func TestSubcribeSuccessAdvancedRouting(t *testing.T) {
client, err := mqtt.NewClient(mqtt.ClientOptions{
Servers: []string{
"tcp://test.mosquitto.org:1883",
},
})
if err != nil {
t.Fatalf("creating client should not have failed: %v", err)
}
err = client.Connect(context.Background())
defer client.DisconnectImmediately()
if err != nil {
t.Fatalf("connect should not have failed: %v", err)
}
reciever := make(chan mqtt.Message)
err = client.Subscribe(context.Background(), func(message mqtt.Message) {
reciever <- message
}, testUUID+"/TestSubcribeSuccessAdvancedRouting/#", mqtt.ExactlyOnce)
if err != nil {
t.Fatalf("subscribe should not have failed: %v", err)
}
err = client.PublishString(context.Background(), testUUID+"/TestSubcribeSuccessAdvancedRouting/abc", "hello world", mqtt.ExactlyOnce)
if err != nil {
t.Fatalf("publish should not have failed: %v", err)
}
message := <-reciever
if message.PayloadString() != "hello world" {
t.Fatalf("message payload should have been 'hello world' but is %v", message.PayloadString())
}
}
// TestSubcribeSuccess checks that a message gets recieved correctly
func TestSubcribeNoRecieve(t *testing.T) {
client, err := mqtt.NewClient(mqtt.ClientOptions{
Servers: []string{
"tcp://test.mosquitto.org:1883",
},
})
if err != nil {
t.Fatalf("creating client should not have failed: %v", err)
}
err = client.Connect(context.Background())
defer client.DisconnectImmediately()
if err != nil {
t.Fatalf("connect should not have failed: %v", err)
}
client.Listen(func(message mqtt.Message) {
t.Fatalf("recieved a message which was not meant to happen: %v", err)
}, testUUID+"/TestSubcribeSuccessAdvancedRouting/abc")
err = client.Subscribe(context.Background(), nil, testUUID+"/TestSubcribeSuccessAdvancedRouting/def", mqtt.ExactlyOnce)
if err != nil {
t.Fatalf("subscribe should not have failed: %v", err)
}
err = client.PublishString(context.Background(), testUUID+"/TestSubcribeSuccessAdvancedRouting/def", "hello world", mqtt.ExactlyOnce)
if err != nil {
t.Fatalf("publish should not have failed: %v", err)
}
<-time.After(500 * time.Millisecond)
}