Compare commits

...

25 Commits

Author SHA1 Message Date
mochi-co
ecbd07fa3a Check against the correct clean session var for abandoning old inflights 2022-08-18 00:19:11 +01:00
zynzel
ad8bf2a931 Keep in sync server.System.Inflight (#92)
* Keep in sync server.System.Inflight

* Fix args order in tests
2022-08-17 23:58:43 +01:00
JB
b8fb068bb9 Update README.md 2022-08-16 22:21:23 +01:00
JB
c1348a37b8 Update README.md 2022-08-16 22:20:50 +01:00
mochi-co
84fc2f848b Abandon inflights at the end of clean-session connections 2022-08-16 21:41:39 +01:00
JB
8703d6d020 Merge pull request #90 from mochi-co/resend-inflights
Adds Inflight TTL and Period Resend
2022-08-16 21:31:42 +01:00
mochi-co
666440fe56 Adds Inflight TTL and Period Resend 2022-08-16 21:19:42 +01:00
JB
1ae050939a Merge pull request #84 from mochi-co/goreport-fixes
Goreport fixes
2022-06-22 15:52:36 +01:00
mochi
f4683d27d0 remove ineffective assignments 2022-06-22 15:45:13 +01:00
mochi
dff2b1db30 apply gofmt -s 2022-06-22 15:40:52 +01:00
JB
9de6b4e427 Merge pull request #83 from mochi-co/tls-client-auth
Expose tls.Config to Listeners
2022-06-22 15:32:23 +01:00
JB
78c1914270 Merge pull request #82 from mochi-co/expose-event-client-username
Add CleanSession and Username to events.Client struct
2022-06-22 15:31:31 +01:00
mochi
f71bf5c3d6 use TLSConfig instead of deprecated TLS field 2022-06-22 15:26:51 +01:00
mochi
53c4a6b09f Add TLSConfig field to allow direct tls.Config setting 2022-06-22 15:26:26 +01:00
mochi
a02c6bd8df update TLS example to use TLSConfig field 2022-06-22 15:25:52 +01:00
mochi
d8f6d63cc8 Add CleanSession and Username to events.Client struct 2022-06-22 12:33:09 +01:00
mochi
bef13eec20 Add OnSubscribe, OnUnsubscribe events examples 2022-05-04 12:58:23 +01:00
mochi
27f3c484ad Extend onsusbcribe, onunsubscribe events 2022-05-04 12:53:04 +01:00
JB
9b5cdb0bcc Merge pull request #74 from muXxer/feat/topic-subscription-events 2022-05-04 12:33:12 +01:00
muXxer
2b60a11d4a Add topic un-/subscribe events 2022-04-28 00:48:20 +02:00
JB
b53774f818 Merge pull request #72 from BoskyWSMFN/master
fix-panic
2022-04-19 08:58:46 +01:00
BoskyWSMFN
7dee729afb fix-panic
fixed runtime panic in server/internal/circ/pool.go occurring on 32-bits architectures caused by misalignment of BytesPool struct members.

https://github.com/golang/go/issues/36606#issue-551005857
2022-04-19 00:26:44 +03:00
mochi
aed535b7bf fix comments 2022-04-13 10:46:13 +01:00
mochi
4ff888ab3b Add auth controller example 2022-04-13 10:37:24 +01:00
JB
31252c081b Add Docker info 2022-04-10 20:46:16 +01:00
25 changed files with 736 additions and 100 deletions

View File

@@ -30,22 +30,31 @@ MQTT stands for MQ Telemetry Transport. It is a publish/subscribe, extremely sim
- Interfaces for Client Authentication and Topic access control.
- Bolt persistence and storage interfaces (see examples folder).
- Directly Publishing from embedding service (`s.Publish(topic, message, retain)`).
- Basic Event Hooks (`OnMessage`, `OnConnect`, `OnDisconnect`, `onProcessMessage`, `OnError`, `OnStorage`).
- ARM32 Compatible.
- Basic Event Hooks (`OnMessage`, `onSubscribe`, `onUnsubscribe`, `OnConnect`, `OnDisconnect`, `onProcessMessage`, `OnError`, `OnStorage`).
- ARM32 Compatible (v1.1.1).
#### Roadmap
- Please open an issue to request new features or event hooks.
- MQTT v5 compatibility?
#### Using the Broker
Mochi MQTT can be used as a standalone broker. Simply checkout this repository and run the `main.go` entrypoint in the `cmd` folder which will expose tcp (:1883), websocket (:1882), and dashboard (:8080) listeners. A docker image is coming soon.
#### Using the Broker from Go
Mochi MQTT can be used as a standalone broker. Simply checkout this repository and run the `main.go` entrypoint in the `cmd` folder which will expose tcp (:1883), websocket (:1882), and dashboard (:8080) listeners.
```
cd cmd
go build -o mqtt && ./mqtt
```
#### Quick Start
#### Using Docker
A simple Dockerfile is provided for running the `cmd/main.go` Websocket, TCP, and Stats server:
```sh
docker build -t mochi:latest .
docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 mochi:latest
```
#### Package Quick Start
``` go
import (
@@ -76,6 +85,8 @@ func main() {
Examples of running the broker with various configurations can be found in the `examples` folder.
#### Network Listeners
The server comes with a variety of pre-packaged network listeners which allow the broker to accept connections on different protocols. The current listeners are:
- `listeners.NewTCP(id, address string)` - A TCP Listener, taking a unique ID and a network address to bind.
@@ -130,7 +141,25 @@ server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.P
```go
server.Events.OnDisconnect = func(cl events.Client, err error) {
fmt.Printf("<< OnDisconnect client dicconnected %s: %v\n", cl.ID, err)
fmt.Printf("<< OnDisconnect client disconnected %s: %v\n", cl.ID, err)
}
```
##### OnSubscribe
`server.Events.OnSubscribe` is called when a client subscribes to a new topic filter.
```go
server.Events.OnSubscribe = func(filter string, cl events.Client, qos byte) {
fmt.Printf("<< OnSubscribe client subscribed %s: %s %v\n", cl.ID, filter, qos)
}
```
##### OnUnsubscribe
`server.Events.OnUnsubscribe` is called when a client unsubscribes from a topic filter.
```go
server.Events.OnUnsubscribe = func(filter string, cl events.Client) {
fmt.Printf("<< OnUnsubscribe client unsubscribed %s: %s\n", cl.ID, filter)
}
```
@@ -147,7 +176,6 @@ If an error is returned, the packet will not be modified. and the existing packe
> This hook is only triggered when a message is received by clients. It is not triggered when using the direct `server.Publish` method.
```go
import "github.com/mochi-co/mqtt/server/events"
@@ -164,6 +192,8 @@ server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.P
The OnMessage hook can also be used to selectively only deliver messages to one or more clients based on their id, using the `AllowClients []string` field on the packet structure.
##### OnError
`server.Events.OnError` is called when an error is encountered on the server, particularly within the use of a client connection status.

105
examples/auth/main.go Normal file
View File

@@ -0,0 +1,105 @@
package main
import (
"fmt"
"log"
"os"
"os/signal"
"syscall"
"github.com/logrusorgru/aurora"
mqtt "github.com/mochi-co/mqtt/server"
"github.com/mochi-co/mqtt/server/listeners"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
server := mqtt.NewServer(nil)
tcp := listeners.NewTCP("t1", ":1883")
err := server.AddListener(tcp, &listeners.Config{
Auth: &Auth{
Users: map[string]string{
"peach": "password1",
"melon": "password2",
"apple": "password3",
},
AllowedTopics: map[string][]string{
// Melon user only has access to melon topics.
// If you were implementing this in the real world, you might ensure
// that any topic prefixed with "melon" is allowed (see ACL func below).
"melon": {"melon/info", "melon/events"},
},
},
})
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
fmt.Println(aurora.BgMagenta(" Started! "))
<-done
fmt.Println(aurora.BgRed(" Caught Signal "))
server.Close()
fmt.Println(aurora.BgGreen(" Finished "))
}
// Auth is an example auth provider for the server. In the real world
// you are more likely to replace these fields with database/cache lookups
// to check against an auth list. As the Auth Controller is an interface, it can
// be built however you want, as long as it fulfils the interface signature.
type Auth struct {
Users map[string]string // A map of usernames (key) with passwords (value).
AllowedTopics map[string][]string // A map of usernames and topics
}
// Authenticate returns true if a username and password are acceptable.
func (a *Auth) Authenticate(user, password []byte) bool {
// If the user exists in the auth users map, and the password is correct,
// then they can connect to the server. In the real world, this could be a database
// or cached users lookup.
if pass, ok := a.Users[string(user)]; ok && pass == string(password) {
return true
}
return false
}
// ACL returns true if a user has access permissions to read or write on a topic.
func (a *Auth) ACL(user []byte, topic string, write bool) bool {
// An example ACL - if the user has an entry in the auth allow list, then they are
// subject to ACL restrictions. Only let them use a topic if it's available for their
// user.
if topics, ok := a.AllowedTopics[string(user)]; ok {
for _, t := range topics {
// In the real world you might allow all topics prefixed with a user's username,
// or similar multi-topic filters.
if t == topic {
return true
}
}
return false
}
// Otherwise, allow all topics.
return true
}

View File

@@ -54,6 +54,16 @@ func main() {
fmt.Printf("<< OnDisconnect client disconnected %s: %v\n", cl.ID, err)
}
// Add OnSubscribe Event Hook
server.Events.OnSubscribe = func(filter string, cl events.Client, qos byte) {
fmt.Printf("<< OnSubscribe client subscribed %s: %s %v\n", cl.ID, filter, qos)
}
// Add OnUnsubscribe Event Hook
server.Events.OnUnsubscribe = func(filter string, cl events.Client) {
fmt.Printf("<< OnUnsubscribe client unsubscribed %s: %s\n", cl.ID, filter)
}
// Add OnMessage Event Hook
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
pkx = pk

View File

@@ -27,8 +27,9 @@ func main() {
// An example of configuring various server options...
options := &mqtt.Options{
BufferSize: 0, // Use default values
BufferBlockSize: 0, // Use default values
BufferSize: 0, // Use default values
BufferBlockSize: 0, // Use default values
InflightTTL: 60 * 15, // Set an example custom 15-min TTL for inflight messages
}
server := mqtt.NewServer(options)

View File

@@ -1,6 +1,7 @@
package main
import (
"crypto/tls"
"fmt"
"log"
"os"
@@ -57,14 +58,30 @@ func main() {
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TLS/SSL"))
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
if err != nil {
log.Fatal(err)
}
// Basic TLS Config
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
// Optionally, if you want clients to authenticate only with certs issued by your CA,
// you might want to use something like this:
// certPool := x509.NewCertPool()
// _ = certPool.AppendCertsFromPEM(caCertPem)
// tlsConfig := &tls.Config{
// ClientCAs: certPool,
// ClientAuth: tls.RequireAndVerifyClientCert,
// }
server := mqtt.NewServer(nil)
tcp := listeners.NewTCP("t1", ":1883")
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
TLS: &listeners.TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
err = server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
TLSConfig: tlsConfig,
})
if err != nil {
log.Fatal(err)
@@ -72,11 +89,8 @@ func main() {
ws := listeners.NewWebsocket("ws1", ":1882")
err = server.AddListener(ws, &listeners.Config{
Auth: new(auth.Allow),
TLS: &listeners.TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
Auth: new(auth.Allow),
TLSConfig: tlsConfig,
})
if err != nil {
log.Fatal(err)
@@ -84,11 +98,8 @@ func main() {
stats := listeners.NewHTTPStats("stats", ":8080")
err = server.AddListener(stats, &listeners.Config{
Auth: new(auth.Allow),
TLS: &listeners.TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
Auth: new(auth.Allow),
TLSConfig: tlsConfig,
})
if err != nil {
log.Fatal(err)

View File

@@ -11,6 +11,8 @@ type Events struct {
OnError // server error.
OnConnect // client connected.
OnDisconnect // client disconnected.
OnSubscribe // topic subscription created.
OnUnsubscribe // topic subscription removed.
}
// Packets is an alias for packets.Packet.
@@ -18,9 +20,11 @@ type Packet packets.Packet
// Client contains limited information about a connected client.
type Client struct {
ID string
Remote string
Listener string
ID string
Remote string
Listener string
Username []byte
CleanSession bool
}
// Clientlike is an interface for Clients and client-like objects that
@@ -40,7 +44,7 @@ type Clientlike interface {
// be dispatched as if the event hook had not been triggered.
// This function will block message dispatching until it returns. To minimise this,
// have the function open a new goroutine on the embedding side.
// The `mqtt.ErrRejectPacket` error can be returned to reject and abandon any futher
// The `mqtt.ErrRejectPacket` error can be returned to reject and abandon any further
// processing of the packet.
type OnProcessMessage func(Client, Packet) (Packet, error)
@@ -66,3 +70,9 @@ type OnDisconnect func(Client, error)
// OnError is called when errors that will not be passed to
// OnDisconnect are handled by the server.
type OnError func(Client, error)
// OnSubscribe is called when a new subscription filter for a client is created.
type OnSubscribe func(filter string, cl Client, qos byte)
// OnUnsubscribe is called when an existing subscription filter for a client is removed.
type OnUnsubscribe func(filter string, cl Client)

View File

@@ -7,8 +7,10 @@ import (
// BytesPool is a pool of []byte.
type BytesPool struct {
// int64/uint64 has to the first words in order
// to be 64-aligned on 32-bit architectures.
used int64 // access atomically
pool *sync.Pool
used int64
}
// NewBytesPool returns a sync.pool of []byte.

View File

@@ -49,6 +49,13 @@ func (cl *Clients) Add(val *Client) {
cl.Unlock()
}
// GetAll returns all the clients.
func (cl *Clients) GetAll() map[string]*Client {
cl.RLock()
defer cl.RUnlock()
return cl.internal
}
// Get returns the value of a client if it exists.
func (cl *Clients) Get(id string) (*Client, bool) {
cl.RLock()
@@ -200,9 +207,11 @@ func (cl *Client) Info() events.Client {
addr = cl.conn.RemoteAddr().String()
}
return events.Client{
ID: cl.ID,
Remote: addr,
Listener: cl.Listener,
ID: cl.ID,
Remote: addr,
Username: cl.Username,
CleanSession: cl.CleanSession,
Listener: cl.Listener,
}
}
@@ -509,6 +518,7 @@ type LWT struct {
type InflightMessage struct {
Packet packets.Packet // the packet currently in-flight.
Sent int64 // the last time the message was sent (for retries) in unixtime.
Created int64 // the unix timestamp when the inflight message was created.
Resends int // the number of times the message was attempted to be sent.
}
@@ -555,8 +565,25 @@ func (i *Inflight) GetAll() map[uint16]InflightMessage {
// message existed.
func (i *Inflight) Delete(key uint16) bool {
i.Lock()
defer i.Unlock()
_, ok := i.internal[key]
delete(i.internal, key)
i.Unlock()
return ok
}
// ClearExpired deletes any inflight messages that have remained longer than
// the servers InflightTTL duration. Returns number of deleted inflights.
func (i *Inflight) ClearExpired(expiry int64) int64 {
i.Lock()
defer i.Unlock()
var deleted int64
for k, m := range i.internal {
if m.Created < expiry || m.Created == 0 {
delete(i.internal, k)
deleted++
}
}
return deleted
}

View File

@@ -69,6 +69,36 @@ func BenchmarkClientsGet(b *testing.B) {
}
}
func TestClientsGetAll(t *testing.T) {
cl := New()
cl.Add(&Client{ID: "t1"})
cl.Add(&Client{ID: "t2"})
cl.Add(&Client{ID: "t3"})
cl.Add(&Client{ID: "t4"})
cl.Add(&Client{ID: "t5"})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
require.Contains(t, cl.internal, "t3")
require.Contains(t, cl.internal, "t4")
require.Contains(t, cl.internal, "t5")
clients := cl.GetAll()
require.Len(t, clients, 5)
}
func BenchmarkClientsGetAll(b *testing.B) {
cl := New()
cl.Add(&Client{ID: "t1"})
cl.Add(&Client{ID: "t2"})
cl.Add(&Client{ID: "t3"})
cl.Add(&Client{ID: "t4"})
cl.Add(&Client{ID: "t5"})
for n := 0; n < b.N; n++ {
clients := cl.GetAll()
require.Len(b, clients, 5)
}
}
func TestClientsLen(t *testing.T) {
cl := New()
cl.Add(&Client{ID: "t1"})
@@ -714,12 +744,13 @@ func TestClientWritePacket(t *testing.T) {
r.Close()
require.Equal(t, tt.bytes, <-o, "Mismatched packet: [i:%d] %d", i, tt.bytes[0])
cl.Stop(testClientStop)
time.Sleep(time.Millisecond * 1)
// The stop cause is either the test error, EOF, or a
// closed pipe, depending on which goroutine runs first.
err = cl.StopCause()
time.Sleep(time.Millisecond * 5)
require.True(t,
errors.Is(err, testClientStop) ||
errors.Is(err, io.EOF) ||
@@ -858,6 +889,43 @@ func BenchmarkInflightDelete(b *testing.B) {
}
}
func TestInflightClearExpired(t *testing.T) {
n := time.Now().Unix()
cl := genClient()
cl.Inflight.Set(1, InflightMessage{
Packet: packets.Packet{},
Created: n - 1,
Sent: 0,
})
cl.Inflight.Set(2, InflightMessage{
Packet: packets.Packet{},
Created: n - 2,
Sent: 0,
})
cl.Inflight.Set(3, InflightMessage{
Packet: packets.Packet{},
Created: n - 3,
Sent: 0,
})
cl.Inflight.Set(5, InflightMessage{
Packet: packets.Packet{},
Created: n - 5,
Sent: 0,
})
require.Len(t, cl.Inflight.internal, 4)
deleted := cl.Inflight.ClearExpired(n - 2)
cl.Inflight.RLock()
defer cl.Inflight.RUnlock()
require.Len(t, cl.Inflight.internal, 2)
require.Equal(t, (n - 1), cl.Inflight.internal[1].Created)
require.Equal(t, (n - 2), cl.Inflight.internal[2].Created)
require.Equal(t, int64(0), cl.Inflight.internal[3].Created)
require.Equal(t, int64(2), deleted)
}
var (
pkTable = []struct {
bytes []byte

View File

@@ -21,7 +21,7 @@ func BenchmarkBytesToString(b *testing.B) {
func TestDecodeString(t *testing.T) {
expect := []struct {
name string
name string
rawBytes []byte
result string
offset int
@@ -88,8 +88,8 @@ func TestDecodeString(t *testing.T) {
shouldFail: ErrOffsetBytesOutOfRange,
},
{
offset: 0,
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
offset: 0,
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
shouldFail: ErrOffsetStrInvalidUTF8,
},
}
@@ -101,7 +101,7 @@ func TestDecodeString(t *testing.T) {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
})
@@ -209,7 +209,7 @@ func TestDecodeByte(t *testing.T) {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+1, offset)
@@ -250,12 +250,11 @@ func TestDecodeUint16(t *testing.T) {
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
@@ -295,7 +294,7 @@ func TestDecodeByteBool(t *testing.T) {
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return

View File

@@ -215,7 +215,7 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
}
if pk.PasswordFlag {
pk.Password, offset, err = decodeBytes(buf, offset)
pk.Password, _, err = decodeBytes(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPassword)
}
@@ -287,7 +287,7 @@ func (pk *Packet) ConnackDecode(buf []byte) error {
return fmt.Errorf("%s: %w", err, ErrMalformedSessionPresent)
}
pk.ReturnCode, offset, err = decodeByte(buf, offset)
pk.ReturnCode, _, err = decodeByte(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedReturnCode)
}

View File

@@ -70,6 +70,9 @@ func (l *HTTPStats) Listen(s *system.Info) error {
Handler: mux,
}
// The following logic is deprecated in favour of passing through the tls.Config
// value directly, however it remains in order to provide backwards compatibility.
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
if err != nil {
@@ -79,6 +82,8 @@ func (l *HTTPStats) Listen(s *system.Info) error {
l.listen.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
}
} else {
l.listen.TLSConfig = l.config.TLSConfig
}
return nil

View File

@@ -73,6 +73,18 @@ func TestHTTPStatsListen(t *testing.T) {
l.listen.Close()
}
func TestHTTPStatsListenTLSConfig(t *testing.T) {
l := NewHTTPStats("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLSConfig: tlsConfigBasic,
})
err := l.Listen(new(system.Info))
require.NoError(t, err)
require.NotNil(t, l.listen.TLSConfig)
l.listen.Close()
}
func TestHTTPStatsListenTLS(t *testing.T) {
l := NewHTTPStats("t1", testPort)
l.SetConfig(&Config{

View File

@@ -1,6 +1,7 @@
package listeners
import (
"crypto/tls"
"net"
"sync"
@@ -10,11 +11,25 @@ import (
// Config contains configuration values for a listener.
type Config struct {
Auth auth.Controller // an authentication controller containing auth and ACL logic.
TLS *TLS // the TLS certficates and settings for the connection.
// Auth controller containing auth and ACL logic for
// allowing or denying access to the server and topics.
Auth auth.Controller
// TLS certficates and settings for the connection.
//
// Deprecated: Prefer exposing the tls.Config directly for greater flexibility.
// Please use TLSConfig instead.
TLS *TLS
// TLSConfig is a tls.Config configuration to be used with the listener.
// See examples folder for basic and mutual-tls use.
TLSConfig *tls.Config
}
// TLS contains the TLS certificates and settings for the listener connection.
//
// Deprecated: Prefer exposing the tls.Config directly for greater flexibility.
// Please use TLSConfig instead.
type TLS struct {
Certificate []byte // the body of a public certificate.
PrivateKey []byte // the body of a private key.

View File

@@ -1,6 +1,8 @@
package listeners
import (
"crypto/tls"
"log"
"testing"
"time"
@@ -37,8 +39,22 @@ WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
-----END RSA PRIVATE KEY-----`)
tlsConfigBasic *tls.Config
)
func init() {
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
if err != nil {
log.Fatal(err)
}
// Basic TLS Config
tlsConfigBasic = &tls.Config{
Certificates: []tls.Certificate{cert},
}
}
func TestNew(t *testing.T) {
l := New(nil)
require.NotNil(t, l.internal)

View File

@@ -62,6 +62,9 @@ func (l *TCP) ID() string {
func (l *TCP) Listen(s *system.Info) error {
var err error
// The following logic is deprecated in favour of passing through the tls.Config
// value directly, however it remains in order to provide backwards compatibility.
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
var cert tls.Certificate
cert, err = tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
@@ -72,9 +75,12 @@ func (l *TCP) Listen(s *system.Info) error {
l.listen, err = tls.Listen(l.protocol, l.address, &tls.Config{
Certificates: []tls.Certificate{cert},
})
} else if l.config.TLSConfig != nil {
l.listen, err = tls.Listen(l.protocol, l.address, l.config.TLSConfig)
} else {
l.listen, err = net.Listen(l.protocol, l.address)
}
if err != nil {
return err
}

View File

@@ -73,6 +73,17 @@ func TestTCPListen(t *testing.T) {
l.listen.Close()
}
func TestTCPListenTLSConfig(t *testing.T) {
l := NewTCP("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLSConfig: tlsConfigBasic,
})
err := l.Listen(nil)
require.NoError(t, err)
l.listen.Close()
}
func TestTCPListenTLS(t *testing.T) {
l := NewTCP("t1", testPort)
l.SetConfig(&Config{

View File

@@ -122,6 +122,9 @@ func (l *Websocket) Listen(s *system.Info) error {
Handler: mux,
}
// The following logic is deprecated in favour of passing through the tls.Config
// value directly, however it remains in order to provide backwards compatibility.
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
if err != nil {
@@ -131,6 +134,8 @@ func (l *Websocket) Listen(s *system.Info) error {
l.listen.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
}
} else {
l.listen.TLSConfig = l.config.TLSConfig
}
return nil

View File

@@ -77,6 +77,18 @@ func TestWebsocketListen(t *testing.T) {
require.NotNil(t, l.listen)
}
func TestWebsocketListenTLSConfig(t *testing.T) {
l := NewWebsocket("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLSConfig: tlsConfigBasic,
})
err := l.Listen(nil)
require.NoError(t, err)
require.NotNil(t, l.listen.TLSConfig)
l.listen.Close()
}
func TestWebsocketListenTLS(t *testing.T) {
l := NewWebsocket("t1", testPort)
l.SetConfig(&Config{

View File

@@ -27,9 +27,10 @@ var (
// Store is a backend for writing and reading to bolt persistent storage.
type Store struct {
path string // the path on which to store the db file.
opts *bbolt.Options // options for configuring the boltdb instance.
db *storm.DB // the boltdb instance.
path string // the path on which to store the db file.
opts *bbolt.Options // options for configuring the boltdb instance.
db *storm.DB // the boltdb instance.
inflightTTL int64 // the number of seconds an inflight message should be retained before being dropped.
}
// New returns a configured instance of the boltdb store.
@@ -50,6 +51,13 @@ func New(path string, opts *bbolt.Options) *Store {
}
}
// SetInflightTTL sets the number of seconds an inflight message should be kept
// before being dropped, in the event it is not delivered. Unless you have a good reason,
// you should allow this to be called by the server (in AddStore) instead of directly.
func (s *Store) SetInflightTTL(seconds int64) {
s.inflightTTL = seconds
}
// Open opens the boltdb instance.
func (s *Store) Open() error {
var err error
@@ -251,7 +259,7 @@ func (s *Store) ReadRetained() (v []persistence.Message, err error) {
return v, nil
}
//ReadServerInfo loads the server info from the boltdb instance.
// ReadServerInfo loads the server info from the boltdb instance.
func (s *Store) ReadServerInfo() (v persistence.ServerInfo, err error) {
if s.db == nil {
return v, ErrDBNotOpen
@@ -264,3 +272,27 @@ func (s *Store) ReadServerInfo() (v persistence.ServerInfo, err error) {
return v, nil
}
// ClearExpiredInflight deletes any inflight messages older than the provided unix timestamp.
func (s *Store) ClearExpiredInflight(expiry int64) error {
if s.db == nil {
return ErrDBNotOpen
}
var v []persistence.Message
err := s.db.Find("T", persistence.KInflight, &v)
if err != nil && err != storm.ErrNotFound {
return err
}
for _, m := range v {
if m.Created < expiry || m.Created == 0 {
err := s.db.DeleteStruct(&persistence.Message{ID: m.ID})
if err != nil {
return err
}
}
}
return nil
}

View File

@@ -49,6 +49,12 @@ func TestNewNoOpts(t *testing.T) {
require.Equal(t, defaultTimeout, s.opts.Timeout)
}
func TestSetInflightTTL(t *testing.T) {
s := New("", nil)
s.SetInflightTTL(5)
require.Equal(t, int64(5), s.inflightTTL)
}
func TestOpen(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
@@ -484,3 +490,51 @@ func TestDeleteRetainedFail(t *testing.T) {
err = s.DeleteRetained("a")
require.Error(t, err)
}
func TestClearExpiredInflight(t *testing.T) {
n := time.Now().Unix()
s := New(tmpPath, nil)
err := s.Open()
defer teardown(s, t)
require.NoError(t, err)
err = s.WriteInflight(persistence.Message{
ID: "i1",
T: persistence.KInflight,
Created: n - 1,
})
require.NoError(t, err)
err = s.WriteInflight(persistence.Message{
ID: "i2",
T: persistence.KInflight,
Created: n - 2,
})
require.NoError(t, err)
err = s.WriteInflight(persistence.Message{
ID: "i3",
T: persistence.KInflight,
Created: n - 3,
})
require.NoError(t, err)
err = s.WriteInflight(persistence.Message{
ID: "i5",
T: persistence.KInflight,
Created: n - 5,
})
require.NoError(t, err)
m, err := s.ReadInflight()
require.NoError(t, err)
require.Len(t, m, 4)
s.ClearExpiredInflight(n - 2)
m, err = s.ReadInflight()
require.NoError(t, err)
require.Len(t, m, 2)
}

View File

@@ -28,22 +28,28 @@ const (
type Store interface {
Open() error
Close()
WriteSubscription(v Subscription) error
WriteClient(v Client) error
WriteInflight(v Message) error
WriteServerInfo(v ServerInfo) error
WriteRetained(v Message) error
DeleteSubscription(id string) error
DeleteClient(id string) error
DeleteInflight(id string) error
DeleteRetained(id string) error
ReadSubscriptions() (v []Subscription, err error)
ReadInflight() (v []Message, err error)
ReadRetained() (v []Message, err error)
WriteSubscription(v Subscription) error
DeleteSubscription(id string) error
ReadClients() (v []Client, err error)
WriteClient(v Client) error
DeleteClient(id string) error
ReadInflight() (v []Message, err error)
WriteInflight(v Message) error
DeleteInflight(id string) error
SetInflightTTL(seconds int64)
ClearExpiredInflight(expiry int64) error
ReadServerInfo() (v ServerInfo, err error)
WriteServerInfo(v ServerInfo) error
ReadRetained() (v []Message, err error)
WriteRetained(v Message) error
DeleteRetained(id string) error
}
// ServerInfo contains information and statistics about the server.
@@ -69,6 +75,7 @@ type Message struct {
ID string // the storage key.
Client string // the id of the client who sent the message (if inflight).
TopicName string // the topic the message was sent to (if retained).
Created int64 // the time the message was created in unixtime (if inflight).
Sent int64 // the last time the message was sent (for retries) in unixtime (if inflight).
Resends int // the number of times the message was attempted to be sent (if inflight).
PacketID uint16 // the unique id of the packet (if inflight).
@@ -103,10 +110,16 @@ type LWT struct {
// MockStore is a mock storage backend for testing.
type MockStore struct {
Fail map[string]bool // issue errors for different methods.
FailOpen bool // error on open.
Closed bool // indicate mock store is closed.
Opened bool // indicate mock store is open.
Fail map[string]bool // issue errors for different methods.
FailOpen bool // error on open.
Closed bool // indicate mock store is closed.
Opened bool // indicate mock store is open.
inflightTTL int64 // inflight expiry duration.
}
// Close closes the storage instance.
func (s *MockStore) SetInflightTTL(seconds int64) {
s.inflightTTL = seconds
}
// Open opens the storage instance.
@@ -275,7 +288,7 @@ func (s *MockStore) ReadRetained() (v []Message, err error) {
}, nil
}
//ReadServerInfo loads the server info from the storage instance.
// ReadServerInfo loads the server info from the storage instance.
func (s *MockStore) ReadServerInfo() (v ServerInfo, err error) {
if _, ok := s.Fail["read_info"]; ok {
return v, errors.New("test_info")
@@ -289,3 +302,8 @@ func (s *MockStore) ReadServerInfo() (v ServerInfo, err error) {
KServerInfo,
}, nil
}
// ReadServerInfo loads the server info from the storage instance.
func (s *MockStore) ClearExpiredInflight(d int64) error {
return nil
}

View File

@@ -26,6 +26,12 @@ func TestMockStoreClose(t *testing.T) {
require.Equal(t, true, s.Closed)
}
func TestMockStoreSetInflightTTL(t *testing.T) {
s := new(MockStore)
s.SetInflightTTL(5)
require.Equal(t, int64(5), s.inflightTTL)
}
func TestMockStoreWriteSubscription(t *testing.T) {
s := new(MockStore)
err := s.WriteSubscription(Subscription{})
@@ -249,3 +255,9 @@ func TestMockStoreReadRetainedFail(t *testing.T) {
_, err := s.ReadRetained()
require.Error(t, err)
}
func TestMockStoreClearExpiredInflight(t *testing.T) {
s := new(MockStore)
err := s.ClearExpiredInflight(2)
require.NoError(t, err)
}

View File

@@ -25,6 +25,9 @@ import (
const (
// Version indicates the current server version.
Version = "1.1.1"
// defaultInflightTTL is the number of seconds a pending inflight message should last.
defaultInflightTTL int64 = 60 * 60 * 24
)
var (
@@ -73,17 +76,19 @@ var (
// Server is an MQTT broker server. It should be created with server.New()
// in order to ensure all the internal fields are correctly populated.
type Server struct {
inline inlineMessages // channels for direct publishing.
Events events.Events // overrideable event hooks.
Store persistence.Store // a persistent storage backend if desired.
Options *Options // configurable server options.
Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections.
Clients *clients.Clients // clients which are known to the broker.
Topics *topics.Index // an index of topic filter subscriptions and retained messages.
System *system.Info // values about the server commonly found in $SYS topics.
bytepool *circ.BytesPool // a byte pool for incoming and outgoing packets.
sysTicker *time.Ticker // the interval ticker for sending updating $SYS topics.
done chan bool // indicate that the server is ending.
inline inlineMessages // channels for direct publishing.
Events events.Events // overrideable event hooks.
Store persistence.Store // a persistent storage backend if desired.
Options *Options // configurable server options.
Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections.
Clients *clients.Clients // clients which are known to the broker.
Topics *topics.Index // an index of topic filter subscriptions and retained messages.
System *system.Info // values about the server commonly found in $SYS topics.
bytepool *circ.BytesPool // a byte pool for incoming and outgoing packets.
sysTicker *time.Ticker // the interval ticker for sending updating $SYS topics.
inflightExpiryTicker *time.Ticker // the interval ticker for cleaning up expired messages.
inflightResendTicker *time.Ticker // the interval ticker for resending unresolved inflight messages.
done chan bool // indicate that the server is ending.
}
// Options contains configurable options for the server.
@@ -93,6 +98,9 @@ type Options struct {
// BufferBlockSize overrides the default buffer block size (DefaultBlockSize) for the client buffers.
BufferBlockSize int
// InflightTTL specifies the duration that a queued inflight message should exist before being purged.
InflightTTL int64
}
// inlineMessages contains channels for handling inline (direct) publishing.
@@ -114,6 +122,10 @@ func NewServer(opts *Options) *Server {
opts = new(Options)
}
if opts.InflightTTL < 1 {
opts.InflightTTL = defaultInflightTTL
}
s := &Server{
done: make(chan bool),
bytepool: circ.NewBytesPool(opts.BufferSize),
@@ -123,7 +135,9 @@ func NewServer(opts *Options) *Server {
Version: Version,
Started: time.Now().Unix(),
},
sysTicker: time.NewTicker(SysTopicInterval * time.Millisecond),
sysTicker: time.NewTicker(SysTopicInterval * time.Millisecond),
inflightExpiryTicker: time.NewTicker(time.Duration(opts.InflightTTL) * time.Second),
inflightResendTicker: time.NewTicker(time.Duration(10) * time.Second),
inline: inlineMessages{
done: make(chan bool),
pub: make(chan packets.Packet, 1024),
@@ -143,6 +157,8 @@ func NewServer(opts *Options) *Server {
// called before calling server.Server().
func (s *Server) AddStore(p persistence.Store) error {
s.Store = p
s.Store.SetInflightTTL(s.Options.InflightTTL)
err := s.Store.Open()
if err != nil {
return err
@@ -198,6 +214,10 @@ func (s *Server) eventLoop() {
return
case <-s.sysTicker.C:
s.publishSysTopics()
case <-s.inflightExpiryTicker.C:
s.clearExpiredInflights(time.Now().Unix())
case <-s.inflightResendTicker.C:
s.resendPendingInflights()
}
}
}
@@ -319,9 +339,11 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
return s.onError(cl.Info(), fmt.Errorf("ack connection packet: %w", err))
}
err = s.ResendClientInflight(cl, true)
if err != nil {
s.onError(cl.Info(), fmt.Errorf("resend in flight: %w", err)) // pass-through, no return.
if sessionPresent {
err = s.ResendClientInflight(cl, true)
if err != nil {
s.onError(cl.Info(), fmt.Errorf("resend in flight: %w", err)) // pass-through, no return.
}
}
if s.Store != nil {
@@ -346,6 +368,10 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
err = cl.StopCause() // Determine true cause of stop.
if cl.CleanSession {
s.clearAbandonedInflights(cl)
}
if s.Events.OnDisconnect != nil {
s.Events.OnDisconnect(cl.Info(), err)
}
@@ -379,6 +405,7 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *clients.Client) boo
// The state associated with a CleanSession MUST NOT be reused in any subsequent session.
if pk.CleanSession || existing.CleanSession {
s.unsubscribeClient(existing)
s.clearAbandonedInflights(existing)
return false
}
@@ -400,6 +427,9 @@ func (s *Server) unsubscribeClient(cl *clients.Client) {
for k := range cl.Subscriptions {
delete(cl.Subscriptions, k)
if s.Topics.Unsubscribe(k, cl.ID) {
if s.Events.OnUnsubscribe != nil {
s.Events.OnUnsubscribe(k, cl.Info())
}
atomic.AddInt64(&s.System.Subscriptions, -1)
}
}
@@ -637,8 +667,9 @@ func (s *Server) publishToSubscribers(pk packets.Packet) {
// if an appropriate ack is not received (or if the client is offline).
sent := time.Now().Unix()
q := client.Inflight.Set(out.PacketID, clients.InflightMessage{
Packet: out,
Sent: sent,
Packet: out,
Created: time.Now().Unix(),
Sent: sent,
})
if q {
atomic.AddInt64(&s.System.Inflight, 1)
@@ -735,8 +766,11 @@ func (s *Server) processSubscribe(cl *clients.Client, pk packets.Packet) error {
if !cl.AC.ACL(cl.Username, pk.Topics[i], false) {
retCodes[i] = packets.ErrSubAckNetworkError
} else {
q := s.Topics.Subscribe(pk.Topics[i], cl.ID, pk.Qoss[i])
if q {
r := s.Topics.Subscribe(pk.Topics[i], cl.ID, pk.Qoss[i])
if r {
if s.Events.OnSubscribe != nil {
s.Events.OnSubscribe(pk.Topics[i], cl.Info(), pk.Qoss[i])
}
atomic.AddInt64(&s.System.Subscriptions, 1)
}
cl.NoteSubscription(pk.Topics[i], pk.Qoss[i])
@@ -785,6 +819,9 @@ func (s *Server) processUnsubscribe(cl *clients.Client, pk packets.Packet) error
for i := 0; i < len(pk.Topics); i++ {
q := s.Topics.Unsubscribe(pk.Topics[i], cl.ID)
if q {
if s.Events.OnUnsubscribe != nil {
s.Events.OnUnsubscribe(pk.Topics[i], cl.Info())
}
atomic.AddInt64(&s.System.Subscriptions, -1)
}
cl.ForgetSubscription(pk.Topics[i])
@@ -867,7 +904,7 @@ func (s *Server) publishSysTopics() {
// ResendClientInflight attempts to resend all undelivered inflight messages
// to a client.
func (s *Server) ResendClientInflight(cl *clients.Client, force bool) error {
if cl.Inflight.Len() == 0 {
if atomic.LoadUint32(&cl.State.Done) == 1 || cl.Inflight.Len() == 0 {
return nil
}
@@ -1004,9 +1041,13 @@ func (s *Server) loadServerInfo(v persistence.ServerInfo) {
// loadSubscriptions restores subscriptions from the datastore.
func (s *Server) loadSubscriptions(v []persistence.Subscription) {
for _, sub := range v {
s.Topics.Subscribe(sub.Filter, sub.Client, sub.QoS)
if cl, ok := s.Clients.Get(sub.Client); ok {
cl.NoteSubscription(sub.Filter, sub.QoS)
if s.Topics.Subscribe(sub.Filter, sub.Client, sub.QoS) {
if cl, ok := s.Clients.Get(sub.Client); ok {
cl.NoteSubscription(sub.Filter, sub.QoS)
if s.Events.OnSubscribe != nil {
s.Events.OnSubscribe(sub.Filter, cl.Info(), sub.QoS)
}
}
}
}
}
@@ -1034,6 +1075,7 @@ func (s *Server) loadInflight(v []persistence.Message) {
TopicName: msg.TopicName,
Payload: msg.Payload,
},
Created: msg.Created,
Sent: msg.Sent,
Resends: msg.Resends,
})
@@ -1051,3 +1093,36 @@ func (s *Server) loadRetained(v []persistence.Message) {
})
}
}
// clearExpiredInflights deletes all inflight messages older than server inflight TTL.
func (s *Server) clearExpiredInflights(dt int64) {
expiry := dt - s.Options.InflightTTL
for _, client := range s.Clients.GetAll() {
deleted := client.Inflight.ClearExpired(expiry)
atomic.AddInt64(&s.System.Inflight, deleted*-1)
}
if s.Store != nil {
s.Store.ClearExpiredInflight(expiry)
}
}
// clearAbandonedInflights deletes all inflight messages for a disconnected user (eg. with a clean session).
func (s *Server) clearAbandonedInflights(cl *clients.Client) {
for i := range cl.Inflight.GetAll() {
cl.Inflight.Delete(i)
atomic.AddInt64(&s.System.Inflight, -1)
}
}
// resendPendingInflights attempts resends of any pending and due inflight messages.
func (s *Server) resendPendingInflights() {
for _, client := range s.Clients.GetAll() {
err := s.ResendClientInflight(client, false)
if err != nil {
// TODO add log-level debugging.
continue
}
}
}

View File

@@ -368,9 +368,10 @@ func TestServerEventOnConnect(t *testing.T) {
time.Sleep(10 * time.Millisecond)
require.Equal(t, events.Client{
ID: "mochi",
Remote: "pipe",
Listener: "tcp",
ID: "mochi",
Remote: "pipe",
Listener: "tcp",
CleanSession: true,
}, hook.client)
require.Equal(t, events.Packet(packets.Packet{
@@ -443,9 +444,10 @@ func TestServerEventOnDisconnect(t *testing.T) {
w.Close()
require.Equal(t, events.Client{
ID: "mochi",
Remote: "pipe",
Listener: "tcp",
ID: "mochi",
Remote: "pipe",
Listener: "tcp",
CleanSession: true,
}, hook.client)
require.ErrorIs(t, ErrClientDisconnect, hook.err)
@@ -506,9 +508,10 @@ func TestServerEventOnDisconnectOnError(t *testing.T) {
require.Equal(t, errx, hook.err)
require.Equal(t, events.Client{
ID: "mochi",
Remote: "pipe",
Listener: "tcp",
ID: "mochi",
Remote: "pipe",
Listener: "tcp",
CleanSession: true,
}, hook.client)
clw, ok := s.Clients.Get("mochi")
@@ -1946,6 +1949,15 @@ func TestServerProcessSubscribeInvalid(t *testing.T) {
func TestServerProcessSubscribe(t *testing.T) {
s, cl, r, w := setupClient()
subscribeEvent := ""
subscribeClient := ""
s.Events.OnSubscribe = func(filter string, cl events.Client, qos byte) {
if filter == "a/b/c" {
subscribeEvent = "a/b/c"
subscribeClient = cl.ID
}
}
s.Topics.RetainMessage(packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
@@ -1995,6 +2007,8 @@ func TestServerProcessSubscribe(t *testing.T) {
require.Equal(t, byte(1), cl.Subscriptions["d/e/f"])
require.Equal(t, topics.Subscriptions{cl.ID: 0}, s.Topics.Subscribers("a/b/c"))
require.Equal(t, topics.Subscriptions{cl.ID: 1}, s.Topics.Subscribers("d/e/f"))
require.Equal(t, "a/b/c", subscribeEvent)
require.Equal(t, cl.ID, subscribeClient)
}
func TestServerProcessSubscribeFailACL(t *testing.T) {
@@ -2114,6 +2128,16 @@ func TestServerProcessUnsubscribeInvalid(t *testing.T) {
func TestServerProcessUnsubscribe(t *testing.T) {
s, cl, r, w := setupClient()
unsubscribeEvent := ""
unsubscribeClient := ""
s.Events.OnUnsubscribe = func(filter string, cl events.Client) {
if filter == "a/b/c" {
unsubscribeEvent = "a/b/c"
unsubscribeClient = cl.ID
}
}
s.Clients.Add(cl)
s.Topics.Subscribe("a/b/c", cl.ID, 0)
s.Topics.Subscribe("d/e/f", cl.ID, 1)
@@ -2155,6 +2179,9 @@ func TestServerProcessUnsubscribe(t *testing.T) {
require.NotEmpty(t, s.Topics.Subscribers("a/b/+"))
require.Contains(t, cl.Subscriptions, "a/b/+")
require.Equal(t, "a/b/c", unsubscribeEvent)
require.Equal(t, cl.ID, unsubscribeClient)
}
func TestServerProcessUnsubscribeWriteError(t *testing.T) {
@@ -2677,3 +2704,76 @@ func TestServerResendClientInflightError(t *testing.T) {
err := s.ResendClientInflight(cl, true)
require.Error(t, err)
}
func TestServerClearExpiredInflights(t *testing.T) {
n := time.Now().Unix()
s := New()
s.Options.InflightTTL = 2
require.NotNil(t, s)
r, _ := net.Pipe()
cl := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
cl.Inflight.Set(1, clients.InflightMessage{
Packet: packets.Packet{},
Created: n - 1,
Sent: 0,
})
cl.Inflight.Set(2, clients.InflightMessage{
Packet: packets.Packet{},
Created: n - 2,
Sent: 0,
})
cl.Inflight.Set(3, clients.InflightMessage{
Packet: packets.Packet{},
Created: n - 3,
Sent: 0,
})
cl.Inflight.Set(5, clients.InflightMessage{
Packet: packets.Packet{},
Created: n - 5,
Sent: 0,
})
s.Clients.Add(cl)
require.Len(t, cl.Inflight.GetAll(), 4)
s.clearExpiredInflights(n)
require.Len(t, cl.Inflight.GetAll(), 2)
require.Equal(t, int64(-2), s.System.Inflight)
}
func TestServerClearAbandonedInflights(t *testing.T) {
s := New()
require.NotNil(t, s)
r, _ := net.Pipe()
cl := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
cl.Inflight.Set(1, clients.InflightMessage{
Packet: packets.Packet{},
Sent: 0,
})
cl.Inflight.Set(2, clients.InflightMessage{
Packet: packets.Packet{},
Sent: 0,
})
cl2 := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
cl2.Inflight.Set(3, clients.InflightMessage{
Packet: packets.Packet{},
Sent: 0,
})
cl2.Inflight.Set(5, clients.InflightMessage{
Packet: packets.Packet{},
Sent: 0,
})
s.Clients.Add(cl)
s.Clients.Add(cl2)
require.Len(t, cl.Inflight.GetAll(), 2)
require.Len(t, cl2.Inflight.GetAll(), 2)
s.clearAbandonedInflights(cl)
require.Len(t, cl.Inflight.GetAll(), 0)
require.Len(t, cl2.Inflight.GetAll(), 2)
require.Equal(t, int64(-2), s.System.Inflight)
}