mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-11-02 12:24:05 +08:00
Compare commits
25 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ecbd07fa3a | ||
|
|
ad8bf2a931 | ||
|
|
b8fb068bb9 | ||
|
|
c1348a37b8 | ||
|
|
84fc2f848b | ||
|
|
8703d6d020 | ||
|
|
666440fe56 | ||
|
|
1ae050939a | ||
|
|
f4683d27d0 | ||
|
|
dff2b1db30 | ||
|
|
9de6b4e427 | ||
|
|
78c1914270 | ||
|
|
f71bf5c3d6 | ||
|
|
53c4a6b09f | ||
|
|
a02c6bd8df | ||
|
|
d8f6d63cc8 | ||
|
|
bef13eec20 | ||
|
|
27f3c484ad | ||
|
|
9b5cdb0bcc | ||
|
|
2b60a11d4a | ||
|
|
b53774f818 | ||
|
|
7dee729afb | ||
|
|
aed535b7bf | ||
|
|
4ff888ab3b | ||
|
|
31252c081b |
44
README.md
44
README.md
@@ -30,22 +30,31 @@ MQTT stands for MQ Telemetry Transport. It is a publish/subscribe, extremely sim
|
||||
- Interfaces for Client Authentication and Topic access control.
|
||||
- Bolt persistence and storage interfaces (see examples folder).
|
||||
- Directly Publishing from embedding service (`s.Publish(topic, message, retain)`).
|
||||
- Basic Event Hooks (`OnMessage`, `OnConnect`, `OnDisconnect`, `onProcessMessage`, `OnError`, `OnStorage`).
|
||||
- ARM32 Compatible.
|
||||
- Basic Event Hooks (`OnMessage`, `onSubscribe`, `onUnsubscribe`, `OnConnect`, `OnDisconnect`, `onProcessMessage`, `OnError`, `OnStorage`).
|
||||
- ARM32 Compatible (v1.1.1).
|
||||
|
||||
#### Roadmap
|
||||
- Please open an issue to request new features or event hooks.
|
||||
- MQTT v5 compatibility?
|
||||
|
||||
#### Using the Broker
|
||||
Mochi MQTT can be used as a standalone broker. Simply checkout this repository and run the `main.go` entrypoint in the `cmd` folder which will expose tcp (:1883), websocket (:1882), and dashboard (:8080) listeners. A docker image is coming soon.
|
||||
#### Using the Broker from Go
|
||||
Mochi MQTT can be used as a standalone broker. Simply checkout this repository and run the `main.go` entrypoint in the `cmd` folder which will expose tcp (:1883), websocket (:1882), and dashboard (:8080) listeners.
|
||||
|
||||
```
|
||||
cd cmd
|
||||
go build -o mqtt && ./mqtt
|
||||
```
|
||||
|
||||
#### Quick Start
|
||||
#### Using Docker
|
||||
|
||||
A simple Dockerfile is provided for running the `cmd/main.go` Websocket, TCP, and Stats server:
|
||||
|
||||
```sh
|
||||
docker build -t mochi:latest .
|
||||
docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 mochi:latest
|
||||
```
|
||||
|
||||
#### Package Quick Start
|
||||
|
||||
``` go
|
||||
import (
|
||||
@@ -76,6 +85,8 @@ func main() {
|
||||
|
||||
Examples of running the broker with various configurations can be found in the `examples` folder.
|
||||
|
||||
|
||||
|
||||
#### Network Listeners
|
||||
The server comes with a variety of pre-packaged network listeners which allow the broker to accept connections on different protocols. The current listeners are:
|
||||
- `listeners.NewTCP(id, address string)` - A TCP Listener, taking a unique ID and a network address to bind.
|
||||
@@ -130,7 +141,25 @@ server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.P
|
||||
|
||||
```go
|
||||
server.Events.OnDisconnect = func(cl events.Client, err error) {
|
||||
fmt.Printf("<< OnDisconnect client dicconnected %s: %v\n", cl.ID, err)
|
||||
fmt.Printf("<< OnDisconnect client disconnected %s: %v\n", cl.ID, err)
|
||||
}
|
||||
```
|
||||
|
||||
##### OnSubscribe
|
||||
`server.Events.OnSubscribe` is called when a client subscribes to a new topic filter.
|
||||
|
||||
```go
|
||||
server.Events.OnSubscribe = func(filter string, cl events.Client, qos byte) {
|
||||
fmt.Printf("<< OnSubscribe client subscribed %s: %s %v\n", cl.ID, filter, qos)
|
||||
}
|
||||
```
|
||||
|
||||
##### OnUnsubscribe
|
||||
`server.Events.OnUnsubscribe` is called when a client unsubscribes from a topic filter.
|
||||
|
||||
```go
|
||||
server.Events.OnUnsubscribe = func(filter string, cl events.Client) {
|
||||
fmt.Printf("<< OnUnsubscribe client unsubscribed %s: %s\n", cl.ID, filter)
|
||||
}
|
||||
```
|
||||
|
||||
@@ -147,7 +176,6 @@ If an error is returned, the packet will not be modified. and the existing packe
|
||||
|
||||
> This hook is only triggered when a message is received by clients. It is not triggered when using the direct `server.Publish` method.
|
||||
|
||||
|
||||
```go
|
||||
import "github.com/mochi-co/mqtt/server/events"
|
||||
|
||||
@@ -164,6 +192,8 @@ server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.P
|
||||
|
||||
The OnMessage hook can also be used to selectively only deliver messages to one or more clients based on their id, using the `AllowClients []string` field on the packet structure.
|
||||
|
||||
|
||||
|
||||
##### OnError
|
||||
`server.Events.OnError` is called when an error is encountered on the server, particularly within the use of a client connection status.
|
||||
|
||||
|
||||
105
examples/auth/main.go
Normal file
105
examples/auth/main.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/logrusorgru/aurora"
|
||||
|
||||
mqtt "github.com/mochi-co/mqtt/server"
|
||||
"github.com/mochi-co/mqtt/server/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
|
||||
|
||||
server := mqtt.NewServer(nil)
|
||||
tcp := listeners.NewTCP("t1", ":1883")
|
||||
err := server.AddListener(tcp, &listeners.Config{
|
||||
Auth: &Auth{
|
||||
Users: map[string]string{
|
||||
"peach": "password1",
|
||||
"melon": "password2",
|
||||
"apple": "password3",
|
||||
},
|
||||
AllowedTopics: map[string][]string{
|
||||
// Melon user only has access to melon topics.
|
||||
// If you were implementing this in the real world, you might ensure
|
||||
// that any topic prefixed with "melon" is allowed (see ACL func below).
|
||||
"melon": {"melon/info", "melon/events"},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
fmt.Println(aurora.BgMagenta(" Started! "))
|
||||
|
||||
<-done
|
||||
fmt.Println(aurora.BgRed(" Caught Signal "))
|
||||
|
||||
server.Close()
|
||||
fmt.Println(aurora.BgGreen(" Finished "))
|
||||
}
|
||||
|
||||
// Auth is an example auth provider for the server. In the real world
|
||||
// you are more likely to replace these fields with database/cache lookups
|
||||
// to check against an auth list. As the Auth Controller is an interface, it can
|
||||
// be built however you want, as long as it fulfils the interface signature.
|
||||
type Auth struct {
|
||||
Users map[string]string // A map of usernames (key) with passwords (value).
|
||||
AllowedTopics map[string][]string // A map of usernames and topics
|
||||
}
|
||||
|
||||
// Authenticate returns true if a username and password are acceptable.
|
||||
func (a *Auth) Authenticate(user, password []byte) bool {
|
||||
// If the user exists in the auth users map, and the password is correct,
|
||||
// then they can connect to the server. In the real world, this could be a database
|
||||
// or cached users lookup.
|
||||
if pass, ok := a.Users[string(user)]; ok && pass == string(password) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ACL returns true if a user has access permissions to read or write on a topic.
|
||||
func (a *Auth) ACL(user []byte, topic string, write bool) bool {
|
||||
|
||||
// An example ACL - if the user has an entry in the auth allow list, then they are
|
||||
// subject to ACL restrictions. Only let them use a topic if it's available for their
|
||||
// user.
|
||||
if topics, ok := a.AllowedTopics[string(user)]; ok {
|
||||
for _, t := range topics {
|
||||
|
||||
// In the real world you might allow all topics prefixed with a user's username,
|
||||
// or similar multi-topic filters.
|
||||
if t == topic {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Otherwise, allow all topics.
|
||||
return true
|
||||
}
|
||||
@@ -54,6 +54,16 @@ func main() {
|
||||
fmt.Printf("<< OnDisconnect client disconnected %s: %v\n", cl.ID, err)
|
||||
}
|
||||
|
||||
// Add OnSubscribe Event Hook
|
||||
server.Events.OnSubscribe = func(filter string, cl events.Client, qos byte) {
|
||||
fmt.Printf("<< OnSubscribe client subscribed %s: %s %v\n", cl.ID, filter, qos)
|
||||
}
|
||||
|
||||
// Add OnUnsubscribe Event Hook
|
||||
server.Events.OnUnsubscribe = func(filter string, cl events.Client) {
|
||||
fmt.Printf("<< OnUnsubscribe client unsubscribed %s: %s\n", cl.ID, filter)
|
||||
}
|
||||
|
||||
// Add OnMessage Event Hook
|
||||
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
|
||||
pkx = pk
|
||||
|
||||
@@ -27,8 +27,9 @@ func main() {
|
||||
|
||||
// An example of configuring various server options...
|
||||
options := &mqtt.Options{
|
||||
BufferSize: 0, // Use default values
|
||||
BufferBlockSize: 0, // Use default values
|
||||
BufferSize: 0, // Use default values
|
||||
BufferBlockSize: 0, // Use default values
|
||||
InflightTTL: 60 * 15, // Set an example custom 15-min TTL for inflight messages
|
||||
}
|
||||
|
||||
server := mqtt.NewServer(options)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
@@ -57,14 +58,30 @@ func main() {
|
||||
|
||||
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TLS/SSL"))
|
||||
|
||||
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Basic TLS Config
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
|
||||
// Optionally, if you want clients to authenticate only with certs issued by your CA,
|
||||
// you might want to use something like this:
|
||||
// certPool := x509.NewCertPool()
|
||||
// _ = certPool.AppendCertsFromPEM(caCertPem)
|
||||
// tlsConfig := &tls.Config{
|
||||
// ClientCAs: certPool,
|
||||
// ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
// }
|
||||
|
||||
server := mqtt.NewServer(nil)
|
||||
tcp := listeners.NewTCP("t1", ":1883")
|
||||
err := server.AddListener(tcp, &listeners.Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLS: &listeners.TLS{
|
||||
Certificate: testCertificate,
|
||||
PrivateKey: testPrivateKey,
|
||||
},
|
||||
err = server.AddListener(tcp, &listeners.Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLSConfig: tlsConfig,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
@@ -72,11 +89,8 @@ func main() {
|
||||
|
||||
ws := listeners.NewWebsocket("ws1", ":1882")
|
||||
err = server.AddListener(ws, &listeners.Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLS: &listeners.TLS{
|
||||
Certificate: testCertificate,
|
||||
PrivateKey: testPrivateKey,
|
||||
},
|
||||
Auth: new(auth.Allow),
|
||||
TLSConfig: tlsConfig,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
@@ -84,11 +98,8 @@ func main() {
|
||||
|
||||
stats := listeners.NewHTTPStats("stats", ":8080")
|
||||
err = server.AddListener(stats, &listeners.Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLS: &listeners.TLS{
|
||||
Certificate: testCertificate,
|
||||
PrivateKey: testPrivateKey,
|
||||
},
|
||||
Auth: new(auth.Allow),
|
||||
TLSConfig: tlsConfig,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
||||
@@ -11,6 +11,8 @@ type Events struct {
|
||||
OnError // server error.
|
||||
OnConnect // client connected.
|
||||
OnDisconnect // client disconnected.
|
||||
OnSubscribe // topic subscription created.
|
||||
OnUnsubscribe // topic subscription removed.
|
||||
}
|
||||
|
||||
// Packets is an alias for packets.Packet.
|
||||
@@ -18,9 +20,11 @@ type Packet packets.Packet
|
||||
|
||||
// Client contains limited information about a connected client.
|
||||
type Client struct {
|
||||
ID string
|
||||
Remote string
|
||||
Listener string
|
||||
ID string
|
||||
Remote string
|
||||
Listener string
|
||||
Username []byte
|
||||
CleanSession bool
|
||||
}
|
||||
|
||||
// Clientlike is an interface for Clients and client-like objects that
|
||||
@@ -40,7 +44,7 @@ type Clientlike interface {
|
||||
// be dispatched as if the event hook had not been triggered.
|
||||
// This function will block message dispatching until it returns. To minimise this,
|
||||
// have the function open a new goroutine on the embedding side.
|
||||
// The `mqtt.ErrRejectPacket` error can be returned to reject and abandon any futher
|
||||
// The `mqtt.ErrRejectPacket` error can be returned to reject and abandon any further
|
||||
// processing of the packet.
|
||||
type OnProcessMessage func(Client, Packet) (Packet, error)
|
||||
|
||||
@@ -66,3 +70,9 @@ type OnDisconnect func(Client, error)
|
||||
// OnError is called when errors that will not be passed to
|
||||
// OnDisconnect are handled by the server.
|
||||
type OnError func(Client, error)
|
||||
|
||||
// OnSubscribe is called when a new subscription filter for a client is created.
|
||||
type OnSubscribe func(filter string, cl Client, qos byte)
|
||||
|
||||
// OnUnsubscribe is called when an existing subscription filter for a client is removed.
|
||||
type OnUnsubscribe func(filter string, cl Client)
|
||||
|
||||
@@ -7,8 +7,10 @@ import (
|
||||
|
||||
// BytesPool is a pool of []byte.
|
||||
type BytesPool struct {
|
||||
// int64/uint64 has to the first words in order
|
||||
// to be 64-aligned on 32-bit architectures.
|
||||
used int64 // access atomically
|
||||
pool *sync.Pool
|
||||
used int64
|
||||
}
|
||||
|
||||
// NewBytesPool returns a sync.pool of []byte.
|
||||
|
||||
@@ -49,6 +49,13 @@ func (cl *Clients) Add(val *Client) {
|
||||
cl.Unlock()
|
||||
}
|
||||
|
||||
// GetAll returns all the clients.
|
||||
func (cl *Clients) GetAll() map[string]*Client {
|
||||
cl.RLock()
|
||||
defer cl.RUnlock()
|
||||
return cl.internal
|
||||
}
|
||||
|
||||
// Get returns the value of a client if it exists.
|
||||
func (cl *Clients) Get(id string) (*Client, bool) {
|
||||
cl.RLock()
|
||||
@@ -200,9 +207,11 @@ func (cl *Client) Info() events.Client {
|
||||
addr = cl.conn.RemoteAddr().String()
|
||||
}
|
||||
return events.Client{
|
||||
ID: cl.ID,
|
||||
Remote: addr,
|
||||
Listener: cl.Listener,
|
||||
ID: cl.ID,
|
||||
Remote: addr,
|
||||
Username: cl.Username,
|
||||
CleanSession: cl.CleanSession,
|
||||
Listener: cl.Listener,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -509,6 +518,7 @@ type LWT struct {
|
||||
type InflightMessage struct {
|
||||
Packet packets.Packet // the packet currently in-flight.
|
||||
Sent int64 // the last time the message was sent (for retries) in unixtime.
|
||||
Created int64 // the unix timestamp when the inflight message was created.
|
||||
Resends int // the number of times the message was attempted to be sent.
|
||||
}
|
||||
|
||||
@@ -555,8 +565,25 @@ func (i *Inflight) GetAll() map[uint16]InflightMessage {
|
||||
// message existed.
|
||||
func (i *Inflight) Delete(key uint16) bool {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
_, ok := i.internal[key]
|
||||
delete(i.internal, key)
|
||||
i.Unlock()
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// ClearExpired deletes any inflight messages that have remained longer than
|
||||
// the servers InflightTTL duration. Returns number of deleted inflights.
|
||||
func (i *Inflight) ClearExpired(expiry int64) int64 {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
var deleted int64
|
||||
for k, m := range i.internal {
|
||||
if m.Created < expiry || m.Created == 0 {
|
||||
delete(i.internal, k)
|
||||
deleted++
|
||||
}
|
||||
}
|
||||
|
||||
return deleted
|
||||
}
|
||||
|
||||
@@ -69,6 +69,36 @@ func BenchmarkClientsGet(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientsGetAll(t *testing.T) {
|
||||
cl := New()
|
||||
cl.Add(&Client{ID: "t1"})
|
||||
cl.Add(&Client{ID: "t2"})
|
||||
cl.Add(&Client{ID: "t3"})
|
||||
cl.Add(&Client{ID: "t4"})
|
||||
cl.Add(&Client{ID: "t5"})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
require.Contains(t, cl.internal, "t2")
|
||||
require.Contains(t, cl.internal, "t3")
|
||||
require.Contains(t, cl.internal, "t4")
|
||||
require.Contains(t, cl.internal, "t5")
|
||||
|
||||
clients := cl.GetAll()
|
||||
require.Len(t, clients, 5)
|
||||
}
|
||||
|
||||
func BenchmarkClientsGetAll(b *testing.B) {
|
||||
cl := New()
|
||||
cl.Add(&Client{ID: "t1"})
|
||||
cl.Add(&Client{ID: "t2"})
|
||||
cl.Add(&Client{ID: "t3"})
|
||||
cl.Add(&Client{ID: "t4"})
|
||||
cl.Add(&Client{ID: "t5"})
|
||||
for n := 0; n < b.N; n++ {
|
||||
clients := cl.GetAll()
|
||||
require.Len(b, clients, 5)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientsLen(t *testing.T) {
|
||||
cl := New()
|
||||
cl.Add(&Client{ID: "t1"})
|
||||
@@ -714,12 +744,13 @@ func TestClientWritePacket(t *testing.T) {
|
||||
r.Close()
|
||||
|
||||
require.Equal(t, tt.bytes, <-o, "Mismatched packet: [i:%d] %d", i, tt.bytes[0])
|
||||
|
||||
cl.Stop(testClientStop)
|
||||
time.Sleep(time.Millisecond * 1)
|
||||
|
||||
// The stop cause is either the test error, EOF, or a
|
||||
// closed pipe, depending on which goroutine runs first.
|
||||
err = cl.StopCause()
|
||||
time.Sleep(time.Millisecond * 5)
|
||||
require.True(t,
|
||||
errors.Is(err, testClientStop) ||
|
||||
errors.Is(err, io.EOF) ||
|
||||
@@ -858,6 +889,43 @@ func BenchmarkInflightDelete(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestInflightClearExpired(t *testing.T) {
|
||||
n := time.Now().Unix()
|
||||
|
||||
cl := genClient()
|
||||
cl.Inflight.Set(1, InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Created: n - 1,
|
||||
Sent: 0,
|
||||
})
|
||||
cl.Inflight.Set(2, InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Created: n - 2,
|
||||
Sent: 0,
|
||||
})
|
||||
cl.Inflight.Set(3, InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Created: n - 3,
|
||||
Sent: 0,
|
||||
})
|
||||
cl.Inflight.Set(5, InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Created: n - 5,
|
||||
Sent: 0,
|
||||
})
|
||||
|
||||
require.Len(t, cl.Inflight.internal, 4)
|
||||
|
||||
deleted := cl.Inflight.ClearExpired(n - 2)
|
||||
cl.Inflight.RLock()
|
||||
defer cl.Inflight.RUnlock()
|
||||
require.Len(t, cl.Inflight.internal, 2)
|
||||
require.Equal(t, (n - 1), cl.Inflight.internal[1].Created)
|
||||
require.Equal(t, (n - 2), cl.Inflight.internal[2].Created)
|
||||
require.Equal(t, int64(0), cl.Inflight.internal[3].Created)
|
||||
require.Equal(t, int64(2), deleted)
|
||||
}
|
||||
|
||||
var (
|
||||
pkTable = []struct {
|
||||
bytes []byte
|
||||
|
||||
@@ -21,7 +21,7 @@ func BenchmarkBytesToString(b *testing.B) {
|
||||
|
||||
func TestDecodeString(t *testing.T) {
|
||||
expect := []struct {
|
||||
name string
|
||||
name string
|
||||
rawBytes []byte
|
||||
result string
|
||||
offset int
|
||||
@@ -88,8 +88,8 @@ func TestDecodeString(t *testing.T) {
|
||||
shouldFail: ErrOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
|
||||
shouldFail: ErrOffsetStrInvalidUTF8,
|
||||
},
|
||||
}
|
||||
@@ -101,7 +101,7 @@ func TestDecodeString(t *testing.T) {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
})
|
||||
@@ -209,7 +209,7 @@ func TestDecodeByte(t *testing.T) {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, i+1, offset)
|
||||
@@ -250,12 +250,11 @@ func TestDecodeUint16(t *testing.T) {
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
|
||||
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
@@ -295,7 +294,7 @@ func TestDecodeByteBool(t *testing.T) {
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
|
||||
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
|
||||
@@ -215,7 +215,7 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
|
||||
}
|
||||
|
||||
if pk.PasswordFlag {
|
||||
pk.Password, offset, err = decodeBytes(buf, offset)
|
||||
pk.Password, _, err = decodeBytes(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPassword)
|
||||
}
|
||||
@@ -287,7 +287,7 @@ func (pk *Packet) ConnackDecode(buf []byte) error {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedSessionPresent)
|
||||
}
|
||||
|
||||
pk.ReturnCode, offset, err = decodeByte(buf, offset)
|
||||
pk.ReturnCode, _, err = decodeByte(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedReturnCode)
|
||||
}
|
||||
|
||||
@@ -70,6 +70,9 @@ func (l *HTTPStats) Listen(s *system.Info) error {
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
// The following logic is deprecated in favour of passing through the tls.Config
|
||||
// value directly, however it remains in order to provide backwards compatibility.
|
||||
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
|
||||
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
|
||||
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
|
||||
if err != nil {
|
||||
@@ -79,6 +82,8 @@ func (l *HTTPStats) Listen(s *system.Info) error {
|
||||
l.listen.TLSConfig = &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
} else {
|
||||
l.listen.TLSConfig = l.config.TLSConfig
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -73,6 +73,18 @@ func TestHTTPStatsListen(t *testing.T) {
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestHTTPStatsListenTLSConfig(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Listen(new(system.Info))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, l.listen.TLSConfig)
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestHTTPStatsListenTLS(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
@@ -10,11 +11,25 @@ import (
|
||||
|
||||
// Config contains configuration values for a listener.
|
||||
type Config struct {
|
||||
Auth auth.Controller // an authentication controller containing auth and ACL logic.
|
||||
TLS *TLS // the TLS certficates and settings for the connection.
|
||||
// Auth controller containing auth and ACL logic for
|
||||
// allowing or denying access to the server and topics.
|
||||
Auth auth.Controller
|
||||
|
||||
// TLS certficates and settings for the connection.
|
||||
//
|
||||
// Deprecated: Prefer exposing the tls.Config directly for greater flexibility.
|
||||
// Please use TLSConfig instead.
|
||||
TLS *TLS
|
||||
|
||||
// TLSConfig is a tls.Config configuration to be used with the listener.
|
||||
// See examples folder for basic and mutual-tls use.
|
||||
TLSConfig *tls.Config
|
||||
}
|
||||
|
||||
// TLS contains the TLS certificates and settings for the listener connection.
|
||||
//
|
||||
// Deprecated: Prefer exposing the tls.Config directly for greater flexibility.
|
||||
// Please use TLSConfig instead.
|
||||
type TLS struct {
|
||||
Certificate []byte // the body of a public certificate.
|
||||
PrivateKey []byte // the body of a private key.
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"log"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -37,8 +39,22 @@ WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
|
||||
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
|
||||
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
|
||||
-----END RSA PRIVATE KEY-----`)
|
||||
|
||||
tlsConfigBasic *tls.Config
|
||||
)
|
||||
|
||||
func init() {
|
||||
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Basic TLS Config
|
||||
tlsConfigBasic = &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
l := New(nil)
|
||||
require.NotNil(t, l.internal)
|
||||
|
||||
@@ -62,6 +62,9 @@ func (l *TCP) ID() string {
|
||||
func (l *TCP) Listen(s *system.Info) error {
|
||||
var err error
|
||||
|
||||
// The following logic is deprecated in favour of passing through the tls.Config
|
||||
// value directly, however it remains in order to provide backwards compatibility.
|
||||
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
|
||||
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
|
||||
var cert tls.Certificate
|
||||
cert, err = tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
|
||||
@@ -72,9 +75,12 @@ func (l *TCP) Listen(s *system.Info) error {
|
||||
l.listen, err = tls.Listen(l.protocol, l.address, &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
})
|
||||
} else if l.config.TLSConfig != nil {
|
||||
l.listen, err = tls.Listen(l.protocol, l.address, l.config.TLSConfig)
|
||||
} else {
|
||||
l.listen, err = net.Listen(l.protocol, l.address)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -73,6 +73,17 @@ func TestTCPListen(t *testing.T) {
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestTCPListenTLSConfig(t *testing.T) {
|
||||
l := NewTCP("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Listen(nil)
|
||||
require.NoError(t, err)
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestTCPListenTLS(t *testing.T) {
|
||||
l := NewTCP("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
|
||||
@@ -122,6 +122,9 @@ func (l *Websocket) Listen(s *system.Info) error {
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
// The following logic is deprecated in favour of passing through the tls.Config
|
||||
// value directly, however it remains in order to provide backwards compatibility.
|
||||
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
|
||||
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
|
||||
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
|
||||
if err != nil {
|
||||
@@ -131,6 +134,8 @@ func (l *Websocket) Listen(s *system.Info) error {
|
||||
l.listen.TLSConfig = &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
} else {
|
||||
l.listen.TLSConfig = l.config.TLSConfig
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -77,6 +77,18 @@ func TestWebsocketListen(t *testing.T) {
|
||||
require.NotNil(t, l.listen)
|
||||
}
|
||||
|
||||
func TestWebsocketListenTLSConfig(t *testing.T) {
|
||||
l := NewWebsocket("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Listen(nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, l.listen.TLSConfig)
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestWebsocketListenTLS(t *testing.T) {
|
||||
l := NewWebsocket("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
|
||||
@@ -27,9 +27,10 @@ var (
|
||||
|
||||
// Store is a backend for writing and reading to bolt persistent storage.
|
||||
type Store struct {
|
||||
path string // the path on which to store the db file.
|
||||
opts *bbolt.Options // options for configuring the boltdb instance.
|
||||
db *storm.DB // the boltdb instance.
|
||||
path string // the path on which to store the db file.
|
||||
opts *bbolt.Options // options for configuring the boltdb instance.
|
||||
db *storm.DB // the boltdb instance.
|
||||
inflightTTL int64 // the number of seconds an inflight message should be retained before being dropped.
|
||||
}
|
||||
|
||||
// New returns a configured instance of the boltdb store.
|
||||
@@ -50,6 +51,13 @@ func New(path string, opts *bbolt.Options) *Store {
|
||||
}
|
||||
}
|
||||
|
||||
// SetInflightTTL sets the number of seconds an inflight message should be kept
|
||||
// before being dropped, in the event it is not delivered. Unless you have a good reason,
|
||||
// you should allow this to be called by the server (in AddStore) instead of directly.
|
||||
func (s *Store) SetInflightTTL(seconds int64) {
|
||||
s.inflightTTL = seconds
|
||||
}
|
||||
|
||||
// Open opens the boltdb instance.
|
||||
func (s *Store) Open() error {
|
||||
var err error
|
||||
@@ -251,7 +259,7 @@ func (s *Store) ReadRetained() (v []persistence.Message, err error) {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
//ReadServerInfo loads the server info from the boltdb instance.
|
||||
// ReadServerInfo loads the server info from the boltdb instance.
|
||||
func (s *Store) ReadServerInfo() (v persistence.ServerInfo, err error) {
|
||||
if s.db == nil {
|
||||
return v, ErrDBNotOpen
|
||||
@@ -264,3 +272,27 @@ func (s *Store) ReadServerInfo() (v persistence.ServerInfo, err error) {
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// ClearExpiredInflight deletes any inflight messages older than the provided unix timestamp.
|
||||
func (s *Store) ClearExpiredInflight(expiry int64) error {
|
||||
if s.db == nil {
|
||||
return ErrDBNotOpen
|
||||
}
|
||||
|
||||
var v []persistence.Message
|
||||
err := s.db.Find("T", persistence.KInflight, &v)
|
||||
if err != nil && err != storm.ErrNotFound {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, m := range v {
|
||||
if m.Created < expiry || m.Created == 0 {
|
||||
err := s.db.DeleteStruct(&persistence.Message{ID: m.ID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -49,6 +49,12 @@ func TestNewNoOpts(t *testing.T) {
|
||||
require.Equal(t, defaultTimeout, s.opts.Timeout)
|
||||
}
|
||||
|
||||
func TestSetInflightTTL(t *testing.T) {
|
||||
s := New("", nil)
|
||||
s.SetInflightTTL(5)
|
||||
require.Equal(t, int64(5), s.inflightTTL)
|
||||
}
|
||||
|
||||
func TestOpen(t *testing.T) {
|
||||
s := New(tmpPath, nil)
|
||||
err := s.Open()
|
||||
@@ -484,3 +490,51 @@ func TestDeleteRetainedFail(t *testing.T) {
|
||||
err = s.DeleteRetained("a")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClearExpiredInflight(t *testing.T) {
|
||||
n := time.Now().Unix()
|
||||
|
||||
s := New(tmpPath, nil)
|
||||
err := s.Open()
|
||||
defer teardown(s, t)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.WriteInflight(persistence.Message{
|
||||
ID: "i1",
|
||||
T: persistence.KInflight,
|
||||
Created: n - 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.WriteInflight(persistence.Message{
|
||||
ID: "i2",
|
||||
T: persistence.KInflight,
|
||||
Created: n - 2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.WriteInflight(persistence.Message{
|
||||
ID: "i3",
|
||||
T: persistence.KInflight,
|
||||
Created: n - 3,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.WriteInflight(persistence.Message{
|
||||
ID: "i5",
|
||||
T: persistence.KInflight,
|
||||
Created: n - 5,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
m, err := s.ReadInflight()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, m, 4)
|
||||
|
||||
s.ClearExpiredInflight(n - 2)
|
||||
|
||||
m, err = s.ReadInflight()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, m, 2)
|
||||
}
|
||||
|
||||
@@ -28,22 +28,28 @@ const (
|
||||
type Store interface {
|
||||
Open() error
|
||||
Close()
|
||||
WriteSubscription(v Subscription) error
|
||||
WriteClient(v Client) error
|
||||
WriteInflight(v Message) error
|
||||
WriteServerInfo(v ServerInfo) error
|
||||
WriteRetained(v Message) error
|
||||
|
||||
DeleteSubscription(id string) error
|
||||
DeleteClient(id string) error
|
||||
DeleteInflight(id string) error
|
||||
DeleteRetained(id string) error
|
||||
|
||||
ReadSubscriptions() (v []Subscription, err error)
|
||||
ReadInflight() (v []Message, err error)
|
||||
ReadRetained() (v []Message, err error)
|
||||
WriteSubscription(v Subscription) error
|
||||
DeleteSubscription(id string) error
|
||||
|
||||
ReadClients() (v []Client, err error)
|
||||
WriteClient(v Client) error
|
||||
DeleteClient(id string) error
|
||||
|
||||
ReadInflight() (v []Message, err error)
|
||||
WriteInflight(v Message) error
|
||||
DeleteInflight(id string) error
|
||||
|
||||
SetInflightTTL(seconds int64)
|
||||
ClearExpiredInflight(expiry int64) error
|
||||
|
||||
ReadServerInfo() (v ServerInfo, err error)
|
||||
WriteServerInfo(v ServerInfo) error
|
||||
|
||||
ReadRetained() (v []Message, err error)
|
||||
WriteRetained(v Message) error
|
||||
DeleteRetained(id string) error
|
||||
}
|
||||
|
||||
// ServerInfo contains information and statistics about the server.
|
||||
@@ -69,6 +75,7 @@ type Message struct {
|
||||
ID string // the storage key.
|
||||
Client string // the id of the client who sent the message (if inflight).
|
||||
TopicName string // the topic the message was sent to (if retained).
|
||||
Created int64 // the time the message was created in unixtime (if inflight).
|
||||
Sent int64 // the last time the message was sent (for retries) in unixtime (if inflight).
|
||||
Resends int // the number of times the message was attempted to be sent (if inflight).
|
||||
PacketID uint16 // the unique id of the packet (if inflight).
|
||||
@@ -103,10 +110,16 @@ type LWT struct {
|
||||
|
||||
// MockStore is a mock storage backend for testing.
|
||||
type MockStore struct {
|
||||
Fail map[string]bool // issue errors for different methods.
|
||||
FailOpen bool // error on open.
|
||||
Closed bool // indicate mock store is closed.
|
||||
Opened bool // indicate mock store is open.
|
||||
Fail map[string]bool // issue errors for different methods.
|
||||
FailOpen bool // error on open.
|
||||
Closed bool // indicate mock store is closed.
|
||||
Opened bool // indicate mock store is open.
|
||||
inflightTTL int64 // inflight expiry duration.
|
||||
}
|
||||
|
||||
// Close closes the storage instance.
|
||||
func (s *MockStore) SetInflightTTL(seconds int64) {
|
||||
s.inflightTTL = seconds
|
||||
}
|
||||
|
||||
// Open opens the storage instance.
|
||||
@@ -275,7 +288,7 @@ func (s *MockStore) ReadRetained() (v []Message, err error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
//ReadServerInfo loads the server info from the storage instance.
|
||||
// ReadServerInfo loads the server info from the storage instance.
|
||||
func (s *MockStore) ReadServerInfo() (v ServerInfo, err error) {
|
||||
if _, ok := s.Fail["read_info"]; ok {
|
||||
return v, errors.New("test_info")
|
||||
@@ -289,3 +302,8 @@ func (s *MockStore) ReadServerInfo() (v ServerInfo, err error) {
|
||||
KServerInfo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ReadServerInfo loads the server info from the storage instance.
|
||||
func (s *MockStore) ClearExpiredInflight(d int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -26,6 +26,12 @@ func TestMockStoreClose(t *testing.T) {
|
||||
require.Equal(t, true, s.Closed)
|
||||
}
|
||||
|
||||
func TestMockStoreSetInflightTTL(t *testing.T) {
|
||||
s := new(MockStore)
|
||||
s.SetInflightTTL(5)
|
||||
require.Equal(t, int64(5), s.inflightTTL)
|
||||
}
|
||||
|
||||
func TestMockStoreWriteSubscription(t *testing.T) {
|
||||
s := new(MockStore)
|
||||
err := s.WriteSubscription(Subscription{})
|
||||
@@ -249,3 +255,9 @@ func TestMockStoreReadRetainedFail(t *testing.T) {
|
||||
_, err := s.ReadRetained()
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMockStoreClearExpiredInflight(t *testing.T) {
|
||||
s := new(MockStore)
|
||||
err := s.ClearExpiredInflight(2)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
121
server/server.go
121
server/server.go
@@ -25,6 +25,9 @@ import (
|
||||
const (
|
||||
// Version indicates the current server version.
|
||||
Version = "1.1.1"
|
||||
|
||||
// defaultInflightTTL is the number of seconds a pending inflight message should last.
|
||||
defaultInflightTTL int64 = 60 * 60 * 24
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -73,17 +76,19 @@ var (
|
||||
// Server is an MQTT broker server. It should be created with server.New()
|
||||
// in order to ensure all the internal fields are correctly populated.
|
||||
type Server struct {
|
||||
inline inlineMessages // channels for direct publishing.
|
||||
Events events.Events // overrideable event hooks.
|
||||
Store persistence.Store // a persistent storage backend if desired.
|
||||
Options *Options // configurable server options.
|
||||
Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections.
|
||||
Clients *clients.Clients // clients which are known to the broker.
|
||||
Topics *topics.Index // an index of topic filter subscriptions and retained messages.
|
||||
System *system.Info // values about the server commonly found in $SYS topics.
|
||||
bytepool *circ.BytesPool // a byte pool for incoming and outgoing packets.
|
||||
sysTicker *time.Ticker // the interval ticker for sending updating $SYS topics.
|
||||
done chan bool // indicate that the server is ending.
|
||||
inline inlineMessages // channels for direct publishing.
|
||||
Events events.Events // overrideable event hooks.
|
||||
Store persistence.Store // a persistent storage backend if desired.
|
||||
Options *Options // configurable server options.
|
||||
Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections.
|
||||
Clients *clients.Clients // clients which are known to the broker.
|
||||
Topics *topics.Index // an index of topic filter subscriptions and retained messages.
|
||||
System *system.Info // values about the server commonly found in $SYS topics.
|
||||
bytepool *circ.BytesPool // a byte pool for incoming and outgoing packets.
|
||||
sysTicker *time.Ticker // the interval ticker for sending updating $SYS topics.
|
||||
inflightExpiryTicker *time.Ticker // the interval ticker for cleaning up expired messages.
|
||||
inflightResendTicker *time.Ticker // the interval ticker for resending unresolved inflight messages.
|
||||
done chan bool // indicate that the server is ending.
|
||||
}
|
||||
|
||||
// Options contains configurable options for the server.
|
||||
@@ -93,6 +98,9 @@ type Options struct {
|
||||
|
||||
// BufferBlockSize overrides the default buffer block size (DefaultBlockSize) for the client buffers.
|
||||
BufferBlockSize int
|
||||
|
||||
// InflightTTL specifies the duration that a queued inflight message should exist before being purged.
|
||||
InflightTTL int64
|
||||
}
|
||||
|
||||
// inlineMessages contains channels for handling inline (direct) publishing.
|
||||
@@ -114,6 +122,10 @@ func NewServer(opts *Options) *Server {
|
||||
opts = new(Options)
|
||||
}
|
||||
|
||||
if opts.InflightTTL < 1 {
|
||||
opts.InflightTTL = defaultInflightTTL
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
done: make(chan bool),
|
||||
bytepool: circ.NewBytesPool(opts.BufferSize),
|
||||
@@ -123,7 +135,9 @@ func NewServer(opts *Options) *Server {
|
||||
Version: Version,
|
||||
Started: time.Now().Unix(),
|
||||
},
|
||||
sysTicker: time.NewTicker(SysTopicInterval * time.Millisecond),
|
||||
sysTicker: time.NewTicker(SysTopicInterval * time.Millisecond),
|
||||
inflightExpiryTicker: time.NewTicker(time.Duration(opts.InflightTTL) * time.Second),
|
||||
inflightResendTicker: time.NewTicker(time.Duration(10) * time.Second),
|
||||
inline: inlineMessages{
|
||||
done: make(chan bool),
|
||||
pub: make(chan packets.Packet, 1024),
|
||||
@@ -143,6 +157,8 @@ func NewServer(opts *Options) *Server {
|
||||
// called before calling server.Server().
|
||||
func (s *Server) AddStore(p persistence.Store) error {
|
||||
s.Store = p
|
||||
s.Store.SetInflightTTL(s.Options.InflightTTL)
|
||||
|
||||
err := s.Store.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -198,6 +214,10 @@ func (s *Server) eventLoop() {
|
||||
return
|
||||
case <-s.sysTicker.C:
|
||||
s.publishSysTopics()
|
||||
case <-s.inflightExpiryTicker.C:
|
||||
s.clearExpiredInflights(time.Now().Unix())
|
||||
case <-s.inflightResendTicker.C:
|
||||
s.resendPendingInflights()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -319,9 +339,11 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
|
||||
return s.onError(cl.Info(), fmt.Errorf("ack connection packet: %w", err))
|
||||
}
|
||||
|
||||
err = s.ResendClientInflight(cl, true)
|
||||
if err != nil {
|
||||
s.onError(cl.Info(), fmt.Errorf("resend in flight: %w", err)) // pass-through, no return.
|
||||
if sessionPresent {
|
||||
err = s.ResendClientInflight(cl, true)
|
||||
if err != nil {
|
||||
s.onError(cl.Info(), fmt.Errorf("resend in flight: %w", err)) // pass-through, no return.
|
||||
}
|
||||
}
|
||||
|
||||
if s.Store != nil {
|
||||
@@ -346,6 +368,10 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
|
||||
|
||||
err = cl.StopCause() // Determine true cause of stop.
|
||||
|
||||
if cl.CleanSession {
|
||||
s.clearAbandonedInflights(cl)
|
||||
}
|
||||
|
||||
if s.Events.OnDisconnect != nil {
|
||||
s.Events.OnDisconnect(cl.Info(), err)
|
||||
}
|
||||
@@ -379,6 +405,7 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *clients.Client) boo
|
||||
// The state associated with a CleanSession MUST NOT be reused in any subsequent session.
|
||||
if pk.CleanSession || existing.CleanSession {
|
||||
s.unsubscribeClient(existing)
|
||||
s.clearAbandonedInflights(existing)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -400,6 +427,9 @@ func (s *Server) unsubscribeClient(cl *clients.Client) {
|
||||
for k := range cl.Subscriptions {
|
||||
delete(cl.Subscriptions, k)
|
||||
if s.Topics.Unsubscribe(k, cl.ID) {
|
||||
if s.Events.OnUnsubscribe != nil {
|
||||
s.Events.OnUnsubscribe(k, cl.Info())
|
||||
}
|
||||
atomic.AddInt64(&s.System.Subscriptions, -1)
|
||||
}
|
||||
}
|
||||
@@ -637,8 +667,9 @@ func (s *Server) publishToSubscribers(pk packets.Packet) {
|
||||
// if an appropriate ack is not received (or if the client is offline).
|
||||
sent := time.Now().Unix()
|
||||
q := client.Inflight.Set(out.PacketID, clients.InflightMessage{
|
||||
Packet: out,
|
||||
Sent: sent,
|
||||
Packet: out,
|
||||
Created: time.Now().Unix(),
|
||||
Sent: sent,
|
||||
})
|
||||
if q {
|
||||
atomic.AddInt64(&s.System.Inflight, 1)
|
||||
@@ -735,8 +766,11 @@ func (s *Server) processSubscribe(cl *clients.Client, pk packets.Packet) error {
|
||||
if !cl.AC.ACL(cl.Username, pk.Topics[i], false) {
|
||||
retCodes[i] = packets.ErrSubAckNetworkError
|
||||
} else {
|
||||
q := s.Topics.Subscribe(pk.Topics[i], cl.ID, pk.Qoss[i])
|
||||
if q {
|
||||
r := s.Topics.Subscribe(pk.Topics[i], cl.ID, pk.Qoss[i])
|
||||
if r {
|
||||
if s.Events.OnSubscribe != nil {
|
||||
s.Events.OnSubscribe(pk.Topics[i], cl.Info(), pk.Qoss[i])
|
||||
}
|
||||
atomic.AddInt64(&s.System.Subscriptions, 1)
|
||||
}
|
||||
cl.NoteSubscription(pk.Topics[i], pk.Qoss[i])
|
||||
@@ -785,6 +819,9 @@ func (s *Server) processUnsubscribe(cl *clients.Client, pk packets.Packet) error
|
||||
for i := 0; i < len(pk.Topics); i++ {
|
||||
q := s.Topics.Unsubscribe(pk.Topics[i], cl.ID)
|
||||
if q {
|
||||
if s.Events.OnUnsubscribe != nil {
|
||||
s.Events.OnUnsubscribe(pk.Topics[i], cl.Info())
|
||||
}
|
||||
atomic.AddInt64(&s.System.Subscriptions, -1)
|
||||
}
|
||||
cl.ForgetSubscription(pk.Topics[i])
|
||||
@@ -867,7 +904,7 @@ func (s *Server) publishSysTopics() {
|
||||
// ResendClientInflight attempts to resend all undelivered inflight messages
|
||||
// to a client.
|
||||
func (s *Server) ResendClientInflight(cl *clients.Client, force bool) error {
|
||||
if cl.Inflight.Len() == 0 {
|
||||
if atomic.LoadUint32(&cl.State.Done) == 1 || cl.Inflight.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1004,9 +1041,13 @@ func (s *Server) loadServerInfo(v persistence.ServerInfo) {
|
||||
// loadSubscriptions restores subscriptions from the datastore.
|
||||
func (s *Server) loadSubscriptions(v []persistence.Subscription) {
|
||||
for _, sub := range v {
|
||||
s.Topics.Subscribe(sub.Filter, sub.Client, sub.QoS)
|
||||
if cl, ok := s.Clients.Get(sub.Client); ok {
|
||||
cl.NoteSubscription(sub.Filter, sub.QoS)
|
||||
if s.Topics.Subscribe(sub.Filter, sub.Client, sub.QoS) {
|
||||
if cl, ok := s.Clients.Get(sub.Client); ok {
|
||||
cl.NoteSubscription(sub.Filter, sub.QoS)
|
||||
if s.Events.OnSubscribe != nil {
|
||||
s.Events.OnSubscribe(sub.Filter, cl.Info(), sub.QoS)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1034,6 +1075,7 @@ func (s *Server) loadInflight(v []persistence.Message) {
|
||||
TopicName: msg.TopicName,
|
||||
Payload: msg.Payload,
|
||||
},
|
||||
Created: msg.Created,
|
||||
Sent: msg.Sent,
|
||||
Resends: msg.Resends,
|
||||
})
|
||||
@@ -1051,3 +1093,36 @@ func (s *Server) loadRetained(v []persistence.Message) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// clearExpiredInflights deletes all inflight messages older than server inflight TTL.
|
||||
func (s *Server) clearExpiredInflights(dt int64) {
|
||||
expiry := dt - s.Options.InflightTTL
|
||||
|
||||
for _, client := range s.Clients.GetAll() {
|
||||
deleted := client.Inflight.ClearExpired(expiry)
|
||||
atomic.AddInt64(&s.System.Inflight, deleted*-1)
|
||||
}
|
||||
|
||||
if s.Store != nil {
|
||||
s.Store.ClearExpiredInflight(expiry)
|
||||
}
|
||||
}
|
||||
|
||||
// clearAbandonedInflights deletes all inflight messages for a disconnected user (eg. with a clean session).
|
||||
func (s *Server) clearAbandonedInflights(cl *clients.Client) {
|
||||
for i := range cl.Inflight.GetAll() {
|
||||
cl.Inflight.Delete(i)
|
||||
atomic.AddInt64(&s.System.Inflight, -1)
|
||||
}
|
||||
}
|
||||
|
||||
// resendPendingInflights attempts resends of any pending and due inflight messages.
|
||||
func (s *Server) resendPendingInflights() {
|
||||
for _, client := range s.Clients.GetAll() {
|
||||
err := s.ResendClientInflight(client, false)
|
||||
if err != nil {
|
||||
// TODO add log-level debugging.
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -368,9 +368,10 @@ func TestServerEventOnConnect(t *testing.T) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
require.Equal(t, events.Client{
|
||||
ID: "mochi",
|
||||
Remote: "pipe",
|
||||
Listener: "tcp",
|
||||
ID: "mochi",
|
||||
Remote: "pipe",
|
||||
Listener: "tcp",
|
||||
CleanSession: true,
|
||||
}, hook.client)
|
||||
|
||||
require.Equal(t, events.Packet(packets.Packet{
|
||||
@@ -443,9 +444,10 @@ func TestServerEventOnDisconnect(t *testing.T) {
|
||||
w.Close()
|
||||
|
||||
require.Equal(t, events.Client{
|
||||
ID: "mochi",
|
||||
Remote: "pipe",
|
||||
Listener: "tcp",
|
||||
ID: "mochi",
|
||||
Remote: "pipe",
|
||||
Listener: "tcp",
|
||||
CleanSession: true,
|
||||
}, hook.client)
|
||||
|
||||
require.ErrorIs(t, ErrClientDisconnect, hook.err)
|
||||
@@ -506,9 +508,10 @@ func TestServerEventOnDisconnectOnError(t *testing.T) {
|
||||
require.Equal(t, errx, hook.err)
|
||||
|
||||
require.Equal(t, events.Client{
|
||||
ID: "mochi",
|
||||
Remote: "pipe",
|
||||
Listener: "tcp",
|
||||
ID: "mochi",
|
||||
Remote: "pipe",
|
||||
Listener: "tcp",
|
||||
CleanSession: true,
|
||||
}, hook.client)
|
||||
|
||||
clw, ok := s.Clients.Get("mochi")
|
||||
@@ -1946,6 +1949,15 @@ func TestServerProcessSubscribeInvalid(t *testing.T) {
|
||||
func TestServerProcessSubscribe(t *testing.T) {
|
||||
s, cl, r, w := setupClient()
|
||||
|
||||
subscribeEvent := ""
|
||||
subscribeClient := ""
|
||||
s.Events.OnSubscribe = func(filter string, cl events.Client, qos byte) {
|
||||
if filter == "a/b/c" {
|
||||
subscribeEvent = "a/b/c"
|
||||
subscribeClient = cl.ID
|
||||
}
|
||||
}
|
||||
|
||||
s.Topics.RetainMessage(packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
@@ -1995,6 +2007,8 @@ func TestServerProcessSubscribe(t *testing.T) {
|
||||
require.Equal(t, byte(1), cl.Subscriptions["d/e/f"])
|
||||
require.Equal(t, topics.Subscriptions{cl.ID: 0}, s.Topics.Subscribers("a/b/c"))
|
||||
require.Equal(t, topics.Subscriptions{cl.ID: 1}, s.Topics.Subscribers("d/e/f"))
|
||||
require.Equal(t, "a/b/c", subscribeEvent)
|
||||
require.Equal(t, cl.ID, subscribeClient)
|
||||
}
|
||||
|
||||
func TestServerProcessSubscribeFailACL(t *testing.T) {
|
||||
@@ -2114,6 +2128,16 @@ func TestServerProcessUnsubscribeInvalid(t *testing.T) {
|
||||
|
||||
func TestServerProcessUnsubscribe(t *testing.T) {
|
||||
s, cl, r, w := setupClient()
|
||||
|
||||
unsubscribeEvent := ""
|
||||
unsubscribeClient := ""
|
||||
s.Events.OnUnsubscribe = func(filter string, cl events.Client) {
|
||||
if filter == "a/b/c" {
|
||||
unsubscribeEvent = "a/b/c"
|
||||
unsubscribeClient = cl.ID
|
||||
}
|
||||
}
|
||||
|
||||
s.Clients.Add(cl)
|
||||
s.Topics.Subscribe("a/b/c", cl.ID, 0)
|
||||
s.Topics.Subscribe("d/e/f", cl.ID, 1)
|
||||
@@ -2155,6 +2179,9 @@ func TestServerProcessUnsubscribe(t *testing.T) {
|
||||
|
||||
require.NotEmpty(t, s.Topics.Subscribers("a/b/+"))
|
||||
require.Contains(t, cl.Subscriptions, "a/b/+")
|
||||
|
||||
require.Equal(t, "a/b/c", unsubscribeEvent)
|
||||
require.Equal(t, cl.ID, unsubscribeClient)
|
||||
}
|
||||
|
||||
func TestServerProcessUnsubscribeWriteError(t *testing.T) {
|
||||
@@ -2677,3 +2704,76 @@ func TestServerResendClientInflightError(t *testing.T) {
|
||||
err := s.ResendClientInflight(cl, true)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestServerClearExpiredInflights(t *testing.T) {
|
||||
n := time.Now().Unix()
|
||||
|
||||
s := New()
|
||||
s.Options.InflightTTL = 2
|
||||
require.NotNil(t, s)
|
||||
|
||||
r, _ := net.Pipe()
|
||||
cl := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
|
||||
cl.Inflight.Set(1, clients.InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Created: n - 1,
|
||||
Sent: 0,
|
||||
})
|
||||
cl.Inflight.Set(2, clients.InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Created: n - 2,
|
||||
Sent: 0,
|
||||
})
|
||||
cl.Inflight.Set(3, clients.InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Created: n - 3,
|
||||
Sent: 0,
|
||||
})
|
||||
cl.Inflight.Set(5, clients.InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Created: n - 5,
|
||||
Sent: 0,
|
||||
})
|
||||
s.Clients.Add(cl)
|
||||
|
||||
require.Len(t, cl.Inflight.GetAll(), 4)
|
||||
s.clearExpiredInflights(n)
|
||||
require.Len(t, cl.Inflight.GetAll(), 2)
|
||||
require.Equal(t, int64(-2), s.System.Inflight)
|
||||
}
|
||||
|
||||
func TestServerClearAbandonedInflights(t *testing.T) {
|
||||
s := New()
|
||||
require.NotNil(t, s)
|
||||
|
||||
r, _ := net.Pipe()
|
||||
cl := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
|
||||
cl.Inflight.Set(1, clients.InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Sent: 0,
|
||||
})
|
||||
cl.Inflight.Set(2, clients.InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Sent: 0,
|
||||
})
|
||||
|
||||
cl2 := clients.NewClient(r, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
|
||||
|
||||
cl2.Inflight.Set(3, clients.InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Sent: 0,
|
||||
})
|
||||
cl2.Inflight.Set(5, clients.InflightMessage{
|
||||
Packet: packets.Packet{},
|
||||
Sent: 0,
|
||||
})
|
||||
s.Clients.Add(cl)
|
||||
s.Clients.Add(cl2)
|
||||
|
||||
require.Len(t, cl.Inflight.GetAll(), 2)
|
||||
require.Len(t, cl2.Inflight.GetAll(), 2)
|
||||
s.clearAbandonedInflights(cl)
|
||||
require.Len(t, cl.Inflight.GetAll(), 0)
|
||||
require.Len(t, cl2.Inflight.GetAll(), 2)
|
||||
require.Equal(t, int64(-2), s.System.Inflight)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user