Compare commits

..

12 Commits

Author SHA1 Message Date
mochi-co
84fc2f848b Abandon inflights at the end of clean-session connections 2022-08-16 21:41:39 +01:00
JB
8703d6d020 Merge pull request #90 from mochi-co/resend-inflights
Adds Inflight TTL and Period Resend
2022-08-16 21:31:42 +01:00
mochi-co
666440fe56 Adds Inflight TTL and Period Resend 2022-08-16 21:19:42 +01:00
JB
1ae050939a Merge pull request #84 from mochi-co/goreport-fixes
Goreport fixes
2022-06-22 15:52:36 +01:00
mochi
f4683d27d0 remove ineffective assignments 2022-06-22 15:45:13 +01:00
mochi
dff2b1db30 apply gofmt -s 2022-06-22 15:40:52 +01:00
JB
9de6b4e427 Merge pull request #83 from mochi-co/tls-client-auth
Expose tls.Config to Listeners
2022-06-22 15:32:23 +01:00
JB
78c1914270 Merge pull request #82 from mochi-co/expose-event-client-username
Add CleanSession and Username to events.Client struct
2022-06-22 15:31:31 +01:00
mochi
f71bf5c3d6 use TLSConfig instead of deprecated TLS field 2022-06-22 15:26:51 +01:00
mochi
53c4a6b09f Add TLSConfig field to allow direct tls.Config setting 2022-06-22 15:26:26 +01:00
mochi
a02c6bd8df update TLS example to use TLSConfig field 2022-06-22 15:25:52 +01:00
mochi
d8f6d63cc8 Add CleanSession and Username to events.Client struct 2022-06-22 12:33:09 +01:00
21 changed files with 522 additions and 87 deletions

View File

