Compare commits

..

13 Commits

Author SHA1 Message Date
JB
1ae050939a Merge pull request #84 from mochi-co/goreport-fixes
Goreport fixes
2022-06-22 15:52:36 +01:00
mochi
f4683d27d0 remove ineffective assignments 2022-06-22 15:45:13 +01:00
mochi
dff2b1db30 apply gofmt -s 2022-06-22 15:40:52 +01:00
JB
9de6b4e427 Merge pull request #83 from mochi-co/tls-client-auth
Expose tls.Config to Listeners
2022-06-22 15:32:23 +01:00
JB
78c1914270 Merge pull request #82 from mochi-co/expose-event-client-username
Add CleanSession and Username to events.Client struct
2022-06-22 15:31:31 +01:00
mochi
f71bf5c3d6 use TLSConfig instead of deprecated TLS field 2022-06-22 15:26:51 +01:00
mochi
53c4a6b09f Add TLSConfig field to allow direct tls.Config setting 2022-06-22 15:26:26 +01:00
mochi
a02c6bd8df update TLS example to use TLSConfig field 2022-06-22 15:25:52 +01:00
mochi
d8f6d63cc8 Add CleanSession and Username to events.Client struct 2022-06-22 12:33:09 +01:00
mochi
bef13eec20 Add OnSubscribe, OnUnsubscribe events examples 2022-05-04 12:58:23 +01:00
mochi
27f3c484ad Extend onsusbcribe, onunsubscribe events 2022-05-04 12:53:04 +01:00
JB
9b5cdb0bcc Merge pull request #74 from muXxer/feat/topic-subscription-events 2022-05-04 12:33:12 +01:00
muXxer
2b60a11d4a Add topic un-/subscribe events 2022-04-28 00:48:20 +02:00
17 changed files with 225 additions and 52 deletions

View File

