Compare commits

..

21 Commits

Author SHA1 Message Date
mochi-co
ef5dcf68d0 Use context to signal client open state 2023-05-06 11:49:02 +01:00
JB
6704cf7227 Add packet ID exhausted hook (#217) 2023-05-06 10:37:27 +01:00
thedevop
9233e6fd39 Expire session if SessionExpiryInterval is 0 (#216)
If SessionExpiryInterval was not set in CONNECT, SessionExpiryIntervalFlag is also not set. According to spec:
  If the Session Expiry Interval is absent the value 0 is used. If it is set to 0, or is absent, the Session ends when the Network Connection is closed.
2023-05-06 10:12:33 +01:00
ħþ
1ca65d9631 Update codes.go (#215)
fix typo
2023-05-06 10:02:25 +01:00
ħþ
33229da885 Update codes.go (#214)
Fix typo
2023-05-06 09:59:50 +01:00
mochi-co
c274d5fd08 Update server version 2023-05-05 00:04:27 +01:00
JB
10e82f41d6 Lock on close outbound (#213) 2023-05-05 00:02:49 +01:00
JB
e6c07b2b78 Add lock to client writes (#212) 2023-05-04 23:17:12 +01:00
JB
eed3ef9606 Add OnPacketIDExhausted hook (#211) 2023-05-04 22:51:40 +01:00
JB
1ec880844d Correctly validate WillProperties (#210)
Co-authored-by: sukvojte <sukvojte@gmail.com>
2023-05-04 22:37:23 +01:00
werbenhu
4b49652a8c Update build.yml (#203) 2023-05-04 18:09:58 +01:00
mochi-co
d46e7b5bcf Protect close of nil outbound channel 2023-04-21 22:09:59 +01:00
mochi-co
17fb7dadbc Protect close of nil outbound channel 2023-04-21 22:00:27 +01:00
werben
ed7fd836e1 #78 storage hook should not execute the relevant code if the client has been reconnected (#198)
* storage hook should not execute the relevant code if the client has been reconnected #78

* add test cases for coverage decrease

add test cases for coverage decrease
2023-04-21 21:52:44 +01:00
mochi-co
605bb93c75 Move msgToPacket to storage.Message.ToPacket 2023-04-21 21:49:49 +01:00
Wind
c73ace2ea0 Simplified code (#195)
Simplify the code for the loadInflight and loadRetained methods.
Adjust the validation order of the processSubscribe method to ensure that it fails quickly if there is an error, since s.hooks. OnACLCheck generally takes a long time.

Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>
2023-04-21 21:29:03 +01:00
thedevop
aac6d699da Ensure to close client WriteLoop (#193)
* Ensure client WriteLoop is closed

* Ensure to close client WriteLoop
2023-04-21 21:20:46 +01:00
Hubertus Hohl
7bd7bd5087 fix: common subscriptions issued by different clients at the same time may be lost (#186) 2023-03-11 23:17:10 +00:00
mochi-co
655bf9fdb1 Update readme 2023-03-11 23:15:51 +00:00
mochi-co
b188055c7d Update server version 2023-03-11 23:14:51 +00:00
JB
aaf1d9d4c6 Configurable client bufio reader/writer sizes (#190) 2023-03-11 23:13:28 +00:00
21 changed files with 362 additions and 187 deletions

View File

@@ -7,37 +7,39 @@ jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v2
uses: actions/setup-go@v3
with:
go-version: 1.19
- name: Vet
run: go vet ./...
- name: Test
run: go test -race ./... && echo true
coverage:
name: Test with Coverage
runs-on: ubuntu-latest
steps:
- name: Install Go
if: success()
uses: actions/setup-go@v2
with:
go-version: 1.19.x
- name: Checkout code
uses: actions/checkout@v2
- name: Calc coverage
run: |
go test -v -covermode=count -coverprofile=coverage.out ./...
- name: Convert coverage.out to coverage.lcov
uses: jandelgado/gcov2lcov-action@v1.0.6
- name: Coveralls
uses: coverallsapp/github-action@v1.1.2
with:
github-token: ${{ secrets.github_token }}
path-to-lcov: coverage.lcov
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: '1.19'
- name: Check out code
uses: actions/checkout@v3
- name: Install dependencies
run: |
go mod download
- name: Run Unit tests
run: |
go test -race -covermode atomic -coverprofile=covprofile ./...
- name: Install goveralls
run: go install github.com/mattn/goveralls@latest
- name: Send coverage
env:
COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: goveralls -coverprofile=covprofile -service=github

View File

@@ -136,6 +136,8 @@ A number of configurable options are available which can be used to alter the be
```go
server := mqtt.New(&mqtt.Options{
Capabilities: mqtt.Capabilities{
ClientNetWriteBufferSize: 4096,
ClientNetReadBufferSize: 4096,
MaximumSessionExpiryInterval: 3600,
Compatibilities: mqtt.Compatibilities{
ObscureNotAuthorized: true,
@@ -145,7 +147,7 @@ server := mqtt.New(&mqtt.Options{
})
```
Review the mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs for a comprehensive list of options.
Review the mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs for a comprehensive list of options. `ClientNetWriteBufferSize` and `ClientNetReadBufferSize` can be configured to adjust memory usage per client, based on your needs.
## Event Hooks
@@ -302,6 +304,7 @@ The function signatures for all the hooks and `mqtt.Hook` interface can be found
| OnQosPublish | Called when a publish packet with Qos >= 1 is issued to a subscriber. |
| OnQosComplete | Called when the Qos flow for a message has been completed. |
| OnQosDropped | Called when an inflight message expires before completion. |
| OnPacketIDExhausted | Called when a client runs out of unused packet ids to assign. |
| OnWill | Called when a client disconnects and intends to issue a will message. Allows packet modification. |
| OnWillSent | Called when an LWT message has been issued from a disconnecting client. |
| OnClientExpired | Called when a client session has expired and should be deleted. |
@@ -343,6 +346,8 @@ server.InjectPacket(cl, packets.Packet{
See the [hooks example](examples/hooks/main.go) to see this feature in action.
### Testing
#### Unit Tests
Mochi MQTT tests over a thousand scenarios with thoughtfully hand written unit tests to ensure each function does exactly what we expect. You can run the tests using go:

View File

@@ -7,6 +7,7 @@ package mqtt
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net"
@@ -87,7 +88,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
defer cl.RUnlock()
clients := make([]*Client, 0, cl.Len())
for _, client := range cl.internal {
if client.Net.Listener == id && atomic.LoadUint32(&client.State.done) == 0 {
if client.Net.Listener == id && !client.Closed() {
clients = append(clients, client)
}
}
@@ -144,7 +145,7 @@ type ClientState struct {
endOnce sync.Once // only end once
isTakenOver uint32 // used to identify orphaned clients
packetID uint32 // the current highest packetID
done uint32 // atomic counter which indicates that the client has closed
open context.Context // indicate that the client is open for packet exchange
outboundQty int32 // number of messages currently in the outbound queue
keepalive uint16 // the number of seconds the connection can wait
}
@@ -156,9 +157,10 @@ func newClient(c net.Conn, o *ops) *Client {
State: ClientState{
Inflight: NewInflights(),
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(o.capabilities.TopicAliasMaximum),
TopicAliases: NewTopicAliases(o.options.Capabilities.TopicAliasMaximum),
keepalive: defaultKeepalive,
outbound: make(chan *packets.Packet, o.capabilities.MaximumClientWritesPending),
open: context.Background(),
outbound: make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending),
},
Properties: ClientProperties{
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
@@ -168,8 +170,11 @@ func newClient(c net.Conn, o *ops) *Client {
if c != nil {
cl.Net = ClientConnection{
Conn: c,
bconn: bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)),
Conn: c,
bconn: bufio.NewReadWriter(
bufio.NewReaderSize(c, o.options.ClientNetReadBufferSize),
bufio.NewWriterSize(c, o.options.ClientNetReadBufferSize),
),
Remote: c.RemoteAddr().String(),
}
}
@@ -198,8 +203,8 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Properties.Clean = pk.Connect.Clean
cl.Properties.Props = pk.Properties.Copy(false)
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.capabilities.ReceiveMaximum)) // server receive max per client
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
cl.State.TopicAliases.Outbound = NewOutboundTopicAliases(cl.Properties.Props.TopicAliasMaximum)
@@ -209,7 +214,7 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Properties.Props.AssignedClientID = cl.ID
}
cl.State.keepalive = cl.ops.capabilities.ServerKeepAlive
cl.State.keepalive = cl.ops.options.Capabilities.ServerKeepAlive
if pk.Connect.Keepalive > 0 {
cl.State.keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22]
}
@@ -262,7 +267,7 @@ func (cl *Client) NextPacketID() (i uint32, err error) {
return 0, packets.ErrQuotaExceeded
}
if i >= cl.ops.capabilities.maximumPacketID {
if i >= cl.ops.options.Capabilities.maximumPacketID {
overflowed = true
i = 0
continue
@@ -327,7 +332,7 @@ func (cl *Client) Read(packetHandler ReadFn) error {
var err error
for {
if atomic.LoadUint32(&cl.State.done) == 1 {
if cl.Closed() {
return nil
}
@@ -352,11 +357,10 @@ func (cl *Client) Read(packetHandler ReadFn) error {
// Stop instructs the client to shut down all processing goroutines and disconnect.
func (cl *Client) Stop(err error) {
if atomic.LoadUint32(&cl.State.done) == 1 {
return
}
cl.State.endOnce.Do(func() {
cl.Lock()
defer cl.Unlock()
if cl.Net.Conn != nil {
_ = cl.Net.Conn.Close() // omit close error
}
@@ -365,7 +369,16 @@ func (cl *Client) Stop(err error) {
cl.State.stopCause.Store(err)
}
atomic.StoreUint32(&cl.State.done, 1)
if cl.State.outbound != nil {
close(cl.State.outbound)
}
if cl.State.open != nil {
var cancel context.CancelFunc
cl.State.open, cancel = context.WithCancel(cl.State.open)
cancel()
}
atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix())
})
}
@@ -380,7 +393,7 @@ func (cl *Client) StopCause() error {
// Closed returns true if client connection is closed.
func (cl *Client) Closed() bool {
return atomic.LoadUint32(&cl.State.done) == 1
return cl.State.open == nil || cl.State.open.Err() != nil
}
// ReadFixedHeader reads in the values of the next packet's fixed header.
@@ -405,7 +418,7 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
return err
}
if cl.ops.capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.capabilities.MaximumPacketSize {
if cl.ops.options.Capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.options.Capabilities.MaximumPacketSize {
return packets.ErrPacketTooLarge // [MQTT-3.2.2-15]
}
@@ -476,14 +489,6 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
// WritePacket encodes and writes a packet to the client.
func (cl *Client) WritePacket(pk packets.Packet) error {
if atomic.LoadUint32(&cl.State.done) == 1 {
return ErrConnectionClosed
}
if cl.Net.Conn == nil {
return nil
}
if pk.Expiry > 0 {
pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6]
}
@@ -497,7 +502,7 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set
}
if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.capabilities.Compatibilities.AlwaysReturnResponseInfo {
if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.options.Capabilities.Compatibilities.AlwaysReturnResponseInfo {
pk.Mods.AllowResponseInfo = true // [MQTT-3.1.2-28] we need to know which properties we can encode
}
@@ -547,6 +552,17 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
return packets.ErrPacketTooLarge // [MQTT-3.1.2-24] [MQTT-3.1.2-25]
}
cl.Lock()
defer cl.Unlock()
if cl.Closed() {
return ErrConnectionClosed
}
if cl.Net.Conn == nil {
return nil
}
nb := net.Buffers{buf.Bytes()}
n, err := nb.WriteTo(cl.Net.Conn)
if err != nil {

View File

@@ -5,6 +5,7 @@
package mqtt
import (
"context"
"errors"
"io"
"net"
@@ -29,11 +30,13 @@ func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
info: new(system.Info),
hooks: new(Hooks),
log: &logger,
capabilities: &Capabilities{
ReceiveMaximum: 10,
TopicAliasMaximum: 10000,
MaximumClientWritesPending: 3,
maximumPacketID: 10,
options: &Options{
Capabilities: &Capabilities{
ReceiveMaximum: 10,
TopicAliasMaximum: 10000,
MaximumClientWritesPending: 3,
maximumPacketID: 10,
},
},
})
@@ -112,8 +115,8 @@ func TestClientsDelete(t *testing.T) {
func TestClientsGetByListener(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1", Net: ClientConnection{Listener: "tcp1"}})
cl.Add(&Client{ID: "t2", Net: ClientConnection{Listener: "ws1"}})
cl.Add(&Client{ID: "t1", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "tcp1"}})
cl.Add(&Client{ID: "t2", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "ws1"}})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
@@ -134,6 +137,8 @@ func TestNewClient(t *testing.T) {
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
require.NotNil(t, cl.Net.Conn)
require.NotNil(t, cl.Net.bconn)
require.NotNil(t, cl.ops)
require.NotNil(t, cl.ops.options.Capabilities)
require.False(t, cl.Net.Inline)
}
@@ -168,8 +173,8 @@ func TestClientParseConnect(t *testing.T) {
require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos)
require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain)
require.Equal(t, uint32(1), cl.Properties.Will.Flag)
require.Equal(t, int32(cl.ops.capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
require.Equal(t, int32(cl.ops.capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.sendQuota)
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.maximumSendQuota)
}
@@ -242,7 +247,7 @@ func TestClientNextPacketIDInUse(t *testing.T) {
func TestClientNextPacketIDExhausted(t *testing.T) {
cl, _, _ := newTestClient()
for i := uint32(1); i <= cl.ops.capabilities.maximumPacketID; i++ {
for i := uint32(1); i <= cl.ops.options.Capabilities.maximumPacketID; i++ {
cl.State.Inflight.internal[uint16(i)] = packets.Packet{PacketID: uint16(i)}
}
@@ -254,17 +259,17 @@ func TestClientNextPacketIDExhausted(t *testing.T) {
func TestClientNextPacketIDOverflow(t *testing.T) {
cl, _, _ := newTestClient()
for i := uint32(0); i < cl.ops.capabilities.maximumPacketID; i++ {
for i := uint32(0); i < cl.ops.options.Capabilities.maximumPacketID; i++ {
cl.State.Inflight.internal[uint16(i)] = packets.Packet{}
}
cl.State.packetID = uint32(cl.ops.capabilities.maximumPacketID - 1)
cl.State.packetID = uint32(cl.ops.options.Capabilities.maximumPacketID - 1)
i, err := cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, cl.ops.capabilities.maximumPacketID, i)
cl.State.Inflight.internal[uint16(cl.ops.capabilities.maximumPacketID)] = packets.Packet{}
require.Equal(t, cl.ops.options.Capabilities.maximumPacketID, i)
cl.State.Inflight.internal[uint16(cl.ops.options.Capabilities.maximumPacketID)] = packets.Packet{}
cl.State.packetID = cl.ops.capabilities.maximumPacketID
cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID
_, err = cl.NextPacketID()
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
@@ -363,7 +368,7 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) {
func TestClientReadFixedHeaderPacketOversized(t *testing.T) {
cl, r, _ := newTestClient()
cl.ops.capabilities.MaximumPacketSize = 2
cl.ops.options.Capabilities.MaximumPacketSize = 2
defer cl.Stop(errClientStop)
go func() {
@@ -462,7 +467,7 @@ func TestClientReadOK(t *testing.T) {
func TestClientReadDone(t *testing.T) {
cl, _, _ := newTestClient()
defer cl.Stop(errClientStop)
cl.State.done = 1
cl.State.open = nil
o := make(chan error)
go func() {
@@ -479,7 +484,7 @@ func TestClientStop(t *testing.T) {
cl.Stop(nil)
require.Equal(t, nil, cl.State.stopCause.Load())
require.Equal(t, time.Now().Unix(), cl.State.disconnected)
require.Equal(t, uint32(1), cl.State.done)
require.True(t, cl.Closed())
require.Equal(t, nil, cl.StopCause())
}

View File

@@ -44,6 +44,7 @@ const (
OnQosPublish
OnQosComplete
OnQosDropped
OnPacketIDExhausted
OnWill
OnWillSent
OnClientExpired
@@ -93,6 +94,7 @@ type Hook interface {
OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int)
OnQosComplete(cl *Client, pk packets.Packet)
OnQosDropped(cl *Client, pk packets.Packet)
OnPacketIDExhausted(cl *Client, pk packets.Packet)
OnWill(cl *Client, will Will) (Will, error)
OnWillSent(cl *Client, pk packets.Packet)
OnClientExpired(cl *Client)
@@ -447,6 +449,16 @@ func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
}
}
// OnPacketIDExhausted is called when the client runs out of unused packet ids to
// assign to a packet.
func (h *Hooks) OnPacketIDExhausted(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketIDExhausted) {
hook.OnPacketIDExhausted(cl, pk)
}
}
}
// OnWill is called when a client disconnects and publishes an LWT message. This method
// differs from OnWillSent in that it allows you to modify the LWT message before it is
// published. The return values of the hook methods are passed-through in the order
@@ -754,6 +766,9 @@ func (h *HookBase) OnQosComplete(cl *Client, pk packets.Packet) {}
// OnQosDropped is called the Qos flow for a message expires.
func (h *HookBase) OnQosDropped(cl *Client, pk packets.Packet) {}
// OnPacketIDExhausted is called when the client runs out of unused packet ids to assign to a packet.
func (h *HookBase) OnPacketIDExhausted(cl *Client, pk packets.Packet) {}
// OnWill is called when a client disconnects and publishes an LWT message.
func (h *HookBase) OnWill(cl *Client, will Will) (Will, error) {
return will, nil

View File

@@ -182,6 +182,10 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
return
}
if cl.StopCause() == packets.ErrSessionTakenOver {
return
}
err := h.db.Delete(clientKey(cl), new(storage.Client))
if err != nil {
h.Log.Error().Err(err).Interface("data", clientKey(cl)).Msg("failed to delete client data")

View File

@@ -232,6 +232,29 @@ func TestOnDisconnectClosedDB(t *testing.T) {
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectSessionTakenOver(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
testClient := &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
testClient.Stop(packets.ErrSessionTakenOver)
teardown(t, h.config.Path, h)
h.OnDisconnect(testClient, nil, true)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)

View File

@@ -184,6 +184,10 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
return
}
if cl.StopCause() == packets.ErrSessionTakenOver {
return
}
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")

View File

@@ -241,6 +241,29 @@ func TestOnDisconnectClosedDB(t *testing.T) {
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectSessionTakenOver(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
testClient := &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
testClient.Stop(packets.ErrSessionTakenOver)
teardown(t, h.config.Path, h)
h.OnDisconnect(testClient, nil, true)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)

View File

@@ -199,6 +199,10 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
return
}
if cl.StopCause() == packets.ErrSessionTakenOver {
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")

View File

@@ -285,6 +285,28 @@ func TestOnDisconnectClosedDB(t *testing.T) {
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectSessionTakenOver(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
testClient := &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
testClient.Stop(packets.ErrSessionTakenOver)
teardown(t, h)
h.OnDisconnect(testClient, nil, true)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()

View File

@@ -117,6 +117,36 @@ func (d *Message) UnmarshalBinary(data []byte) error {
return json.Unmarshal(data, d)
}
// ToPacket converts a storage.Message to a standard packet.
func (d *Message) ToPacket() packets.Packet {
pk := packets.Packet{
FixedHeader: d.FixedHeader,
PacketID: d.PacketID,
TopicName: d.TopicName,
Payload: d.Payload,
Origin: d.Origin,
Created: d.Created,
Properties: packets.Properties{
PayloadFormat: d.Properties.PayloadFormat,
PayloadFormatFlag: d.Properties.PayloadFormatFlag,
MessageExpiryInterval: d.Properties.MessageExpiryInterval,
ContentType: d.Properties.ContentType,
ResponseTopic: d.Properties.ResponseTopic,
CorrelationData: d.Properties.CorrelationData,
SubscriptionIdentifier: d.Properties.SubscriptionIdentifier,
TopicAlias: d.Properties.TopicAlias,
User: d.Properties.User,
},
}
// Return a deep copy of the packet data otherwise the slices will
// continue pointing at the values from the storage packet.
pk = pk.Copy(true)
pk.FixedHeader.Dup = d.FixedHeader.Dup
return pk
}
// Subscription is a storable representation of an mqtt subscription.
type Subscription struct {
T string `json:"t"`

View File

@@ -194,3 +194,35 @@ func TestSysInfoUnmarshalBinaryEmpty(t *testing.T) {
require.NoError(t, err)
require.Equal(t, SystemInfo{}, d)
}
func TestMessageToPacket(t *testing.T) {
d := messageStruct
pk := d.ToPacket()
require.Equal(t, packets.Packet{
Payload: []byte("payload"),
FixedHeader: packets.FixedHeader{
Remaining: d.FixedHeader.Remaining,
Type: d.FixedHeader.Type,
Qos: d.FixedHeader.Qos,
Dup: d.FixedHeader.Dup,
Retain: d.FixedHeader.Retain,
},
Origin: d.Origin,
TopicName: d.TopicName,
Properties: packets.Properties{
PayloadFormat: d.Properties.PayloadFormat,
PayloadFormatFlag: d.Properties.PayloadFormatFlag,
MessageExpiryInterval: d.Properties.MessageExpiryInterval,
ContentType: d.Properties.ContentType,
ResponseTopic: d.Properties.ResponseTopic,
CorrelationData: d.Properties.CorrelationData,
SubscriptionIdentifier: d.Properties.SubscriptionIdentifier,
TopicAlias: d.Properties.TopicAlias,
User: d.Properties.User,
},
PacketID: 100,
Created: d.Created,
}, pk)
}

View File

@@ -241,6 +241,7 @@ func TestHooksNonReturns(t *testing.T) {
h.OnQosPublish(cl, packets.Packet{}, time.Now().Unix(), 0)
h.OnQosComplete(cl, packets.Packet{})
h.OnQosDropped(cl, packets.Packet{})
h.OnPacketIDExhausted(cl, packets.Packet{})
h.OnWillSent(cl, packets.Packet{})
h.OnClientExpired(cl)
h.OnRetainedExpired("a/b/c")

View File

@@ -21,7 +21,7 @@ func (c Code) Error() string {
}
var (
// QosCodes indicicates the reason codes for each Qos byte.
// QosCodes indicates the reason codes for each Qos byte.
QosCodes = map[byte]Code{
0: CodeGrantedQos0,
1: CodeGrantedQos1,
@@ -120,7 +120,7 @@ var (
ErrQosNotSupported = Code{Code: 0x9B, Reason: "qos not supported"}
ErrUseAnotherServer = Code{Code: 0x9C, Reason: "use another server"}
ErrServerMoved = Code{Code: 0x9D, Reason: "server moved"}
ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptiptions not supported"}
ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptions not supported"}
ErrConnectionRateExceeded = Code{Code: 0x9F, Reason: "connection rate exceeded"}
ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}

View File

@@ -16,22 +16,23 @@ import (
// All of the valid packet types and their packet identifier.
const (
Reserved byte = iota // 0 - we use this in packet tests to indicate special-test or all packets.
Connect // 1
Connack // 2
Publish // 3
Puback // 4
Pubrec // 5
Pubrel // 6
Pubcomp // 7
Subscribe // 8
Suback // 9
Unsubscribe // 10
Unsuback // 11
Pingreq // 12
Pingresp // 13
Disconnect // 14
Auth // 15
Reserved byte = iota // 0 - we use this in packet tests to indicate special-test or all packets.
Connect // 1
Connack // 2
Publish // 3
Puback // 4
Pubrec // 5
Pubrel // 6
Pubcomp // 7
Subscribe // 8
Suback // 9
Unsubscribe // 10
Unsuback // 11
Pingreq // 12
Pingresp // 13
Disconnect // 14
Auth // 15
WillProperties byte = 99 // Special byte for validating Will Properties.
)
var (
@@ -313,7 +314,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
(&pk.Properties).Encode(pk, pb, 0)
(&pk.Properties).Encode(pk.FixedHeader.Type, pk.Mods, pb, 0)
nb.Write(pb.Bytes())
}
@@ -322,7 +323,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
if pk.Connect.WillFlag {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
(&pk.Connect).WillProperties.Encode(pk, pb, 0)
(&pk.Connect).WillProperties.Encode(WillProperties, pk.Mods, pb, 0)
nb.Write(pb.Bytes())
}
@@ -393,7 +394,7 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
if pk.Connect.WillFlag { // [MQTT-3.1.2-7]
if pk.ProtocolVersion == 5 {
n, err := pk.Connect.WillProperties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:]))
n, err := pk.Connect.WillProperties.Decode(WillProperties, bytes.NewBuffer(buf[offset:]))
if err != nil {
return ErrMalformedWillProperties
}
@@ -496,7 +497,7 @@ func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+2) // +SessionPresent +ReasonCode
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+2) // +SessionPresent +ReasonCode
nb.Write(pb.Bytes())
}
@@ -539,7 +540,7 @@ func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error {
nb.WriteByte(pk.ReasonCode)
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
nb.Write(pb.Bytes())
}
@@ -608,7 +609,7 @@ func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+len(pk.Payload))
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.Payload))
nb.Write(pb.Bytes())
}
@@ -692,7 +693,7 @@ func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
if pk.ReasonCode >= ErrUnspecifiedError.Code || pb.Len() > 1 {
nb.WriteByte(pk.ReasonCode)
}
@@ -833,7 +834,7 @@ func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+len(pk.ReasonCodes))
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.ReasonCodes))
nb.Write(pb.Bytes())
}
@@ -890,7 +891,7 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+xb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
nb.Write(pb.Bytes())
}
@@ -985,7 +986,7 @@ func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
nb.Write(pb.Bytes())
}
@@ -1038,7 +1039,7 @@ func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+xb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
nb.Write(pb.Bytes())
}
@@ -1101,7 +1102,7 @@ func (pk *Packet) AuthEncode(buf *bytes.Buffer) error {
nb.WriteByte(pk.ReasonCode)
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
nb.Write(pb.Bytes())
pk.FixedHeader.Remaining = nb.Len()

View File

@@ -42,11 +42,11 @@ const (
// validPacketProperties indicates which properties are valid for which packet types.
var validPacketProperties = map[byte]map[byte]byte{
PropPayloadFormat: {Publish: 1},
PropMessageExpiryInterval: {Publish: 1},
PropContentType: {Publish: 1},
PropResponseTopic: {Publish: 1},
PropCorrelationData: {Publish: 1},
PropPayloadFormat: {Publish: 1, WillProperties: 1},
PropMessageExpiryInterval: {Publish: 1, WillProperties: 1},
PropContentType: {Publish: 1, WillProperties: 1},
PropResponseTopic: {Publish: 1, WillProperties: 1},
PropCorrelationData: {Publish: 1, WillProperties: 1},
PropSubscriptionIdentifier: {Publish: 1, Subscribe: 1},
PropSessionExpiryInterval: {Connect: 1, Connack: 1, Disconnect: 1},
PropAssignedClientID: {Connack: 1},
@@ -54,7 +54,7 @@ var validPacketProperties = map[byte]map[byte]byte{
PropAuthenticationMethod: {Connect: 1, Connack: 1, Auth: 1},
PropAuthenticationData: {Connect: 1, Connack: 1, Auth: 1},
PropRequestProblemInfo: {Connect: 1},
PropWillDelayInterval: {Connect: 1},
PropWillDelayInterval: {WillProperties: 1},
PropRequestResponseInfo: {Connect: 1},
PropResponseInfo: {Connack: 1},
PropServerReference: {Connack: 1, Disconnect: 1},
@@ -64,7 +64,7 @@ var validPacketProperties = map[byte]map[byte]byte{
PropTopicAlias: {Publish: 1},
PropMaximumQos: {Connack: 1},
PropRetainAvailable: {Connack: 1},
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1},
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1, WillProperties: 1},
PropMaximumPacketSize: {Connect: 1, Connack: 1},
PropWildcardSubAvailable: {Connack: 1},
PropSubIDAvailable: {Connack: 1},
@@ -194,14 +194,12 @@ func (p *Properties) canEncode(pkt byte, k byte) bool {
}
// Encode encodes properties into a bytes buffer.
func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
func (p *Properties) Encode(pkt byte, mods Mods, b *bytes.Buffer, n int) {
if p == nil {
return
}
var buf bytes.Buffer
pkt := pk.FixedHeader.Type
if p.canEncode(pkt, PropPayloadFormat) && p.PayloadFormatFlag {
buf.WriteByte(PropPayloadFormat)
buf.WriteByte(p.PayloadFormat)
@@ -217,13 +215,13 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
buf.Write(encodeString(p.ContentType)) // [MQTT-3.3.2-19]
}
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
p.ResponseTopic != "" && !strings.ContainsAny(p.ResponseTopic, "+#") { // [MQTT-3.1.2-28]
buf.WriteByte(PropResponseTopic)
buf.Write(encodeString(p.ResponseTopic)) // [MQTT-3.3.2-13]
}
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
if mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
buf.WriteByte(PropCorrelationData)
buf.Write(encodeBytes(p.CorrelationData))
}
@@ -277,7 +275,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
buf.WriteByte(p.RequestResponseInfo)
}
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
buf.WriteByte(PropResponseInfo)
buf.Write(encodeString(p.ResponseInfo))
}
@@ -289,9 +287,9 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
// [MQTT-3.2.2-19] [MQTT-3.14.2-3] [MQTT-3.4.2-2] [MQTT-3.5.2-2]
// [MQTT-3.6.2-2] [MQTT-3.9.2-1] [MQTT-3.11.2-1] [MQTT-3.15.2-2]
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
if !mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
b := encodeString(p.ReasonString)
if pk.Mods.MaxSize == 0 || uint32(n+len(b)+1) < pk.Mods.MaxSize {
if mods.MaxSize == 0 || uint32(n+len(b)+1) < mods.MaxSize {
buf.WriteByte(PropReasonString)
buf.Write(b)
}
@@ -322,7 +320,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
buf.WriteByte(p.RetainAvailable)
}
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
if !mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
pb := bytes.NewBuffer([]byte{})
for _, v := range p.User {
pb.WriteByte(PropUser)
@@ -331,7 +329,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
}
// [MQTT-3.2.2-20] [MQTT-3.14.2-4] [MQTT-3.4.2-3] [MQTT-3.5.2-3]
// [MQTT-3.6.2-3] [MQTT-3.9.2-2] [MQTT-3.11.2-2] [MQTT-3.15.2-3]
if pk.Mods.MaxSize == 0 || uint32(n+pb.Len()+1) < pk.Mods.MaxSize {
if mods.MaxSize == 0 || uint32(n+pb.Len()+1) < mods.MaxSize {
buf.Write(pb.Bytes())
}
}
@@ -361,7 +359,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
}
// Decode decodes property bytes into a properties struct.
func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
func (p *Properties) Decode(pkt byte, b *bytes.Buffer) (n int, err error) {
if p == nil {
return 0, nil
}
@@ -384,8 +382,8 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
return n + bu, err
}
if _, ok := validPacketProperties[k][pk]; !ok {
return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pk, ErrProtocolViolationUnsupportedProperty)
if _, ok := validPacketProperties[k][pkt]; !ok {
return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pkt, ErrProtocolViolationUnsupportedProperty)
}
switch k {

View File

@@ -202,14 +202,14 @@ func init() {
func TestEncodeProperties(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
require.Equal(t, propertiesBytes, b.Bytes())
}
func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{DisallowProblemInfo: true}}, b, 0)
props.Encode(Reserved, Mods{DisallowProblemInfo: true}, b, 0)
require.NotEqual(t, propertiesBytes, b.Bytes())
require.False(t, bytes.Contains(b.Bytes(), []byte{31, 0, 6}))
require.False(t, bytes.Contains(b.Bytes(), []byte{38, 0, 5}))
@@ -219,7 +219,7 @@ func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
func TestEncodePropertiesDisallowResponseInfo(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: false}}, b, 0)
props.Encode(Reserved, Mods{AllowResponseInfo: false}, b, 0)
require.NotEqual(t, propertiesBytes, b.Bytes())
require.NotContains(t, b.Bytes(), []byte{8, 0, 5})
require.NotContains(t, b.Bytes(), []byte{9, 0, 4})
@@ -232,7 +232,7 @@ func TestEncodePropertiesNil(t *testing.T) {
pr := tmp{}
b := bytes.NewBuffer([]byte{})
pr.p.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}}, b, 0)
pr.p.Encode(Reserved, Mods{}, b, 0)
require.Equal(t, []byte{}, b.Bytes())
}
@@ -240,7 +240,7 @@ func TestEncodeZeroProperties(t *testing.T) {
// [MQTT-2.2.2-1] If there are no properties, this MUST be indicated by including a Property Length of zero.
props := new(Properties)
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
require.Equal(t, []byte{0x00}, b.Bytes())
}
@@ -250,7 +250,7 @@ func TestDecodeProperties(t *testing.T) {
props := new(Properties)
n, err := props.Decode(Reserved, b)
require.NoError(t, err)
require.Equal(t, 172 + 2, n)
require.Equal(t, 172+2, n)
require.EqualValues(t, propertiesStruct, *props)
}

View File

@@ -26,7 +26,7 @@ import (
)
const (
Version = "2.2.5" // the current server version.
Version = "2.2.8" // the current server version.
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
)
@@ -86,6 +86,12 @@ type Options struct {
// server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
Capabilities *Capabilities
// ClientNetWriteBufferSize specifies the size of the client *bufio.Writer write buffer.
ClientNetWriteBufferSize int
// ClientNetReadBufferSize specifies the size of the client *bufio.Reader read buffer.
ClientNetReadBufferSize int
// Logger specifies a custom configured implementation of zerolog to override
// the servers default logger configuration. If you wish to change the log level,
// of the default logger, you can do so by setting
@@ -124,10 +130,10 @@ type loop struct {
// ops contains server values which can be propagated to other structs.
type ops struct {
capabilities *Capabilities // a pointer to the server capabilities, for referencing in clients
info *system.Info // pointers to server system info
hooks *Hooks // pointer to the server hooks
log *zerolog.Logger // a structured logger for the client
options *Options // a pointer to the server options and capabilities, for referencing in clients
info *system.Info // pointers to server system info
hooks *Hooks // pointer to the server hooks
log *zerolog.Logger // a structured logger for the client
}
// New returns a new instance of mochi mqtt broker. Optional parameters
@@ -178,6 +184,14 @@ func (o *Options) ensureDefaults() {
o.SysTopicResendInterval = defaultSysTopicInterval
}
if o.ClientNetWriteBufferSize == 0 {
o.ClientNetWriteBufferSize = 1024 * 2
}
if o.ClientNetReadBufferSize == 0 {
o.ClientNetReadBufferSize = 1024 * 2
}
if o.Logger == nil {
log := zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.InfoLevel).Output(zerolog.ConsoleWriter{Out: os.Stderr})
o.Logger = &log
@@ -190,10 +204,10 @@ func (o *Options) ensureDefaults() {
// topic validation checks.
func (s *Server) NewClient(c net.Conn, listener string, id string, inline bool) *Client {
cl := newClient(c, &ops{ // [MQTT-3.1.2-6] implicit
capabilities: s.Options.Capabilities,
info: s.Info,
hooks: s.hooks,
log: s.Log,
options: s.Options,
info: s.Info,
hooks: s.hooks,
log: s.Log,
})
cl.ID = id
@@ -358,9 +372,8 @@ func (s *Server) attachClient(cl *Client, listener string) error {
}
s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", listener).Msg("client disconnected")
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryIntervalFlag && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
s.hooks.OnDisconnect(cl, err, expire)
close(cl.State.outbound)
if expire && atomic.LoadUint32(&cl.State.isTakenOver) == 0 {
cl.ClearInflights(math.MaxInt64, 0)
@@ -449,9 +462,9 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
atomic.StoreUint32(&existing.State.isTakenOver, 1)
if existing.State.Inflight.Len() > 0 {
cl.State.Inflight = existing.State.Inflight.Clone() // [MQTT-3.1.2-5]
if cl.State.Inflight.maximumReceiveQuota == 0 && cl.ops.capabilities.ReceiveMaximum != 0 {
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.capabilities.ReceiveMaximum)) // server receive max per client
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
if cl.State.Inflight.maximumReceiveQuota == 0 && cl.ops.options.Capabilities.ReceiveMaximum != 0 {
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
}
}
@@ -813,6 +826,7 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet
if out.FixedHeader.Qos > 0 {
i, err := cl.NextPacketID() // [MQTT-4.3.2-1] [MQTT-4.3.3-1]
if err != nil {
s.hooks.OnPacketIDExhausted(cl, pk)
s.Log.Warn().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Msg("packet ids exhausted")
return out, packets.ErrQuotaExceeded
}
@@ -833,7 +847,7 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet
}
}
if cl.Net.Conn == nil || atomic.LoadUint32(&cl.State.done) == 1 {
if cl.Net.Conn == nil || cl.Closed() {
return out, packets.CodeDisconnect
}
@@ -984,15 +998,15 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
if code != packets.CodeSuccess {
reasonCodes[i] = code.Code // NB 3.9.3 Non-normative 0x91
continue
} else if !IsValidFilter(sub.Filter, false) {
reasonCodes[i] = packets.ErrTopicFilterInvalid.Code
} else if sub.NoLocal && IsSharedFilter(sub.Filter) {
reasonCodes[i] = packets.ErrProtocolViolationInvalidSharedNoLocal.Code // [MQTT-3.8.3-4]
} else if !s.hooks.OnACLCheck(cl, sub.Filter, false) {
reasonCodes[i] = packets.ErrNotAuthorized.Code
if s.Options.Capabilities.Compatibilities.ObscureNotAuthorized {
reasonCodes[i] = packets.ErrUnspecifiedError.Code
}
} else if !IsValidFilter(sub.Filter, false) {
reasonCodes[i] = packets.ErrTopicFilterInvalid.Code
} else if sub.NoLocal && IsSharedFilter(sub.Filter) {
reasonCodes[i] = packets.ErrProtocolViolationInvalidSharedNoLocal.Code // [MQTT-3.8.3-4]
} else {
isNew := s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
if isNew {
@@ -1401,25 +1415,7 @@ func (s *Server) loadClients(v []storage.Client) {
func (s *Server) loadInflight(v []storage.Message) {
for _, msg := range v {
if client, ok := s.Clients.Get(msg.Origin); ok {
client.State.Inflight.Set(packets.Packet{
FixedHeader: msg.FixedHeader,
PacketID: msg.PacketID,
TopicName: msg.TopicName,
Payload: msg.Payload,
Origin: msg.Origin,
Created: msg.Created,
Properties: packets.Properties{
PayloadFormat: msg.Properties.PayloadFormat,
PayloadFormatFlag: msg.Properties.PayloadFormatFlag,
MessageExpiryInterval: msg.Properties.MessageExpiryInterval,
ContentType: msg.Properties.ContentType,
ResponseTopic: msg.Properties.ResponseTopic,
CorrelationData: msg.Properties.CorrelationData,
SubscriptionIdentifier: msg.Properties.SubscriptionIdentifier,
TopicAlias: msg.Properties.TopicAlias,
User: msg.Properties.User,
},
})
client.State.Inflight.Set(msg.ToPacket())
}
}
}
@@ -1427,24 +1423,7 @@ func (s *Server) loadInflight(v []storage.Message) {
// loadRetained restores retained messages from the datastore.
func (s *Server) loadRetained(v []storage.Message) {
for _, msg := range v {
s.Topics.RetainMessage(packets.Packet{
FixedHeader: msg.FixedHeader,
TopicName: msg.TopicName,
Payload: msg.Payload,
Origin: msg.Origin,
Created: msg.Created,
Properties: packets.Properties{
PayloadFormat: msg.Properties.PayloadFormat,
PayloadFormatFlag: msg.Properties.PayloadFormatFlag,
MessageExpiryInterval: msg.Properties.MessageExpiryInterval,
ContentType: msg.Properties.ContentType,
ResponseTopic: msg.Properties.ResponseTopic,
CorrelationData: msg.Properties.CorrelationData,
SubscriptionIdentifier: msg.Properties.SubscriptionIdentifier,
TopicAlias: msg.Properties.TopicAlias,
User: msg.Properties.User,
},
})
s.Topics.RetainMessage(msg.ToPacket())
}
}

View File

@@ -1535,14 +1535,16 @@ func TestPublishToClientExceedClientWritesPending(t *testing.T) {
info: new(system.Info),
hooks: new(Hooks),
log: &logger,
capabilities: &Capabilities{
MaximumClientWritesPending: 3,
options: &Options{
Capabilities: &Capabilities{
MaximumClientWritesPending: 3,
},
},
})
s.Clients.Add(cl)
for i := int32(0); i < cl.ops.capabilities.MaximumClientWritesPending; i++ {
for i := int32(0); i < cl.ops.options.Capabilities.MaximumClientWritesPending; i++ {
cl.State.outbound <- new(packets.Packet)
atomic.AddInt32(&cl.State.outboundQty, 1)
}
@@ -1585,7 +1587,7 @@ func TestPublishToClientServerTopicAlias(t *testing.T) {
func TestPublishToClientExhaustedPacketID(t *testing.T) {
s := newServer()
cl, _, _ := newTestClient()
for i := uint32(0); i <= cl.ops.capabilities.maximumPacketID; i++ {
for i := uint32(0); i <= cl.ops.options.Capabilities.maximumPacketID; i++ {
cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)})
}
@@ -1654,7 +1656,7 @@ func TestPublishToSubscribersExhaustedPacketIDs(t *testing.T) {
s := newServer()
cl, r, w := newTestClient()
s.Clients.Add(cl)
for i := uint32(0); i <= cl.ops.capabilities.maximumPacketID; i++ {
for i := uint32(0); i <= cl.ops.options.Capabilities.maximumPacketID; i++ {
cl.State.Inflight.Set(packets.Packet{PacketID: 1})
}
@@ -2474,7 +2476,7 @@ func TestServerProcessPacketDisconnect(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 0, s.loop.willDelayed.Len())
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.State.done))
require.True(t, cl.Closed())
require.Equal(t, time.Now().Unix(), atomic.LoadInt64(&cl.State.disconnected))
}
@@ -2804,7 +2806,7 @@ func TestServerClearExpiredClients(t *testing.T) {
cl0, _, _ := newTestClient()
cl0.ID = "c0"
cl0.State.disconnected = n - 10
cl0.State.done = 1
cl0.State.open = nil
cl0.Properties.ProtocolVersion = 5
cl0.Properties.Props.SessionExpiryInterval = 12
cl0.Properties.Props.SessionExpiryIntervalFlag = true
@@ -2814,7 +2816,7 @@ func TestServerClearExpiredClients(t *testing.T) {
cl1, _, _ := newTestClient()
cl1.ID = "c1"
cl1.State.disconnected = n - 10
cl1.State.done = 1
cl1.State.open = nil
cl1.Properties.ProtocolVersion = 5
cl1.Properties.Props.SessionExpiryInterval = 8
cl1.Properties.Props.SessionExpiryIntervalFlag = true
@@ -2824,7 +2826,7 @@ func TestServerClearExpiredClients(t *testing.T) {
cl2, _, _ := newTestClient()
cl2.ID = "c2"
cl2.State.disconnected = n - 10
cl2.State.done = 1
cl2.State.open = nil
cl2.Properties.ProtocolVersion = 5
cl2.Properties.Props.SessionExpiryInterval = 0
cl2.Properties.Props.SessionExpiryIntervalFlag = true

View File

@@ -301,6 +301,9 @@ func NewTopicsIndex() *TopicsIndex {
// Subscribe adds a new subscription for a client to a topic filter, returning
// true if the subscription was new.
func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) bool {
x.root.Lock()
defer x.root.Unlock()
var existed bool
prefix, _ := isolateParticle(subscription.Filter, 0)
if strings.EqualFold(prefix, SharePrefix) {
@@ -320,6 +323,9 @@ func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription
// Unsubscribe removes a subscription filter for a client, returning true if the
// subscription existed.
func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
x.root.Lock()
defer x.root.Unlock()
var d int
if strings.HasPrefix(filter, SharePrefix) {
d = 2
@@ -346,6 +352,9 @@ func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
// 1 if a retained message was added, and -1 if the retained message was removed.
// 0 is returned if sequential empty payloads are received.
func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
x.root.Lock()
defer x.root.Unlock()
n := x.set(pk.TopicName, 0)
n.Lock()
defer n.Unlock()