mirror of
https://github.com/lucacasonato/mqtt.git
synced 2025-09-27 03:05:59 +08:00
Added routing and subscribing
This commit is contained in:
36
README.md
36
README.md
@@ -81,3 +81,39 @@ if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
```
|
||||
|
||||
### subscribing
|
||||
|
||||
```go
|
||||
err := client.Subscribe(context.WithTimeout(1 * time.Second), func(message mqtt.Message) {
|
||||
fmt.Printf("recieved a message with content %v\n", message.PayloadString())
|
||||
}, "api/v0/main/client1", mqtt.AtLeastOnce)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
```
|
||||
|
||||
### subscribing without listening
|
||||
|
||||
```go
|
||||
err := client.Subscribe(context.WithTimeout(1 * time.Second), nil, "api/v0/main/client1", mqtt.AtLeastOnce)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
```
|
||||
|
||||
### listening without subscribing
|
||||
|
||||
```go
|
||||
err := client.Listen(func(message mqtt.Message) {
|
||||
v := interface{}{}
|
||||
err := message.PayloadJSON(&v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("recieved a message with content %v\n", v)
|
||||
}, "api/v0/main/client1")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
```
|
||||
|
19
mqtt.go
19
mqtt.go
@@ -12,6 +12,7 @@ import (
|
||||
type Client struct {
|
||||
Options ClientOptions // The options that were used to create this client
|
||||
client paho.Client
|
||||
router *router
|
||||
}
|
||||
|
||||
// ClientOptions is the list of options used to create a client
|
||||
@@ -37,6 +38,14 @@ var (
|
||||
ErrMinimumOneServer = errors.New("mqtt: at least one server needs to be specified")
|
||||
)
|
||||
|
||||
func handle(callback MessageHandler) paho.MessageHandler {
|
||||
return func(client paho.Client, message paho.Message) {
|
||||
if callback != nil {
|
||||
callback(Message{message: message})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NewClient creates a new client with the specified options
|
||||
func NewClient(options ClientOptions) (*Client, error) {
|
||||
pahoOptions := paho.NewClientOptions()
|
||||
@@ -66,7 +75,15 @@ func NewClient(options ClientOptions) (*Client, error) {
|
||||
pahoOptions.SetAutoReconnect(options.AutoReconnect)
|
||||
|
||||
pahoClient := paho.NewClient(pahoOptions)
|
||||
return &Client{client: pahoClient, Options: options}, nil
|
||||
router := newRouter()
|
||||
pahoClient.AddRoute("#", handle(func(message Message) {
|
||||
routes := router.match(&message)
|
||||
for _, route := range routes {
|
||||
route.handler(message)
|
||||
}
|
||||
}))
|
||||
|
||||
return &Client{client: pahoClient, Options: options, router: router}, nil
|
||||
}
|
||||
|
||||
// Connect tries to establish a conenction with the mqtt servers
|
||||
|
@@ -29,7 +29,7 @@ func TestPublishSuccess(t *testing.T) {
|
||||
t.Fatalf("connect should not have failed: %v", err)
|
||||
}
|
||||
|
||||
err = client.Publish(context.Background(), testUUID+"/test_publish", []byte("hello"), mqtt.AtLeastOnce)
|
||||
err = client.Publish(context.Background(), testUUID+"/TestPublishSuccess", []byte("hello"), mqtt.AtLeastOnce)
|
||||
if err != nil {
|
||||
t.Fatalf("publish should not have failed: %v", err)
|
||||
}
|
||||
@@ -52,7 +52,7 @@ func TestPublishContextTimeout(t *testing.T) {
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
|
||||
defer cancel()
|
||||
err = client.Publish(ctx, testUUID+"/test_publish", []byte("hello"), mqtt.AtLeastOnce)
|
||||
err = client.Publish(ctx, testUUID+"/TestPublishContextTimeout", []byte("hello"), mqtt.AtLeastOnce)
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatalf("publish should have returned the error context.DeadlineExceeded")
|
||||
}
|
||||
@@ -79,7 +79,7 @@ func TestPublishContextCancelled(t *testing.T) {
|
||||
cancel()
|
||||
}()
|
||||
defer cancel()
|
||||
err = client.Publish(ctx, testUUID+"/test_publish", []byte("hello"), mqtt.AtLeastOnce)
|
||||
err = client.Publish(ctx, testUUID+"/TestPublishContextCancelled", []byte("hello"), mqtt.AtLeastOnce)
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("publish should have returned the error context.Canceled")
|
||||
}
|
||||
@@ -100,7 +100,7 @@ func TestPublishFailed(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("connect should not have failed: %v", err)
|
||||
}
|
||||
err = client.Publish(context.Background(), testUUID+"/test_publish", nil, 3)
|
||||
err = client.Publish(context.Background(), testUUID+"/TestPublishFailed", nil, 3)
|
||||
if err == nil {
|
||||
t.Fatalf("publish should have failed")
|
||||
}
|
||||
@@ -122,7 +122,7 @@ func TestPublishSuccessRetained(t *testing.T) {
|
||||
t.Fatalf("connect should not have failed: %v", err)
|
||||
}
|
||||
|
||||
err = client.Publish(context.Background(), testUUID+"/test_publish", []byte("hello"), mqtt.AtLeastOnce, mqtt.Retain)
|
||||
err = client.Publish(context.Background(), testUUID+"/TestPublishSuccessRetained", []byte("hello"), mqtt.AtLeastOnce, mqtt.Retain)
|
||||
if err != nil {
|
||||
t.Fatalf("publish should not have failed: %v", err)
|
||||
}
|
||||
@@ -144,7 +144,7 @@ func TestPublisStringSuccess(t *testing.T) {
|
||||
t.Fatalf("connect should not have failed: %v", err)
|
||||
}
|
||||
|
||||
err = client.PublishString(context.Background(), testUUID+"/test_publish", "world", mqtt.AtLeastOnce)
|
||||
err = client.PublishString(context.Background(), testUUID+"/TestPublisStringSuccess", "world", mqtt.AtLeastOnce)
|
||||
if err != nil {
|
||||
t.Fatalf("publish should not have failed: %v", err)
|
||||
}
|
||||
@@ -166,7 +166,7 @@ func TestPublisJSONSuccess(t *testing.T) {
|
||||
t.Fatalf("connect should not have failed: %v", err)
|
||||
}
|
||||
|
||||
err = client.PublishJSON(context.Background(), testUUID+"/test_publish", []string{"hello", "world"}, mqtt.AtLeastOnce)
|
||||
err = client.PublishJSON(context.Background(), testUUID+"/TestPublisJSONSuccess", []string{"hello", "world"}, mqtt.AtLeastOnce)
|
||||
if err != nil {
|
||||
t.Fatalf("publish should not have failed: %v", err)
|
||||
}
|
||||
@@ -188,7 +188,7 @@ func TestPublisJSONFailed(t *testing.T) {
|
||||
t.Fatalf("connect should not have failed: %v", err)
|
||||
}
|
||||
|
||||
err = client.PublishJSON(context.Background(), testUUID+"/test_publish", make(chan int), mqtt.AtLeastOnce)
|
||||
err = client.PublishJSON(context.Background(), testUUID+"/TestPublisJSONFailed", make(chan int), mqtt.AtLeastOnce)
|
||||
if _, ok := err.(*json.UnsupportedTypeError); !ok {
|
||||
t.Fatalf("publish error should be of type *json.UnsupportedTypeError: %v", err)
|
||||
}
|
||||
|
77
router.go
Normal file
77
router.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type router struct {
|
||||
routes []route
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
func newRouter() *router {
|
||||
return &router{routes: []route{}, lock: sync.RWMutex{}}
|
||||
}
|
||||
|
||||
type route struct {
|
||||
topic string
|
||||
handler MessageHandler
|
||||
}
|
||||
|
||||
func match(route []string, topic []string) bool {
|
||||
if len(route) == 0 {
|
||||
return len(topic) == 0
|
||||
}
|
||||
|
||||
if len(topic) == 0 {
|
||||
return route[0] == "#"
|
||||
}
|
||||
|
||||
if route[0] == "#" {
|
||||
return true
|
||||
}
|
||||
|
||||
if (route[0] == "+") || (route[0] == topic[0]) {
|
||||
return match(route[1:], topic[1:])
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func routeIncludesTopic(route, topic string) bool {
|
||||
return match(routeSplit(route), strings.Split(topic, "/"))
|
||||
}
|
||||
|
||||
func routeSplit(route string) []string {
|
||||
var result []string
|
||||
if strings.HasPrefix(route, "$share") {
|
||||
result = strings.Split(route, "/")[2:]
|
||||
} else {
|
||||
result = strings.Split(route, "/")
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *route) match(message *Message) bool {
|
||||
return r.topic == message.Topic() || routeIncludesTopic(r.topic, message.Topic())
|
||||
}
|
||||
|
||||
func (r *router) addRoute(topic string, handler MessageHandler) {
|
||||
if handler != nil {
|
||||
r.lock.Lock()
|
||||
r.routes = append(r.routes, route{topic: topic, handler: handler})
|
||||
r.lock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) match(message *Message) []route {
|
||||
routes := []route{}
|
||||
r.lock.RLock()
|
||||
for _, route := range r.routes {
|
||||
if route.match(message) {
|
||||
routes = append(routes, route)
|
||||
}
|
||||
}
|
||||
r.lock.RUnlock()
|
||||
return routes
|
||||
}
|
58
subscribe.go
Normal file
58
subscribe.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
paho "github.com/eclipse/paho.mqtt.golang"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
message paho.Message
|
||||
}
|
||||
|
||||
type MessageHandler func(Message)
|
||||
|
||||
func (m *Message) Topic() string {
|
||||
return m.message.Topic()
|
||||
}
|
||||
|
||||
func (m *Message) QOS() QOS {
|
||||
return QOS(m.message.Qos())
|
||||
}
|
||||
|
||||
func (m *Message) IsDuplicate() bool {
|
||||
return m.message.Duplicate()
|
||||
}
|
||||
|
||||
func (m *Message) Acknowledge() {
|
||||
m.message.Ack()
|
||||
}
|
||||
|
||||
func (m *Message) Payload() []byte {
|
||||
return m.message.Payload()
|
||||
}
|
||||
|
||||
func (m *Message) PayloadString() string {
|
||||
return string(m.message.Payload())
|
||||
}
|
||||
|
||||
func (m *Message) PayloadJSON(v interface{}) error {
|
||||
return json.Unmarshal(m.message.Payload(), v)
|
||||
}
|
||||
|
||||
func (c *Client) Listen(handler MessageHandler, topics ...string) {
|
||||
for _, topic := range topics {
|
||||
c.router.addRoute(topic, handler)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Subscribe(ctx context.Context, handler MessageHandler, topic string, qos QOS) error {
|
||||
token := c.client.Subscribe(topic, byte(qos), nil)
|
||||
err := tokenWithContext(ctx, token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.router.addRoute(topic, handler)
|
||||
return nil
|
||||
}
|
178
subscribe_test.go
Normal file
178
subscribe_test.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package mqtt_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lucacasonato/mqtt"
|
||||
)
|
||||
|
||||
// TestSubcribeSuccess checks that a message gets recieved correctly
|
||||
func TestSubcribeSuccess(t *testing.T) {
|
||||
client, err := mqtt.NewClient(mqtt.ClientOptions{
|
||||
Servers: []string{
|
||||
"tcp://test.mosquitto.org:1883",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("creating client should not have failed: %v", err)
|
||||
}
|
||||
err = client.Connect(context.Background())
|
||||
defer client.DisconnectImmediately()
|
||||
if err != nil {
|
||||
t.Fatalf("connect should not have failed: %v", err)
|
||||
}
|
||||
|
||||
reciever := make(chan mqtt.Message)
|
||||
err = client.Subscribe(context.Background(), func(message mqtt.Message) {
|
||||
reciever <- message
|
||||
}, testUUID+"/TestSubcribeSuccess", mqtt.ExactlyOnce)
|
||||
if err != nil {
|
||||
t.Fatalf("subscribe should not have failed: %v", err)
|
||||
}
|
||||
err = client.PublishString(context.Background(), testUUID+"/TestSubcribeSuccess", "[1, 2]", mqtt.ExactlyOnce)
|
||||
if err != nil {
|
||||
t.Fatalf("publish should not have failed: %v", err)
|
||||
}
|
||||
message := <-reciever
|
||||
if string(message.Payload()) != "[1, 2]" {
|
||||
t.Fatalf("message payload should have been byte array '%v' but is %v", []byte("[1, 2]"), message.Payload())
|
||||
}
|
||||
if message.PayloadString() != "[1, 2]" {
|
||||
t.Fatalf("message payload should have been '[1, 2]' but is %v", message.PayloadString())
|
||||
}
|
||||
v := []int{}
|
||||
err = message.PayloadJSON(&v)
|
||||
if err != nil {
|
||||
t.Fatalf("json should have unmarshalled: %v", err)
|
||||
}
|
||||
if len(v) != 2 || v[0] != 1 || v[1] != 2 {
|
||||
t.Fatalf("message payload should have been []int{1, 2} but is %v", v)
|
||||
}
|
||||
if message.Topic() != testUUID+"/TestSubcribeSuccess" {
|
||||
t.Fatalf("message topic should be %v but is %v", testUUID+"/TestSubcribeSuccess", message.Topic())
|
||||
}
|
||||
if message.QOS() != mqtt.ExactlyOnce {
|
||||
t.Fatalf("message qos should be mqtt.ExactlyOnce but is %v", message.QOS())
|
||||
}
|
||||
if message.IsDuplicate() != false {
|
||||
t.Fatalf("message IsDuplicate should be false but is %v", message.IsDuplicate())
|
||||
}
|
||||
message.Acknowledge()
|
||||
}
|
||||
|
||||
// TestListenSuccess checks that a listener recieves a message correctly
|
||||
func TestListenSuccess(t *testing.T) {
|
||||
client, err := mqtt.NewClient(mqtt.ClientOptions{
|
||||
Servers: []string{
|
||||
"tcp://test.mosquitto.org:1883",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("creating client should not have failed: %v", err)
|
||||
}
|
||||
err = client.Connect(context.Background())
|
||||
defer client.DisconnectImmediately()
|
||||
if err != nil {
|
||||
t.Fatalf("connect should not have failed: %v", err)
|
||||
}
|
||||
reciever := make(chan mqtt.Message)
|
||||
err = client.Subscribe(context.Background(), func(message mqtt.Message) {}, testUUID+"/TestListenSuccess", mqtt.ExactlyOnce)
|
||||
if err != nil {
|
||||
t.Fatalf("subscribe should not have failed: %v", err)
|
||||
}
|
||||
client.Listen(func(message mqtt.Message) {
|
||||
reciever <- message
|
||||
}, testUUID+"/TestListenSuccess")
|
||||
err = client.PublishString(context.Background(), testUUID+"/TestListenSuccess", "hello", mqtt.ExactlyOnce)
|
||||
if err != nil {
|
||||
t.Fatalf("publish should not have failed: %v", err)
|
||||
}
|
||||
message := <-reciever
|
||||
if message.PayloadString() != "hello" {
|
||||
t.Fatalf("message payload should have been 'hello' but is %v", message)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubcribeSuccess checks that a message gets recieved correctly
|
||||
func TestSubcribeFailure(t *testing.T) {
|
||||
client, err := mqtt.NewClient(mqtt.ClientOptions{
|
||||
Servers: []string{
|
||||
"tcp://test.mosquitto.org:1883",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("creating client should not have failed: %v", err)
|
||||
}
|
||||
err = client.Connect(context.Background())
|
||||
defer client.DisconnectImmediately()
|
||||
if err != nil {
|
||||
t.Fatalf("connect should not have failed: %v", err)
|
||||
}
|
||||
err = client.Subscribe(context.Background(), func(message mqtt.Message) {}, testUUID+"/#/test_publish", mqtt.ExactlyOnce) // # in the middle of a subscribe is not allowed
|
||||
if err == nil {
|
||||
t.Fatalf("subscribe should have failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubcribeSuccess checks that a message gets recieved correctly
|
||||
func TestSubcribeSuccessAdvancedRouting(t *testing.T) {
|
||||
client, err := mqtt.NewClient(mqtt.ClientOptions{
|
||||
Servers: []string{
|
||||
"tcp://test.mosquitto.org:1883",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("creating client should not have failed: %v", err)
|
||||
}
|
||||
err = client.Connect(context.Background())
|
||||
defer client.DisconnectImmediately()
|
||||
if err != nil {
|
||||
t.Fatalf("connect should not have failed: %v", err)
|
||||
}
|
||||
reciever := make(chan mqtt.Message)
|
||||
err = client.Subscribe(context.Background(), func(message mqtt.Message) {
|
||||
reciever <- message
|
||||
}, testUUID+"/TestSubcribeSuccessAdvancedRouting/#", mqtt.ExactlyOnce)
|
||||
if err != nil {
|
||||
t.Fatalf("subscribe should not have failed: %v", err)
|
||||
}
|
||||
err = client.PublishString(context.Background(), testUUID+"/TestSubcribeSuccessAdvancedRouting/abc", "hello world", mqtt.ExactlyOnce)
|
||||
if err != nil {
|
||||
t.Fatalf("publish should not have failed: %v", err)
|
||||
}
|
||||
message := <-reciever
|
||||
if message.PayloadString() != "hello world" {
|
||||
t.Fatalf("message payload should have been 'hello world' but is %v", message.PayloadString())
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubcribeSuccess checks that a message gets recieved correctly
|
||||
func TestSubcribeNoRecieve(t *testing.T) {
|
||||
client, err := mqtt.NewClient(mqtt.ClientOptions{
|
||||
Servers: []string{
|
||||
"tcp://test.mosquitto.org:1883",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("creating client should not have failed: %v", err)
|
||||
}
|
||||
err = client.Connect(context.Background())
|
||||
defer client.DisconnectImmediately()
|
||||
if err != nil {
|
||||
t.Fatalf("connect should not have failed: %v", err)
|
||||
}
|
||||
client.Listen(func(message mqtt.Message) {
|
||||
t.Fatalf("recieved a message which was not meant to happen: %v", err)
|
||||
}, testUUID+"/TestSubcribeSuccessAdvancedRouting/abc")
|
||||
err = client.Subscribe(context.Background(), nil, testUUID+"/TestSubcribeSuccessAdvancedRouting/def", mqtt.ExactlyOnce)
|
||||
if err != nil {
|
||||
t.Fatalf("subscribe should not have failed: %v", err)
|
||||
}
|
||||
err = client.PublishString(context.Background(), testUUID+"/TestSubcribeSuccessAdvancedRouting/def", "hello world", mqtt.ExactlyOnce)
|
||||
if err != nil {
|
||||
t.Fatalf("publish should not have failed: %v", err)
|
||||
}
|
||||
<-time.After(500 * time.Millisecond)
|
||||
}
|
Reference in New Issue
Block a user