@@ -27,8 +27,9 @@ func main() {
// An example of configuring various server options...
options := &mqtt.Options{
BufferSize: 0, // Use default values
BufferBlockSize: 0, // Use default values
BufferSize: 0, // Use default values
BufferBlockSize: 0, // Use default values
InflightTTL: 60 * 15, // Set an example custom 15-min TTL for inflight messages
}
server := mqtt.NewServer(options)

View File

@@ -1,6 +1,7 @@
package main
import (
"crypto/tls"
"fmt"
"log"
"os"
@@ -57,14 +58,30 @@ func main() {
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TLS/SSL"))
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
if err != nil {
log.Fatal(err)
}
// Basic TLS Config
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
// Optionally, if you want clients to authenticate only with certs issued by your CA,
// you might want to use something like this:
// certPool := x509.NewCertPool()
// _ = certPool.AppendCertsFromPEM(caCertPem)
// tlsConfig := &tls.Config{
// ClientCAs: certPool,
// ClientAuth: tls.RequireAndVerifyClientCert,
// }
server := mqtt.NewServer(nil)
tcp := listeners.NewTCP("t1", ":1883")
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
TLS: &listeners.TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
err = server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
TLSConfig: tlsConfig,
})
if err != nil {
log.Fatal(err)
@@ -72,11 +89,8 @@ func main() {
ws := listeners.NewWebsocket("ws1", ":1882")
err = server.AddListener(ws, &listeners.Config{
Auth: new(auth.Allow),
TLS: &listeners.TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
Auth: new(auth.Allow),
TLSConfig: tlsConfig,
})
if err != nil {
log.Fatal(err)
@@ -84,11 +98,8 @@ func main() {
stats := listeners.NewHTTPStats("stats", ":8080")
err = server.AddListener(stats, &listeners.Config{
Auth: new(auth.Allow),
TLS: &listeners.TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
Auth: new(auth.Allow),
TLSConfig: tlsConfig,
})
if err != nil {
log.Fatal(err)

View File

@@ -20,9 +20,11 @@ type Packet packets.Packet
// Client contains limited information about a connected client.
type Client struct {
ID string
Remote string
Listener string
ID string
Remote string
Listener string
Username []byte
CleanSession bool
}
// Clientlike is an interface for Clients and client-like objects that
@@ -42,7 +44,7 @@ type Clientlike interface {
// be dispatched as if the event hook had not been triggered.
// This function will block message dispatching until it returns. To minimise this,
// have the function open a new goroutine on the embedding side.
// The `mqtt.ErrRejectPacket` error can be returned to reject and abandon any futher
// The `mqtt.ErrRejectPacket` error can be returned to reject and abandon any further
// processing of the packet.
type OnProcessMessage func(Client, Packet) (Packet, error)

View File

@@ -49,6 +49,13 @@ func (cl *Clients) Add(val *Client) {
cl.Unlock()
}
// GetAll returns all the clients.
func (cl *Clients) GetAll() map[string]*Client {
cl.RLock()
defer cl.RUnlock()
return cl.internal
}
// Get returns the value of a client if it exists.
func (cl *Clients) Get(id string) (*Client, bool) {
cl.RLock()
@@ -200,9 +207,11 @@ func (cl *Client) Info() events.Client {
addr = cl.conn.RemoteAddr().String()
}
return events.Client{
ID: cl.ID,
Remote: addr,
Listener: cl.Listener,
ID: cl.ID,
Remote: addr,
Username: cl.Username,
CleanSession: cl.CleanSession,
Listener: cl.Listener,
}
}
@@ -509,6 +518,7 @@ type LWT struct {
type InflightMessage struct {
Packet packets.Packet // the packet currently in-flight.
Sent int64 // the last time the message was sent (for retries) in unixtime.
Created int64 // the unix timestamp when the inflight message was created.
Resends int // the number of times the message was attempted to be sent.
}
@@ -555,8 +565,21 @@ func (i *Inflight) GetAll() map[uint16]InflightMessage {
// message existed.
func (i *Inflight) Delete(key uint16) bool {
i.Lock()
defer i.Unlock()
_, ok := i.internal[key]
delete(i.internal, key)
i.Unlock()
return ok
}
// ClearExpired deletes any inflight messages that have remained longer than
// the servers InflightTTL duration.
func (i *Inflight) ClearExpired(expiry int64) {
i.Lock()
defer i.Unlock()
for k, m := range i.internal {
if m.Created < expiry || m.Created == 0 {
delete(i.internal, k)
}
}
}

View File

@@ -69,6 +69,36 @@ func BenchmarkClientsGet(b *testing.B) {
}
}
func TestClientsGetAll(t *testing.T) {
cl := New()
cl.Add(&Client{ID: "t1"})
cl.Add(&Client{ID: "t2"})
cl.Add(&Client{ID: "t3"})
cl.Add(&Client{ID: "t4"})
cl.Add(&Client{ID: "t5"})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
require.Contains(t, cl.internal, "t3")
require.Contains(t, cl.internal, "t4")
require.Contains(t, cl.internal, "t5")
clients := cl.GetAll()
require.Len(t, clients, 5)
}
func BenchmarkClientsGetAll(b *testing.B) {
cl := New()
cl.Add(&Client{ID: "t1"})
cl.Add(&Client{ID: "t2"})
cl.Add(&Client{ID: "t3"})
cl.Add(&Client{ID: "t4"})
cl.Add(&Client{ID: "t5"})
for n := 0; n < b.N; n++ {
clients := cl.GetAll()
require.Len(b, clients, 5)
}
}
func TestClientsLen(t *testing.T) {
cl := New()
cl.Add(&Client{ID: "t1"})
@@ -714,12 +744,13 @@ func TestClientWritePacket(t *testing.T) {
r.Close()
require.Equal(t, tt.bytes, <-o, "Mismatched packet: [i:%d] %d", i, tt.bytes[0])
cl.Stop(testClientStop)
time.Sleep(time.Millisecond * 1)
// The stop cause is either the test error, EOF, or a
// closed pipe, depending on which goroutine runs first.
err = cl.StopCause()
time.Sleep(time.Millisecond * 5)
require.True(t,
errors.Is(err, testClientStop) ||
errors.Is(err, io.EOF) ||
@@ -858,6 +889,42 @@ func BenchmarkInflightDelete(b *testing.B) {
}
}
func TestInflightClearExpired(t *testing.T) {
n := time.Now().Unix()
cl := genClient()
cl.Inflight.Set(1, InflightMessage{
Packet: packets.Packet{},
Created: n - 1,
Sent: 0,
})
cl.Inflight.Set(2, InflightMessage{
Packet: packets.Packet{},
Created: n - 2,
Sent: 0,
})
cl.Inflight.Set(3, InflightMessage{
Packet: packets.Packet{},
Created: n - 3,
Sent: 0,
})
cl.Inflight.Set(5, InflightMessage{
Packet: packets.Packet{},
Created: n - 5,
Sent: 0,
})
require.Len(t, cl.Inflight.internal, 4)
cl.Inflight.ClearExpired(n - 2)
cl.Inflight.RLock()
defer cl.Inflight.RUnlock()
require.Len(t, cl.Inflight.internal, 2)
require.Equal(t, (n - 1), cl.Inflight.internal[1].Created)
require.Equal(t, (n - 2), cl.Inflight.internal[2].Created)
require.Equal(t, int64(0), cl.Inflight.internal[3].Created)
}
var (
pkTable = []struct {
bytes []byte

View File

@@ -21,7 +21,7 @@ func BenchmarkBytesToString(b *testing.B) {
func TestDecodeString(t *testing.T) {
expect := []struct {
name string
name string
rawBytes []byte
result string
offset int
@@ -88,8 +88,8 @@ func TestDecodeString(t *testing.T) {
shouldFail: ErrOffsetBytesOutOfRange,
},
{
offset: 0,
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
offset: 0,
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
shouldFail: ErrOffsetStrInvalidUTF8,
},
}
@@ -250,13 +250,12 @@ func TestDecodeUint16(t *testing.T) {
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+2, offset)
@@ -295,7 +294,7 @@ func TestDecodeByteBool(t *testing.T) {
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return

View File

@@ -215,7 +215,7 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
}
if pk.PasswordFlag {
pk.Password, offset, err = decodeBytes(buf, offset)
pk.Password, _, err = decodeBytes(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPassword)
}
@@ -287,7 +287,7 @@ func (pk *Packet) ConnackDecode(buf []byte) error {
return fmt.Errorf("%s: %w", err, ErrMalformedSessionPresent)
}
pk.ReturnCode, offset, err = decodeByte(buf, offset)
pk.ReturnCode, _, err = decodeByte(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedReturnCode)
}

View File

@@ -70,6 +70,9 @@ func (l *HTTPStats) Listen(s *system.Info) error {
Handler: mux,
}
// The following logic is deprecated in favour of passing through the tls.Config
// value directly, however it remains in order to provide backwards compatibility.
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
if err != nil {
@@ -79,6 +82,8 @@ func (l *HTTPStats) Listen(s *system.Info) error {
l.listen.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
}
} else {
l.listen.TLSConfig = l.config.TLSConfig
}
return nil

View File

@@ -73,6 +73,18 @@ func TestHTTPStatsListen(t *testing.T) {
l.listen.Close()
}
func TestHTTPStatsListenTLSConfig(t *testing.T) {
l := NewHTTPStats("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLSConfig: tlsConfigBasic,
})
err := l.Listen(new(system.Info))
require.NoError(t, err)
require.NotNil(t, l.listen.TLSConfig)
l.listen.Close()
}
func TestHTTPStatsListenTLS(t *testing.T) {
l := NewHTTPStats("t1", testPort)
l.SetConfig(&Config{

View File

@@ -1,6 +1,7 @@
package listeners
import (
"crypto/tls"
"net"
"sync"
@@ -10,11 +11,25 @@ import (
// Config contains configuration values for a listener.
type Config struct {
Auth auth.Controller // an authentication controller containing auth and ACL logic.
TLS *TLS // the TLS certficates and settings for the connection.
// Auth controller containing auth and ACL logic for
// allowing or denying access to the server and topics.
Auth auth.Controller
// TLS certficates and settings for the connection.
//
// Deprecated: Prefer exposing the tls.Config directly for greater flexibility.
// Please use TLSConfig instead.
TLS *TLS
// TLSConfig is a tls.Config configuration to be used with the listener.
// See examples folder for basic and mutual-tls use.
TLSConfig *tls.Config
}
// TLS contains the TLS certificates and settings for the listener connection.
//
// Deprecated: Prefer exposing the tls.Config directly for greater flexibility.
// Please use TLSConfig instead.
type TLS struct {
Certificate []byte // the body of a public certificate.
PrivateKey []byte // the body of a private key.

View File

@@ -1,6 +1,8 @@
package listeners
import (
"crypto/tls"
"log"
"testing"
"time"
@@ -37,8 +39,22 @@ WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
-----END RSA PRIVATE KEY-----`)
tlsConfigBasic *tls.Config
)
func init() {
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
if err != nil {
log.Fatal(err)
}
// Basic TLS Config
tlsConfigBasic = &tls.Config{
Certificates: []tls.Certificate{cert},
}
}
func TestNew(t *testing.T) {
l := New(nil)
require.NotNil(t, l.internal)

View File

@@ -62,6 +62,9 @@ func (l *TCP) ID() string {
func (l *TCP) Listen(s *system.Info) error {
var err error
// The following logic is deprecated in favour of passing through the tls.Config
// value directly, however it remains in order to provide backwards compatibility.
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
var cert tls.Certificate
cert, err = tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
@@ -72,9 +75,12 @@ func (l *TCP) Listen(s *system.Info) error {
l.listen, err = tls.Listen(l.protocol, l.address, &tls.Config{
Certificates: []tls.Certificate{cert},
})
} else if l.config.TLSConfig != nil {
l.listen, err = tls.Listen(l.protocol, l.address, l.config.TLSConfig)
} else {
l.listen, err = net.Listen(l.protocol, l.address)
}
if err != nil {
return err
}

View File

@@ -73,6 +73,17 @@ func TestTCPListen(t *testing.T) {
l.listen.Close()
}
func TestTCPListenTLSConfig(t *testing.T) {
l := NewTCP("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLSConfig: tlsConfigBasic,
})
err := l.Listen(nil)
require.NoError(t, err)
l.listen.Close()
}
func TestTCPListenTLS(t *testing.T) {
l := NewTCP("t1", testPort)
l.SetConfig(&Config{

View File

@@ -122,6 +122,9 @@ func (l *Websocket) Listen(s *system.Info) error {
Handler: mux,
}
// The following logic is deprecated in favour of passing through the tls.Config
// value directly, however it remains in order to provide backwards compatibility.
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
if err != nil {
@@ -131,6 +134,8 @@ func (l *Websocket) Listen(s *system.Info) error {
l.listen.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
}
} else {
l.listen.TLSConfig = l.config.TLSConfig
}
return nil

View File

@@ -77,6 +77,18 @@ func TestWebsocketListen(t *testing.T) {
require.NotNil(t, l.listen)
}
func TestWebsocketListenTLSConfig(t *testing.T) {
l := NewWebsocket("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLSConfig: tlsConfigBasic,
})
err := l.Listen(nil)
require.NoError(t, err)
require.NotNil(t, l.listen.TLSConfig)
l.listen.Close()
}
func TestWebsocketListenTLS(t *testing.T) {
l := NewWebsocket("t1", testPort)
l.SetConfig(&Config{

View File

@@ -27,9 +27,10 @@ var (
// Store is a backend for writing and reading to bolt persistent storage.
type Store struct {
path string // the path on which to store the db file.
opts *bbolt.Options // options for configuring the boltdb instance.
db *storm.DB // the boltdb instance.
path string // the path on which to store the db file.
opts *bbolt.Options // options for configuring the boltdb instance.
db *storm.DB // the boltdb instance.
inflightTTL int64 // the number of seconds an inflight message should be retained before being dropped.
}
// New returns a configured instance of the boltdb store.
@@ -50,6 +51,13 @@ func New(path string, opts *bbolt.Options) *Store {
}
}
// SetInflightTTL sets the number of seconds an inflight message should be kept
// before being dropped, in the event it is not delivered. Unless you have a good reason,
// you should allow this to be called by the server (in AddStore) instead of directly.
func (s *Store) SetInflightTTL(seconds int64) {
s.inflightTTL = seconds
}
// Open opens the boltdb instance.
func (s *Store) Open() error {
var err error
@@ -251,7 +259,7 @@ func (s *Store) ReadRetained() (v []persistence.Message, err error) {
return v, nil
}
//ReadServerInfo loads the server info from the boltdb instance.
// ReadServerInfo loads the server info from the boltdb instance.
func (s *Store) ReadServerInfo() (v persistence.ServerInfo, err error) {
if s.db == nil {
return v, ErrDBNotOpen
@@ -264,3 +272,27 @@ func (s *Store) ReadServerInfo() (v persistence.ServerInfo, err error) {
return v, nil
}
// ClearExpiredInflight deletes any inflight messages older than the provided unix timestamp.
func (s *Store) ClearExpiredInflight(expiry int64) error {
if s.db == nil {
return ErrDBNotOpen
}
var v []persistence.Message
err := s.db.Find("T", persistence.KInflight, &v)
if err != nil && err != storm.ErrNotFound {
return err
}
for _, m := range v {
if m.Created < expiry || m.Created == 0 {
err := s.db.DeleteStruct(&persistence.Message{ID: m.ID})
if err != nil {
return err
}
}
}
return nil
}

View File

@@ -49,6 +49,12 @@ func TestNewNoOpts(t *testing.T) {
require.Equal(t, defaultTimeout, s.opts.Timeout)
}
func TestSetInflightTTL(t *testing.T) {
s := New("", nil)
s.SetInflightTTL(5)
require.Equal(t, int64(5), s.inflightTTL)
}
func TestOpen(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
@@ -484,3 +490,51 @@ func TestDeleteRetainedFail(t *testing.T) {
err = s.DeleteRetained("a")
require.Error(t, err)
}
func TestClearExpiredInflight(t *testing.T) {
n := time.Now().Unix()
s := New(tmpPath, nil)
err := s.Open()
defer teardown(s, t)
require.NoError(t, err)
err = s.WriteInflight(persistence.Message{
ID: "i1",
T: persistence.KInflight,
Created: n - 1,
})
require.NoError(t, err)
err = s.WriteInflight(persistence.Message{
ID: "i2",
T: persistence.KInflight,
Created: n - 2,
})
require.NoError(t, err)
err = s.WriteInflight(persistence.Message{
ID: "i3",
T: persistence.KInflight,
Created: n - 3,
})
require.NoError(t, err)
err = s.WriteInflight(persistence.Message{
ID: "i5",
T: persistence.KInflight,
Created: n - 5,
})
require.NoError(t, err)
m, err := s.ReadInflight()
require.NoError(t, err)
require.Len(t, m, 4)
s.ClearExpiredInflight(n - 2)
m, err = s.ReadInflight()
require.NoError(t, err)
require.Len(t, m, 2)
}

View File

@@ -28,22 +28,28 @@ const (
type Store interface {
Open() error
Close()
WriteSubscription(v Subscription) error
WriteClient(v Client) error
WriteInflight(v Message) error
WriteServerInfo(v ServerInfo) error
WriteRetained(v Message) error
DeleteSubscription(id string) error
DeleteClient(id string) error
DeleteInflight(id string) error
DeleteRetained(id string) error
ReadSubscriptions() (v []Subscription, err error)
ReadInflight() (v []Message, err error)
ReadRetained() (v []Message, err error)
WriteSubscription(v Subscription) error
DeleteSubscription(id string) error
ReadClients() (v []Client, err error)
WriteClient(v Client) error
DeleteClient(id string) error
ReadInflight() (v []Message, err error)
WriteInflight(v Message) error
DeleteInflight(id string) error
SetInflightTTL(seconds int64)
ClearExpiredInflight(expiry int64) error
ReadServerInfo() (v ServerInfo, err error)
WriteServerInfo(v ServerInfo) error
ReadRetained() (v []Message, err error)
WriteRetained(v Message) error
DeleteRetained(id string) error
}
// ServerInfo contains information and statistics about the server.
@@ -69,6 +75,7 @@ type Message struct {
ID string // the storage key.
Client string // the id of the client who sent the message (if inflight).
TopicName string // the topic the message was sent to (if retained).
Created int64 // the time the message was created in unixtime (if inflight).
Sent int64 // the last time the message was sent (for retries) in unixtime (if inflight).
Resends int // the number of times the message was attempted to be sent (if inflight).
PacketID uint16 // the unique id of the packet (if inflight).
@@ -103,10 +110,16 @@ type LWT struct {
// MockStore is a mock storage backend for testing.
type MockStore struct {
Fail map[string]bool // issue errors for different methods.
FailOpen bool // error on open.
Closed bool // indicate mock store is closed.
Opened bool // indicate mock store is open.
Fail map[string]bool // issue errors for different methods.
FailOpen bool // error on open.
Closed bool // indicate mock store is closed.
Opened bool // indicate mock store is open.
inflightTTL int64 // inflight expiry duration.
}
// Close closes the storage instance.
func (s *MockStore) SetInflightTTL(seconds int64) {
s.inflightTTL = seconds
}
// Open opens the storage instance.
@@ -275,7 +288,7 @@ func (s *MockStore) ReadRetained() (v []Message, err error) {
}, nil
}
//ReadServerInfo loads the server info from the storage instance.
// ReadServerInfo loads the server info from the storage instance.
func (s *MockStore) ReadServerInfo() (v ServerInfo, err error) {
if _, ok := s.Fail["read_info"]; ok {
return v, errors.New("test_info")
@@ -289,3 +302,8 @@ func (s *MockStore) ReadServerInfo() (v ServerInfo, err error) {
KServerInfo,
}, nil
}
// ReadServerInfo loads the server info from the storage instance.
func (s *MockStore) ClearExpiredInflight(d int64) error {
return nil
}

View File

@@ -26,6 +26,12 @@ func TestMockStoreClose(t *testing.T) {
require.Equal(t, true, s.Closed)
}
func TestMockStoreSetInflightTTL(t *testing.T) {
s := new(MockStore)
s.SetInflightTTL(5)
require.Equal(t, int64(5), s.inflightTTL)
}
func TestMockStoreWriteSubscription(t *testing.T) {
s := new(MockStore)
err := s.WriteSubscription(Subscription{})
@@ -249,3 +255,9 @@ func TestMockStoreReadRetainedFail(t *testing.T) {
_, err := s.ReadRetained()
require.Error(t, err)
}
func TestMockStoreClearExpiredInflight(t *testing.T) {
s := new(MockStore)
err := s.ClearExpiredInflight(2)
require.NoError(t, err)
}

View File

@@ -25,6 +25,9 @@ import (
const (
// Version indicates the current server version.
Version = "1.1.1"
// defaultInflightTTL is the number of seconds a pending inflight message should last.
defaultInflightTTL int64 = 60 * 60 * 24
)
var (
@@ -73,17 +76,19 @@ var (
// Server is an MQTT broker server. It should be created with server.New()
// in order to ensure all the internal fields are correctly populated.
type Server struct {
inline inlineMessages // channels for direct publishing.
Events events.Events // overrideable event hooks.
Store persistence.Store // a persistent storage backend if desired.
Options *Options // configurable server options.
Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections.
Clients *clients.Clients // clients which are known to the broker.
Topics *topics.Index // an index of topic filter subscriptions and retained messages.
System *system.Info // values about the server commonly found in $SYS topics.
bytepool *circ.BytesPool // a byte pool for incoming and outgoing packets.
sysTicker *time.Ticker // the interval ticker for sending updating $SYS topics.
done chan bool // indicate that the server is ending.
inline inlineMessages // channels for direct publishing.
Events events.Events // overrideable event hooks.
Store persistence.Store // a persistent storage backend if desired.
Options *Options // configurable server options.
Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections.
Clients *clients.Clients // clients which are known to the broker.
Topics *topics.Index // an index of topic filter subscriptions and retained messages.
System *system.Info // values about the server commonly found in $SYS topics.
bytepool *circ.BytesPool // a byte pool for incoming and outgoing packets.
sysTicker *time.Ticker // the interval ticker for sending updating $SYS topics.
inflightExpiryTicker *time.Ticker // the interval ticker for cleaning up expired messages.
inflightResendTicker *time.Ticker // the interval ticker for resending unresolved inflight messages.
done chan bool // indicate that the server is ending.
}
// Options contains configurable options for the server.
@@ -93,6 +98,9 @@ type Options struct {
// BufferBlockSize overrides the default buffer block size (DefaultBlockSize) for the client buffers.
BufferBlockSize int
// InflightTTL specifies the duration that a queued inflight message should exist before being purged.
InflightTTL int64
}
// inlineMessages contains channels for handling inline (direct) publishing.
@@ -114,6 +122,10 @@ func NewServer(opts *Options) *Server {
opts = new(Options)
}
if opts.InflightTTL < 1 {
opts.InflightTTL = defaultInflightTTL
}
s := &Server{
done: make(chan bool),
bytepool: circ.NewBytesPool(opts.BufferSize),
@@ -123,7 +135,9 @@ func NewServer(opts *Options) *Server {
Version: Version,
Started: time.Now().Unix(),
},
sysTicker: time.NewTicker(SysTopicInterval * time.Millisecond),
sysTicker: time.NewTicker(SysTopicInterval * time.Millisecond),
inflightExpiryTicker: time.NewTicker(time.Duration(opts.InflightTTL) * time.Second),
inflightResendTicker: time.NewTicker(time.Duration(10) * time.Second),
inline: inlineMessages{
done: make(chan bool),
pub: make(chan packets.Packet, 1024),
@@ -143,6 +157,8 @@ func NewServer(opts *Options) *Server {
// called before calling server.Server().
func (s *Server) AddStore(p persistence.Store) error {
s.Store = p
s.Store.SetInflightTTL(s.Options.InflightTTL)
err := s.Store.Open()
if err != nil {
return err
@@ -198,6 +214,10 @@ func (s *Server) eventLoop() {
return
case <-s.sysTicker.C:
s.publishSysTopics()
case <-s.inflightExpiryTicker.C:
s.clearExpiredInflights(time.Now().Unix())
case <-s.inflightResendTicker.C:
s.resendPendingInflights()
}
}
}
@@ -319,9 +339,11 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
return s.onError(cl.Info(), fmt.Errorf("ack connection packet: %w", err))
}
err = s.ResendClientInflight(cl, true)
if err != nil {
s.onError(cl.Info(), fmt.Errorf("resend in flight: %w", err)) // pass-through, no return.
if sessionPresent {
err = s.ResendClientInflight(cl, true)
if err != nil {
s.onError(cl.Info(), fmt.Errorf("resend in flight: %w", err)) // pass-through, no return.
}
}
if s.Store != nil {
@@ -346,6 +368,10 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
err = cl.StopCause() // Determine true cause of stop.
if !sessionPresent {
s.clearAbandonedInflights(cl)
}
if s.Events.OnDisconnect != nil {
s.Events.OnDisconnect(cl.Info(), err)
}
@@ -379,6 +405,7 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *clients.Client) boo
// The state associated with a CleanSession MUST NOT be reused in any subsequent session.
if pk.CleanSession || existing.CleanSession {
s.unsubscribeClient(existing)
s.clearAbandonedInflights(existing)
return false
}
@@ -640,8 +667,9 @@ func (s *Server) publishToSubscribers(pk packets.Packet) {
// if an appropriate ack is not received (or if the client is offline).
sent := time.Now().Unix()
q := client.Inflight.Set(out.PacketID, clients.InflightMessage{
Packet: out,
Sent: sent,
Packet: out,
Created: time.Now().Unix(),
Sent: sent,
})
if q {
atomic.AddInt64(&s.System.Inflight, 1)
@@ -876,7 +904,7 @@ func (s *Server) publishSysTopics() {
// ResendClientInflight attempts to resend all undelivered inflight messages
// to a client.
func (s *Server) ResendClientInflight(cl *clients.Client, force bool) error {
if cl.Inflight.Len() == 0 {
if atomic.LoadUint32(&cl.State.Done) == 1 || cl.Inflight.Len() == 0 {
return nil
}
@@ -1047,6 +1075,7 @@ func (s *Server) loadInflight(v []persistence.Message) {
TopicName: msg.TopicName,
Payload: msg.Payload,
},
Created: msg.Created,
Sent: msg.Sent,
Resends: msg.Resends,
})
@@ -1064,3 +1093,34 @@ func (s *Server) loadRetained(v []persistence.Message) {
})
}
}
// clearExpiredInflights deletes all inflight messages older than server inflight TTL.
func (s *Server) clearExpiredInflights(dt int64) {
expiry := dt - s.Options.InflightTTL
for _, client := range s.Clients.GetAll() {
client.Inflight.ClearExpired(expiry)
}
if s.Store != nil {
s.Store.ClearExpiredInflight(expiry)
}
}
// clearAbandonedInflights deletes all inflight messages for a disconnected user (eg. with a clean session).
func (s *Server) clearAbandonedInflights(cl *clients.Client) {
for i := range cl.Inflight.GetAll() {
cl.Inflight.Delete(i)
}
}
// resendPendingInflights attempts resends of any pending and due inflight messages.
func (s *Server) resendPendingInflights() {
for _, client := range s.Clients.GetAll() {
err := s.ResendClientInflight(client, false)
if err != nil {
// TODO add log-level debugging.
continue
}
}
}

View File

@@ -368,9 +368,10 @@ func TestServerEventOnConnect(t *testing.T) {
time.Sleep(10 * time.Millisecond)
require.Equal(t, events.Client{
ID: "mochi",
Remote: "pipe",
Listener: "tcp",
ID: "mochi",
Remote: "pipe",
Listener: "tcp",
CleanSession: true,
}, hook.client)
require.Equal(t, events.Packet(packets.Packet{
@@ -443,9 +444,10 @@ func TestServerEventOnDisconnect(t *testing.T) {
w.Close()
require.Equal(t, events.Client{
ID: "mochi",
Remote: "pipe",
Listener: "tcp",
ID: "mochi",
Remote: "pipe",
Listener: "tcp",
CleanSession: true,
}, hook.client)
require.ErrorIs(t, ErrClientDisconnect, hook.err)
@@ -506,9 +508,10 @@ func TestServerEventOnDisconnectOnError(t *testing.T) {
require.Equal(t, errx, hook.err)
require.Equal(t, events.Client{
ID: "mochi",
Remote: "pipe",
Listener: "tcp",
ID: "mochi",
Remote: "pipe",
Listener: "tcp",
CleanSession: true,
}, hook.client)
clw, ok := s.Clients.Get("mochi")
@@ -2701,3 +2704,74 @@ func TestServerResendClientInflightError(t *testing.T) {
err := s.ResendClientInflight(cl, true)
require.Error(t, err)
}
func TestServerClearExpiredInflights(t *testing.T) {
n := time.Now().Unix()
s := New()
s.Options.InflightTTL = 2
require.NotNil(t, s)
r, _ := net.Pipe()
cl := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
cl.Inflight.Set(1, clients.InflightMessage{
Packet: packets.Packet{},
Created: n - 1,
Sent: 0,
})
cl.Inflight.Set(2, clients.InflightMessage{
Packet: packets.Packet{},
Created: n - 2,
Sent: 0,
})
cl.Inflight.Set(3, clients.InflightMessage{
Packet: packets.Packet{},
Created: n - 3,
Sent: 0,
})
cl.Inflight.Set(5, clients.InflightMessage{
Packet: packets.Packet{},
Created: n - 5,
Sent: 0,
})
s.Clients.Add(cl)
require.Len(t, cl.Inflight.GetAll(), 4)
s.clearExpiredInflights(n)
require.Len(t, cl.Inflight.GetAll(), 2)
}
func TestServerClearAbandonedInflights(t *testing.T) {
s := New()
require.NotNil(t, s)
r, _ := net.Pipe()
cl := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
cl.Inflight.Set(1, clients.InflightMessage{
Packet: packets.Packet{},
Sent: 0,
})
cl.Inflight.Set(2, clients.InflightMessage{
Packet: packets.Packet{},
Sent: 0,
})
cl2 := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
cl2.Inflight.Set(3, clients.InflightMessage{
Packet: packets.Packet{},
Sent: 0,
})
cl2.Inflight.Set(5, clients.InflightMessage{
Packet: packets.Packet{},
Sent: 0,
})
s.Clients.Add(cl)
s.Clients.Add(cl2)
require.Len(t, cl.Inflight.GetAll(), 2)
require.Len(t, cl2.Inflight.GetAll(), 2)
s.clearAbandonedInflights(cl)
require.Len(t, cl.Inflight.GetAll(), 0)
require.Len(t, cl2.Inflight.GetAll(), 2)
}