Client write buffers (#165)

* Replace fanpool with client write buffers
This commit is contained in:
JB
2023-02-09 22:34:30 +00:00
committed by GitHub
parent 7ba1352a60
commit bb54cc68e6
13 changed files with 189 additions and 284 deletions

View File

@@ -41,10 +41,10 @@ import "github.com/mochi-co/mqtt/v2"
- Direct Packet Injection using special inline client, or masquerade as existing clients.
- Performant and Stable:
- Our classic trie-based Topic-Subscription model.
- A new fixed 'FanPool' worker queues to ensure consistent resource allocation and throughput reliability.
- Client-specific write buffers to avoid issues with slow-reading or irregular client behaviour.
- Passes all [Paho Interoperability Tests](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) for MQTT v5 and MQTT v3.
- Over a thousand carefully considered unit test scenarios.
- TCP, Websocket, (including SSL/TLS) and $SYS Dashboard listeners.
- TCP, Websocket (including SSL/TLS), and $SYS Dashboard listeners.
- Built-in Redis, Badger, and Bolt Persistence using Hooks (but you can also make your own).
- Built-in Rule-based Authentication and ACL Ledger using Hooks (also make your own).
@@ -297,6 +297,7 @@ The function signatures for all the hooks and `mqtt.Hook` interface can be found
| OnUnsubscribed | Called when a client successfully unsubscribes from one or more filters. |
| OnPublish | Called when a client publishes a message. Allows packet modification. |
| OnPublished | Called when a client has published a message to subscribers. |
| OnPublishDropped | Called when a message to a client is dropped before delivery, such as if the client is taking too long to respond. |
| OnRetainMessage | Called then a published message is retained. |
| OnQosPublish | Called when a publish packet with Qos >= 1 is issued to a subscriber. |
| OnQosComplete | Called when the Qos flow for a message has been completed. |
@@ -366,23 +367,23 @@ Performance benchmarks were tested using [MQTT-Stresser](https://github.com/inov
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=2 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.0.0 | 139,860 | 135,960 | 132,059 | 217,499 | 211,027 | 204,555 |
| Mochi v2.2.0 | 127,216 | 125,748 | 124,279 | 319,250 | 309,327 | 299,405 |
| Mosquitto v2.0.15 | 155,920 | 155,919 | 155,918 | 185,485 | 185,097 | 184,709 |
| EMQX v5.0.11 | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17649 |
| EMQX v5.0.11 | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17,649 |
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=10 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.0.0 | 55,189 | 34,840 | 21,298 | 56,980 | 28,557 | 23,781 |
| Mochi v2.2.0 | 45,615 | 30,129 | 21,138 | 232,717 | 86,323 | 50,402 |
| Mosquitto v2.0.15 | 42,729 | 38,633 | 29,879 | 23,241 | 19,714 | 18,806 |
| EMQX v5.0.11 | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3756 |
| EMQX v5.0.11 | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3,756 |
Million Message Challenge (hit the server with 1 million messages immediately):
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=100 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.0.0 | 13,573 | 3,678 | 1,848 | 34,309 | 2,470 | 5,636 |
| Mochi v2.2.0 | 51,044 | 4,682 | 2,345 | 72,634 | 7,645 | 2,464 |
| Mosquitto v2.0.15 | 3,826 | 3,395 | 3,032 | 1,200 | 1,150 | 1,118 |
| EMQX v5.0.11 | 4,086 | 2,432 | 2,274 | 434 | 333 | 311 |

View File

@@ -135,15 +135,17 @@ type Will struct {
// State tracks the state of the client.
type ClientState struct {
TopicAliases TopicAliases // a map of topic aliases
stopCause atomic.Value // reason for stopping
Inflight *Inflight // a map of in-flight qos messages
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
endOnce sync.Once // only end once
packetID uint32 // the current highest packetID
done uint32 // atomic counter which indicates that the client has closed
keepalive uint16 // the number of seconds the connection can wait
TopicAliases TopicAliases // a map of topic aliases
stopCause atomic.Value // reason for stopping
Inflight *Inflight // a map of in-flight qos messages
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
outbound chan packets.Packet // queue for pending outbound packets
endOnce sync.Once // only end once
packetID uint32 // the current highest packetID
done uint32 // atomic counter which indicates that the client has closed
outboundQty int32 // number of messages currently in the outbound queue
keepalive uint16 // the number of seconds the connection can wait
}
// newClient returns a new instance of Client. This is almost exclusively used by Server
@@ -155,6 +157,7 @@ func newClient(c net.Conn, o *ops) *Client {
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(o.capabilities.TopicAliasMaximum),
keepalive: defaultKeepalive,
outbound: make(chan packets.Packet, o.capabilities.MaximumClientWritesPending),
},
Properties: ClientProperties{
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
@@ -175,6 +178,16 @@ func newClient(c net.Conn, o *ops) *Client {
return cl
}
// WriteLoop ranges over pending outbound messages and writes them to the client connection.
func (cl *Client) WriteLoop() {
for pk := range cl.State.outbound {
if err := cl.WritePacket(pk); err != nil {
cl.ops.log.Debug().Err(err).Str("client", cl.ID).Interface("packet", pk).Msg("failed publishing packet")
}
atomic.AddInt32(&cl.State.outboundQty, -1)
}
}
// ParseConnect parses the connect parameters and properties for a client.
func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Net.Listener = lid
@@ -223,12 +236,12 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
func (cl *Client) refreshDeadline(keepalive uint16) {
if cl.Net.Conn != nil {
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
if keepalive > 0 {
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
}
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
if keepalive > 0 {
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
}
if cl.Net.Conn != nil {
_ = cl.Net.Conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
}
}
@@ -237,6 +250,8 @@ func (cl *Client) refreshDeadline(keepalive uint16) {
// If no unused packet ids are available, an error is returned and the client
// should be disconnected.
func (cl *Client) NextPacketID() (i uint32, err error) {
cl.Lock()
defer cl.Unlock()
i = atomic.LoadUint32(&cl.State.packetID)
started := i + 1
overflowed := false
@@ -543,7 +558,6 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
atomic.AddInt64(&cl.ops.info.MessagesSent, 1)
}
cl.refreshDeadline(cl.State.keepalive)
cl.ops.hooks.OnPacketSent(cl, pk, buf.Bytes())
return err

View File

@@ -30,8 +30,9 @@ func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
hooks: new(Hooks),
log: &logger,
capabilities: &Capabilities{
ReceiveMaximum: 10,
TopicAliasMaximum: 10000,
ReceiveMaximum: 10,
TopicAliasMaximum: 10000,
MaximumClientWritesPending: 3,
},
})
@@ -42,6 +43,9 @@ func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
cl.State.Inflight.receiveQuota = 10
cl.Properties.Props.TopicAliasMaximum = 0
cl.Properties.Props.RequestResponseInfo = 0x1
go cl.WriteLoop()
return
}

View File

@@ -0,0 +1,52 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"flag"
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
tcpAddr := flag.String("tcp", ":1883", "network address for TCP listener")
flag.Parse()
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
_ = server.AddHook(new(auth.AllowHook), nil)
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@@ -1,101 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co, chowyu08, muXxer
package mqtt
import (
"sync"
"sync/atomic"
xh "github.com/cespare/xxhash/v2"
)
// taskChan is a channel for incoming task functions.
type taskChan chan func()
// FanPool is a fixed-sized fan-style worker pool with multiple
// working 'columns'. Instead of a single queue channel processed by
// many goroutines, this fan pool uses many queue channels each
// processed by a single goroutine.
// Very special thanks are given to the authors of HMQ in particular
// @chowyu08 and @muXxer for their work on the fixpool worker pool
// https://github.com/fhmq/hmq/blob/master/pool/fixpool.go
// from which this fan-pool is heavily inspired.
type FanPool struct {
queue []taskChan
wg sync.WaitGroup
capacity uint64
perChan uint64
Mutex sync.Mutex
}
// New returns a new instance of FanPool. fanSize controls the number of 'columns'
// of the fan, whereas queueSize controls the size of each column's queue.
func NewFanPool(fanSize, queueSize uint64) *FanPool {
pool := &FanPool{
capacity: fanSize,
perChan: queueSize,
queue: make([]taskChan, fanSize),
}
pool.fillWorkers(fanSize)
return pool
}
// fillWorkers adds columns to the fan pool with an associated worker goroutine.
func (p *FanPool) fillWorkers(n uint64) {
for i := uint64(0); i < n; i++ {
p.queue[i] = make(taskChan, p.perChan)
go p.worker(p.queue[i])
p.wg.Add(1)
}
}
// worker is a worker goroutine which processes tasks from a single queue.
func (p *FanPool) worker(ch taskChan) {
defer p.wg.Done()
var task func()
var ok bool
for {
task, ok = <-ch
if !ok {
return
}
task()
}
}
// Enqueue adds a new task to the queue to be processed.
func (p *FanPool) Enqueue(id string, task func()) {
if p.Size() == 0 {
return
}
// We can use xh.Sum64 to get a specific queue index
// which remains the same for a client id, giving each
// client their own queue.
p.queue[xh.Sum64([]byte(id))%p.Size()] <- task
}
// Wait blocks until all the workers in the pool have completed.
func (p *FanPool) Wait() {
p.wg.Wait()
}
// Close issues a shutdown signal to the workers.
func (p *FanPool) Close() {
for i := 0; i < int(p.Size()); i++ {
if p.queue[i] != nil {
close(p.queue[i])
}
}
p.queue = nil
atomic.StoreUint64(&p.capacity, 0)
}
// Size returns the current number of workers in the pool.
func (p *FanPool) Size() uint64 {
return atomic.LoadUint64(&p.capacity)
}

View File

@@ -1,89 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestFanPool(t *testing.T) {
f := NewFanPool(1, 2)
require.NotNil(t, f)
require.Equal(t, uint64(1), f.capacity)
require.Equal(t, 2, cap(f.queue[0]))
o := make(chan bool)
go func() {
f.Enqueue("test", func() {
o <- true
})
}()
require.True(t, <-o)
f.Close()
f.Wait()
}
func TestFillWorkers(t *testing.T) {
f := &FanPool{
perChan: 3,
queue: make([]taskChan, 2),
}
f.fillWorkers(2)
require.Len(t, f.queue, 2)
require.Equal(t, 3, cap(f.queue[0]))
}
func TestEnqueue(t *testing.T) {
f := &FanPool{
capacity: 2,
queue: []taskChan{
make(taskChan, 2),
make(taskChan, 2),
},
}
go func() {
f.Enqueue("a", func() {})
}()
require.NotNil(t, <-f.queue[1])
}
func TestEnqueueOnEmpty(t *testing.T) {
f := &FanPool{
queue: []taskChan{},
}
go func() {
f.Enqueue("a", func() {})
}()
require.Len(t, f.queue, 0)
}
func TestSize(t *testing.T) {
f := &FanPool{
capacity: 10,
}
require.Equal(t, uint64(10), f.Size())
}
func TestClose(t *testing.T) {
f := &FanPool{
capacity: 3,
queue: []taskChan{
make(taskChan, 2),
make(taskChan, 2),
make(taskChan, 2),
},
}
f.Close()
require.Equal(t, uint64(0), f.Size())
require.Nil(t, f.queue)
}

2
go.mod
View File

@@ -6,7 +6,6 @@ require (
github.com/alicebob/miniredis/v2 v2.23.0
github.com/asdine/storm v2.1.2+incompatible
github.com/asdine/storm/v3 v3.2.1
github.com/cespare/xxhash/v2 v2.1.2
github.com/go-redis/redis/v8 v8.11.5
github.com/gorilla/websocket v1.5.0
github.com/jinzhu/copier v0.3.5
@@ -21,6 +20,7 @@ require (
require (
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 // indirect
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgraph-io/badger v1.6.0 // indirect
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 // indirect

View File

@@ -39,6 +39,7 @@ const (
OnUnsubscribed
OnPublish
OnPublished
OnPublishDropped
OnRetainMessage
OnQosPublish
OnQosComplete
@@ -87,6 +88,7 @@ type Hook interface {
OnUnsubscribed(cl *Client, pk packets.Packet)
OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error)
OnPublished(cl *Client, pk packets.Packet)
OnPublishDropped(cl *Client, pk packets.Packet)
OnRetainMessage(cl *Client, pk packets.Packet, r int64)
OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int)
OnQosComplete(cl *Client, pk packets.Packet)
@@ -393,6 +395,16 @@ func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) {
}
}
// OnPublishDropped is called when a message to a client was dropped instead of delivered
// such as when a client is too slow to respond.
func (h *Hooks) OnPublishDropped(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPublishDropped) {
hook.OnPublishDropped(cl, pk)
}
}
}
// OnRetainMessage is called then a published message is retained.
func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
for _, hook := range h.GetAll() {
@@ -727,6 +739,9 @@ func (h *HookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, err
// OnPublished is called when a client has published a message to subscribers.
func (h *HookBase) OnPublished(cl *Client, pk packets.Packet) {}
// OnPublishDropped is called when a message to a client is dropped instead of being delivered.
func (h *HookBase) OnPublishDropped(cl *Client, pk packets.Packet) {}
// OnRetainMessage is called then a published message is retained.
func (h *HookBase) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {}

View File

@@ -236,6 +236,7 @@ func TestHooksNonReturns(t *testing.T) {
h.OnSubscribed(cl, packets.Packet{}, []byte{1})
h.OnUnsubscribed(cl, packets.Packet{})
h.OnPublished(cl, packets.Packet{})
h.OnPublishDropped(cl, packets.Packet{})
h.OnRetainMessage(cl, packets.Packet{}, 0)
h.OnQosPublish(cl, packets.Packet{}, time.Now().Unix(), 0)
h.OnQosComplete(cl, packets.Packet{})

View File

@@ -113,6 +113,7 @@ var (
ErrPacketTooLarge = Code{Code: 0x95, Reason: "packet too large"}
ErrMessageRateTooHigh = Code{Code: 0x96, Reason: "message rate too high"}
ErrQuotaExceeded = Code{Code: 0x97, Reason: "quota exceeded"}
ErrPendingClientWritesExceeded = Code{Code: 0x97, Reason: "too many pending writes"}
ErrAdministrativeAction = Code{Code: 0x98, Reason: "administrative action"}
ErrPayloadFormatInvalid = Code{Code: 0x99, Reason: "payload format invalid"}
ErrRetainNotSupported = Code{Code: 0x9A, Reason: "retain not supported"}

View File

@@ -26,10 +26,8 @@ import (
)
const (
Version = "2.1.8" // the current server version.
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
defaultFanPoolSize uint64 = 32 // the number of concurrent workers in the pool
defaultFanPoolQueueSize uint64 = 1024 // the capacity of each worker queue
Version = "2.2.0" // the current server version.
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
)
var (
@@ -47,6 +45,7 @@ var (
SharedSubAvailable: 1, // shared subscriptions are available
ServerKeepAlive: 10, // default keepalive for clients
MinimumProtocolVersion: 3, // minimum supported mqtt version (3.0.0)
MaximumClientWritesPending: 1024 * 8, // maximum number of pending message writes for a client
}
ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists.
@@ -56,6 +55,7 @@ var (
// Capabilities indicates the capabilities and features provided by the server.
type Capabilities struct {
MaximumMessageExpiryInterval int64
MaximumClientWritesPending int32
MaximumSessionExpiryInterval uint32
MaximumPacketSize uint32
ReceiveMaximum uint16
@@ -80,7 +80,9 @@ type Compatibilities struct {
// Options contains configurable options for the server.
type Options struct {
// Capabilities defines the server features and behaviour.
// Capabilities defines the server features and behaviour. If you only wish to modify
// several of these values, set them explicitly - e.g.
// server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
Capabilities *Capabilities
// Logger specifies a custom configured implementation of zerolog to override
@@ -91,16 +93,6 @@ type Options struct {
// server.Log = &l
Logger *zerolog.Logger
// FanPoolSize is the number of individual workers and queues to initialize.
// Bigger is not necessarily better, and you should rely on defaults unless
// you have know what you are doing.
FanPoolSize uint64
// FanPoolQueueSize is the size of the queue per worker. Increase this value
// accordingly if you anticipate having intermittent but massive numbers of
// messages. Cluster support is roadmapped.
FanPoolQueueSize uint64
// SysTopicResendInterval specifies the interval between $SYS topic updates in seconds.
SysTopicResendInterval int64
}
@@ -113,7 +105,6 @@ type Server struct {
Clients *Clients // clients known to the broker
Topics *TopicsIndex // an index of topic filter subscriptions and retained messages
Info *system.Info // values about the server commonly known as $SYS topics
fanpool *FanPool // a fixed size worker pool for processing inbound and outbound messages
loop *loop // loop contains tickers for the system event loop
done chan bool // indicate that the server is ending
Log *zerolog.Logger // minimal no-alloc logger
@@ -165,8 +156,7 @@ func New(opts *Options) *Server {
Version: Version,
Started: time.Now().Unix(),
},
fanpool: NewFanPool(opts.FanPoolSize, opts.FanPoolQueueSize),
Log: opts.Logger,
Log: opts.Logger,
hooks: &Hooks{
Log: opts.Logger,
},
@@ -185,14 +175,6 @@ func (o *Options) ensureDefaults() {
o.SysTopicResendInterval = defaultSysTopicInterval
}
if o.FanPoolSize == 0 {
o.FanPoolSize = defaultFanPoolSize
}
if o.FanPoolQueueSize < 1 {
o.FanPoolQueueSize = defaultFanPoolQueueSize
}
if o.Logger == nil {
log := zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.InfoLevel).Output(zerolog.ConsoleWriter{Out: os.Stderr})
o.Logger = &log
@@ -219,6 +201,8 @@ func (s *Server) NewClient(c net.Conn, listener string, id string, inline bool)
// By default we don't want to restrict developer publishes,
// but if you do, reset this after creating inline client.
cl.State.Inflight.ResetReceiveQuota(math.MaxInt32)
} else {
go cl.WriteLoop() // can only write to real clients
}
return cl
@@ -382,6 +366,8 @@ func (s *Server) attachClient(cl *Client, listener string) error {
s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", listener).Msg("client disconnected")
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryIntervalFlag && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
s.hooks.OnDisconnect(cl, err, expire)
close(cl.State.outbound)
if expire {
s.UnsubscribeClient(cl)
cl.ClearInflights(math.MaxInt64, 0)
@@ -458,8 +444,6 @@ func (s *Server) validateConnect(cl *Client, pk packets.Packet) packets.Code {
// session is abandoned.
func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
if existing, ok := s.Clients.Get(pk.Connect.ClientIdentifier); ok {
existing.Lock()
defer existing.Unlock()
s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
if pk.Connect.Clean || (existing.Properties.Clean && cl.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
s.UnsubscribeClient(existing)
@@ -468,13 +452,15 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
}
if existing.State.Inflight.Len() > 0 {
existing.State.Inflight.Lock()
cl.State.Inflight = existing.State.Inflight // [MQTT-3.1.2-5]
existing.State.Inflight.Unlock()
if cl.State.Inflight.maximumReceiveQuota == 0 && cl.ops.capabilities.ReceiveMaximum != 0 {
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.capabilities.ReceiveMaximum)) // server receive max per client
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
}
}
for _, sub := range existing.State.Subscriptions.GetAll() {
existed := !s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
if !existed {
@@ -714,10 +700,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
}
if pk.FixedHeader.Qos == 0 {
s.fanpool.Enqueue(cl.ID, func() {
s.publishToSubscribers(pk)
})
s.publishToSubscribers(pk)
s.hooks.OnPublished(cl, pk)
return nil
}
@@ -746,11 +729,9 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
s.hooks.OnQosComplete(cl, ack)
}
s.fanpool.Enqueue(cl.ID, func() {
s.publishToSubscribers(pk)
})
s.publishToSubscribers(pk)
s.hooks.OnPublished(cl, pk)
return nil
}
@@ -845,15 +826,25 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet
if sentQuota == 0 && atomic.LoadInt32(&cl.State.Inflight.maximumSendQuota) > 0 {
out.Expiry = -1
cl.State.Inflight.Set(out)
return pk, nil
return out, nil
}
}
if cl.Net.Conn == nil || atomic.LoadUint32(&cl.State.done) == 1 {
return pk, packets.CodeDisconnect
return out, packets.CodeDisconnect
}
return out, cl.WritePacket(out)
select {
case cl.State.outbound <- out:
atomic.AddInt32(&cl.State.outboundQty, 1)
default:
cl.ops.hooks.OnPublishDropped(cl, pk)
cl.State.Inflight.Delete(out.PacketID) // packet was dropped due to irregular circumstances, so rollback inflight.
cl.State.Inflight.IncreaseSendQuota()
return out, packets.ErrPendingClientWritesExceeded
}
return out, nil
}
func (s *Server) publishRetainedToClient(cl *Client, sub packets.Subscription, existed bool) {
@@ -868,7 +859,7 @@ func (s *Server) publishRetainedToClient(cl *Client, sub packets.Subscription, e
for _, pkv := range s.Topics.Messages(sub.Filter) { // [MQTT-3.8.4-4]
_, err := s.publishToClient(cl, sub, pkv)
if err != nil {
s.Log.Warn().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Interface("packet", pkv).Msg("failed to publish retained message")
s.Log.Debug().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Interface("packet", pkv).Msg("failed to publish retained message")
}
}
}
@@ -1113,7 +1104,6 @@ func (s *Server) UnsubscribeClient(cl *Client) {
// processAuth processes an Auth packet.
func (s *Server) processAuth(cl *Client, pk packets.Packet) error {
_, err := s.hooks.OnAuthPacket(cl, pk)
fmt.Println("err", err)
if err != nil {
return err
}
@@ -1222,8 +1212,6 @@ func (s *Server) publishSysTopics() {
func (s *Server) Close() error {
close(s.done)
s.Listeners.CloseAll(s.closeListenerClients)
s.fanpool.Close()
s.fanpool.Wait()
s.hooks.OnStopped()
s.hooks.Stop()

View File

@@ -54,10 +54,8 @@ func newServer() *Server {
cc.ReceiveMaximum = 0
s := New(&Options{
Logger: &logger,
FanPoolSize: 2,
FanPoolQueueSize: 10,
Capabilities: &cc,
Logger: &logger,
Capabilities: &cc,
})
s.AddHook(new(AllowHook), nil)
return s
@@ -68,8 +66,6 @@ func TestOptionsSetDefaults(t *testing.T) {
opts.ensureDefaults()
require.Equal(t, defaultSysTopicInterval, opts.SysTopicResendInterval)
require.Equal(t, defaultFanPoolSize, opts.FanPoolSize)
require.Equal(t, defaultFanPoolQueueSize, opts.FanPoolQueueSize)
require.Equal(t, DefaultServerCapabilities, opts.Capabilities)
opts = new(Options)
@@ -86,7 +82,6 @@ func TestNew(t *testing.T) {
require.NotNil(t, s.Info)
require.NotNil(t, s.Log)
require.NotNil(t, s.Options)
require.NotNil(t, s.fanpool)
require.NotNil(t, s.loop)
require.NotNil(t, s.loop.sysTopics)
require.NotNil(t, s.loop.inflightExpiry)
@@ -418,11 +413,13 @@ func TestEstablishConnectionInheritExisting(t *testing.T) {
r, w := net.Pipe()
o := make(chan error)
go func() {
o <- s.EstablishConnection("tcp", r)
err := s.EstablishConnection("tcp", r)
o <- err
}()
go func() {
w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes)
time.Sleep(time.Millisecond) // we want to receive the retained message, so we need to wait a moment before sending the disconnect.
w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
}()
@@ -455,6 +452,7 @@ func TestEstablishConnectionInheritExisting(t *testing.T) {
require.Equal(t, connackPlusPacket, <-recv)
require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectTakeover).RawBytes, <-takeover)
time.Sleep(time.Microsecond * 100)
w.Close()
r.Close()
@@ -560,9 +558,7 @@ func TestEstablishConnectionInheritExistingClean(t *testing.T) {
func TestEstablishConnectionBadAuthentication(t *testing.T) {
s := New(&Options{
Logger: &logger,
FanPoolSize: 2,
FanPoolQueueSize: 10,
Logger: &logger,
})
defer s.Close()
@@ -596,9 +592,7 @@ func TestEstablishConnectionBadAuthentication(t *testing.T) {
func TestEstablishConnectionBadAuthenticationAckFailure(t *testing.T) {
s := New(&Options{
Logger: &logger,
FanPoolSize: 2,
FanPoolQueueSize: 10,
Logger: &logger,
})
defer s.Close()
@@ -1105,9 +1099,7 @@ func TestServerProcessPublishInvalidTopic(t *testing.T) {
func TestServerProcessPublishACLCheckDeny(t *testing.T) {
s := New(&Options{
Logger: &logger,
FanPoolSize: 2,
FanPoolQueueSize: 10,
Logger: &logger,
})
s.Serve()
defer s.Close()
@@ -1390,6 +1382,7 @@ func TestPublishToClientServerDowngradeQos(t *testing.T) {
pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
pkx.FixedHeader.Qos = 2
s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, pkx)
time.Sleep(time.Microsecond * 100)
w.Close()
}()
@@ -1403,6 +1396,31 @@ func TestPublishToClientServerDowngradeQos(t *testing.T) {
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).RawBytes, <-receiverBuf)
}
func TestPublishToClientExceedClientWritesPending(t *testing.T) {
s := newServer()
_, w := net.Pipe()
cl := newClient(w, &ops{
info: new(system.Info),
hooks: new(Hooks),
log: &logger,
capabilities: &Capabilities{
MaximumClientWritesPending: 3,
},
})
s.Clients.Add(cl)
for i := int32(0); i < cl.ops.capabilities.MaximumClientWritesPending; i++ {
cl.State.outbound <- packets.Packet{}
atomic.AddInt32(&cl.State.outboundQty, 1)
}
_, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, packets.Packet{})
require.Error(t, err)
require.ErrorIs(t, packets.ErrPendingClientWritesExceeded, err)
}
func TestPublishToClientServerTopicAlias(t *testing.T) {
s := newServer()
cl, r, w := newTestClient()
@@ -1414,6 +1432,7 @@ func TestPublishToClientServerTopicAlias(t *testing.T) {
pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).Packet
s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx)
s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx)
time.Sleep(time.Millisecond)
w.Close()
}()
@@ -1944,6 +1963,7 @@ func TestServerProcessOutboundQos2Flow(t *testing.T) {
for i, tx := range tt {
t.Run("qos step"+strconv.Itoa(i), func(t *testing.T) {
r, w := net.Pipe()
time.Sleep(time.Millisecond)
cl.Net.Conn = w
recv := make(chan []byte)
@@ -1960,6 +1980,7 @@ func TestServerProcessOutboundQos2Flow(t *testing.T) {
require.NoError(t, err)
}
time.Sleep(time.Millisecond)
w.Close()
if i != 2 {
@@ -2171,9 +2192,7 @@ func TestServerProcessSubscribeNoConnection(t *testing.T) {
func TestServerProcessSubscribeACLCheckDeny(t *testing.T) {
s := New(&Options{
Logger: &logger,
FanPoolSize: 2,
FanPoolQueueSize: 10,
Logger: &logger,
})
s.Serve()
cl, r, w := newTestClient()
@@ -2192,9 +2211,7 @@ func TestServerProcessSubscribeACLCheckDeny(t *testing.T) {
func TestServerProcessSubscribeACLCheckDenyObscure(t *testing.T) {
s := New(&Options{
Logger: &logger,
FanPoolSize: 2,
FanPoolQueueSize: 10,
Logger: &logger,
})
s.Serve()
s.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
@@ -2580,7 +2597,6 @@ func TestServerClose(t *testing.T) {
err := s.AddListener(listeners.NewMockListener("t1", ":1882"))
require.NoError(t, err)
s.Serve()
require.Equal(t, uint64(2), s.fanpool.Size())
// receive the disconnect
recv := make(chan []byte)
@@ -2600,7 +2616,6 @@ func TestServerClose(t *testing.T) {
s.Close()
time.Sleep(time.Millisecond)
require.Equal(t, false, listener.(*listeners.MockListener).IsServing())
require.Equal(t, uint64(0), s.fanpool.Size())
require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectShuttingDown).RawBytes, <-recv)
}

View File

@@ -347,6 +347,8 @@ func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
// 0 is returned if sequential empty payloads are received.
func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
n := x.set(pk.TopicName, 0)
n.Lock()
defer n.Unlock()
if len(pk.Payload) > 0 {
n.retainPath = pk.TopicName
x.Retained.Add(pk.TopicName, pk)
@@ -361,6 +363,7 @@ func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
n.retainPath = ""
x.Retained.Delete(pk.TopicName) // [MQTT-3.3.1-6] [MQTT-3.3.1-7]
x.trim(n)
return out
}
@@ -619,6 +622,7 @@ type particle struct {
subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address
shared *SharedSubscriptions // a map of shared subscriptions keyed on group name
retainPath string // path of a retained message
sync.Mutex // mutex for when making changes to the particle
}
// newParticle returns a pointer to a new instance of particle.