Compare commits

..

28 Commits

Author SHA1 Message Date
mochi-co
1e8f922102 update server version 2023-02-20 18:14:57 +00:00
Hubertus Hohl
4c16e5593f fix: correct decoding of packets including Properties exceeding 127 bytes in length (#172) 2023-02-20 18:14:19 +00:00
mochi-co
49cada4fbc Update server version 2023-02-10 23:39:27 +00:00
JB
ef34510c0b Expose dropped publish messages count in sys info (#170) 2023-02-10 23:38:20 +00:00
JB
e5716caad1 Fix potential NextPacketID endless loop, expand tests (#169)
* Fix possible NextPacketID endless loop, expand tests

* Optimize NextPacketID

* Use math constants
2023-02-10 23:27:21 +00:00
thedevop
4b039cb35c Add PublishDropped metrics (#167)
* Add PublishDropped

* Add PublishDropped

* Add PublishDropped

* Update storage_test.go

* Update system.go

* Update server.go
2023-02-10 14:44:01 +00:00
JB
aac245441a No longer issue retained messages on session takeover (#166) 2023-02-09 23:57:24 +00:00
JB
bb54cc68e6 Client write buffers (#165)
* Replace fanpool with client write buffers
2023-02-09 22:34:30 +00:00
thedevop
7ba1352a60 Add Clone to system.Info (#163)
* Add Clone using atomic operations

* Add Clone using atomic operations

* Use sysinfo.Clone

* Unit test for Clone

* Add Clone using atomic operations

* Update

* Update
2023-02-09 19:07:17 +00:00
mochi-co
ca849131eb Update server version 2023-02-05 11:07:07 +00:00
Wind
ba7e534122 failed to delete inflight data (#162)
The s.hooks.OnQosPublish method needs to be called, otherwise the following s.hooks.OnQosComplete or processPuback(s.hooks.OnQosComplete) method will report a data not found error.
2023-02-05 10:53:49 +00:00
mochi-co
db760c34a5 Update server version 2023-02-04 10:57:27 +00:00
JB
ae3ee81bb4 Rename Quota methods for clarity (#159) 2023-02-04 10:53:45 +00:00
JB
c2ca02d149 Move refreshDeadline to only trigger on successful transmission (#157) 2023-02-04 10:16:05 +00:00
Jeroen Rinzema
77a64d9c87 Include a listener accepting an existing net.Listener (#155) 2023-02-04 10:10:10 +00:00
Wind
8dec9cc962 invalid config type provided (#152)
* invalid config type provided

examples/persistence/bolt/main.go: invalid config type provided

* fixed ErrReceiveMaximum(receive maximum exceeded)

No quotas of the inflight is set in the readStore method, so each quota is equal to 0. The inheritClientSession method overrides the quotas of the new client inflight, so the processPublish method reports an ErrReceiveMaximum and disconnects the client.

* reset receive quota

receive quota should be reset across connections (as specified in the spec).
2023-02-04 10:06:26 +00:00
mochi-co
f90e52328d Update server version 2023-01-16 20:08:55 +00:00
JB
50aae47618 Publish retained messages only after connack (#147) 2023-01-16 19:50:01 +00:00
JB
0d79f2d63b Use Atomic instead of RWMutex for Hooks concurrency (#148)
* Use Atomic instead of RWMutex for Hooks concurrency
* Lock Hooks on Add Hook
2023-01-16 19:49:36 +00:00
JB
300152413c Ignore retain as published v3 (#142)
* Optimise Capabilities struct alignment

* Only use RetainAsPublished for v5 clients
2023-01-13 23:38:49 +00:00
mochi-co
0de1d731db Update version number 2023-01-10 00:01:21 +00:00
JB
80746abc52 Use correct connack return codes for MQTTv3 (#140) 2023-01-10 00:00:43 +00:00
mochi-co
a73cf4ca0e Update server version 2023-01-09 23:08:49 +00:00
mochi-co
bc549ee7ed Fix example imports 2023-01-09 22:52:24 +00:00
mochi-co
c464b46713 export client.Net.Conn for external use 2023-01-09 22:49:40 +00:00
mochi-co
05ce56008c Small code improvements 2023-01-09 22:49:20 +00:00
JB
8254cb0cbc Make hooks safe for concurrency (#139)
Co-authored-by: thedevop <60499013+thedevop@users.noreply.github.com>
2023-01-09 22:41:44 +00:00
mochi-co
4ae58b79e3 Update server version 2023-01-07 20:13:48 +00:00
30 changed files with 719 additions and 478 deletions

View File

@@ -41,10 +41,10 @@ import "github.com/mochi-co/mqtt/v2"
- Direct Packet Injection using special inline client, or masquerade as existing clients.
- Performant and Stable:
- Our classic trie-based Topic-Subscription model.
- A new fixed 'FanPool' worker queues to ensure consistent resource allocation and throughput reliability.
- Client-specific write buffers to avoid issues with slow-reading or irregular client behaviour.
- Passes all [Paho Interoperability Tests](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) for MQTT v5 and MQTT v3.
- Over a thousand carefully considered unit test scenarios.
- TCP, Websocket, (including SSL/TLS) and $SYS Dashboard listeners.
- TCP, Websocket (including SSL/TLS), and $SYS Dashboard listeners.
- Built-in Redis, Badger, and Bolt Persistence using Hooks (but you can also make your own).
- Built-in Rule-based Authentication and ACL Ledger using Hooks (also make your own).
@@ -83,7 +83,11 @@ docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 mochi:latest
Importing Mochi MQTT as a package requires just a few lines of code to get started.
``` go
import (
"log"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
@@ -94,7 +98,7 @@ func main() {
_ = server.AddHook(new(auth.AllowHook), nil)
// Create a TCP listener on a standard port.
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
tcp := listeners.NewTCP("t1", ":1883", nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
@@ -116,6 +120,7 @@ The server comes with a variety of pre-packaged network listeners which allow th
| --- | --- |
| listeners.NewTCP | A TCP listener |
| listeners.NewUnixSock | A Unix Socket listener |
| listeners.NewNet | A net.Listener listener |
| listeners.NewWebsocket | A Websocket listener |
| listeners.NewHTTPStats | An HTTP $SYS info dashboard |
@@ -292,6 +297,7 @@ The function signatures for all the hooks and `mqtt.Hook` interface can be found
| OnUnsubscribed | Called when a client successfully unsubscribes from one or more filters. |
| OnPublish | Called when a client publishes a message. Allows packet modification. |
| OnPublished | Called when a client has published a message to subscribers. |
| OnPublishDropped | Called when a message to a client is dropped before delivery, such as if the client is taking too long to respond. |
| OnRetainMessage | Called then a published message is retained. |
| OnQosPublish | Called when a publish packet with Qos >= 1 is issued to a subscriber. |
| OnQosComplete | Called when the Qos flow for a message has been completed. |
@@ -361,23 +367,23 @@ Performance benchmarks were tested using [MQTT-Stresser](https://github.com/inov
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=2 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.0.0 | 139,860 | 135,960 | 132,059 | 217,499 | 211,027 | 204,555 |
| Mochi v2.2.0 | 127,216 | 125,748 | 124,279 | 319,250 | 309,327 | 299,405 |
| Mosquitto v2.0.15 | 155,920 | 155,919 | 155,918 | 185,485 | 185,097 | 184,709 |
| EMQX v5.0.11 | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17649 |
| EMQX v5.0.11 | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17,649 |
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=10 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.0.0 | 55,189 | 34,840 | 21,298 | 56,980 | 28,557 | 23,781 |
| Mochi v2.2.0 | 45,615 | 30,129 | 21,138 | 232,717 | 86,323 | 50,402 |
| Mosquitto v2.0.15 | 42,729 | 38,633 | 29,879 | 23,241 | 19,714 | 18,806 |
| EMQX v5.0.11 | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3756 |
| EMQX v5.0.11 | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3,756 |
Million Message Challenge (hit the server with 1 million messages immediately):
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=100 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.0.0 | 13,573 | 3,678 | 1,848 | 34,309 | 2,470 | 5,636 |
| Mochi v2.2.0 | 51,044 | 4,682 | 2,345 | 72,634 | 7,645 | 2,464 |
| Mosquitto v2.0.15 | 3,826 | 3,395 | 3,032 | 1,200 | 1,150 | 1,118 |
| EMQX v5.0.11 | 4,086 | 2,432 | 2,274 | 434 | 333 | 311 |

View File

@@ -106,7 +106,7 @@ type Client struct {
// ClientConnection contains the connection transport and metadata for the client.
type ClientConnection struct {
conn net.Conn // the net.Conn used to establish the connection
Conn net.Conn // the net.Conn used to establish the connection
bconn *bufio.ReadWriter // a buffered net.Conn for reading packets
Remote string // the remote address of the client
Listener string // listener id of the client
@@ -135,15 +135,17 @@ type Will struct {
// State tracks the state of the client.
type ClientState struct {
TopicAliases TopicAliases // a map of topic aliases
stopCause atomic.Value // reason for stopping
Inflight *Inflight // a map of in-flight qos messages
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
endOnce sync.Once // only end once
packetID uint32 // the current highest packetID
done uint32 // atomic counter which indicates that the client has closed
keepalive uint16 // the number of seconds the connection can wait
TopicAliases TopicAliases // a map of topic aliases
stopCause atomic.Value // reason for stopping
Inflight *Inflight // a map of in-flight qos messages
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
outbound chan packets.Packet // queue for pending outbound packets
endOnce sync.Once // only end once
packetID uint32 // the current highest packetID
done uint32 // atomic counter which indicates that the client has closed
outboundQty int32 // number of messages currently in the outbound queue
keepalive uint16 // the number of seconds the connection can wait
}
// newClient returns a new instance of Client. This is almost exclusively used by Server
@@ -155,6 +157,7 @@ func newClient(c net.Conn, o *ops) *Client {
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(o.capabilities.TopicAliasMaximum),
keepalive: defaultKeepalive,
outbound: make(chan packets.Packet, o.capabilities.MaximumClientWritesPending),
},
Properties: ClientProperties{
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
@@ -164,7 +167,7 @@ func newClient(c net.Conn, o *ops) *Client {
if c != nil {
cl.Net = ClientConnection{
conn: c,
Conn: c,
bconn: bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)),
Remote: c.RemoteAddr().String(),
}
@@ -175,6 +178,16 @@ func newClient(c net.Conn, o *ops) *Client {
return cl
}
// WriteLoop ranges over pending outbound messages and writes them to the client connection.
func (cl *Client) WriteLoop() {
for pk := range cl.State.outbound {
if err := cl.WritePacket(pk); err != nil {
cl.ops.log.Debug().Err(err).Str("client", cl.ID).Interface("packet", pk).Msg("failed publishing packet")
}
atomic.AddInt32(&cl.State.outboundQty, -1)
}
}
// ParseConnect parses the connect parameters and properties for a client.
func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Net.Listener = lid
@@ -223,13 +236,13 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
func (cl *Client) refreshDeadline(keepalive uint16) {
if cl.Net.conn != nil {
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
if keepalive > 0 {
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
}
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
if keepalive > 0 {
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
}
_ = cl.Net.conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
if cl.Net.Conn != nil {
_ = cl.Net.Conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
}
}
@@ -237,28 +250,30 @@ func (cl *Client) refreshDeadline(keepalive uint16) {
// If no unused packet ids are available, an error is returned and the client
// should be disconnected.
func (cl *Client) NextPacketID() (i uint32, err error) {
cl.Lock()
defer cl.Unlock()
i = atomic.LoadUint32(&cl.State.packetID)
started := i + 1
started := i
overflowed := false
for {
if i >= 65535 {
overflowed = true
i = 1
} else {
i++
}
if overflowed && i == started {
return 0, packets.ErrQuotaExceeded
}
if i >= cl.ops.capabilities.maximumPacketID {
overflowed = true
i = 0
continue
}
i++
if _, ok := cl.State.Inflight.Get(uint16(i)); !ok {
break
atomic.StoreUint32(&cl.State.packetID, i)
return i, nil
}
}
atomic.StoreUint32(&cl.State.packetID, i)
return i, nil
}
// ResendInflightMessages attempts to resend any pending inflight messages to connected clients.
@@ -272,7 +287,7 @@ func (cl *Client) ResendInflightMessages(force bool) error {
tk.FixedHeader.Dup = true // [MQTT-3.3.1-1] [MQTT-3.3.1-3]
}
// cl.ops.hooks.OnQosPublish(cl, tk.Packet, nt, tk.Resends)
cl.ops.hooks.OnQosPublish(cl, tk, tk.Created, 0)
err := cl.WritePacket(tk)
if err != nil {
return err
@@ -297,7 +312,7 @@ func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 {
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
cl.ops.hooks.OnQosDropped(cl, tk)
atomic.AddInt64(&cl.ops.info.Inflight, -1)
deleted = append(deleted, uint16(tk.PacketID))
deleted = append(deleted, tk.PacketID)
}
}
}
@@ -341,8 +356,8 @@ func (cl *Client) Stop(err error) {
}
cl.State.endOnce.Do(func() {
if cl.Net.conn != nil {
_ = cl.Net.conn.Close() // omit close error
if cl.Net.Conn != nil {
_ = cl.Net.Conn.Close() // omit close error
}
if err != nil {
@@ -464,11 +479,10 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
return ErrConnectionClosed
}
if cl.Net.conn == nil {
if cl.Net.Conn == nil {
return nil
}
defer cl.refreshDeadline(cl.State.keepalive)
if pk.Expiry > 0 {
pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6]
}
@@ -533,7 +547,7 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
}
nb := net.Buffers{buf.Bytes()}
n, err := nb.WriteTo(cl.Net.conn)
n, err := nb.WriteTo(cl.Net.Conn)
if err != nil {
return err
}

View File

@@ -30,8 +30,10 @@ func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
hooks: new(Hooks),
log: &logger,
capabilities: &Capabilities{
ReceiveMaximum: 10,
TopicAliasMaximum: 10000,
ReceiveMaximum: 10,
TopicAliasMaximum: 10000,
MaximumClientWritesPending: 3,
maximumPacketID: 10,
},
})
@@ -42,6 +44,9 @@ func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
cl.State.Inflight.receiveQuota = 10
cl.Properties.Props.TopicAliasMaximum = 0
cl.Properties.Props.RequestResponseInfo = 0x1
go cl.WriteLoop()
return
}
@@ -127,7 +132,7 @@ func TestNewClient(t *testing.T) {
require.NotNil(t, cl.State.TopicAliases)
require.Equal(t, defaultKeepalive, cl.State.keepalive)
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
require.NotNil(t, cl.Net.conn)
require.NotNil(t, cl.Net.Conn)
require.NotNil(t, cl.Net.bconn)
require.False(t, cl.Net.Inline)
}
@@ -237,28 +242,32 @@ func TestClientNextPacketIDInUse(t *testing.T) {
func TestClientNextPacketIDExhausted(t *testing.T) {
cl, _, _ := newTestClient()
for i := 0; i <= 65535; i++ {
cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)})
for i := uint32(1); i <= cl.ops.capabilities.maximumPacketID; i++ {
cl.State.Inflight.internal[uint16(i)] = packets.Packet{PacketID: uint16(i)}
}
i, err := cl.NextPacketID()
require.Equal(t, uint32(0), i)
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
require.Equal(t, uint32(0), i)
}
func TestClientNextPacketIDOverflow(t *testing.T) {
cl, _, _ := newTestClient()
for i := uint32(0); i < cl.ops.capabilities.maximumPacketID; i++ {
cl.State.Inflight.internal[uint16(i)] = packets.Packet{}
}
cl.State.packetID = uint32(65534)
cl.State.packetID = uint32(cl.ops.capabilities.maximumPacketID - 1)
i, err := cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(65535), i)
require.Equal(t, cl.ops.capabilities.maximumPacketID, i)
cl.State.Inflight.internal[uint16(cl.ops.capabilities.maximumPacketID)] = packets.Packet{}
i, err = cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(1), i)
cl.State.packetID = cl.ops.capabilities.maximumPacketID
_, err = cl.NextPacketID()
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
}
func TestClientClearInflights(t *testing.T) {
@@ -320,7 +329,7 @@ func TestClientResendInflightMessagesNoMessages(t *testing.T) {
func TestClientRefreshDeadline(t *testing.T) {
cl, _, _ := newTestClient()
cl.refreshDeadline(10)
require.NotNil(t, cl.Net.conn) // how do we check net.Conn deadline?
require.NotNil(t, cl.Net.Conn) // how do we check net.Conn deadline?
}
func TestClientReadFixedHeader(t *testing.T) {
@@ -586,7 +595,7 @@ func TestClientReadPacket(t *testing.T) {
func TestClientReadPacketInvalidTypeError(t *testing.T) {
cl, _, _ := newTestClient()
cl.Net.conn.Close()
cl.Net.Conn.Close()
_, err := cl.ReadPacket(&packets.FixedHeader{})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid packet type")
@@ -610,7 +619,7 @@ func TestClientWritePacket(t *testing.T) {
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
time.Sleep(2 * time.Millisecond)
cl.Net.conn.Close()
cl.Net.Conn.Close()
require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc)
@@ -692,7 +701,7 @@ func TestClientWritePacketWriteNoConn(t *testing.T) {
func TestClientWritePacketWriteError(t *testing.T) {
cl, _, _ := newTestClient()
cl.Net.conn.Close()
cl.Net.Conn.Close()
err := cl.WritePacket(*pkTable[1].Packet)
require.Error(t, err)

View File

@@ -0,0 +1,52 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"flag"
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
tcpAddr := flag.String("tcp", ":1883", "network address for TCP listener")
flag.Parse()
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
_ = server.AddHook(new(auth.AllowHook), nil)
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@@ -36,7 +36,7 @@ func main() {
}
err = server.AddHook(new(debug.Hook), &debug.Options{
ShowPacketData: true,
// ShowPacketData: true,
})
if err != nil {
log.Fatal(err)

View File

@@ -30,12 +30,15 @@ func main() {
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
err := server.AddHook(new(bolt.Hook), bolt.Options{
err := server.AddHook(new(bolt.Hook), &bolt.Options{
Path: "bolt.db",
Options: &bbolt.Options{
Timeout: 500 * time.Millisecond,
},
})
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)

View File

@@ -1,101 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co, chowyu08, muXxer
package mqtt
import (
"sync"
"sync/atomic"
xh "github.com/cespare/xxhash/v2"
)
// taskChan is a channel for incoming task functions.
type taskChan chan func()
// FanPool is a fixed-sized fan-style worker pool with multiple
// working 'columns'. Instead of a single queue channel processed by
// many goroutines, this fan pool uses many queue channels each
// processed by a single goroutine.
// Very special thanks are given to the authors of HMQ in particular
// @chowyu08 and @muXxer for their work on the fixpool worker pool
// https://github.com/fhmq/hmq/blob/master/pool/fixpool.go
// from which this fan-pool is heavily inspired.
type FanPool struct {
queue []taskChan
wg sync.WaitGroup
capacity uint64
perChan uint64
Mutex sync.Mutex
}
// New returns a new instance of FanPool. fanSize controls the number of 'columns'
// of the fan, whereas queueSize controls the size of each column's queue.
func NewFanPool(fanSize, queueSize uint64) *FanPool {
pool := &FanPool{
capacity: fanSize,
perChan: queueSize,
queue: make([]taskChan, fanSize),
}
pool.fillWorkers(fanSize)
return pool
}
// fillWorkers adds columns to the fan pool with an associated worker goroutine.
func (p *FanPool) fillWorkers(n uint64) {
for i := uint64(0); i < n; i++ {
p.queue[i] = make(taskChan, p.perChan)
go p.worker(p.queue[i])
p.wg.Add(1)
}
}
// worker is a worker goroutine which processes tasks from a single queue.
func (p *FanPool) worker(ch taskChan) {
defer p.wg.Done()
var task func()
var ok bool
for {
task, ok = <-ch
if !ok {
return
}
task()
}
}
// Enqueue adds a new task to the queue to be processed.
func (p *FanPool) Enqueue(id string, task func()) {
if p.Size() == 0 {
return
}
// We can use xh.Sum64 to get a specific queue index
// which remains the same for a client id, giving each
// client their own queue.
p.queue[xh.Sum64([]byte(id))%p.Size()] <- task
}
// Wait blocks until all the workers in the pool have completed.
func (p *FanPool) Wait() {
p.wg.Wait()
}
// Close issues a shutdown signal to the workers.
func (p *FanPool) Close() {
for i := 0; i < int(p.Size()); i++ {
if p.queue[i] != nil {
close(p.queue[i])
}
}
p.queue = nil
atomic.StoreUint64(&p.capacity, 0)
}
// Size returns the current number of workers in the pool.
func (p *FanPool) Size() uint64 {
return atomic.LoadUint64(&p.capacity)
}

View File

@@ -1,89 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestFanPool(t *testing.T) {
f := NewFanPool(1, 2)
require.NotNil(t, f)
require.Equal(t, uint64(1), f.capacity)
require.Equal(t, 2, cap(f.queue[0]))
o := make(chan bool)
go func() {
f.Enqueue("test", func() {
o <- true
})
}()
require.True(t, <-o)
f.Close()
f.Wait()
}
func TestFillWorkers(t *testing.T) {
f := &FanPool{
perChan: 3,
queue: make([]taskChan, 2),
}
f.fillWorkers(2)
require.Len(t, f.queue, 2)
require.Equal(t, 3, cap(f.queue[0]))
}
func TestEnqueue(t *testing.T) {
f := &FanPool{
capacity: 2,
queue: []taskChan{
make(taskChan, 2),
make(taskChan, 2),
},
}
go func() {
f.Enqueue("a", func() {})
}()
require.NotNil(t, <-f.queue[1])
}
func TestEnqueueOnEmpty(t *testing.T) {
f := &FanPool{
queue: []taskChan{},
}
go func() {
f.Enqueue("a", func() {})
}()
require.Len(t, f.queue, 0)
}
func TestSize(t *testing.T) {
f := &FanPool{
capacity: 10,
}
require.Equal(t, uint64(10), f.Size())
}
func TestClose(t *testing.T) {
f := &FanPool{
capacity: 3,
queue: []taskChan{
make(taskChan, 2),
make(taskChan, 2),
make(taskChan, 2),
},
}
f.Close()
require.Equal(t, uint64(0), f.Size())
require.Nil(t, f.queue)
}

2
go.mod
View File

@@ -6,7 +6,6 @@ require (
github.com/alicebob/miniredis/v2 v2.23.0
github.com/asdine/storm v2.1.2+incompatible
github.com/asdine/storm/v3 v3.2.1
github.com/cespare/xxhash/v2 v2.1.2
github.com/go-redis/redis/v8 v8.11.5
github.com/gorilla/websocket v1.5.0
github.com/jinzhu/copier v0.3.5
@@ -21,6 +20,7 @@ require (
require (
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 // indirect
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgraph-io/badger v1.6.0 // indirect
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 // indirect

112
hooks.go
View File

@@ -1,6 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
// SPDX-FileContributor: mochi-co, thedevop
package mqtt
@@ -39,6 +39,7 @@ const (
OnUnsubscribed
OnPublish
OnPublished
OnPublishDropped
OnRetainMessage
OnQosPublish
OnQosComplete
@@ -87,6 +88,7 @@ type Hook interface {
OnUnsubscribed(cl *Client, pk packets.Packet)
OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error)
OnPublished(cl *Client, pk packets.Packet)
OnPublishDropped(cl *Client, pk packets.Packet)
OnRetainMessage(cl *Client, pk packets.Packet, r int64)
OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int)
OnQosComplete(cl *Client, pk packets.Packet)
@@ -110,10 +112,10 @@ type HookOptions struct {
// Hooks is a slice of Hook interfaces to be called in sequence.
type Hooks struct {
Log *zerolog.Logger // a logger for the hook (from the server)
internal []Hook // a slice of hooks
internal atomic.Value // a slice of []Hook
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
qty int64 // the number of hooks in use
sync.Mutex // a mutex
sync.Mutex // a mutex for locking when adding hooks
}
// Len returns the number of hooks added.
@@ -123,7 +125,7 @@ func (h *Hooks) Len() int64 {
// Provides returns true if any one hook provides any of the requested hook methods.
func (h *Hooks) Provides(b ...byte) bool {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
for _, hb := range b {
if hook.Provides(hb) {
return true
@@ -138,26 +140,39 @@ func (h *Hooks) Provides(b ...byte) bool {
func (h *Hooks) Add(hook Hook, config any) error {
h.Lock()
defer h.Unlock()
if h.internal == nil {
h.internal = []Hook{}
}
err := hook.Init(config)
if err != nil {
return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err)
}
h.internal = append(h.internal, hook)
i, ok := h.internal.Load().([]Hook)
if !ok {
i = []Hook{}
}
i = append(i, hook)
h.internal.Store(i)
atomic.AddInt64(&h.qty, 1)
h.wg.Add(1)
return nil
}
// GetAll returns a slice of all the hooks.
func (h *Hooks) GetAll() []Hook {
i, ok := h.internal.Load().([]Hook)
if !ok {
return []Hook{}
}
return i
}
// Stop indicates all attached hooks to gracefully end.
func (h *Hooks) Stop() {
go func() {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
h.Log.Info().Str("hook", hook.ID()).Msg("stopping hook")
if err := hook.Stop(); err != nil {
h.Log.Debug().Err(err).Str("hook", hook.ID()).Msg("problem stopping hook")
@@ -172,7 +187,7 @@ func (h *Hooks) Stop() {
// OnSysInfoTick is called when the $SYS topic values are published out.
func (h *Hooks) OnSysInfoTick(sys *system.Info) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSysInfoTick) {
hook.OnSysInfoTick(sys)
}
@@ -181,7 +196,7 @@ func (h *Hooks) OnSysInfoTick(sys *system.Info) {
// OnStarted is called when the server has successfully started.
func (h *Hooks) OnStarted() {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnStarted) {
hook.OnStarted()
}
@@ -190,7 +205,7 @@ func (h *Hooks) OnStarted() {
// OnStopped is called when the server has successfully stopped.
func (h *Hooks) OnStopped() {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnStopped) {
hook.OnStopped()
}
@@ -199,7 +214,7 @@ func (h *Hooks) OnStopped() {
// OnConnect is called when a new client connects.
func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnConnect) {
hook.OnConnect(cl, pk)
}
@@ -208,7 +223,7 @@ func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) {
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSessionEstablished) {
hook.OnSessionEstablished(cl, pk)
}
@@ -217,7 +232,7 @@ func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
// OnDisconnect is called when a client is disconnected for any reason.
func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnDisconnect) {
hook.OnDisconnect(cl, err, expire)
}
@@ -227,7 +242,7 @@ func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
// OnPacketRead is called when a packet is received from a client.
func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketRead) {
npk, err := hook.OnPacketRead(cl, pkx)
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
@@ -248,7 +263,7 @@ func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet,
// to create their own auth packet handling mechanisms.
func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnAuthPacket) {
npk, err := hook.OnAuthPacket(cl, pkx)
if err != nil {
@@ -264,7 +279,7 @@ func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet,
// OnPacketEncode is called immediately before a packet is encoded to be sent to a client.
func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketEncode) {
pk = hook.OnPacketEncode(cl, pk)
}
@@ -275,7 +290,7 @@ func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
// OnPacketProcessed is called when a packet has been received and successfully handled by the broker.
func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketProcessed) {
hook.OnPacketProcessed(cl, pk, err)
}
@@ -285,7 +300,7 @@ func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
// OnPacketSent is called when a packet has been sent to a client. It takes a bytes parameter
// containing the bytes sent.
func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketSent) {
hook.OnPacketSent(cl, pk, b)
}
@@ -297,7 +312,7 @@ func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
// before the packet is processed. The return values of the hook methods are passed-through
// in the order the hooks were attached.
func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSubscribe) {
pk = hook.OnSubscribe(cl, pk)
}
@@ -307,7 +322,7 @@ func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
// OnSubscribed is called when a client subscribes to one or more filters.
func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSubscribed) {
hook.OnSubscribed(cl, pk, reasonCodes)
}
@@ -319,7 +334,7 @@ func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte)
// remove or add clients to a publish to subscribers process, or to select the subscriber for a shared
// group in a custom manner (such as based on client id, ip, etc).
func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSelectSubscribers) {
subs = hook.OnSelectSubscribers(subs, pk)
}
@@ -332,7 +347,7 @@ func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subsc
// before the packet is processed. The return values of the hook methods are passed-through
// in the order the hooks were attached.
func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnUnsubscribe) {
pk = hook.OnUnsubscribe(cl, pk)
}
@@ -342,7 +357,7 @@ func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnUnsubscribed) {
hook.OnUnsubscribed(cl, pk)
}
@@ -354,7 +369,7 @@ func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
// The return values of the hook methods are passed-through in the order the hooks were attached.
func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPublish) {
npk, err := hook.OnPublish(cl, pkx)
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
@@ -373,16 +388,26 @@ func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, er
// OnPublished is called when a client has published a message to subscribers.
func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPublished) {
hook.OnPublished(cl, pk)
}
}
}
// OnPublishDropped is called when a message to a client was dropped instead of delivered
// such as when a client is too slow to respond.
func (h *Hooks) OnPublishDropped(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPublishDropped) {
hook.OnPublishDropped(cl, pk)
}
}
}
// OnRetainMessage is called then a published message is retained.
func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnRetainMessage) {
hook.OnRetainMessage(cl, pk, r)
}
@@ -393,7 +418,7 @@ func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
// In other words, this method is called when a new inflight message is created or resent.
// It is typically used to store a new inflight message.
func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosPublish) {
hook.OnQosPublish(cl, pk, sent, resends)
}
@@ -404,7 +429,7 @@ func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends
// In other words, when an inflight message is resolved.
// It is typically used to delete an inflight message from a store.
func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosComplete) {
hook.OnQosComplete(cl, pk)
}
@@ -415,7 +440,7 @@ func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
// an inflight message expires or is abandoned. It is typically used to delete an
// inflight message from a store.
func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosDropped) {
hook.OnQosDropped(cl, pk)
}
@@ -427,7 +452,7 @@ func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
// published. The return values of the hook methods are passed-through in the order
// the hooks were attached.
func (h *Hooks) OnWill(cl *Client, will Will) Will {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnWill) {
mlwt, err := hook.OnWill(cl, will)
if err != nil {
@@ -443,7 +468,7 @@ func (h *Hooks) OnWill(cl *Client, will Will) Will {
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnWillSent) {
hook.OnWillSent(cl, pk)
}
@@ -452,7 +477,7 @@ func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
// OnClientExpired is called when a client session has expired and should be deleted.
func (h *Hooks) OnClientExpired(cl *Client) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnClientExpired) {
hook.OnClientExpired(cl)
}
@@ -461,7 +486,7 @@ func (h *Hooks) OnClientExpired(cl *Client) {
// OnRetainedExpired is called when a retained message has expired and should be deleted.
func (h *Hooks) OnRetainedExpired(filter string) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnRetainedExpired) {
hook.OnRetainedExpired(filter)
}
@@ -471,7 +496,7 @@ func (h *Hooks) OnRetainedExpired(filter string) {
// StoredClients returns all clients, e.g. from a persistent store, is used to
// populate the server clients list before start.
func (h *Hooks) StoredClients() (v []storage.Client, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredClients) {
v, err := hook.StoredClients()
if err != nil {
@@ -491,7 +516,7 @@ func (h *Hooks) StoredClients() (v []storage.Client, err error) {
// StoredSubscriptions returns all subcriptions, e.g. from a persistent store, and is
// used to populate the server subscriptions list before start.
func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredSubscriptions) {
v, err := hook.StoredSubscriptions()
if err != nil {
@@ -511,7 +536,7 @@ func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
// StoredInflightMessages returns all inflight messages, e.g. from a persistent store,
// and is used to populate the restored clients with inflight messages before start.
func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredInflightMessages) {
v, err := hook.StoredInflightMessages()
if err != nil {
@@ -531,7 +556,7 @@ func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
// StoredRetainedMessages returns all retained messages, e.g. from a persistent store,
// and is used to populate the server topics with retained messages before start.
func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredRetainedMessages) {
v, err := hook.StoredRetainedMessages()
if err != nil {
@@ -550,7 +575,7 @@ func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
// StoredSysInfo returns a set of system info values.
func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredSysInfo) {
v, err := hook.StoredSysInfo()
if err != nil {
@@ -572,7 +597,7 @@ func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
// server (see hooks/auth/allow_all or basic). It can be used in custom hooks to
// check connecting users against an existing user database.
func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnConnectAuthenticate) {
if ok := hook.OnConnectAuthenticate(cl, pk); ok {
return true
@@ -588,7 +613,7 @@ func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
// (see hooks/auth/allow_all or basic). It can be used in custom hooks to
// check publishing and subscribing users against an existing permissions or roles database.
func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnACLCheck) {
if ok := hook.OnACLCheck(cl, topic, write); ok {
return true
@@ -714,6 +739,9 @@ func (h *HookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, err
// OnPublished is called when a client has published a message to subscribers.
func (h *HookBase) OnPublished(cl *Client, pk packets.Packet) {}
// OnPublishDropped is called when a message to a client is dropped instead of being delivered.
func (h *HookBase) OnPublishDropped(cl *Client, pk packets.Packet) {}
// OnRetainMessage is called then a published message is retained.
func (h *HookBase) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {}

View File

@@ -104,6 +104,7 @@ var (
ClientsMaximum: 7,
MessagesReceived: 10,
MessagesSent: 11,
MessagesDropped: 20,
PacketsReceived: 12,
PacketsSent: 13,
Retained: 15,
@@ -111,7 +112,7 @@ var (
InflightDropped: 17,
},
}
sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`)
sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"messages_dropped":20,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`)
)
func TestClientMarshalBinary(t *testing.T) {

View File

@@ -27,6 +27,10 @@ type modifiedHookBase struct {
var errTestHook = errors.New("error")
func (h *modifiedHookBase) ID() string {
return "modified"
}
func (h *modifiedHookBase) Init(config any) error {
if config != nil {
return errTestHook
@@ -178,12 +182,20 @@ func TestHooksProvides(t *testing.T) {
require.False(t, h.Provides(OnDisconnect))
}
func TestHooksAddAndLen(t *testing.T) {
func TestHooksAddLenGetAll(t *testing.T) {
h := new(Hooks)
err := h.Add(new(HookBase), nil)
require.NoError(t, err)
require.Equal(t, int64(1), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(1), h.Len())
err = h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
require.Equal(t, int64(2), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(2), h.Len())
all := h.GetAll()
require.Equal(t, "base", all[0].ID())
require.Equal(t, "modified", all[1].ID())
}
func TestHooksAddInitFailure(t *testing.T) {
@@ -224,6 +236,7 @@ func TestHooksNonReturns(t *testing.T) {
h.OnSubscribed(cl, packets.Packet{}, []byte{1})
h.OnUnsubscribed(cl, packets.Packet{})
h.OnPublished(cl, packets.Packet{})
h.OnPublishDropped(cl, packets.Packet{})
h.OnRetainMessage(cl, packets.Packet{}, 0)
h.OnQosPublish(cl, packets.Packet{}, time.Now().Unix(), 0)
h.OnQosComplete(cl, packets.Packet{})

View File

@@ -104,14 +104,14 @@ func (i *Inflight) Delete(id uint16) bool {
}
// TakeRecieveQuota reduces the receive quota by 1.
func (i *Inflight) TakeReceiveQuota() {
func (i *Inflight) DecreaseReceiveQuota() {
if atomic.LoadInt32(&i.receiveQuota) > 0 {
atomic.AddInt32(&i.receiveQuota, -1)
}
}
// TakeRecieveQuota increases the receive quota by 1.
func (i *Inflight) ReturnReceiveQuota() {
func (i *Inflight) IncreaseReceiveQuota() {
if atomic.LoadInt32(&i.receiveQuota) < atomic.LoadInt32(&i.maximumReceiveQuota) {
atomic.AddInt32(&i.receiveQuota, 1)
}
@@ -123,15 +123,15 @@ func (i *Inflight) ResetReceiveQuota(n int32) {
atomic.StoreInt32(&i.maximumReceiveQuota, n)
}
// TakeSendQuota reduces the send quota by 1.
func (i *Inflight) TakeSendQuota() {
// DecreaseSendQuota reduces the send quota by 1.
func (i *Inflight) DecreaseSendQuota() {
if atomic.LoadInt32(&i.sendQuota) > 0 {
atomic.AddInt32(&i.sendQuota, -1)
}
}
// ReturnSendQuota increases the send quota by 1.
func (i *Inflight) ReturnSendQuota() {
// IncreaseSendQuota increases the send quota by 1.
func (i *Inflight) IncreaseSendQuota() {
if atomic.LoadInt32(&i.sendQuota) < atomic.LoadInt32(&i.maximumSendQuota) {
atomic.AddInt32(&i.sendQuota, 1)
}

View File

@@ -95,12 +95,12 @@ func TestReceiveQuota(t *testing.T) {
require.Equal(t, int32(4), atomic.LoadInt32(&i.receiveQuota))
// Return 1
i.ReturnReceiveQuota()
i.IncreaseReceiveQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
// Try to go over max limit
i.ReturnReceiveQuota()
i.IncreaseReceiveQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
@@ -110,12 +110,12 @@ func TestReceiveQuota(t *testing.T) {
require.Equal(t, int32(1), atomic.LoadInt32(&i.receiveQuota))
// Take 1
i.TakeReceiveQuota()
i.DecreaseReceiveQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
// Try to go below zero
i.TakeReceiveQuota()
i.DecreaseReceiveQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
}
@@ -137,12 +137,12 @@ func TestSendQuota(t *testing.T) {
require.Equal(t, int32(4), atomic.LoadInt32(&i.sendQuota))
// Return 1
i.ReturnSendQuota()
i.IncreaseSendQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
// Try to go over max limit
i.ReturnSendQuota()
i.IncreaseSendQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
@@ -152,12 +152,12 @@ func TestSendQuota(t *testing.T) {
require.Equal(t, int32(1), atomic.LoadInt32(&i.sendQuota))
// Take 1
i.TakeSendQuota()
i.DecreaseSendQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
// Try to go below zero
i.TakeSendQuota()
i.DecreaseSendQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
}

View File

@@ -107,28 +107,7 @@ func (l *HTTPStats) Close(closeClients CloseFn) {
// jsonHandler is an HTTP handler which outputs the $SYS stats as JSON.
func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) {
info := &system.Info{
Version: l.sysInfo.Version,
Started: atomic.LoadInt64(&l.sysInfo.Started),
Time: atomic.LoadInt64(&l.sysInfo.Time),
Uptime: atomic.LoadInt64(&l.sysInfo.Uptime),
BytesReceived: atomic.LoadInt64(&l.sysInfo.BytesReceived),
BytesSent: atomic.LoadInt64(&l.sysInfo.BytesSent),
ClientsConnected: atomic.LoadInt64(&l.sysInfo.ClientsConnected),
ClientsMaximum: atomic.LoadInt64(&l.sysInfo.ClientsMaximum),
ClientsTotal: atomic.LoadInt64(&l.sysInfo.ClientsTotal),
ClientsDisconnected: atomic.LoadInt64(&l.sysInfo.ClientsDisconnected),
MessagesReceived: atomic.LoadInt64(&l.sysInfo.MessagesReceived),
MessagesSent: atomic.LoadInt64(&l.sysInfo.MessagesSent),
InflightDropped: atomic.LoadInt64(&l.sysInfo.InflightDropped),
Subscriptions: atomic.LoadInt64(&l.sysInfo.Subscriptions),
PacketsReceived: atomic.LoadInt64(&l.sysInfo.PacketsReceived),
PacketsSent: atomic.LoadInt64(&l.sysInfo.PacketsSent),
Retained: atomic.LoadInt64(&l.sysInfo.Retained),
Inflight: atomic.LoadInt64(&l.sysInfo.Inflight),
MemoryAlloc: atomic.LoadInt64(&l.sysInfo.MemoryAlloc),
Threads: atomic.LoadInt64(&l.sysInfo.Threads),
}
info := *l.sysInfo.Clone()
out, err := json.MarshalIndent(info, "", "\t")
if err != nil {

88
listeners/net.go Normal file
View File

@@ -0,0 +1,88 @@
package listeners
import (
"net"
"sync"
"sync/atomic"
"github.com/rs/zerolog"
)
// Net is a listener for establishing client connections on basic TCP protocol.
type Net struct { // [MQTT-4.2.0-1]
mu sync.Mutex
listener net.Listener // a net.Listener which will listen for new clients
id string // the internal id of the listener
log *zerolog.Logger // server logger
end uint32 // ensure the close methods are only called once
}
// NewNet initialises and returns a listener serving incoming connections on the given net.Listener
func NewNet(id string, listener net.Listener) *Net {
return &Net{
id: id,
listener: listener,
}
}
// ID returns the id of the listener.
func (l *Net) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *Net) Address() string {
return l.listener.Addr().String()
}
// Protocol returns the network of the listener.
func (l *Net) Protocol() string {
return l.listener.Addr().Network()
}
// Init initializes the listener.
func (l *Net) Init(log *zerolog.Logger) error {
l.log = log
return nil
}
// Serve starts waiting for new TCP connections, and calls the establish
// connection callback for any received.
func (l *Net) Serve(establish EstablishFn) {
for {
if atomic.LoadUint32(&l.end) == 1 {
return
}
conn, err := l.listener.Accept()
if err != nil {
return
}
if atomic.LoadUint32(&l.end) == 0 {
go func() {
err = establish(l.id, conn)
if err != nil {
l.log.Warn().Err(err).Send()
}
}()
}
}
}
// Close closes the listener and any client connections.
func (l *Net) Close(closeClients CloseFn) {
l.mu.Lock()
defer l.mu.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}
if l.listener != nil {
err := l.listener.Close()
if err != nil {
return
}
}
}

105
listeners/net_test.go Normal file
View File

@@ -0,0 +1,105 @@
package listeners
import (
"errors"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestNewNet(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, "t1", l.id)
}
func TestNetID(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, "t1", l.ID())
}
func TestNetAddress(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, n.Addr().String(), l.Address())
}
func TestNetProtocol(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, "tcp", l.Protocol())
}
func TestNetInit(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
err = l.Init(&logger)
l.Close(MockCloser)
require.NoError(t, err)
}
func TestNetServeAndClose(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
err = l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.True(t, closed)
<-o
l.Close(MockCloser) // coverage: close closed
l.Serve(MockEstablisher) // coverage: serve closed
}
func TestNetEstablishThenEnd(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
err = l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
established := make(chan bool)
go func() {
l.Serve(func(id string, c net.Conn) error {
established <- true
return errors.New("ending") // return an error to exit immediately
})
o <- true
}()
time.Sleep(time.Millisecond)
net.Dial("tcp", n.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o
}

View File

@@ -154,7 +154,7 @@ func (ws *wsConn) Read(p []byte) (int, error) {
br, err = r.Read(p[n:])
n += br
if err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) {
err = nil
}
return n, err

View File

@@ -376,7 +376,7 @@ func TestEncodeUint16(t *testing.T) {
result = encodeUint16(32767)
require.Equal(t, []byte{0x7f, 0xff}, result)
result = encodeUint16(65535)
result = encodeUint16(math.MaxUint16)
require.Equal(t, []byte{0xff, 0xff}, result)
}

View File

@@ -113,6 +113,7 @@ var (
ErrPacketTooLarge = Code{Code: 0x95, Reason: "packet too large"}
ErrMessageRateTooHigh = Code{Code: 0x96, Reason: "message rate too high"}
ErrQuotaExceeded = Code{Code: 0x97, Reason: "quota exceeded"}
ErrPendingClientWritesExceeded = Code{Code: 0x97, Reason: "too many pending writes"}
ErrAdministrativeAction = Code{Code: 0x98, Reason: "administrative action"}
ErrPayloadFormatInvalid = Code{Code: 0x99, Reason: "payload format invalid"}
ErrRetainNotSupported = Code{Code: 0x9A, Reason: "retain not supported"}
@@ -124,4 +125,23 @@ var (
ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}
ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"}
// MQTTv3 specific bytes.
Err3UnsupportedProtocolVersion = Code{Code: 0x01}
Err3ClientIdentifierNotValid = Code{Code: 0x02}
Err3ServerUnavailable = Code{Code: 0x03}
ErrMalformedUsernameOrPassword = Code{Code: 0x04}
Err3NotAuthorized = Code{Code: 0x05}
// V5CodesToV3 maps MQTTv5 Connack reason codes to MQTTv3 return codes.
// This is required because MQTTv3 has different return byte specification.
// See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349257
V5CodesToV3 = map[Code]Code{
ErrUnsupportedProtocolVersion: Err3UnsupportedProtocolVersion,
ErrClientIdentifierNotValid: Err3ClientIdentifierNotValid,
ErrServerUnavailable: Err3ServerUnavailable,
ErrMalformedUsername: ErrMalformedUsernameOrPassword,
ErrMalformedPassword: ErrMalformedUsernameOrPassword,
ErrBadUsernameOrPassword: Err3NotAuthorized,
}
)

View File

@@ -8,6 +8,7 @@ import (
"bytes"
"errors"
"fmt"
"math"
"strconv"
"strings"
"sync"
@@ -208,7 +209,10 @@ func (pk *Packet) Copy(allowTransfer bool) Packet {
Created: pk.Created,
Expiry: pk.Expiry,
Origin: pk.Origin,
PacketID: pk.PacketID, // ... ? Packet ID must not be transferred (in this manner)
}
if allowTransfer {
p.PacketID = pk.PacketID
}
if len(pk.Connect.ProtocolName) > 0 {
@@ -379,7 +383,7 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
}
pk.Connect.ClientIdentifier, offset, err = decodeString(buf, offset) //[MQTT-3.1.3-1] [MQTT-3.1.3-2] [MQTT-3.1.3-3] [MQTT-3.1.3-4]
@@ -393,7 +397,7 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
if err != nil {
return ErrMalformedWillProperties
}
offset += n + 1
offset += n
}
pk.Connect.WillTopic, offset, err = decodeString(buf, offset)
@@ -439,11 +443,11 @@ func (pk *Packet) ConnectValidate() Code {
return ErrProtocolViolationReservedBit // [MQTT-3.1.2-3]
}
if len(pk.Connect.Password) > 65535 {
if len(pk.Connect.Password) > math.MaxUint16 {
return ErrProtocolViolationPasswordTooLong
}
if len(pk.Connect.Username) > 65535 {
if len(pk.Connect.Username) > math.MaxUint16 {
return ErrProtocolViolationUsernameTooLong
}
@@ -463,7 +467,7 @@ func (pk *Packet) ConnectValidate() Code {
return ErrProtocolViolationPasswordNoFlag // [MQTT-3.1.2-18]
}
if len(pk.Connect.ClientIdentifier) > 65535 {
if len(pk.Connect.ClientIdentifier) > math.MaxUint16 {
return ErrClientIdentifierNotValid
}
@@ -640,7 +644,7 @@ func (pk *Packet) PublishDecode(buf []byte) error {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
}
pk.Payload = buf[offset:]
@@ -857,7 +861,7 @@ func (pk *Packet) SubackDecode(buf []byte) error {
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
}
pk.ReasonCodes = buf[offset:]
@@ -914,7 +918,7 @@ func (pk *Packet) SubscribeDecode(buf []byte) error {
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
}
var filter string
@@ -1010,7 +1014,7 @@ func (pk *Packet) UnsubackDecode(buf []byte) error {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
pk.ReasonCodes = buf[offset:]
}
@@ -1062,7 +1066,7 @@ func (pk *Packet) UnsubscribeDecode(buf []byte) error {
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
}
var filter string

View File

@@ -464,6 +464,9 @@ func TestCopy(t *testing.T) {
require.Equal(t, tt.Packet.Created, pkc.Created, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Origin, pkc.Origin, pkInfo, tt.Case, tt.Desc)
require.EqualValues(t, pkc.Properties, tt.Packet.Properties)
pkcc := tt.Packet.Copy(false)
require.Equal(t, uint16(0), pkcc.PacketID, pkInfo, tt.Case, tt.Desc)
}
}

View File

@@ -366,13 +366,14 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
return 0, nil
}
n, _, err = DecodeLength(b)
var bu int
n, bu, err = DecodeLength(b)
if err != nil {
return n, err
return n + bu, err
}
if n == 0 {
return n, nil
return n + bu, nil
}
bt := b.Bytes()
@@ -380,11 +381,11 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
for offset := 0; offset < n; {
k, offset, err = decodeByte(bt, offset)
if err != nil {
return n, err
return n + bu, err
}
if _, ok := validPacketProperties[k][pk]; !ok {
return n, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pk, ErrProtocolViolationUnsupportedProperty)
return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pk, ErrProtocolViolationUnsupportedProperty)
}
switch k {
@@ -406,7 +407,7 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
n, bu, err := DecodeLength(bytes.NewBuffer(bt[offset:]))
if err != nil {
return n, err
return n + bu, err
}
p.SubscriptionIdentifier = append(p.SubscriptionIdentifier, n)
offset += bu
@@ -452,7 +453,7 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
var k, v string
k, offset, err = decodeString(bt, offset)
if err != nil {
return n, err
return n + bu, err
}
v, offset, err = decodeString(bt, offset)
p.User = append(p.User, UserProperty{Key: k, Val: v})
@@ -470,9 +471,9 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
}
if err != nil {
return n, err
return n + bu, err
}
}
return n, nil
return n + bu, nil
}

View File

@@ -250,7 +250,7 @@ func TestDecodeProperties(t *testing.T) {
props := new(Properties)
n, err := props.Decode(Reserved, b)
require.NoError(t, err)
require.Equal(t, 172, n)
require.Equal(t, 172 + 2, n)
require.EqualValues(t, propertiesStruct, *props)
}

View File

@@ -89,6 +89,7 @@ const (
TConnackServerUnavailable
TConnackBadUsernamePassword
TConnackBadUsernamePasswordNoSession
TConnackMqtt5BadUsernamePasswordNoSession
TConnackNotAuthorised
TConnackMalSessionPresent
TConnackMalReturnCode
@@ -1316,10 +1317,28 @@ var TPacketData = map[byte]TPacketCases{
Desc: "bad username or password no session",
RawBytes: []byte{
Connack << 4, 2, // fixed header
0, // No session present
ErrBadUsernameOrPassword.Code,
0, // No session present
Err3NotAuthorized.Code, // use v3 remapping
},
Packet: &Packet{
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 2,
},
ReasonCode: Err3NotAuthorized.Code,
},
},
{
Case: TConnackMqtt5BadUsernamePasswordNoSession,
Desc: "mqtt5 bad username or password no session",
RawBytes: []byte{
Connack << 4, 3, // fixed header
0, // No session present
ErrBadUsernameOrPassword.Code,
0,
},
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 2,
@@ -1327,6 +1346,7 @@ var TPacketData = map[byte]TPacketCases{
ReasonCode: ErrBadUsernameOrPassword.Code,
},
},
{
Case: TConnackNotAuthorised,
Desc: "not authorised",
@@ -1804,13 +1824,10 @@ var TPacketData = map[byte]TPacketCases{
Case: TPublishRetainMqtt5,
Desc: "retain mqtt5",
RawBytes: []byte{
Publish<<4 | 1<<0, 35, // Fixed header
Publish<<4 | 1<<0, 19, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
16, // properties length
38, // User Properties (38)
0, 5, 'h', 'e', 'l', 'l', 'o',
0, 6, 228, 184, 150, 231, 149, 140,
0, // properties length
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
},
Packet: &Packet{
@@ -1818,18 +1835,11 @@ var TPacketData = map[byte]TPacketCases{
FixedHeader: FixedHeader{
Type: Publish,
Retain: true,
Remaining: 35,
Remaining: 19,
},
TopicName: "a/b/c",
Properties: Properties{
User: []UserProperty{
{
Key: "hello",
Val: "世界",
},
},
},
Payload: []byte("hello mochi"),
TopicName: "a/b/c",
Properties: Properties{},
Payload: []byte("hello mochi"),
},
},
{

130
server.go
View File

@@ -26,10 +26,8 @@ import (
)
const (
Version = "2.1.1" // the current server version.
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
defaultFanPoolSize uint64 = 32 // the number of concurrent workers in the pool
defaultFanPoolQueueSize uint64 = 1024 // the capacity of each worker queue
Version = "2.2.2" // the current server version.
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
)
var (
@@ -47,6 +45,7 @@ var (
SharedSubAvailable: 1, // shared subscriptions are available
ServerKeepAlive: 10, // default keepalive for clients
MinimumProtocolVersion: 3, // minimum supported mqtt version (3.0.0)
MaximumClientWritesPending: 1024 * 8, // maximum number of pending message writes for a client
}
ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists.
@@ -56,18 +55,20 @@ var (
// Capabilities indicates the capabilities and features provided by the server.
type Capabilities struct {
MaximumMessageExpiryInterval int64
MaximumClientWritesPending int32
MaximumSessionExpiryInterval uint32
MaximumPacketSize uint32
maximumPacketID uint32 // unexported, used for testing only
ReceiveMaximum uint16
TopicAliasMaximum uint16
ServerKeepAlive uint16
SharedSubAvailable byte
MinimumProtocolVersion byte
Compatibilities Compatibilities
MaximumQos byte
RetainAvailable byte
WildcardSubAvailable byte
SubIDAvailable byte
SharedSubAvailable byte
MinimumProtocolVersion byte
}
// Compatibilities provides flags for using compatibility modes.
@@ -80,7 +81,9 @@ type Compatibilities struct {
// Options contains configurable options for the server.
type Options struct {
// Capabilities defines the server features and behaviour.
// Capabilities defines the server features and behaviour. If you only wish to modify
// several of these values, set them explicitly - e.g.
// server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
Capabilities *Capabilities
// Logger specifies a custom configured implementation of zerolog to override
@@ -91,16 +94,6 @@ type Options struct {
// server.Log = &l
Logger *zerolog.Logger
// FanPoolSize is the number of individual workers and queues to initialize.
// Bigger is not necessarily better, and you should rely on defaults unless
// you have know what you are doing.
FanPoolSize uint64
// FanPoolQueueSize is the size of the queue per worker. Increase this value
// accordingly if you anticipate having intermittent but massive numbers of
// messages. Cluster support is roadmapped.
FanPoolQueueSize uint64
// SysTopicResendInterval specifies the interval between $SYS topic updates in seconds.
SysTopicResendInterval int64
}
@@ -113,7 +106,6 @@ type Server struct {
Clients *Clients // clients known to the broker
Topics *TopicsIndex // an index of topic filter subscriptions and retained messages
Info *system.Info // values about the server commonly known as $SYS topics
fanpool *FanPool // a fixed size worker pool for processing inbound and outbound messages
loop *loop // loop contains tickers for the system event loop
done chan bool // indicate that the server is ending
Log *zerolog.Logger // minimal no-alloc logger
@@ -165,8 +157,7 @@ func New(opts *Options) *Server {
Version: Version,
Started: time.Now().Unix(),
},
fanpool: NewFanPool(opts.FanPoolSize, opts.FanPoolQueueSize),
Log: opts.Logger,
Log: opts.Logger,
hooks: &Hooks{
Log: opts.Logger,
},
@@ -181,18 +172,12 @@ func (o *Options) ensureDefaults() {
o.Capabilities = DefaultServerCapabilities
}
o.Capabilities.maximumPacketID = math.MaxUint16 // spec maximum is 65535
if o.SysTopicResendInterval == 0 {
o.SysTopicResendInterval = defaultSysTopicInterval
}
if o.FanPoolSize == 0 {
o.FanPoolSize = defaultFanPoolSize
}
if o.FanPoolQueueSize < 1 {
o.FanPoolQueueSize = defaultFanPoolQueueSize
}
if o.Logger == nil {
log := zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.InfoLevel).Output(zerolog.ConsoleWriter{Out: os.Stderr})
o.Logger = &log
@@ -219,6 +204,8 @@ func (s *Server) NewClient(c net.Conn, listener string, id string, inline bool)
// By default we don't want to restrict developer publishes,
// but if you do, reset this after creating inline client.
cl.State.Inflight.ResetReceiveQuota(math.MaxInt32)
} else {
go cl.WriteLoop() // can only write to real clients
}
return cl
@@ -375,6 +362,8 @@ func (s *Server) attachClient(cl *Client, listener string) error {
s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", listener).Msg("client disconnected")
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryIntervalFlag && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
s.hooks.OnDisconnect(cl, err, expire)
close(cl.State.outbound)
if expire {
s.UnsubscribeClient(cl)
cl.ClearInflights(math.MaxInt64, 0)
@@ -451,8 +440,6 @@ func (s *Server) validateConnect(cl *Client, pk packets.Packet) packets.Code {
// session is abandoned.
func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
if existing, ok := s.Clients.Get(pk.Connect.ClientIdentifier); ok {
existing.Lock()
defer existing.Unlock()
s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
if pk.Connect.Clean || (existing.Properties.Clean && cl.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
s.UnsubscribeClient(existing)
@@ -460,14 +447,22 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
return false // [MQTT-3.2.2-3]
}
cl.State.Inflight = existing.State.Inflight // [MQTT-3.1.2-5]
if existing.State.Inflight.Len() > 0 {
existing.State.Inflight.Lock()
cl.State.Inflight = existing.State.Inflight // [MQTT-3.1.2-5]
existing.State.Inflight.Unlock()
if cl.State.Inflight.maximumReceiveQuota == 0 && cl.ops.capabilities.ReceiveMaximum != 0 {
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.capabilities.ReceiveMaximum)) // server receive max per client
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
}
}
for _, sub := range existing.State.Subscriptions.GetAll() {
existed := !s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
if !existed {
atomic.AddInt64(&s.Info.Subscriptions, 1)
}
cl.State.Subscriptions.Add(sub.Filter, sub)
s.publishRetainedToClient(cl, sub, existed)
}
return true // [MQTT-3.2.2-3]
@@ -489,6 +484,12 @@ func (s *Server) sendConnack(cl *Client, reason packets.Code, present bool) erro
}
if reason.Code >= packets.ErrUnspecifiedError.Code {
if cl.Properties.ProtocolVersion < 5 {
if v3reason, ok := packets.V5CodesToV3[reason]; ok { // NB v3 3.2.2.3 Connack return codes
reason = v3reason
}
}
properties.ReasonString = reason.Reason
ack := packets.Packet{
FixedHeader: packets.FixedHeader{
@@ -588,7 +589,7 @@ func (s *Server) processPacket(cl *Client, pk packets.Packet) error {
if ok := cl.State.Inflight.Delete(next.PacketID); ok {
atomic.AddInt64(&s.Info.Inflight, -1)
}
cl.State.Inflight.TakeSendQuota()
cl.State.Inflight.DecreaseSendQuota()
}
}
@@ -695,15 +696,12 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
}
if pk.FixedHeader.Qos == 0 {
s.fanpool.Enqueue(cl.ID, func() {
s.publishToSubscribers(pk)
})
s.publishToSubscribers(pk)
s.hooks.OnPublished(cl, pk)
return nil
}
cl.State.Inflight.TakeReceiveQuota()
cl.State.Inflight.DecreaseReceiveQuota()
ack := s.buildAck(pk.PacketID, packets.Puback, 0, pk.Properties, packets.QosCodes[pk.FixedHeader.Qos]) // [MQTT-4.3.2-4]
if pk.FixedHeader.Qos == 2 {
ack = s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.CodeSuccess) // [MQTT-3.3.4-1] [MQTT-4.3.3-8]
@@ -711,6 +709,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
if ok := cl.State.Inflight.Set(ack); ok {
atomic.AddInt64(&s.Info.Inflight, 1)
s.hooks.OnQosPublish(cl, ack, ack.Created, 0)
}
err := cl.WritePacket(ack)
@@ -722,15 +721,13 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
if ok := cl.State.Inflight.Delete(ack.PacketID); ok {
atomic.AddInt64(&s.Info.Inflight, -1)
}
cl.State.Inflight.ReturnReceiveQuota()
s.hooks.OnQosComplete(cl, pk)
cl.State.Inflight.IncreaseReceiveQuota()
s.hooks.OnQosComplete(cl, ack)
}
s.fanpool.Enqueue(cl.ID, func() {
s.publishToSubscribers(pk)
})
s.publishToSubscribers(pk)
s.hooks.OnPublished(cl, pk)
return nil
}
@@ -773,13 +770,13 @@ func (s *Server) publishToSubscribers(pk packets.Packet) {
}
}
func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (out packets.Packet, err error) {
func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (packets.Packet, error) {
if sub.NoLocal && pk.Origin == cl.ID {
return pk, nil // [MQTT-3.8.3-3]
}
out = pk.Copy(false)
if !sub.RetainAsPublished { // ![MQTT-3.3.1-13]
out := pk.Copy(false)
if cl.Properties.ProtocolVersion == 5 && !sub.RetainAsPublished { // ![MQTT-3.3.1-13]
out.FixedHeader.Retain = false // [MQTT-3.3.1-12]
}
@@ -819,22 +816,32 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet
if ok := cl.State.Inflight.Set(out); ok { // [MQTT-4.3.2-3] [MQTT-4.3.3-3]
atomic.AddInt64(&s.Info.Inflight, 1)
s.hooks.OnQosPublish(cl, out, out.Created, 0)
cl.State.Inflight.DecreaseSendQuota()
}
if sentQuota == 0 && atomic.LoadInt32(&cl.State.Inflight.maximumSendQuota) > 0 {
out.Expiry = -1
cl.State.Inflight.Set(out)
return pk, nil
return out, nil
}
}
if cl.Net.conn == nil || atomic.LoadUint32(&cl.State.done) == 1 {
return pk, packets.CodeDisconnect
if cl.Net.Conn == nil || atomic.LoadUint32(&cl.State.done) == 1 {
return out, packets.CodeDisconnect
}
cl.State.Inflight.TakeSendQuota()
select {
case cl.State.outbound <- out:
atomic.AddInt32(&cl.State.outboundQty, 1)
default:
atomic.AddInt64(&s.Info.MessagesDropped, 1)
cl.ops.hooks.OnPublishDropped(cl, pk)
cl.State.Inflight.Delete(out.PacketID) // packet was dropped due to irregular circumstances, so rollback inflight.
cl.State.Inflight.IncreaseSendQuota()
return out, packets.ErrPendingClientWritesExceeded
}
return out, cl.WritePacket(out)
return out, nil
}
func (s *Server) publishRetainedToClient(cl *Client, sub packets.Subscription, existed bool) {
@@ -849,7 +856,7 @@ func (s *Server) publishRetainedToClient(cl *Client, sub packets.Subscription, e
for _, pkv := range s.Topics.Messages(sub.Filter) { // [MQTT-3.8.4-4]
_, err := s.publishToClient(cl, sub, pkv)
if err != nil {
s.Log.Warn().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Interface("packet", pkv).Msg("failed to publish retained message")
s.Log.Debug().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Interface("packet", pkv).Msg("failed to publish retained message")
}
}
}
@@ -883,7 +890,7 @@ func (s *Server) processPuback(cl *Client, pk packets.Packet) error {
}
if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5]
cl.State.Inflight.ReturnSendQuota()
cl.State.Inflight.IncreaseSendQuota()
atomic.AddInt64(&s.Info.Inflight, -1)
s.hooks.OnQosComplete(cl, pk)
}
@@ -906,7 +913,7 @@ func (s *Server) processPubrec(cl *Client, pk packets.Packet) error {
}
ack := s.buildAck(pk.PacketID, packets.Pubrel, 1, pk.Properties, packets.CodeSuccess) // [MQTT-4.3.3-4] ![MQTT-4.3.3-6]
cl.State.Inflight.TakeReceiveQuota() // -1 RECV QUOTA
cl.State.Inflight.DecreaseReceiveQuota() // -1 RECV QUOTA
cl.State.Inflight.Set(ack) // [MQTT-4.3.3-5]
return cl.WritePacket(ack)
}
@@ -933,8 +940,8 @@ func (s *Server) processPubrel(cl *Client, pk packets.Packet) error {
return err
}
cl.State.Inflight.ReturnReceiveQuota() // +1 RECV QUOTA
cl.State.Inflight.ReturnSendQuota() // +1 SENT QUOTA
cl.State.Inflight.IncreaseReceiveQuota() // +1 RECV QUOTA
cl.State.Inflight.IncreaseSendQuota() // +1 SENT QUOTA
if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.3-12]
atomic.AddInt64(&s.Info.Inflight, -1)
s.hooks.OnQosComplete(cl, pk)
@@ -946,8 +953,8 @@ func (s *Server) processPubrel(cl *Client, pk packets.Packet) error {
// processPubcomp processes a Pubcomp packet, denoting completion of a QOS 2 packet sent from the server.
func (s *Server) processPubcomp(cl *Client, pk packets.Packet) error {
// regardless of whether the pubcomp is a success or failure, we end the qos flow, delete inflight, and restore the quotas.
cl.State.Inflight.ReturnReceiveQuota() // +1 RECV QUOTA
cl.State.Inflight.ReturnSendQuota() // +1 SENT QUOTA
cl.State.Inflight.IncreaseReceiveQuota() // +1 RECV QUOTA
cl.State.Inflight.IncreaseSendQuota() // +1 SENT QUOTA
if ok := cl.State.Inflight.Delete(pk.PacketID); ok {
atomic.AddInt64(&s.Info.Inflight, -1)
s.hooks.OnQosComplete(cl, pk)
@@ -1094,7 +1101,6 @@ func (s *Server) UnsubscribeClient(cl *Client) {
// processAuth processes an Auth packet.
func (s *Server) processAuth(cl *Client, pk packets.Packet) error {
_, err := s.hooks.OnAuthPacket(cl, pk)
fmt.Println("err", err)
if err != nil {
return err
}
@@ -1182,6 +1188,7 @@ func (s *Server) publishSysTopics() {
SysPrefix + "/broker/packets/sent": AtomicItoa(&s.Info.PacketsSent),
SysPrefix + "/broker/messages/received": AtomicItoa(&s.Info.MessagesReceived),
SysPrefix + "/broker/messages/sent": AtomicItoa(&s.Info.MessagesSent),
SysPrefix + "/broker/messages/dropped": AtomicItoa(&s.Info.MessagesDropped),
SysPrefix + "/broker/messages/inflight": AtomicItoa(&s.Info.Inflight),
SysPrefix + "/broker/retained": AtomicItoa(&s.Info.Retained),
SysPrefix + "/broker/subscriptions": AtomicItoa(&s.Info.Subscriptions),
@@ -1203,8 +1210,6 @@ func (s *Server) publishSysTopics() {
func (s *Server) Close() error {
close(s.done)
s.Listeners.CloseAll(s.closeListenerClients)
s.fanpool.Close()
s.fanpool.Wait()
s.hooks.OnStopped()
s.hooks.Stop()
@@ -1324,6 +1329,7 @@ func (s *Server) loadServerInfo(v system.Info) {
atomic.StoreInt64(&s.Info.ClientsDisconnected, v.ClientsDisconnected)
atomic.StoreInt64(&s.Info.MessagesReceived, v.MessagesReceived)
atomic.StoreInt64(&s.Info.MessagesSent, v.MessagesSent)
atomic.StoreInt64(&s.Info.MessagesDropped, v.MessagesDropped)
atomic.StoreInt64(&s.Info.PacketsReceived, v.PacketsReceived)
atomic.StoreInt64(&s.Info.PacketsSent, v.PacketsSent)
atomic.StoreInt64(&s.Info.InflightDropped, v.InflightDropped)

View File

@@ -54,10 +54,8 @@ func newServer() *Server {
cc.ReceiveMaximum = 0
s := New(&Options{
Logger: &logger,
FanPoolSize: 2,
FanPoolQueueSize: 10,
Capabilities: &cc,
Logger: &logger,
Capabilities: &cc,
})
s.AddHook(new(AllowHook), nil)
return s
@@ -68,8 +66,6 @@ func TestOptionsSetDefaults(t *testing.T) {
opts.ensureDefaults()
require.Equal(t, defaultSysTopicInterval, opts.SysTopicResendInterval)
require.Equal(t, defaultFanPoolSize, opts.FanPoolSize)
require.Equal(t, defaultFanPoolQueueSize, opts.FanPoolQueueSize)
require.Equal(t, DefaultServerCapabilities, opts.Capabilities)
opts = new(Options)
@@ -86,7 +82,6 @@ func TestNew(t *testing.T) {
require.NotNil(t, s.Info)
require.NotNil(t, s.Log)
require.NotNil(t, s.Options)
require.NotNil(t, s.fanpool)
require.NotNil(t, s.loop)
require.NotNil(t, s.loop.sysTopics)
require.NotNil(t, s.loop.inflightExpiry)
@@ -117,7 +112,7 @@ func TestServerNewClient(t *testing.T) {
require.NotNil(t, cl.State.TopicAliases)
require.Equal(t, defaultKeepalive, cl.State.keepalive)
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
require.NotNil(t, cl.Net.conn)
require.NotNil(t, cl.Net.Conn)
require.NotNil(t, cl.Net.bconn)
require.NotNil(t, cl.ops)
require.Equal(t, s.Log, cl.ops.log)
@@ -408,18 +403,19 @@ func TestEstablishConnectionInheritExisting(t *testing.T) {
cl.Properties.ProtocolVersion = 5
cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier
cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
cl.State.Inflight.Set(*packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
s.Clients.Add(cl)
r, w := net.Pipe()
o := make(chan error)
go func() {
o <- s.EstablishConnection("tcp", r)
err := s.EstablishConnection("tcp", r)
o <- err
}()
go func() {
w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes)
time.Sleep(time.Millisecond) // we want to receive the queued inflight, so we need to wait a moment before sending the disconnect.
w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
}()
@@ -445,9 +441,14 @@ func TestEstablishConnectionInheritExisting(t *testing.T) {
require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect
}
require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedSessionExists).RawBytes, <-recv)
connackPlusPacket := append(
packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedSessionExists).RawBytes,
packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes...,
)
require.Equal(t, connackPlusPacket, <-recv)
require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectTakeover).RawBytes, <-takeover)
time.Sleep(time.Microsecond * 100)
w.Close()
r.Close()
@@ -553,9 +554,7 @@ func TestEstablishConnectionInheritExistingClean(t *testing.T) {
func TestEstablishConnectionBadAuthentication(t *testing.T) {
s := New(&Options{
Logger: &logger,
FanPoolSize: 2,
FanPoolQueueSize: 10,
Logger: &logger,
})
defer s.Close()
@@ -589,9 +588,7 @@ func TestEstablishConnectionBadAuthentication(t *testing.T) {
func TestEstablishConnectionBadAuthenticationAckFailure(t *testing.T) {
s := New(&Options{
Logger: &logger,
FanPoolSize: 2,
FanPoolQueueSize: 10,
Logger: &logger,
})
defer s.Close()
@@ -806,7 +803,7 @@ func TestInheritClientSession(t *testing.T) {
n := time.Now().Unix()
existing, _, _ := newTestClient()
existing.Net.conn = nil
existing.Net.Conn = nil
existing.ID = "mochi"
existing.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
existing.State.Inflight = NewInflights()
@@ -1023,7 +1020,7 @@ func TestServerProcessPacketPublishAndReceive(t *testing.T) {
w2.Close()
}()
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf)
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes, <-receiverBuf)
require.Equal(t, 1, len(s.Topics.Messages("a/b/c")))
}
@@ -1098,9 +1095,7 @@ func TestServerProcessPublishInvalidTopic(t *testing.T) {
func TestServerProcessPublishACLCheckDeny(t *testing.T) {
s := New(&Options{
Logger: &logger,
FanPoolSize: 2,
FanPoolQueueSize: 10,
Logger: &logger,
})
s.Serve()
defer s.Close()
@@ -1383,6 +1378,7 @@ func TestPublishToClientServerDowngradeQos(t *testing.T) {
pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
pkx.FixedHeader.Qos = 2
s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, pkx)
time.Sleep(time.Microsecond * 100)
w.Close()
}()
@@ -1396,6 +1392,31 @@ func TestPublishToClientServerDowngradeQos(t *testing.T) {
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).RawBytes, <-receiverBuf)
}
func TestPublishToClientExceedClientWritesPending(t *testing.T) {
s := newServer()
_, w := net.Pipe()
cl := newClient(w, &ops{
info: new(system.Info),
hooks: new(Hooks),
log: &logger,
capabilities: &Capabilities{
MaximumClientWritesPending: 3,
},
})
s.Clients.Add(cl)
for i := int32(0); i < cl.ops.capabilities.MaximumClientWritesPending; i++ {
cl.State.outbound <- packets.Packet{}
atomic.AddInt32(&cl.State.outboundQty, 1)
}
_, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, packets.Packet{})
require.Error(t, err)
require.ErrorIs(t, packets.ErrPendingClientWritesExceeded, err)
}
func TestPublishToClientServerTopicAlias(t *testing.T) {
s := newServer()
cl, r, w := newTestClient()
@@ -1407,6 +1428,7 @@ func TestPublishToClientServerTopicAlias(t *testing.T) {
pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).Packet
s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx)
s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx)
time.Sleep(time.Millisecond)
w.Close()
}()
@@ -1428,7 +1450,7 @@ func TestPublishToClientServerTopicAlias(t *testing.T) {
func TestPublishToClientExhaustedPacketID(t *testing.T) {
s := newServer()
cl, _, _ := newTestClient()
for i := 0; i <= 65535; i++ {
for i := uint32(0); i <= cl.ops.capabilities.maximumPacketID; i++ {
cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)})
}
@@ -1440,7 +1462,7 @@ func TestPublishToClientExhaustedPacketID(t *testing.T) {
func TestPublishToClientNoConn(t *testing.T) {
s := newServer()
cl, _, _ := newTestClient()
cl.Net.conn = nil
cl.Net.Conn = nil
_, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c"}, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
require.Error(t, err)
@@ -1497,7 +1519,7 @@ func TestPublishToSubscribersExhaustedPacketIDs(t *testing.T) {
s := newServer()
cl, r, w := newTestClient()
s.Clients.Add(cl)
for i := 0; i <= 65535; i++ {
for i := uint32(0); i <= cl.ops.capabilities.maximumPacketID; i++ {
cl.State.Inflight.Set(packets.Packet{PacketID: 1})
}
@@ -1537,7 +1559,7 @@ func TestPublishRetainedToClient(t *testing.T) {
subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2})
require.True(t, subbed)
retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetainMqtt5).Packet)
require.Equal(t, int64(1), retained)
go func() {
@@ -1548,7 +1570,7 @@ func TestPublishRetainedToClient(t *testing.T) {
buf, err := io.ReadAll(r)
require.NoError(t, err)
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, buf)
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes, buf)
}
func TestPublishRetainedToClientIsShared(t *testing.T) {
@@ -1863,7 +1885,7 @@ func TestServerProcessInboundQos2Flow(t *testing.T) {
for i, tx := range tt {
t.Run("qos step"+strconv.Itoa(i), func(t *testing.T) {
r, w = net.Pipe()
cl.Net.conn = w
cl.Net.Conn = w
recv := make(chan []byte)
go func() { // receive the ack
@@ -1937,7 +1959,8 @@ func TestServerProcessOutboundQos2Flow(t *testing.T) {
for i, tx := range tt {
t.Run("qos step"+strconv.Itoa(i), func(t *testing.T) {
r, w := net.Pipe()
cl.Net.conn = w
time.Sleep(time.Millisecond)
cl.Net.Conn = w
recv := make(chan []byte)
go func() { // receive the ack
@@ -1953,6 +1976,7 @@ func TestServerProcessOutboundQos2Flow(t *testing.T) {
require.NoError(t, err)
}
time.Sleep(time.Millisecond)
w.Close()
if i != 2 {
@@ -2064,7 +2088,7 @@ func TestServerProcessSubscribeWithRetain(t *testing.T) {
require.NoError(t, err)
require.Equal(t, append(
packets.TPacketData[packets.Suback].Get(packets.TSuback).RawBytes,
packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes...,
packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes...,
), buf)
}
@@ -2164,9 +2188,7 @@ func TestServerProcessSubscribeNoConnection(t *testing.T) {
func TestServerProcessSubscribeACLCheckDeny(t *testing.T) {
s := New(&Options{
Logger: &logger,
FanPoolSize: 2,
FanPoolQueueSize: 10,
Logger: &logger,
})
s.Serve()
cl, r, w := newTestClient()
@@ -2185,9 +2207,7 @@ func TestServerProcessSubscribeACLCheckDeny(t *testing.T) {
func TestServerProcessSubscribeACLCheckDenyObscure(t *testing.T) {
s := New(&Options{
Logger: &logger,
FanPoolSize: 2,
FanPoolQueueSize: 10,
Logger: &logger,
})
s.Serve()
s.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
@@ -2452,7 +2472,7 @@ func TestServerSendLWTDelayed(t *testing.T) {
recv <- buf
}()
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-recv)
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes, <-recv)
}
func TestServerReadStore(t *testing.T) {
@@ -2573,7 +2593,6 @@ func TestServerClose(t *testing.T) {
err := s.AddListener(listeners.NewMockListener("t1", ":1882"))
require.NoError(t, err)
s.Serve()
require.Equal(t, uint64(2), s.fanpool.Size())
// receive the disconnect
recv := make(chan []byte)
@@ -2593,7 +2612,6 @@ func TestServerClose(t *testing.T) {
s.Close()
time.Sleep(time.Millisecond)
require.Equal(t, false, listener.(*listeners.MockListener).IsServing())
require.Equal(t, uint64(0), s.fanpool.Size())
require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectShuttingDown).RawBytes, <-recv)
}

View File

@@ -4,6 +4,8 @@
package system
import "sync/atomic"
// Info contains atomic counters and values for various server statistics
// commonly found in $SYS topics (and others).
// based on https://github.com/mqtt/mqtt.org/wiki/SYS-Topics
@@ -20,6 +22,7 @@ type Info struct {
ClientsTotal int64 `json:"clients_total"` // total number of connected and disconnected clients with a persistent session currently connected and registered
MessagesReceived int64 `json:"messages_received"` // total number of publish messages received
MessagesSent int64 `json:"messages_sent"` // total number of publish messages sent
MessagesDropped int64 `json:"messages_dropped"` // total number of publish messages dropped to slow subscriber
Retained int64 `json:"retained"` // total number of retained messages active on the broker
Inflight int64 `json:"inflight"` // the number of messages currently in-flight
InflightDropped int64 `json:"inflight_dropped"` // the number of inflight messages which were dropped
@@ -29,3 +32,30 @@ type Info struct {
MemoryAlloc int64 `json:"memory_alloc"` // memory currently allocated
Threads int64 `json:"threads"` // number of active goroutines, named as threads for platform ambiguity
}
// Clone makes a copy of Info using atomic operation
func (i *Info) Clone() *Info {
return &Info{
Version: i.Version,
Started: atomic.LoadInt64(&i.Started),
Time: atomic.LoadInt64(&i.Time),
Uptime: atomic.LoadInt64(&i.Uptime),
BytesReceived: atomic.LoadInt64(&i.BytesReceived),
BytesSent: atomic.LoadInt64(&i.BytesSent),
ClientsConnected: atomic.LoadInt64(&i.ClientsConnected),
ClientsMaximum: atomic.LoadInt64(&i.ClientsMaximum),
ClientsTotal: atomic.LoadInt64(&i.ClientsTotal),
ClientsDisconnected: atomic.LoadInt64(&i.ClientsDisconnected),
MessagesReceived: atomic.LoadInt64(&i.MessagesReceived),
MessagesSent: atomic.LoadInt64(&i.MessagesSent),
MessagesDropped: atomic.LoadInt64(&i.MessagesDropped),
Retained: atomic.LoadInt64(&i.Retained),
Inflight: atomic.LoadInt64(&i.Inflight),
InflightDropped: atomic.LoadInt64(&i.InflightDropped),
Subscriptions: atomic.LoadInt64(&i.Subscriptions),
PacketsReceived: atomic.LoadInt64(&i.PacketsReceived),
PacketsSent: atomic.LoadInt64(&i.PacketsSent),
MemoryAlloc: atomic.LoadInt64(&i.MemoryAlloc),
Threads: atomic.LoadInt64(&i.Threads),
}
}

37
system/system_test.go Normal file
View File

@@ -0,0 +1,37 @@
package system
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestClone(t *testing.T) {
o := &Info{
Version: "version",
Started: 1,
Time: 2,
Uptime: 3,
BytesReceived: 4,
BytesSent: 5,
ClientsConnected: 6,
ClientsMaximum: 7,
ClientsTotal: 8,
ClientsDisconnected: 9,
MessagesReceived: 10,
MessagesSent: 11,
MessagesDropped: 20,
Retained: 12,
Inflight: 13,
InflightDropped: 14,
Subscriptions: 15,
PacketsReceived: 16,
PacketsSent: 17,
MemoryAlloc: 18,
Threads: 19,
}
n := o.Clone()
require.Equal(t, o, n)
}

View File

@@ -347,6 +347,8 @@ func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
// 0 is returned if sequential empty payloads are received.
func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
n := x.set(pk.TopicName, 0)
n.Lock()
defer n.Unlock()
if len(pk.Payload) > 0 {
n.retainPath = pk.TopicName
x.Retained.Add(pk.TopicName, pk)
@@ -361,6 +363,7 @@ func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
n.retainPath = ""
x.Retained.Delete(pk.TopicName) // [MQTT-3.3.1-6] [MQTT-3.3.1-7]
x.trim(n)
return out
}
@@ -619,6 +622,7 @@ type particle struct {
subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address
shared *SharedSubscriptions // a map of shared subscriptions keyed on group name
retainPath string // path of a retained message
sync.Mutex // mutex for when making changes to the particle
}
// newParticle returns a pointer to a new instance of particle.