@@ -30,7 +30,7 @@ MQTT stands for MQ Telemetry Transport. It is a publish/subscribe, extremely sim
- Interfaces for Client Authentication and Topic access control.
- Bolt persistence and storage interfaces (see examples folder).
- Directly Publishing from embedding service (`s.Publish(topic, message, retain)`).
- Basic Event Hooks (`OnMessage`, `OnConnect`, `OnDisconnect`, `onProcessMessage`, `OnError`, `OnStorage`).
- Basic Event Hooks (`OnMessage`, `onSubscribe`, `onUnsubscribe`, `OnConnect`, `OnDisconnect`, `onProcessMessage`, `OnError`, `OnStorage`).
- ARM32 Compatible.
#### Roadmap
@@ -141,7 +141,25 @@ server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.P
```go
server.Events.OnDisconnect = func(cl events.Client, err error) {
fmt.Printf("<< OnDisconnect client dicconnected %s: %v\n", cl.ID, err)
fmt.Printf("<< OnDisconnect client disconnected %s: %v\n", cl.ID, err)
}
```
##### OnSubscribe
`server.Events.OnSubscribe` is called when a client subscribes to a new topic filter.
```go
server.Events.OnSubscribe = func(filter string, cl events.Client, qos byte) {
fmt.Printf("<< OnSubscribe client subscribed %s: %s %v\n", cl.ID, filter, qos)
}
```
##### OnUnsubscribe
`server.Events.OnUnsubscribe` is called when a client unsubscribes from a topic filter.
```go
server.Events.OnUnsubscribe = func(filter string, cl events.Client) {
fmt.Printf("<< OnUnsubscribe client unsubscribed %s: %s\n", cl.ID, filter)
}
```
@@ -158,7 +176,6 @@ If an error is returned, the packet will not be modified. and the existing packe
> This hook is only triggered when a message is received by clients. It is not triggered when using the direct `server.Publish` method.
```go
import "github.com/mochi-co/mqtt/server/events"
@@ -175,6 +192,8 @@ server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.P
The OnMessage hook can also be used to selectively only deliver messages to one or more clients based on their id, using the `AllowClients []string` field on the packet structure.
##### OnError
`server.Events.OnError` is called when an error is encountered on the server, particularly within the use of a client connection status.

View File

@@ -54,6 +54,16 @@ func main() {
fmt.Printf("<< OnDisconnect client disconnected %s: %v\n", cl.ID, err)
}
// Add OnSubscribe Event Hook
server.Events.OnSubscribe = func(filter string, cl events.Client, qos byte) {
fmt.Printf("<< OnSubscribe client subscribed %s: %s %v\n", cl.ID, filter, qos)
}
// Add OnUnsubscribe Event Hook
server.Events.OnUnsubscribe = func(filter string, cl events.Client) {
fmt.Printf("<< OnUnsubscribe client unsubscribed %s: %s\n", cl.ID, filter)
}
// Add OnMessage Event Hook
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
pkx = pk

View File

@@ -1,6 +1,7 @@
package main
import (
"crypto/tls"
"fmt"
"log"
"os"
@@ -57,14 +58,30 @@ func main() {
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TLS/SSL"))
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
if err != nil {
log.Fatal(err)
}
// Basic TLS Config
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
// Optionally, if you want clients to authenticate only with certs issued by your CA,
// you might want to use something like this:
// certPool := x509.NewCertPool()
// _ = certPool.AppendCertsFromPEM(caCertPem)
// tlsConfig := &tls.Config{
// ClientCAs: certPool,
// ClientAuth: tls.RequireAndVerifyClientCert,
// }
server := mqtt.NewServer(nil)
tcp := listeners.NewTCP("t1", ":1883")
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
TLS: &listeners.TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
err = server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
TLSConfig: tlsConfig,
})
if err != nil {
log.Fatal(err)
@@ -72,11 +89,8 @@ func main() {
ws := listeners.NewWebsocket("ws1", ":1882")
err = server.AddListener(ws, &listeners.Config{
Auth: new(auth.Allow),
TLS: &listeners.TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
Auth: new(auth.Allow),
TLSConfig: tlsConfig,
})
if err != nil {
log.Fatal(err)
@@ -84,11 +98,8 @@ func main() {
stats := listeners.NewHTTPStats("stats", ":8080")
err = server.AddListener(stats, &listeners.Config{
Auth: new(auth.Allow),
TLS: &listeners.TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
Auth: new(auth.Allow),
TLSConfig: tlsConfig,
})
if err != nil {
log.Fatal(err)

View File

@@ -11,6 +11,8 @@ type Events struct {
OnError // server error.
OnConnect // client connected.
OnDisconnect // client disconnected.
OnSubscribe // topic subscription created.
OnUnsubscribe // topic subscription removed.
}
// Packets is an alias for packets.Packet.
@@ -18,9 +20,11 @@ type Packet packets.Packet
// Client contains limited information about a connected client.
type Client struct {
ID string
Remote string
Listener string
ID string
Remote string
Listener string
Username []byte
CleanSession bool
}
// Clientlike is an interface for Clients and client-like objects that
@@ -40,7 +44,7 @@ type Clientlike interface {
// be dispatched as if the event hook had not been triggered.
// This function will block message dispatching until it returns. To minimise this,
// have the function open a new goroutine on the embedding side.
// The `mqtt.ErrRejectPacket` error can be returned to reject and abandon any futher
// The `mqtt.ErrRejectPacket` error can be returned to reject and abandon any further
// processing of the packet.
type OnProcessMessage func(Client, Packet) (Packet, error)
@@ -66,3 +70,9 @@ type OnDisconnect func(Client, error)
// OnError is called when errors that will not be passed to
// OnDisconnect are handled by the server.
type OnError func(Client, error)
// OnSubscribe is called when a new subscription filter for a client is created.
type OnSubscribe func(filter string, cl Client, qos byte)
// OnUnsubscribe is called when an existing subscription filter for a client is removed.
type OnUnsubscribe func(filter string, cl Client)

View File

@@ -200,9 +200,11 @@ func (cl *Client) Info() events.Client {
addr = cl.conn.RemoteAddr().String()
}
return events.Client{
ID: cl.ID,
Remote: addr,
Listener: cl.Listener,
ID: cl.ID,
Remote: addr,
Username: cl.Username,
CleanSession: cl.CleanSession,
Listener: cl.Listener,
}
}

View File

@@ -21,7 +21,7 @@ func BenchmarkBytesToString(b *testing.B) {
func TestDecodeString(t *testing.T) {
expect := []struct {
name string
name string
rawBytes []byte
result string
offset int
@@ -88,8 +88,8 @@ func TestDecodeString(t *testing.T) {
shouldFail: ErrOffsetBytesOutOfRange,
},
{
offset: 0,
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
offset: 0,
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
shouldFail: ErrOffsetStrInvalidUTF8,
},
}
@@ -101,7 +101,7 @@ func TestDecodeString(t *testing.T) {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
})
@@ -209,7 +209,7 @@ func TestDecodeByte(t *testing.T) {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+1, offset)
@@ -250,12 +250,11 @@ func TestDecodeUint16(t *testing.T) {
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
@@ -295,7 +294,7 @@ func TestDecodeByteBool(t *testing.T) {
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return

View File

@@ -215,7 +215,7 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
}
if pk.PasswordFlag {
pk.Password, offset, err = decodeBytes(buf, offset)
pk.Password, _, err = decodeBytes(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPassword)
}
@@ -287,7 +287,7 @@ func (pk *Packet) ConnackDecode(buf []byte) error {
return fmt.Errorf("%s: %w", err, ErrMalformedSessionPresent)
}
pk.ReturnCode, offset, err = decodeByte(buf, offset)
pk.ReturnCode, _, err = decodeByte(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedReturnCode)
}

View File

@@ -70,6 +70,9 @@ func (l *HTTPStats) Listen(s *system.Info) error {
Handler: mux,
}
// The following logic is deprecated in favour of passing through the tls.Config
// value directly, however it remains in order to provide backwards compatibility.
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
if err != nil {
@@ -79,6 +82,8 @@ func (l *HTTPStats) Listen(s *system.Info) error {
l.listen.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
}
} else {
l.listen.TLSConfig = l.config.TLSConfig
}
return nil

View File

@@ -73,6 +73,18 @@ func TestHTTPStatsListen(t *testing.T) {
l.listen.Close()
}
func TestHTTPStatsListenTLSConfig(t *testing.T) {
l := NewHTTPStats("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLSConfig: tlsConfigBasic,
})
err := l.Listen(new(system.Info))
require.NoError(t, err)
require.NotNil(t, l.listen.TLSConfig)
l.listen.Close()
}
func TestHTTPStatsListenTLS(t *testing.T) {
l := NewHTTPStats("t1", testPort)
l.SetConfig(&Config{

View File

@@ -1,6 +1,7 @@
package listeners
import (
"crypto/tls"
"net"
"sync"
@@ -10,11 +11,25 @@ import (
// Config contains configuration values for a listener.
type Config struct {
Auth auth.Controller // an authentication controller containing auth and ACL logic.
TLS *TLS // the TLS certficates and settings for the connection.
// Auth controller containing auth and ACL logic for
// allowing or denying access to the server and topics.
Auth auth.Controller
// TLS certficates and settings for the connection.
//
// Deprecated: Prefer exposing the tls.Config directly for greater flexibility.
// Please use TLSConfig instead.
TLS *TLS
// TLSConfig is a tls.Config configuration to be used with the listener.
// See examples folder for basic and mutual-tls use.
TLSConfig *tls.Config
}
// TLS contains the TLS certificates and settings for the listener connection.
//
// Deprecated: Prefer exposing the tls.Config directly for greater flexibility.
// Please use TLSConfig instead.
type TLS struct {
Certificate []byte // the body of a public certificate.
PrivateKey []byte // the body of a private key.

View File

@@ -1,6 +1,8 @@
package listeners
import (
"crypto/tls"
"log"
"testing"
"time"
@@ -37,8 +39,22 @@ WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
-----END RSA PRIVATE KEY-----`)
tlsConfigBasic *tls.Config
)
func init() {
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
if err != nil {
log.Fatal(err)
}
// Basic TLS Config
tlsConfigBasic = &tls.Config{
Certificates: []tls.Certificate{cert},
}
}
func TestNew(t *testing.T) {
l := New(nil)
require.NotNil(t, l.internal)

View File

@@ -62,6 +62,9 @@ func (l *TCP) ID() string {
func (l *TCP) Listen(s *system.Info) error {
var err error
// The following logic is deprecated in favour of passing through the tls.Config
// value directly, however it remains in order to provide backwards compatibility.
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
var cert tls.Certificate
cert, err = tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
@@ -72,9 +75,12 @@ func (l *TCP) Listen(s *system.Info) error {
l.listen, err = tls.Listen(l.protocol, l.address, &tls.Config{
Certificates: []tls.Certificate{cert},
})
} else if l.config.TLSConfig != nil {
l.listen, err = tls.Listen(l.protocol, l.address, l.config.TLSConfig)
} else {
l.listen, err = net.Listen(l.protocol, l.address)
}
if err != nil {
return err
}

View File

@@ -73,6 +73,17 @@ func TestTCPListen(t *testing.T) {
l.listen.Close()
}
func TestTCPListenTLSConfig(t *testing.T) {
l := NewTCP("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLSConfig: tlsConfigBasic,
})
err := l.Listen(nil)
require.NoError(t, err)
l.listen.Close()
}
func TestTCPListenTLS(t *testing.T) {
l := NewTCP("t1", testPort)
l.SetConfig(&Config{

View File

@@ -122,6 +122,9 @@ func (l *Websocket) Listen(s *system.Info) error {
Handler: mux,
}
// The following logic is deprecated in favour of passing through the tls.Config
// value directly, however it remains in order to provide backwards compatibility.
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
if err != nil {
@@ -131,6 +134,8 @@ func (l *Websocket) Listen(s *system.Info) error {
l.listen.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
}
} else {
l.listen.TLSConfig = l.config.TLSConfig
}
return nil

View File

@@ -77,6 +77,18 @@ func TestWebsocketListen(t *testing.T) {
require.NotNil(t, l.listen)
}
func TestWebsocketListenTLSConfig(t *testing.T) {
l := NewWebsocket("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLSConfig: tlsConfigBasic,
})
err := l.Listen(nil)
require.NoError(t, err)
require.NotNil(t, l.listen.TLSConfig)
l.listen.Close()
}
func TestWebsocketListenTLS(t *testing.T) {
l := NewWebsocket("t1", testPort)
l.SetConfig(&Config{

View File

@@ -400,6 +400,9 @@ func (s *Server) unsubscribeClient(cl *clients.Client) {
for k := range cl.Subscriptions {
delete(cl.Subscriptions, k)
if s.Topics.Unsubscribe(k, cl.ID) {
if s.Events.OnUnsubscribe != nil {
s.Events.OnUnsubscribe(k, cl.Info())
}
atomic.AddInt64(&s.System.Subscriptions, -1)
}
}
@@ -735,8 +738,11 @@ func (s *Server) processSubscribe(cl *clients.Client, pk packets.Packet) error {
if !cl.AC.ACL(cl.Username, pk.Topics[i], false) {
retCodes[i] = packets.ErrSubAckNetworkError
} else {
q := s.Topics.Subscribe(pk.Topics[i], cl.ID, pk.Qoss[i])
if q {
r := s.Topics.Subscribe(pk.Topics[i], cl.ID, pk.Qoss[i])
if r {
if s.Events.OnSubscribe != nil {
s.Events.OnSubscribe(pk.Topics[i], cl.Info(), pk.Qoss[i])
}
atomic.AddInt64(&s.System.Subscriptions, 1)
}
cl.NoteSubscription(pk.Topics[i], pk.Qoss[i])
@@ -785,6 +791,9 @@ func (s *Server) processUnsubscribe(cl *clients.Client, pk packets.Packet) error
for i := 0; i < len(pk.Topics); i++ {
q := s.Topics.Unsubscribe(pk.Topics[i], cl.ID)
if q {
if s.Events.OnUnsubscribe != nil {
s.Events.OnUnsubscribe(pk.Topics[i], cl.Info())
}
atomic.AddInt64(&s.System.Subscriptions, -1)
}
cl.ForgetSubscription(pk.Topics[i])
@@ -1004,9 +1013,13 @@ func (s *Server) loadServerInfo(v persistence.ServerInfo) {
// loadSubscriptions restores subscriptions from the datastore.
func (s *Server) loadSubscriptions(v []persistence.Subscription) {
for _, sub := range v {
s.Topics.Subscribe(sub.Filter, sub.Client, sub.QoS)
if cl, ok := s.Clients.Get(sub.Client); ok {
cl.NoteSubscription(sub.Filter, sub.QoS)
if s.Topics.Subscribe(sub.Filter, sub.Client, sub.QoS) {
if cl, ok := s.Clients.Get(sub.Client); ok {
cl.NoteSubscription(sub.Filter, sub.QoS)
if s.Events.OnSubscribe != nil {
s.Events.OnSubscribe(sub.Filter, cl.Info(), sub.QoS)
}
}
}
}
}

View File

@@ -368,9 +368,10 @@ func TestServerEventOnConnect(t *testing.T) {
time.Sleep(10 * time.Millisecond)
require.Equal(t, events.Client{
ID: "mochi",
Remote: "pipe",
Listener: "tcp",
ID: "mochi",
Remote: "pipe",
Listener: "tcp",
CleanSession: true,
}, hook.client)
require.Equal(t, events.Packet(packets.Packet{
@@ -443,9 +444,10 @@ func TestServerEventOnDisconnect(t *testing.T) {
w.Close()
require.Equal(t, events.Client{
ID: "mochi",
Remote: "pipe",
Listener: "tcp",
ID: "mochi",
Remote: "pipe",
Listener: "tcp",
CleanSession: true,
}, hook.client)
require.ErrorIs(t, ErrClientDisconnect, hook.err)
@@ -506,9 +508,10 @@ func TestServerEventOnDisconnectOnError(t *testing.T) {
require.Equal(t, errx, hook.err)
require.Equal(t, events.Client{
ID: "mochi",
Remote: "pipe",
Listener: "tcp",
ID: "mochi",
Remote: "pipe",
Listener: "tcp",
CleanSession: true,
}, hook.client)
clw, ok := s.Clients.Get("mochi")
@@ -1946,6 +1949,15 @@ func TestServerProcessSubscribeInvalid(t *testing.T) {
func TestServerProcessSubscribe(t *testing.T) {
s, cl, r, w := setupClient()
subscribeEvent := ""
subscribeClient := ""
s.Events.OnSubscribe = func(filter string, cl events.Client, qos byte) {
if filter == "a/b/c" {
subscribeEvent = "a/b/c"
subscribeClient = cl.ID
}
}
s.Topics.RetainMessage(packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
@@ -1995,6 +2007,8 @@ func TestServerProcessSubscribe(t *testing.T) {
require.Equal(t, byte(1), cl.Subscriptions["d/e/f"])
require.Equal(t, topics.Subscriptions{cl.ID: 0}, s.Topics.Subscribers("a/b/c"))
require.Equal(t, topics.Subscriptions{cl.ID: 1}, s.Topics.Subscribers("d/e/f"))
require.Equal(t, "a/b/c", subscribeEvent)
require.Equal(t, cl.ID, subscribeClient)
}
func TestServerProcessSubscribeFailACL(t *testing.T) {
@@ -2114,6 +2128,16 @@ func TestServerProcessUnsubscribeInvalid(t *testing.T) {
func TestServerProcessUnsubscribe(t *testing.T) {
s, cl, r, w := setupClient()
unsubscribeEvent := ""
unsubscribeClient := ""
s.Events.OnUnsubscribe = func(filter string, cl events.Client) {
if filter == "a/b/c" {
unsubscribeEvent = "a/b/c"
unsubscribeClient = cl.ID
}
}
s.Clients.Add(cl)
s.Topics.Subscribe("a/b/c", cl.ID, 0)
s.Topics.Subscribe("d/e/f", cl.ID, 1)
@@ -2155,6 +2179,9 @@ func TestServerProcessUnsubscribe(t *testing.T) {
require.NotEmpty(t, s.Topics.Subscribers("a/b/+"))
require.Contains(t, cl.Subscriptions, "a/b/+")
require.Equal(t, "a/b/c", unsubscribeEvent)
require.Equal(t, cl.ID, unsubscribeClient)
}
func TestServerProcessUnsubscribeWriteError(t *testing.T) {