mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-10-04 07:46:34 +08:00
Compare commits
12 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
84fc2f848b | ||
![]() |
8703d6d020 | ||
![]() |
666440fe56 | ||
![]() |
1ae050939a | ||
![]() |
f4683d27d0 | ||
![]() |
dff2b1db30 | ||
![]() |
9de6b4e427 | ||
![]() |
78c1914270 | ||
![]() |
f71bf5c3d6 | ||
![]() |
53c4a6b09f | ||
![]() |
a02c6bd8df | ||
![]() |
d8f6d63cc8 |
@@ -27,8 +27,9 @@ func main() {
|
||||
|
||||
// An example of configuring various server options...
|
||||
options := &mqtt.Options{
|
||||
BufferSize: 0, // Use default values
|
||||
BufferBlockSize: 0, // Use default values
|
||||
BufferSize: 0, // Use default values
|
||||
BufferBlockSize: 0, // Use default values
|
||||
InflightTTL: 60 * 15, // Set an example custom 15-min TTL for inflight messages
|
||||
}
|
||||
|
||||
server := mqtt.NewServer(options)
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
@@ -57,14 +58,30 @@ func main() {
|
||||
|
||||
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TLS/SSL"))
|
||||
|
||||
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Basic TLS Config
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
|
||||
// Optionally, if you want clients to authenticate only with certs issued by your CA,
|
||||
// you might want to use something like this:
|
||||
// certPool := x509.NewCertPool()
|
||||
// _ = certPool.AppendCertsFromPEM(caCertPem)
|
||||
// tlsConfig := &tls.Config{
|
||||
// ClientCAs: certPool,
|
||||
// ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
// }
|
||||
|
||||
server := mqtt.NewServer(nil)
|
||||
tcp := listeners.NewTCP("t1", ":1883")
|
||||
err := server.AddListener(tcp, &listeners.Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLS: &listeners.TLS{
|
||||
Certificate: testCertificate,
|
||||
PrivateKey: testPrivateKey,
|
||||
},
|
||||
err = server.AddListener(tcp, &listeners.Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLSConfig: tlsConfig,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
@@ -72,11 +89,8 @@ func main() {
|
||||
|
||||
ws := listeners.NewWebsocket("ws1", ":1882")
|
||||
err = server.AddListener(ws, &listeners.Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLS: &listeners.TLS{
|
||||
Certificate: testCertificate,
|
||||
PrivateKey: testPrivateKey,
|
||||
},
|
||||
Auth: new(auth.Allow),
|
||||
TLSConfig: tlsConfig,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
@@ -84,11 +98,8 @@ func main() {
|
||||
|
||||
stats := listeners.NewHTTPStats("stats", ":8080")
|
||||
err = server.AddListener(stats, &listeners.Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLS: &listeners.TLS{
|
||||
Certificate: testCertificate,
|
||||
PrivateKey: testPrivateKey,
|
||||
},
|
||||
Auth: new(auth.Allow),
|
||||
TLSConfig: tlsConfig,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@@ -20,9 +20,11 @@ type Packet packets.Packet
|
||||
|
||||
// Client contains limited information about a connected client.
|
||||
type Client struct {
|
||||
ID string
|
||||
Remote string
|
||||
Listener string
|
||||
ID string
|
||||
Remote string
|
||||
Listener string
|
||||
Username []byte
|
||||
CleanSession bool
|
||||
}
|
||||
|
||||
// Clientlike is an interface for Clients and client-like objects that
|
||||
@@ -42,7 +44,7 @@ type Clientlike interface {
|
||||
// be dispatched as if the event hook had not been triggered.
|
||||
// This function will block message dispatching until it returns. To minimise this,
|
||||
// have the function open a new goroutine on the embedding side.
|
||||
// The `mqtt.ErrRejectPacket` error can be returned to reject and abandon any futher
|
||||
// The `mqtt.ErrRejectPacket` error can be returned to reject and abandon any further
|
||||
// processing of the packet.
|
||||
type OnProcessMessage func(Client, Packet) (Packet, error)
|
||||
|
||||
|
@@ -49,6 +49,13 @@ func (cl *Clients) Add(val *Client) {
|
||||
cl.Unlock()
|
||||
}
|
||||
|
||||
// GetAll returns all the clients.
|
||||
func (cl *Clients) GetAll() map[string]*Client {
|
||||
cl.RLock()
|
||||
defer cl.RUnlock()
|
||||
return cl.internal
|
||||
}
|
||||
|
||||
// Get returns the value of a client if it exists.
|
||||
func (cl *Clients) Get(id string) (*Client, bool) {
|
||||
cl.RLock()
|
||||
@@ -200,9 +207,11 @@ func (cl *Client) Info() events.Client {
|
||||
addr = cl.conn.RemoteAddr().String()
|
||||
}
|
||||
return events.Client{
|
||||
ID: cl.ID,
|
||||
Remote: addr,
|
||||
Listener: cl.Listener,
|
||||
ID: cl.ID,
|
||||
Remote: addr,
|
||||
Username: cl.Username,
|
||||
CleanSession: cl.CleanSession,
|
||||
Listener: cl.Listener,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -509,6 +518,7 @@ type LWT struct {
|
||||
type InflightMessage struct {
|
||||
Packet packets.Packet // the packet currently in-flight.
|
||||
Sent int64 // the last time the message was sent (for retries) in unixtime.
|
||||
Created int64 // the unix timestamp when the inflight message was created.
|
||||
Resends int // the number of times the message was attempted to be sent.
|
||||
}
|
||||
|
||||
@@ -555,8 +565,21 @@ func (i *Inflight) GetAll() map[uint16]InflightMessage {
|
||||
// message existed.
|
||||
func (i *Inflight) Delete(key uint16) bool {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
_, ok := i.internal[key]
|
||||
delete(i.internal, key)
|
||||
i.Unlock()
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// ClearExpired deletes any inflight messages that have remained longer than
|
||||
// the servers InflightTTL duration.
|
||||
func (i *Inflight) ClearExpired(expiry int64) {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
for k, m := range i.internal {
|
||||
if m.Created < expiry || m.Created == 0 {
|
||||
delete(i.internal, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -69,6 +69,36 @@ func BenchmarkClientsGet(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientsGetAll(t *testing.T) {
|
||||
cl := New()
|
||||
cl.Add(&Client{ID: "t1"})
|
||||
cl.Add(&Client{ID: "t2"})
|
||||
cl.Add(&Client{ID: "t3"})
|
||||
cl.Add(&Client{ID: "t4"})
|
||||
cl.Add(&Client{ID: "t5"})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
require.Contains(t, cl.internal, "t2")
|
||||
require.Contains(t, cl.internal, "t3")
|
||||
require.Contains(t, cl.internal, "t4")
|
||||
require.Contains(t, cl.internal, "t5")
|
||||
|
||||
clients := cl.GetAll()
|
||||
require.Len(t, clients, 5)
|
||||
}
|
||||
|
||||
func BenchmarkClientsGetAll(b *testing.B) {
|
||||
cl := New()
|
||||
cl.Add(&Client{ID: "t1"})
|
||||
cl.Add(&Client{ID: "t2"})
|
||||
cl.Add(&Client{ID: "t3"})
|
||||
cl.Add(&Client{ID: "t4"})
|
||||
cl.Add(&Client{ID: "t5"})
|
||||
for n := 0; n < b.N; n++ {
|
||||
clients := cl.GetAll()
|
||||
require.Len(b, clients, 5)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientsLen(t *testing.T) {
|
||||
cl := New()
|
||||
cl.Add(&Client{ID: "t1"})
|
||||
@@ -714,12 +744,13 @@ func TestClientWritePacket(t *testing.T) {
|
||||
r.Close()
|
||||
|
||||
require.Equal(t, tt.bytes, <-o, "Mismatched packet: [i:%d] %d", i, tt.bytes[0])
|
||||
|
||||
cl.Stop(testClientStop)
|
||||
time.Sleep(time.Millisecond * 1)
|
||||
|
||||
// The stop cause is either the test error, EOF, or a
|
||||
// closed pipe, depending on which goroutine runs first.
|
||||
err = cl.StopCause()
|
||||
time.Sleep(time.Millisecond * 5)
|
||||
require.True(t,
|
||||
errors.Is(err, testClientStop) ||
|
||||
errors.Is(err, io.EOF) ||
|
||||
@@ -858,6 +889,42 @@ func BenchmarkInflightDelete(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestInflightClearExpired(t *testing.T) {
|
||||
n := time.Now().Unix()
|
||||
|
||||
cl := genClient()
|
||||
cl.Inflight.Set(1, InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Created: n - 1,
|
||||
Sent: 0,
|
||||
})
|
||||
cl.Inflight.Set(2, InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Created: n - 2,
|
||||
Sent: 0,
|
||||
})
|
||||
cl.Inflight.Set(3, InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Created: n - 3,
|
||||
Sent: 0,
|
||||
})
|
||||
cl.Inflight.Set(5, InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Created: n - 5,
|
||||
Sent: 0,
|
||||
})
|
||||
|
||||
require.Len(t, cl.Inflight.internal, 4)
|
||||
|
||||
cl.Inflight.ClearExpired(n - 2)
|
||||
cl.Inflight.RLock()
|
||||
defer cl.Inflight.RUnlock()
|
||||
require.Len(t, cl.Inflight.internal, 2)
|
||||
require.Equal(t, (n - 1), cl.Inflight.internal[1].Created)
|
||||
require.Equal(t, (n - 2), cl.Inflight.internal[2].Created)
|
||||
require.Equal(t, int64(0), cl.Inflight.internal[3].Created)
|
||||
}
|
||||
|
||||
var (
|
||||
pkTable = []struct {
|
||||
bytes []byte
|
||||
|
@@ -21,7 +21,7 @@ func BenchmarkBytesToString(b *testing.B) {
|
||||
|
||||
func TestDecodeString(t *testing.T) {
|
||||
expect := []struct {
|
||||
name string
|
||||
name string
|
||||
rawBytes []byte
|
||||
result string
|
||||
offset int
|
||||
@@ -88,8 +88,8 @@ func TestDecodeString(t *testing.T) {
|
||||
shouldFail: ErrOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
|
||||
shouldFail: ErrOffsetStrInvalidUTF8,
|
||||
},
|
||||
}
|
||||
@@ -101,7 +101,7 @@ func TestDecodeString(t *testing.T) {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
})
|
||||
@@ -209,7 +209,7 @@ func TestDecodeByte(t *testing.T) {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, i+1, offset)
|
||||
@@ -250,12 +250,11 @@ func TestDecodeUint16(t *testing.T) {
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
|
||||
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
@@ -295,7 +294,7 @@ func TestDecodeByteBool(t *testing.T) {
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
|
||||
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
|
@@ -215,7 +215,7 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
|
||||
}
|
||||
|
||||
if pk.PasswordFlag {
|
||||
pk.Password, offset, err = decodeBytes(buf, offset)
|
||||
pk.Password, _, err = decodeBytes(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPassword)
|
||||
}
|
||||
@@ -287,7 +287,7 @@ func (pk *Packet) ConnackDecode(buf []byte) error {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedSessionPresent)
|
||||
}
|
||||
|
||||
pk.ReturnCode, offset, err = decodeByte(buf, offset)
|
||||
pk.ReturnCode, _, err = decodeByte(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedReturnCode)
|
||||
}
|
||||
|
@@ -70,6 +70,9 @@ func (l *HTTPStats) Listen(s *system.Info) error {
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
// The following logic is deprecated in favour of passing through the tls.Config
|
||||
// value directly, however it remains in order to provide backwards compatibility.
|
||||
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
|
||||
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
|
||||
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
|
||||
if err != nil {
|
||||
@@ -79,6 +82,8 @@ func (l *HTTPStats) Listen(s *system.Info) error {
|
||||
l.listen.TLSConfig = &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
} else {
|
||||
l.listen.TLSConfig = l.config.TLSConfig
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@@ -73,6 +73,18 @@ func TestHTTPStatsListen(t *testing.T) {
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestHTTPStatsListenTLSConfig(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Listen(new(system.Info))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, l.listen.TLSConfig)
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestHTTPStatsListenTLS(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
@@ -10,11 +11,25 @@ import (
|
||||
|
||||
// Config contains configuration values for a listener.
|
||||
type Config struct {
|
||||
Auth auth.Controller // an authentication controller containing auth and ACL logic.
|
||||
TLS *TLS // the TLS certficates and settings for the connection.
|
||||
// Auth controller containing auth and ACL logic for
|
||||
// allowing or denying access to the server and topics.
|
||||
Auth auth.Controller
|
||||
|
||||
// TLS certficates and settings for the connection.
|
||||
//
|
||||
// Deprecated: Prefer exposing the tls.Config directly for greater flexibility.
|
||||
// Please use TLSConfig instead.
|
||||
TLS *TLS
|
||||
|
||||
// TLSConfig is a tls.Config configuration to be used with the listener.
|
||||
// See examples folder for basic and mutual-tls use.
|
||||
TLSConfig *tls.Config
|
||||
}
|
||||
|
||||
// TLS contains the TLS certificates and settings for the listener connection.
|
||||
//
|
||||
// Deprecated: Prefer exposing the tls.Config directly for greater flexibility.
|
||||
// Please use TLSConfig instead.
|
||||
type TLS struct {
|
||||
Certificate []byte // the body of a public certificate.
|
||||
PrivateKey []byte // the body of a private key.
|
||||
|
@@ -1,6 +1,8 @@
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"log"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -37,8 +39,22 @@ WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
|
||||
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
|
||||
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
|
||||
-----END RSA PRIVATE KEY-----`)
|
||||
|
||||
tlsConfigBasic *tls.Config
|
||||
)
|
||||
|
||||
func init() {
|
||||
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Basic TLS Config
|
||||
tlsConfigBasic = &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
l := New(nil)
|
||||
require.NotNil(t, l.internal)
|
||||
|
@@ -62,6 +62,9 @@ func (l *TCP) ID() string {
|
||||
func (l *TCP) Listen(s *system.Info) error {
|
||||
var err error
|
||||
|
||||
// The following logic is deprecated in favour of passing through the tls.Config
|
||||
// value directly, however it remains in order to provide backwards compatibility.
|
||||
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
|
||||
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
|
||||
var cert tls.Certificate
|
||||
cert, err = tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
|
||||
@@ -72,9 +75,12 @@ func (l *TCP) Listen(s *system.Info) error {
|
||||
l.listen, err = tls.Listen(l.protocol, l.address, &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
})
|
||||
} else if l.config.TLSConfig != nil {
|
||||
l.listen, err = tls.Listen(l.protocol, l.address, l.config.TLSConfig)
|
||||
} else {
|
||||
l.listen, err = net.Listen(l.protocol, l.address)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -73,6 +73,17 @@ func TestTCPListen(t *testing.T) {
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestTCPListenTLSConfig(t *testing.T) {
|
||||
l := NewTCP("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Listen(nil)
|
||||
require.NoError(t, err)
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestTCPListenTLS(t *testing.T) {
|
||||
l := NewTCP("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
|
@@ -122,6 +122,9 @@ func (l *Websocket) Listen(s *system.Info) error {
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
// The following logic is deprecated in favour of passing through the tls.Config
|
||||
// value directly, however it remains in order to provide backwards compatibility.
|
||||
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
|
||||
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
|
||||
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
|
||||
if err != nil {
|
||||
@@ -131,6 +134,8 @@ func (l *Websocket) Listen(s *system.Info) error {
|
||||
l.listen.TLSConfig = &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
} else {
|
||||
l.listen.TLSConfig = l.config.TLSConfig
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@@ -77,6 +77,18 @@ func TestWebsocketListen(t *testing.T) {
|
||||
require.NotNil(t, l.listen)
|
||||
}
|
||||
|
||||
func TestWebsocketListenTLSConfig(t *testing.T) {
|
||||
l := NewWebsocket("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Listen(nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, l.listen.TLSConfig)
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestWebsocketListenTLS(t *testing.T) {
|
||||
l := NewWebsocket("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
|
@@ -27,9 +27,10 @@ var (
|
||||
|
||||
// Store is a backend for writing and reading to bolt persistent storage.
|
||||
type Store struct {
|
||||
path string // the path on which to store the db file.
|
||||
opts *bbolt.Options // options for configuring the boltdb instance.
|
||||
db *storm.DB // the boltdb instance.
|
||||
path string // the path on which to store the db file.
|
||||
opts *bbolt.Options // options for configuring the boltdb instance.
|
||||
db *storm.DB // the boltdb instance.
|
||||
inflightTTL int64 // the number of seconds an inflight message should be retained before being dropped.
|
||||
}
|
||||
|
||||
// New returns a configured instance of the boltdb store.
|
||||
@@ -50,6 +51,13 @@ func New(path string, opts *bbolt.Options) *Store {
|
||||
}
|
||||
}
|
||||
|
||||
// SetInflightTTL sets the number of seconds an inflight message should be kept
|
||||
// before being dropped, in the event it is not delivered. Unless you have a good reason,
|
||||
// you should allow this to be called by the server (in AddStore) instead of directly.
|
||||
func (s *Store) SetInflightTTL(seconds int64) {
|
||||
s.inflightTTL = seconds
|
||||
}
|
||||
|
||||
// Open opens the boltdb instance.
|
||||
func (s *Store) Open() error {
|
||||
var err error
|
||||
@@ -251,7 +259,7 @@ func (s *Store) ReadRetained() (v []persistence.Message, err error) {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
//ReadServerInfo loads the server info from the boltdb instance.
|
||||
// ReadServerInfo loads the server info from the boltdb instance.
|
||||
func (s *Store) ReadServerInfo() (v persistence.ServerInfo, err error) {
|
||||
if s.db == nil {
|
||||
return v, ErrDBNotOpen
|
||||
@@ -264,3 +272,27 @@ func (s *Store) ReadServerInfo() (v persistence.ServerInfo, err error) {
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// ClearExpiredInflight deletes any inflight messages older than the provided unix timestamp.
|
||||
func (s *Store) ClearExpiredInflight(expiry int64) error {
|
||||
if s.db == nil {
|
||||
return ErrDBNotOpen
|
||||
}
|
||||
|
||||
var v []persistence.Message
|
||||
err := s.db.Find("T", persistence.KInflight, &v)
|
||||
if err != nil && err != storm.ErrNotFound {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, m := range v {
|
||||
if m.Created < expiry || m.Created == 0 {
|
||||
err := s.db.DeleteStruct(&persistence.Message{ID: m.ID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -49,6 +49,12 @@ func TestNewNoOpts(t *testing.T) {
|
||||
require.Equal(t, defaultTimeout, s.opts.Timeout)
|
||||
}
|
||||
|
||||
func TestSetInflightTTL(t *testing.T) {
|
||||
s := New("", nil)
|
||||
s.SetInflightTTL(5)
|
||||
require.Equal(t, int64(5), s.inflightTTL)
|
||||
}
|
||||
|
||||
func TestOpen(t *testing.T) {
|
||||
s := New(tmpPath, nil)
|
||||
err := s.Open()
|
||||
@@ -484,3 +490,51 @@ func TestDeleteRetainedFail(t *testing.T) {
|
||||
err = s.DeleteRetained("a")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClearExpiredInflight(t *testing.T) {
|
||||
n := time.Now().Unix()
|
||||
|
||||
s := New(tmpPath, nil)
|
||||
err := s.Open()
|
||||
defer teardown(s, t)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.WriteInflight(persistence.Message{
|
||||
ID: "i1",
|
||||
T: persistence.KInflight,
|
||||
Created: n - 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.WriteInflight(persistence.Message{
|
||||
ID: "i2",
|
||||
T: persistence.KInflight,
|
||||
Created: n - 2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.WriteInflight(persistence.Message{
|
||||
ID: "i3",
|
||||
T: persistence.KInflight,
|
||||
Created: n - 3,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.WriteInflight(persistence.Message{
|
||||
ID: "i5",
|
||||
T: persistence.KInflight,
|
||||
Created: n - 5,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
m, err := s.ReadInflight()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, m, 4)
|
||||
|
||||
s.ClearExpiredInflight(n - 2)
|
||||
|
||||
m, err = s.ReadInflight()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, m, 2)
|
||||
}
|
||||
|
@@ -28,22 +28,28 @@ const (
|
||||
type Store interface {
|
||||
Open() error
|
||||
Close()
|
||||
WriteSubscription(v Subscription) error
|
||||
WriteClient(v Client) error
|
||||
WriteInflight(v Message) error
|
||||
WriteServerInfo(v ServerInfo) error
|
||||
WriteRetained(v Message) error
|
||||
|
||||
DeleteSubscription(id string) error
|
||||
DeleteClient(id string) error
|
||||
DeleteInflight(id string) error
|
||||
DeleteRetained(id string) error
|
||||
|
||||
ReadSubscriptions() (v []Subscription, err error)
|
||||
ReadInflight() (v []Message, err error)
|
||||
ReadRetained() (v []Message, err error)
|
||||
WriteSubscription(v Subscription) error
|
||||
DeleteSubscription(id string) error
|
||||
|
||||
ReadClients() (v []Client, err error)
|
||||
WriteClient(v Client) error
|
||||
DeleteClient(id string) error
|
||||
|
||||
ReadInflight() (v []Message, err error)
|
||||
WriteInflight(v Message) error
|
||||
DeleteInflight(id string) error
|
||||
|
||||
SetInflightTTL(seconds int64)
|
||||
ClearExpiredInflight(expiry int64) error
|
||||
|
||||
ReadServerInfo() (v ServerInfo, err error)
|
||||
WriteServerInfo(v ServerInfo) error
|
||||
|
||||
ReadRetained() (v []Message, err error)
|
||||
WriteRetained(v Message) error
|
||||
DeleteRetained(id string) error
|
||||
}
|
||||
|
||||
// ServerInfo contains information and statistics about the server.
|
||||
@@ -69,6 +75,7 @@ type Message struct {
|
||||
ID string // the storage key.
|
||||
Client string // the id of the client who sent the message (if inflight).
|
||||
TopicName string // the topic the message was sent to (if retained).
|
||||
Created int64 // the time the message was created in unixtime (if inflight).
|
||||
Sent int64 // the last time the message was sent (for retries) in unixtime (if inflight).
|
||||
Resends int // the number of times the message was attempted to be sent (if inflight).
|
||||
PacketID uint16 // the unique id of the packet (if inflight).
|
||||
@@ -103,10 +110,16 @@ type LWT struct {
|
||||
|
||||
// MockStore is a mock storage backend for testing.
|
||||
type MockStore struct {
|
||||
Fail map[string]bool // issue errors for different methods.
|
||||
FailOpen bool // error on open.
|
||||
Closed bool // indicate mock store is closed.
|
||||
Opened bool // indicate mock store is open.
|
||||
Fail map[string]bool // issue errors for different methods.
|
||||
FailOpen bool // error on open.
|
||||
Closed bool // indicate mock store is closed.
|
||||
Opened bool // indicate mock store is open.
|
||||
inflightTTL int64 // inflight expiry duration.
|
||||
}
|
||||
|
||||
// Close closes the storage instance.
|
||||
func (s *MockStore) SetInflightTTL(seconds int64) {
|
||||
s.inflightTTL = seconds
|
||||
}
|
||||
|
||||
// Open opens the storage instance.
|
||||
@@ -275,7 +288,7 @@ func (s *MockStore) ReadRetained() (v []Message, err error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
//ReadServerInfo loads the server info from the storage instance.
|
||||
// ReadServerInfo loads the server info from the storage instance.
|
||||
func (s *MockStore) ReadServerInfo() (v ServerInfo, err error) {
|
||||
if _, ok := s.Fail["read_info"]; ok {
|
||||
return v, errors.New("test_info")
|
||||
@@ -289,3 +302,8 @@ func (s *MockStore) ReadServerInfo() (v ServerInfo, err error) {
|
||||
KServerInfo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ReadServerInfo loads the server info from the storage instance.
|
||||
func (s *MockStore) ClearExpiredInflight(d int64) error {
|
||||
return nil
|
||||
}
|
||||
|
@@ -26,6 +26,12 @@ func TestMockStoreClose(t *testing.T) {
|
||||
require.Equal(t, true, s.Closed)
|
||||
}
|
||||
|
||||
func TestMockStoreSetInflightTTL(t *testing.T) {
|
||||
s := new(MockStore)
|
||||
s.SetInflightTTL(5)
|
||||
require.Equal(t, int64(5), s.inflightTTL)
|
||||
}
|
||||
|
||||
func TestMockStoreWriteSubscription(t *testing.T) {
|
||||
s := new(MockStore)
|
||||
err := s.WriteSubscription(Subscription{})
|
||||
@@ -249,3 +255,9 @@ func TestMockStoreReadRetainedFail(t *testing.T) {
|
||||
_, err := s.ReadRetained()
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMockStoreClearExpiredInflight(t *testing.T) {
|
||||
s := new(MockStore)
|
||||
err := s.ClearExpiredInflight(2)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
@@ -25,6 +25,9 @@ import (
|
||||
const (
|
||||
// Version indicates the current server version.
|
||||
Version = "1.1.1"
|
||||
|
||||
// defaultInflightTTL is the number of seconds a pending inflight message should last.
|
||||
defaultInflightTTL int64 = 60 * 60 * 24
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -73,17 +76,19 @@ var (
|
||||
// Server is an MQTT broker server. It should be created with server.New()
|
||||
// in order to ensure all the internal fields are correctly populated.
|
||||
type Server struct {
|
||||
inline inlineMessages // channels for direct publishing.
|
||||
Events events.Events // overrideable event hooks.
|
||||
Store persistence.Store // a persistent storage backend if desired.
|
||||
Options *Options // configurable server options.
|
||||
Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections.
|
||||
Clients *clients.Clients // clients which are known to the broker.
|
||||
Topics *topics.Index // an index of topic filter subscriptions and retained messages.
|
||||
System *system.Info // values about the server commonly found in $SYS topics.
|
||||
bytepool *circ.BytesPool // a byte pool for incoming and outgoing packets.
|
||||
sysTicker *time.Ticker // the interval ticker for sending updating $SYS topics.
|
||||
done chan bool // indicate that the server is ending.
|
||||
inline inlineMessages // channels for direct publishing.
|
||||
Events events.Events // overrideable event hooks.
|
||||
Store persistence.Store // a persistent storage backend if desired.
|
||||
Options *Options // configurable server options.
|
||||
Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections.
|
||||
Clients *clients.Clients // clients which are known to the broker.
|
||||
Topics *topics.Index // an index of topic filter subscriptions and retained messages.
|
||||
System *system.Info // values about the server commonly found in $SYS topics.
|
||||
bytepool *circ.BytesPool // a byte pool for incoming and outgoing packets.
|
||||
sysTicker *time.Ticker // the interval ticker for sending updating $SYS topics.
|
||||
inflightExpiryTicker *time.Ticker // the interval ticker for cleaning up expired messages.
|
||||
inflightResendTicker *time.Ticker // the interval ticker for resending unresolved inflight messages.
|
||||
done chan bool // indicate that the server is ending.
|
||||
}
|
||||
|
||||
// Options contains configurable options for the server.
|
||||
@@ -93,6 +98,9 @@ type Options struct {
|
||||
|
||||
// BufferBlockSize overrides the default buffer block size (DefaultBlockSize) for the client buffers.
|
||||
BufferBlockSize int
|
||||
|
||||
// InflightTTL specifies the duration that a queued inflight message should exist before being purged.
|
||||
InflightTTL int64
|
||||
}
|
||||
|
||||
// inlineMessages contains channels for handling inline (direct) publishing.
|
||||
@@ -114,6 +122,10 @@ func NewServer(opts *Options) *Server {
|
||||
opts = new(Options)
|
||||
}
|
||||
|
||||
if opts.InflightTTL < 1 {
|
||||
opts.InflightTTL = defaultInflightTTL
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
done: make(chan bool),
|
||||
bytepool: circ.NewBytesPool(opts.BufferSize),
|
||||
@@ -123,7 +135,9 @@ func NewServer(opts *Options) *Server {
|
||||
Version: Version,
|
||||
Started: time.Now().Unix(),
|
||||
},
|
||||
sysTicker: time.NewTicker(SysTopicInterval * time.Millisecond),
|
||||
sysTicker: time.NewTicker(SysTopicInterval * time.Millisecond),
|
||||
inflightExpiryTicker: time.NewTicker(time.Duration(opts.InflightTTL) * time.Second),
|
||||
inflightResendTicker: time.NewTicker(time.Duration(10) * time.Second),
|
||||
inline: inlineMessages{
|
||||
done: make(chan bool),
|
||||
pub: make(chan packets.Packet, 1024),
|
||||
@@ -143,6 +157,8 @@ func NewServer(opts *Options) *Server {
|
||||
// called before calling server.Server().
|
||||
func (s *Server) AddStore(p persistence.Store) error {
|
||||
s.Store = p
|
||||
s.Store.SetInflightTTL(s.Options.InflightTTL)
|
||||
|
||||
err := s.Store.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -198,6 +214,10 @@ func (s *Server) eventLoop() {
|
||||
return
|
||||
case <-s.sysTicker.C:
|
||||
s.publishSysTopics()
|
||||
case <-s.inflightExpiryTicker.C:
|
||||
s.clearExpiredInflights(time.Now().Unix())
|
||||
case <-s.inflightResendTicker.C:
|
||||
s.resendPendingInflights()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -319,9 +339,11 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
|
||||
return s.onError(cl.Info(), fmt.Errorf("ack connection packet: %w", err))
|
||||
}
|
||||
|
||||
err = s.ResendClientInflight(cl, true)
|
||||
if err != nil {
|
||||
s.onError(cl.Info(), fmt.Errorf("resend in flight: %w", err)) // pass-through, no return.
|
||||
if sessionPresent {
|
||||
err = s.ResendClientInflight(cl, true)
|
||||
if err != nil {
|
||||
s.onError(cl.Info(), fmt.Errorf("resend in flight: %w", err)) // pass-through, no return.
|
||||
}
|
||||
}
|
||||
|
||||
if s.Store != nil {
|
||||
@@ -346,6 +368,10 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
|
||||
|
||||
err = cl.StopCause() // Determine true cause of stop.
|
||||
|
||||
if !sessionPresent {
|
||||
s.clearAbandonedInflights(cl)
|
||||
}
|
||||
|
||||
if s.Events.OnDisconnect != nil {
|
||||
s.Events.OnDisconnect(cl.Info(), err)
|
||||
}
|
||||
@@ -379,6 +405,7 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *clients.Client) boo
|
||||
// The state associated with a CleanSession MUST NOT be reused in any subsequent session.
|
||||
if pk.CleanSession || existing.CleanSession {
|
||||
s.unsubscribeClient(existing)
|
||||
s.clearAbandonedInflights(existing)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -640,8 +667,9 @@ func (s *Server) publishToSubscribers(pk packets.Packet) {
|
||||
// if an appropriate ack is not received (or if the client is offline).
|
||||
sent := time.Now().Unix()
|
||||
q := client.Inflight.Set(out.PacketID, clients.InflightMessage{
|
||||
Packet: out,
|
||||
Sent: sent,
|
||||
Packet: out,
|
||||
Created: time.Now().Unix(),
|
||||
Sent: sent,
|
||||
})
|
||||
if q {
|
||||
atomic.AddInt64(&s.System.Inflight, 1)
|
||||
@@ -876,7 +904,7 @@ func (s *Server) publishSysTopics() {
|
||||
// ResendClientInflight attempts to resend all undelivered inflight messages
|
||||
// to a client.
|
||||
func (s *Server) ResendClientInflight(cl *clients.Client, force bool) error {
|
||||
if cl.Inflight.Len() == 0 {
|
||||
if atomic.LoadUint32(&cl.State.Done) == 1 || cl.Inflight.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1047,6 +1075,7 @@ func (s *Server) loadInflight(v []persistence.Message) {
|
||||
TopicName: msg.TopicName,
|
||||
Payload: msg.Payload,
|
||||
},
|
||||
Created: msg.Created,
|
||||
Sent: msg.Sent,
|
||||
Resends: msg.Resends,
|
||||
})
|
||||
@@ -1064,3 +1093,34 @@ func (s *Server) loadRetained(v []persistence.Message) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// clearExpiredInflights deletes all inflight messages older than server inflight TTL.
|
||||
func (s *Server) clearExpiredInflights(dt int64) {
|
||||
expiry := dt - s.Options.InflightTTL
|
||||
|
||||
for _, client := range s.Clients.GetAll() {
|
||||
client.Inflight.ClearExpired(expiry)
|
||||
}
|
||||
|
||||
if s.Store != nil {
|
||||
s.Store.ClearExpiredInflight(expiry)
|
||||
}
|
||||
}
|
||||
|
||||
// clearAbandonedInflights deletes all inflight messages for a disconnected user (eg. with a clean session).
|
||||
func (s *Server) clearAbandonedInflights(cl *clients.Client) {
|
||||
for i := range cl.Inflight.GetAll() {
|
||||
cl.Inflight.Delete(i)
|
||||
}
|
||||
}
|
||||
|
||||
// resendPendingInflights attempts resends of any pending and due inflight messages.
|
||||
func (s *Server) resendPendingInflights() {
|
||||
for _, client := range s.Clients.GetAll() {
|
||||
err := s.ResendClientInflight(client, false)
|
||||
if err != nil {
|
||||
// TODO add log-level debugging.
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -368,9 +368,10 @@ func TestServerEventOnConnect(t *testing.T) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
require.Equal(t, events.Client{
|
||||
ID: "mochi",
|
||||
Remote: "pipe",
|
||||
Listener: "tcp",
|
||||
ID: "mochi",
|
||||
Remote: "pipe",
|
||||
Listener: "tcp",
|
||||
CleanSession: true,
|
||||
}, hook.client)
|
||||
|
||||
require.Equal(t, events.Packet(packets.Packet{
|
||||
@@ -443,9 +444,10 @@ func TestServerEventOnDisconnect(t *testing.T) {
|
||||
w.Close()
|
||||
|
||||
require.Equal(t, events.Client{
|
||||
ID: "mochi",
|
||||
Remote: "pipe",
|
||||
Listener: "tcp",
|
||||
ID: "mochi",
|
||||
Remote: "pipe",
|
||||
Listener: "tcp",
|
||||
CleanSession: true,
|
||||
}, hook.client)
|
||||
|
||||
require.ErrorIs(t, ErrClientDisconnect, hook.err)
|
||||
@@ -506,9 +508,10 @@ func TestServerEventOnDisconnectOnError(t *testing.T) {
|
||||
require.Equal(t, errx, hook.err)
|
||||
|
||||
require.Equal(t, events.Client{
|
||||
ID: "mochi",
|
||||
Remote: "pipe",
|
||||
Listener: "tcp",
|
||||
ID: "mochi",
|
||||
Remote: "pipe",
|
||||
Listener: "tcp",
|
||||
CleanSession: true,
|
||||
}, hook.client)
|
||||
|
||||
clw, ok := s.Clients.Get("mochi")
|
||||
@@ -2701,3 +2704,74 @@ func TestServerResendClientInflightError(t *testing.T) {
|
||||
err := s.ResendClientInflight(cl, true)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestServerClearExpiredInflights(t *testing.T) {
|
||||
n := time.Now().Unix()
|
||||
|
||||
s := New()
|
||||
s.Options.InflightTTL = 2
|
||||
require.NotNil(t, s)
|
||||
|
||||
r, _ := net.Pipe()
|
||||
cl := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
|
||||
cl.Inflight.Set(1, clients.InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Created: n - 1,
|
||||
Sent: 0,
|
||||
})
|
||||
cl.Inflight.Set(2, clients.InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Created: n - 2,
|
||||
Sent: 0,
|
||||
})
|
||||
cl.Inflight.Set(3, clients.InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Created: n - 3,
|
||||
Sent: 0,
|
||||
})
|
||||
cl.Inflight.Set(5, clients.InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Created: n - 5,
|
||||
Sent: 0,
|
||||
})
|
||||
s.Clients.Add(cl)
|
||||
|
||||
require.Len(t, cl.Inflight.GetAll(), 4)
|
||||
s.clearExpiredInflights(n)
|
||||
require.Len(t, cl.Inflight.GetAll(), 2)
|
||||
}
|
||||
|
||||
func TestServerClearAbandonedInflights(t *testing.T) {
|
||||
s := New()
|
||||
require.NotNil(t, s)
|
||||
|
||||
r, _ := net.Pipe()
|
||||
cl := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
|
||||
cl.Inflight.Set(1, clients.InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Sent: 0,
|
||||
})
|
||||
cl.Inflight.Set(2, clients.InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Sent: 0,
|
||||
})
|
||||
|
||||
cl2 := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
|
||||
|
||||
cl2.Inflight.Set(3, clients.InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Sent: 0,
|
||||
})
|
||||
cl2.Inflight.Set(5, clients.InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Sent: 0,
|
||||
})
|
||||
s.Clients.Add(cl)
|
||||
s.Clients.Add(cl2)
|
||||
|
||||
require.Len(t, cl.Inflight.GetAll(), 2)
|
||||
require.Len(t, cl2.Inflight.GetAll(), 2)
|
||||
s.clearAbandonedInflights(cl)
|
||||
require.Len(t, cl.Inflight.GetAll(), 0)
|
||||
require.Len(t, cl2.Inflight.GetAll(), 2)
|
||||
}
|
||||
|
Reference in New Issue
Block a user