mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-10-04 15:52:55 +08:00
Compare commits
19 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
1adb02e087 | ||
![]() |
4d4140aa99 | ||
![]() |
e31840a37d | ||
![]() |
7d2e16f2d3 | ||
![]() |
92cd935a16 | ||
![]() |
25ce27ce2d | ||
![]() |
527d084a4b | ||
![]() |
bb9f937bb0 | ||
![]() |
511fe88684 | ||
![]() |
75504ff201 | ||
![]() |
a556feb325 | ||
![]() |
d06f47f4b9 | ||
![]() |
8d4cc091b4 | ||
![]() |
d8f28cb843 | ||
![]() |
88861c219d | ||
![]() |
7ba6cf28d9 | ||
![]() |
c174cfdc6b | ||
![]() |
4f198a99dd | ||
![]() |
2a9c9fcc40 |
@@ -113,6 +113,7 @@ Examples of running the broker with various configurations can be found in the [
|
|||||||
The server comes with a variety of pre-packaged network listeners which allow the broker to accept connections on different protocols. The current listeners are:
|
The server comes with a variety of pre-packaged network listeners which allow the broker to accept connections on different protocols. The current listeners are:
|
||||||
|
|
||||||
- `listeners.NewTCP(...)` - A TCP listener.
|
- `listeners.NewTCP(...)` - A TCP listener.
|
||||||
|
- `listeners.NewUnixSock(...)` - A Unix Socket listener.
|
||||||
- `listeners.NewWebsocket(...)` A Websocket listener.
|
- `listeners.NewWebsocket(...)` A Websocket listener.
|
||||||
- `listeners.NewHTTPStats(...)` An HTTP $SYS info dashboard.
|
- `listeners.NewHTTPStats(...)` An HTTP $SYS info dashboard.
|
||||||
- Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know!
|
- Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know!
|
||||||
@@ -296,7 +297,6 @@ The function signatures for all the hooks and `mqtt.Hook` interface can be found
|
|||||||
| OnWillSent | Called when an LWT message has been issued from a disconnecting client. |
|
| OnWillSent | Called when an LWT message has been issued from a disconnecting client. |
|
||||||
| OnClientExpired | Called when a client session has expired and should be deleted. |
|
| OnClientExpired | Called when a client session has expired and should be deleted. |
|
||||||
| OnRetainedExpired | Called when a retained message has expired and should be deleted. |
|
| OnRetainedExpired | Called when a retained message has expired and should be deleted. |
|
||||||
| OnExpireInflights | Called when the server issues a clear request for expired inflight messages.|
|
|
||||||
| StoredClients | Returns clients, eg. from a persistent store. |
|
| StoredClients | Returns clients, eg. from a persistent store. |
|
||||||
| StoredSubscriptions | Returns client subscriptions, eg. from a persistent store. |
|
| StoredSubscriptions | Returns client subscriptions, eg. from a persistent store. |
|
||||||
| StoredInflightMessages | Returns inflight messages, eg. from a persistent store. |
|
| StoredInflightMessages | Returns inflight messages, eg. from a persistent store. |
|
||||||
|
15
clients.go
15
clients.go
@@ -290,17 +290,18 @@ func (cl *Client) ResendInflightMessages(force bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session.
|
// ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session.
|
||||||
func (cl *Client) ClearInflights(now, maximumExpiry int64) int64 {
|
func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 {
|
||||||
var deleted int64
|
deleted := []uint16{}
|
||||||
for _, tk := range cl.State.Inflight.GetAll(false) {
|
for _, tk := range cl.State.Inflight.GetAll(false) {
|
||||||
if (tk.Expiry > 0 && tk.Expiry < now) || tk.Created+maximumExpiry < now {
|
if (tk.Expiry > 0 && tk.Expiry < now) || tk.Created+maximumExpiry < now {
|
||||||
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
|
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
|
||||||
cl.ops.hooks.OnQosDropped(cl, tk)
|
cl.ops.hooks.OnQosDropped(cl, tk)
|
||||||
atomic.AddInt64(&cl.ops.info.Inflight, -1)
|
atomic.AddInt64(&cl.ops.info.Inflight, -1)
|
||||||
deleted++
|
deleted = append(deleted, uint16(tk.PacketID))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return deleted
|
return deleted
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -383,6 +384,10 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cl.ops.capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.capabilities.MaximumPacketSize {
|
||||||
|
return packets.ErrPacketTooLarge // [MQTT-3.2.2-15]
|
||||||
|
}
|
||||||
|
|
||||||
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1))
|
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -472,8 +477,8 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
|
|||||||
pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set
|
pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set
|
||||||
}
|
}
|
||||||
|
|
||||||
if cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.capabilities.Compatibilities.AlwaysReturnResponseInfo {
|
if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.capabilities.Compatibilities.AlwaysReturnResponseInfo {
|
||||||
pk.Mods.AllowResponseInfo = true // NB we need to know which properties we can encode
|
pk.Mods.AllowResponseInfo = true // [MQTT-3.1.2-28] we need to know which properties we can encode
|
||||||
}
|
}
|
||||||
|
|
||||||
pk = cl.ops.hooks.OnPacketEncode(cl, pk)
|
pk = cl.ops.hooks.OnPacketEncode(cl, pk)
|
||||||
|
@@ -272,7 +272,9 @@ func TestClientClearInflights(t *testing.T) {
|
|||||||
cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n})
|
cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n})
|
||||||
require.Equal(t, 5, cl.State.Inflight.Len())
|
require.Equal(t, 5, cl.State.Inflight.Len())
|
||||||
|
|
||||||
cl.ClearInflights(n, 4)
|
deleted := cl.ClearInflights(n, 4)
|
||||||
|
require.Len(t, deleted, 3)
|
||||||
|
require.ElementsMatch(t, []uint16{1, 2, 5}, deleted)
|
||||||
require.Equal(t, 2, cl.State.Inflight.Len())
|
require.Equal(t, 2, cl.State.Inflight.Len())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -350,6 +352,22 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) {
|
|||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClientReadFixedHeaderPacketOversized(t *testing.T) {
|
||||||
|
cl, r, _ := newTestClient()
|
||||||
|
cl.ops.capabilities.MaximumPacketSize = 2
|
||||||
|
defer cl.Stop(errClientStop)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes)
|
||||||
|
r.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
fh := new(packets.FixedHeader)
|
||||||
|
err := cl.ReadFixedHeader(fh)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, packets.ErrPacketTooLarge)
|
||||||
|
}
|
||||||
|
|
||||||
func TestClientReadFixedHeaderReadEOF(t *testing.T) {
|
func TestClientReadFixedHeaderReadEOF(t *testing.T) {
|
||||||
cl, r, _ := newTestClient()
|
cl, r, _ := newTestClient()
|
||||||
defer cl.Stop(errClientStop)
|
defer cl.Stop(errClientStop)
|
||||||
|
@@ -29,7 +29,6 @@ func main() {
|
|||||||
server.Options.Capabilities.ServerKeepAlive = 60
|
server.Options.Capabilities.ServerKeepAlive = 60
|
||||||
server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
|
server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
|
||||||
server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true
|
server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true
|
||||||
server.Options.Capabilities.Compatibilities.AlwaysReturnResponseInfo = true
|
|
||||||
|
|
||||||
_ = server.AddHook(new(pahoAuthHook), nil)
|
_ = server.AddHook(new(pahoAuthHook), nil)
|
||||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||||
|
24
hooks.go
24
hooks.go
@@ -47,7 +47,6 @@ const (
|
|||||||
OnWillSent
|
OnWillSent
|
||||||
OnClientExpired
|
OnClientExpired
|
||||||
OnRetainedExpired
|
OnRetainedExpired
|
||||||
OnExpireInflights
|
|
||||||
StoredClients
|
StoredClients
|
||||||
StoredSubscriptions
|
StoredSubscriptions
|
||||||
StoredInflightMessages
|
StoredInflightMessages
|
||||||
@@ -96,7 +95,6 @@ type Hook interface {
|
|||||||
OnWillSent(cl *Client, pk packets.Packet)
|
OnWillSent(cl *Client, pk packets.Packet)
|
||||||
OnClientExpired(cl *Client)
|
OnClientExpired(cl *Client)
|
||||||
OnRetainedExpired(filter string)
|
OnRetainedExpired(filter string)
|
||||||
OnExpireInflights(cl *Client, expiry int64)
|
|
||||||
StoredClients() ([]storage.Client, error)
|
StoredClients() ([]storage.Client, error)
|
||||||
StoredSubscriptions() ([]storage.Subscription, error)
|
StoredSubscriptions() ([]storage.Subscription, error)
|
||||||
StoredInflightMessages() ([]storage.Message, error)
|
StoredInflightMessages() ([]storage.Message, error)
|
||||||
@@ -351,7 +349,7 @@ func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnPublish is called when a client publishes a message. This method differs from OnMessage
|
// OnPublish is called when a client publishes a message. This method differs from OnPublished
|
||||||
// in that it allows you to modify you to modify the incoming packet before it is processed.
|
// in that it allows you to modify you to modify the incoming packet before it is processed.
|
||||||
// The return values of the hook methods are passed-through in the order the hooks were attached.
|
// The return values of the hook methods are passed-through in the order the hooks were attached.
|
||||||
func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
|
func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
|
||||||
@@ -414,8 +412,8 @@ func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// OnQosDropped is called the Qos flow for a message expires. In other words, when
|
// OnQosDropped is called the Qos flow for a message expires. In other words, when
|
||||||
// an inflight message expires or is abandoned.
|
// an inflight message expires or is abandoned. It is typically used to delete an
|
||||||
// It is typically used to delete an inflight message from a store.
|
// inflight message from a store.
|
||||||
func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
|
func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
|
||||||
for _, hook := range h.internal {
|
for _, hook := range h.internal {
|
||||||
if hook.Provides(OnQosDropped) {
|
if hook.Provides(OnQosDropped) {
|
||||||
@@ -601,19 +599,6 @@ func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnExpireInflights is called when the server issues a clear request for expired
|
|
||||||
// inflight messages. Expiry should be the time after which the message is no longer
|
|
||||||
// valid (usually some time in the past). A message has expired if it's created time
|
|
||||||
// is older than time.Now() minus Inflight TTL. This method can be used to expire
|
|
||||||
// old inflight messages in a persistent store which doesnt support per-item TTL.
|
|
||||||
func (h *Hooks) OnExpireInflights(cl *Client, expiry int64) {
|
|
||||||
for _, hook := range h.internal {
|
|
||||||
if hook.Provides(OnExpireInflights) {
|
|
||||||
hook.OnExpireInflights(cl, expiry)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// HookBase provides a set of default methods for each hook. It should be embedded in
|
// HookBase provides a set of default methods for each hook. It should be embedded in
|
||||||
// all hooks.
|
// all hooks.
|
||||||
type HookBase struct {
|
type HookBase struct {
|
||||||
@@ -755,9 +740,6 @@ func (h *HookBase) OnClientExpired(cl *Client) {}
|
|||||||
// OnRetainedExpired is called when a retained message for a topic has expired.
|
// OnRetainedExpired is called when a retained message for a topic has expired.
|
||||||
func (h *HookBase) OnRetainedExpired(topic string) {}
|
func (h *HookBase) OnRetainedExpired(topic string) {}
|
||||||
|
|
||||||
// OnExpireInflights is called when the server issues a clear request for expired inflight messages.
|
|
||||||
func (h *HookBase) OnExpireInflights(cl *Client, expiry int64) {}
|
|
||||||
|
|
||||||
// StoredClients returns all clients from a store.
|
// StoredClients returns all clients from a store.
|
||||||
func (h *HookBase) StoredClients() (v []storage.Client, err error) {
|
func (h *HookBase) StoredClients() (v []storage.Client, err error) {
|
||||||
return
|
return
|
||||||
|
@@ -80,7 +80,6 @@ func (h *Hook) Provides(b byte) bool {
|
|||||||
mqtt.OnSysInfoTick,
|
mqtt.OnSysInfoTick,
|
||||||
mqtt.OnClientExpired,
|
mqtt.OnClientExpired,
|
||||||
mqtt.OnRetainedExpired,
|
mqtt.OnRetainedExpired,
|
||||||
mqtt.OnExpireInflights,
|
|
||||||
mqtt.StoredClients,
|
mqtt.StoredClients,
|
||||||
mqtt.StoredInflightMessages,
|
mqtt.StoredInflightMessages,
|
||||||
mqtt.StoredRetainedMessages,
|
mqtt.StoredRetainedMessages,
|
||||||
@@ -348,32 +347,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnExpireInflights removes all inflight messages which have passed the provided expiry time.
|
// OnRetainedExpired deletes expired retained messages from the store.
|
||||||
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
|
func (h *Hook) OnRetainedExpired(filter string) {
|
||||||
if h.db == nil {
|
if h.db == nil {
|
||||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var v []storage.Message
|
|
||||||
err := h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
|
|
||||||
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
|
|
||||||
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, m := range v {
|
|
||||||
if m.Created < expiry || m.Created == 0 {
|
|
||||||
err := h.db.Delete(m.ID, new(storage.Message))
|
|
||||||
if err != nil {
|
|
||||||
h.Log.Error().Err(err).Interface("data", m.ID).Msg("failed to delete inflight message data")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnRetainedExpired deletes expired retained messages from the store.
|
|
||||||
func (h *Hook) OnRetainedExpired(filter string) {
|
|
||||||
err := h.db.Delete(retainedKey(filter), new(storage.Message))
|
err := h.db.Delete(retainedKey(filter), new(storage.Message))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data")
|
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data")
|
||||||
@@ -382,6 +362,11 @@ func (h *Hook) OnRetainedExpired(filter string) {
|
|||||||
|
|
||||||
// OnClientExpired deleted expired clients from the store.
|
// OnClientExpired deleted expired clients from the store.
|
||||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||||
|
if h.db == nil {
|
||||||
|
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
err := h.db.Delete(clientKey(cl), new(storage.Client))
|
err := h.db.Delete(clientKey(cl), new(storage.Client))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data")
|
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data")
|
||||||
|
@@ -5,13 +5,11 @@
|
|||||||
package badger
|
package badger
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/asdine/storm/v3"
|
|
||||||
"github.com/mochi-co/mqtt/v2"
|
"github.com/mochi-co/mqtt/v2"
|
||||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||||
"github.com/mochi-co/mqtt/v2/packets"
|
"github.com/mochi-co/mqtt/v2/packets"
|
||||||
@@ -170,6 +168,21 @@ func TestOnClientExpired(t *testing.T) {
|
|||||||
require.ErrorIs(t, badgerhold.ErrNotFound, err)
|
require.ErrorIs(t, badgerhold.ErrNotFound, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||||
|
h := new(Hook)
|
||||||
|
h.SetOpts(&logger, nil)
|
||||||
|
h.OnClientExpired(client)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||||
|
h := new(Hook)
|
||||||
|
h.SetOpts(&logger, nil)
|
||||||
|
err := h.Init(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
teardown(t, h.config.Path, h)
|
||||||
|
h.OnClientExpired(client)
|
||||||
|
}
|
||||||
|
|
||||||
func TestOnSessionEstablishedNoDB(t *testing.T) {
|
func TestOnSessionEstablishedNoDB(t *testing.T) {
|
||||||
h := new(Hook)
|
h := new(Hook)
|
||||||
h.SetOpts(&logger, nil)
|
h.SetOpts(&logger, nil)
|
||||||
@@ -333,6 +346,21 @@ func TestOnRetainedExpired(t *testing.T) {
|
|||||||
require.ErrorIs(t, err, badgerhold.ErrNotFound)
|
require.ErrorIs(t, err, badgerhold.ErrNotFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOnRetainExpiredNoDB(t *testing.T) {
|
||||||
|
h := new(Hook)
|
||||||
|
h.SetOpts(&logger, nil)
|
||||||
|
h.OnRetainedExpired("a/b/c")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOnRetainExpiredClosedDB(t *testing.T) {
|
||||||
|
h := new(Hook)
|
||||||
|
h.SetOpts(&logger, nil)
|
||||||
|
err := h.Init(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
teardown(t, h.config.Path, h)
|
||||||
|
h.OnRetainedExpired("a/b/c")
|
||||||
|
}
|
||||||
|
|
||||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||||
h := new(Hook)
|
h := new(Hook)
|
||||||
h.SetOpts(&logger, nil)
|
h.SetOpts(&logger, nil)
|
||||||
@@ -419,48 +447,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
|
|||||||
h.OnQosDropped(client, packets.Packet{})
|
h.OnQosDropped(client, packets.Packet{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOnExpireInflights(t *testing.T) {
|
|
||||||
h := new(Hook)
|
|
||||||
h.SetOpts(&logger, nil)
|
|
||||||
|
|
||||||
err := h.Init(nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer teardown(t, h.config.Path, h)
|
|
||||||
|
|
||||||
err = h.db.Upsert("i1", &storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1})
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = h.db.Upsert("i2", &storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20})
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
|
||||||
|
|
||||||
var v []storage.Message
|
|
||||||
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
|
|
||||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
require.Len(t, v, 1)
|
|
||||||
require.Equal(t, "i1", v[0].ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOnExpireInflightsNoDB(t *testing.T) {
|
|
||||||
h := new(Hook)
|
|
||||||
h.SetOpts(&logger, nil)
|
|
||||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOnExpireInflightsClosedDB(t *testing.T) {
|
|
||||||
h := new(Hook)
|
|
||||||
h.SetOpts(&logger, nil)
|
|
||||||
err := h.Init(nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
teardown(t, h.config.Path, h)
|
|
||||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOnSysInfoTick(t *testing.T) {
|
func TestOnSysInfoTick(t *testing.T) {
|
||||||
h := new(Hook)
|
h := new(Hook)
|
||||||
h.SetOpts(&logger, nil)
|
h.SetOpts(&logger, nil)
|
||||||
|
@@ -85,7 +85,6 @@ func (h *Hook) Provides(b byte) bool {
|
|||||||
mqtt.OnSysInfoTick,
|
mqtt.OnSysInfoTick,
|
||||||
mqtt.OnClientExpired,
|
mqtt.OnClientExpired,
|
||||||
mqtt.OnRetainedExpired,
|
mqtt.OnRetainedExpired,
|
||||||
mqtt.OnExpireInflights,
|
|
||||||
mqtt.StoredClients,
|
mqtt.StoredClients,
|
||||||
mqtt.StoredInflightMessages,
|
mqtt.StoredInflightMessages,
|
||||||
mqtt.StoredRetainedMessages,
|
mqtt.StoredRetainedMessages,
|
||||||
@@ -369,34 +368,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnExpireInflights removes all inflight messages which have passed the
|
// OnRetainedExpired deletes expired retained messages from the store.
|
||||||
// provided expiry time.
|
func (h *Hook) OnRetainedExpired(filter string) {
|
||||||
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
|
|
||||||
if h.db == nil {
|
if h.db == nil {
|
||||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var v []storage.Message
|
|
||||||
err := h.db.Find("T", storage.InflightKey, &v)
|
|
||||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
|
||||||
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, m := range v {
|
|
||||||
if m.Created < expiry || m.Created == 0 {
|
|
||||||
err := h.db.DeleteStruct(&storage.Message{ID: m.ID})
|
|
||||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
|
||||||
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to clear inflight data")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnRetainedExpired deletes expired retained messages from the store.
|
|
||||||
func (h *Hook) OnRetainedExpired(filter string) {
|
|
||||||
if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil {
|
if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil {
|
||||||
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish")
|
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish")
|
||||||
}
|
}
|
||||||
@@ -404,6 +382,11 @@ func (h *Hook) OnRetainedExpired(filter string) {
|
|||||||
|
|
||||||
// OnClientExpired deleted expired clients from the store.
|
// OnClientExpired deleted expired clients from the store.
|
||||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||||
|
if h.db == nil {
|
||||||
|
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
|
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
|
||||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
|
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
|
||||||
|
@@ -5,7 +5,6 @@
|
|||||||
package bolt
|
package bolt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -212,6 +211,21 @@ func TestOnClientExpired(t *testing.T) {
|
|||||||
require.ErrorIs(t, storm.ErrNotFound, err)
|
require.ErrorIs(t, storm.ErrNotFound, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||||
|
h := new(Hook)
|
||||||
|
h.SetOpts(&logger, nil)
|
||||||
|
err := h.Init(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
teardown(t, h.config.Path, h)
|
||||||
|
h.OnClientExpired(client)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||||
|
h := new(Hook)
|
||||||
|
h.SetOpts(&logger, nil)
|
||||||
|
h.OnClientExpired(client)
|
||||||
|
}
|
||||||
|
|
||||||
func TestOnDisconnectNoDB(t *testing.T) {
|
func TestOnDisconnectNoDB(t *testing.T) {
|
||||||
h := new(Hook)
|
h := new(Hook)
|
||||||
h.SetOpts(&logger, nil)
|
h.SetOpts(&logger, nil)
|
||||||
@@ -341,6 +355,21 @@ func TestOnRetainedExpired(t *testing.T) {
|
|||||||
require.Equal(t, storm.ErrNotFound, err)
|
require.Equal(t, storm.ErrNotFound, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOnRetainedExpiredClosedDB(t *testing.T) {
|
||||||
|
h := new(Hook)
|
||||||
|
h.SetOpts(&logger, nil)
|
||||||
|
err := h.Init(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
teardown(t, h.config.Path, h)
|
||||||
|
h.OnRetainedExpired("a/b/c")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOnRetainedExpiredNoDB(t *testing.T) {
|
||||||
|
h := new(Hook)
|
||||||
|
h.SetOpts(&logger, nil)
|
||||||
|
h.OnRetainedExpired("a/b/c")
|
||||||
|
}
|
||||||
|
|
||||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||||
h := new(Hook)
|
h := new(Hook)
|
||||||
h.SetOpts(&logger, nil)
|
h.SetOpts(&logger, nil)
|
||||||
@@ -427,48 +456,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
|
|||||||
h.OnQosDropped(client, packets.Packet{})
|
h.OnQosDropped(client, packets.Packet{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOnExpireInflights(t *testing.T) {
|
|
||||||
h := new(Hook)
|
|
||||||
h.SetOpts(&logger, nil)
|
|
||||||
|
|
||||||
err := h.Init(nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer teardown(t, h.config.Path, h)
|
|
||||||
|
|
||||||
err = h.db.Save(&storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1})
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = h.db.Save(&storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20})
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
|
||||||
|
|
||||||
var v []storage.Message
|
|
||||||
err = h.db.Find("T", storage.InflightKey, &v)
|
|
||||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
require.Len(t, v, 1)
|
|
||||||
require.Equal(t, "i1", v[0].ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOnExpireInflightsClosedDB(t *testing.T) {
|
|
||||||
h := new(Hook)
|
|
||||||
h.SetOpts(&logger, nil)
|
|
||||||
err := h.Init(nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
teardown(t, h.config.Path, h)
|
|
||||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOnExpireInflightsNoDB(t *testing.T) {
|
|
||||||
h := new(Hook)
|
|
||||||
h.SetOpts(&logger, nil)
|
|
||||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOnSysInfoTick(t *testing.T) {
|
func TestOnSysInfoTick(t *testing.T) {
|
||||||
h := new(Hook)
|
h := new(Hook)
|
||||||
h.SetOpts(&logger, nil)
|
h.SetOpts(&logger, nil)
|
||||||
|
@@ -83,7 +83,6 @@ func (h *Hook) Provides(b byte) bool {
|
|||||||
mqtt.OnSysInfoTick,
|
mqtt.OnSysInfoTick,
|
||||||
mqtt.OnClientExpired,
|
mqtt.OnClientExpired,
|
||||||
mqtt.OnRetainedExpired,
|
mqtt.OnRetainedExpired,
|
||||||
mqtt.OnExpireInflights,
|
|
||||||
mqtt.StoredClients,
|
mqtt.StoredClients,
|
||||||
mqtt.StoredInflightMessages,
|
mqtt.StoredInflightMessages,
|
||||||
mqtt.StoredRetainedMessages,
|
mqtt.StoredRetainedMessages,
|
||||||
@@ -364,37 +363,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnExpireInflights removes all inflight messages which have passed the
|
// OnRetainedExpired deletes expired retained messages from the store.
|
||||||
// provided expiry time.
|
func (h *Hook) OnRetainedExpired(filter string) {
|
||||||
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
|
|
||||||
if h.db == nil {
|
if h.db == nil {
|
||||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
|
|
||||||
if err != nil && !errors.Is(err, redis.Nil) {
|
|
||||||
h.Log.Error().Err(err).Msg("failed to HGetAll inflight data")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, row := range rows {
|
|
||||||
var d storage.Message
|
|
||||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
|
||||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data")
|
|
||||||
}
|
|
||||||
|
|
||||||
if d.Created < expiry || d.Created == 0 {
|
|
||||||
err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), d.ID).Err()
|
|
||||||
if err != nil {
|
|
||||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnRetainedExpired deletes expired retained messages from the store.
|
|
||||||
func (h *Hook) OnRetainedExpired(filter string) {
|
|
||||||
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err()
|
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data")
|
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data")
|
||||||
@@ -403,6 +378,11 @@ func (h *Hook) OnRetainedExpired(filter string) {
|
|||||||
|
|
||||||
// OnClientExpired deleted expired clients from the store.
|
// OnClientExpired deleted expired clients from the store.
|
||||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||||
|
if h.db == nil {
|
||||||
|
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
|
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
|
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
|
||||||
|
@@ -253,6 +253,22 @@ func TestOnClientExpired(t *testing.T) {
|
|||||||
require.ErrorIs(t, redis.Nil, err)
|
require.ErrorIs(t, redis.Nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||||
|
s := miniredis.RunT(t)
|
||||||
|
defer s.Close()
|
||||||
|
h := newHook(t, s.Addr())
|
||||||
|
teardown(t, h)
|
||||||
|
h.OnClientExpired(client)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||||
|
s := miniredis.RunT(t)
|
||||||
|
defer s.Close()
|
||||||
|
h := newHook(t, s.Addr())
|
||||||
|
h.db = nil
|
||||||
|
h.OnClientExpired(client)
|
||||||
|
}
|
||||||
|
|
||||||
func TestOnDisconnectNoDB(t *testing.T) {
|
func TestOnDisconnectNoDB(t *testing.T) {
|
||||||
s := miniredis.RunT(t)
|
s := miniredis.RunT(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
@@ -392,6 +408,22 @@ func TestOnRetainedExpired(t *testing.T) {
|
|||||||
require.ErrorIs(t, err, redis.Nil)
|
require.ErrorIs(t, err, redis.Nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOnRetainedExpiredClosedDB(t *testing.T) {
|
||||||
|
s := miniredis.RunT(t)
|
||||||
|
defer s.Close()
|
||||||
|
h := newHook(t, s.Addr())
|
||||||
|
teardown(t, h)
|
||||||
|
h.OnRetainedExpired("a/b/c")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOnRetainedExpiredNoDB(t *testing.T) {
|
||||||
|
s := miniredis.RunT(t)
|
||||||
|
defer s.Close()
|
||||||
|
h := newHook(t, s.Addr())
|
||||||
|
h.db = nil
|
||||||
|
h.OnRetainedExpired("a/b/c")
|
||||||
|
}
|
||||||
|
|
||||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||||
s := miniredis.RunT(t)
|
s := miniredis.RunT(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
@@ -484,60 +516,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
|
|||||||
h.OnQosDropped(client, packets.Packet{})
|
h.OnQosDropped(client, packets.Packet{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOnExpireInflights(t *testing.T) {
|
|
||||||
s := miniredis.RunT(t)
|
|
||||||
defer s.Close()
|
|
||||||
h := newHook(t, s.Addr())
|
|
||||||
defer teardown(t, h)
|
|
||||||
|
|
||||||
n := time.Now().Unix()
|
|
||||||
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1",
|
|
||||||
&storage.Message{ID: "i1", T: storage.InflightKey, Created: n - 1},
|
|
||||||
).Err()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2",
|
|
||||||
&storage.Message{ID: "i2", T: storage.InflightKey, Created: n - 20},
|
|
||||||
).Err()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3",
|
|
||||||
&storage.Message{ID: "i3", T: storage.InflightKey},
|
|
||||||
).Err()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
|
||||||
|
|
||||||
var r []storage.Message
|
|
||||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, rows, 1)
|
|
||||||
for _, row := range rows {
|
|
||||||
var d storage.Message
|
|
||||||
err = d.UnmarshalBinary([]byte(row))
|
|
||||||
require.NoError(t, err)
|
|
||||||
r = append(r, d)
|
|
||||||
}
|
|
||||||
require.Len(t, r, 1)
|
|
||||||
require.Equal(t, "i1", r[0].ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOnExpireInflightsClosedDB(t *testing.T) {
|
|
||||||
s := miniredis.RunT(t)
|
|
||||||
defer s.Close()
|
|
||||||
h := newHook(t, s.Addr())
|
|
||||||
teardown(t, h)
|
|
||||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOnExpireInflightsNoDB(t *testing.T) {
|
|
||||||
s := miniredis.RunT(t)
|
|
||||||
defer s.Close()
|
|
||||||
h := newHook(t, s.Addr())
|
|
||||||
h.db = nil
|
|
||||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOnSysInfoTick(t *testing.T) {
|
func TestOnSysInfoTick(t *testing.T) {
|
||||||
s := miniredis.RunT(t)
|
s := miniredis.RunT(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
@@ -231,7 +231,6 @@ func TestHooksNonReturns(t *testing.T) {
|
|||||||
h.OnWillSent(cl, packets.Packet{})
|
h.OnWillSent(cl, packets.Packet{})
|
||||||
h.OnClientExpired(cl)
|
h.OnClientExpired(cl)
|
||||||
h.OnRetainedExpired("a/b/c")
|
h.OnRetainedExpired("a/b/c")
|
||||||
h.OnExpireInflights(cl, time.Now().Unix()-1)
|
|
||||||
|
|
||||||
// on second iteration, check added hook methods
|
// on second iteration, check added hook methods
|
||||||
err := h.Add(new(modifiedHookBase), nil)
|
err := h.Add(new(modifiedHookBase), nil)
|
||||||
|
98
listeners/unixsock.go
Normal file
98
listeners/unixsock.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||||
|
// SPDX-FileContributor: jason@zgwit.com
|
||||||
|
|
||||||
|
package listeners
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UnixSock is a listener for establishing client connections on basic UnixSock protocol.
|
||||||
|
type UnixSock struct {
|
||||||
|
sync.RWMutex
|
||||||
|
id string // the internal id of the listener.
|
||||||
|
address string // the network address to bind to.
|
||||||
|
listen net.Listener // a net.Listener which will listen for new clients.
|
||||||
|
log *zerolog.Logger // server logger
|
||||||
|
end uint32 // ensure the close methods are only called once.
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUnixSock initialises and returns a new UnixSock listener, listening on an address.
|
||||||
|
func NewUnixSock(id, address string) *UnixSock {
|
||||||
|
return &UnixSock{
|
||||||
|
id: id,
|
||||||
|
address: address,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ID returns the id of the listener.
|
||||||
|
func (l *UnixSock) ID() string {
|
||||||
|
return l.id
|
||||||
|
}
|
||||||
|
|
||||||
|
// Address returns the address of the listener.
|
||||||
|
func (l *UnixSock) Address() string {
|
||||||
|
return l.address
|
||||||
|
}
|
||||||
|
|
||||||
|
// Protocol returns the address of the listener.
|
||||||
|
func (l *UnixSock) Protocol() string {
|
||||||
|
return "unix"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init initializes the listener.
|
||||||
|
func (l *UnixSock) Init(log *zerolog.Logger) error {
|
||||||
|
l.log = log
|
||||||
|
|
||||||
|
var err error
|
||||||
|
_ = os.Remove(l.address)
|
||||||
|
l.listen, err = net.Listen("unix", l.address)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serve starts waiting for new UnixSock connections, and calls the establish
|
||||||
|
// connection callback for any received.
|
||||||
|
func (l *UnixSock) Serve(establish EstablishFn) {
|
||||||
|
for {
|
||||||
|
if atomic.LoadUint32(&l.end) == 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := l.listen.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if atomic.LoadUint32(&l.end) == 0 {
|
||||||
|
go func() {
|
||||||
|
err = establish(l.id, conn)
|
||||||
|
if err != nil {
|
||||||
|
l.log.Warn().Err(err).Send()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the listener and any client connections.
|
||||||
|
func (l *UnixSock) Close(closeClients CloseFn) {
|
||||||
|
l.Lock()
|
||||||
|
defer l.Unlock()
|
||||||
|
|
||||||
|
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||||
|
closeClients(l.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
if l.listen != nil {
|
||||||
|
err := l.listen.Close()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
96
listeners/unixsock_test.go
Normal file
96
listeners/unixsock_test.go
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||||
|
// SPDX-FileContributor: jason@zgwit.com
|
||||||
|
|
||||||
|
package listeners
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
const testUnixAddr = "mochi.sock"
|
||||||
|
|
||||||
|
func TestNewUnixSock(t *testing.T) {
|
||||||
|
l := NewUnixSock("t1", testUnixAddr)
|
||||||
|
require.Equal(t, "t1", l.id)
|
||||||
|
require.Equal(t, testUnixAddr, l.address)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnixSockID(t *testing.T) {
|
||||||
|
l := NewUnixSock("t1", testUnixAddr)
|
||||||
|
require.Equal(t, "t1", l.ID())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnixSockAddress(t *testing.T) {
|
||||||
|
l := NewUnixSock("t1", testUnixAddr)
|
||||||
|
require.Equal(t, testUnixAddr, l.Address())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnixSockProtocol(t *testing.T) {
|
||||||
|
l := NewUnixSock("t1", testUnixAddr)
|
||||||
|
require.Equal(t, "unix", l.Protocol())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnixSockInit(t *testing.T) {
|
||||||
|
l := NewUnixSock("t1", testUnixAddr)
|
||||||
|
err := l.Init(&logger)
|
||||||
|
l.Close(MockCloser)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
l2 := NewUnixSock("t2", testUnixAddr)
|
||||||
|
err = l2.Init(&logger)
|
||||||
|
l2.Close(MockCloser)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnixSockServeAndClose(t *testing.T) {
|
||||||
|
l := NewUnixSock("t1", testUnixAddr)
|
||||||
|
err := l.Init(&logger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
o := make(chan bool)
|
||||||
|
go func(o chan bool) {
|
||||||
|
l.Serve(MockEstablisher)
|
||||||
|
o <- true
|
||||||
|
}(o)
|
||||||
|
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
|
||||||
|
var closed bool
|
||||||
|
l.Close(func(id string) {
|
||||||
|
closed = true
|
||||||
|
})
|
||||||
|
|
||||||
|
require.True(t, closed)
|
||||||
|
<-o
|
||||||
|
|
||||||
|
l.Close(MockCloser) // coverage: close closed
|
||||||
|
l.Serve(MockEstablisher) // coverage: serve closed
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnixSockEstablishThenEnd(t *testing.T) {
|
||||||
|
l := NewUnixSock("t1", testUnixAddr)
|
||||||
|
err := l.Init(&logger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
o := make(chan bool)
|
||||||
|
established := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
l.Serve(func(id string, c net.Conn) error {
|
||||||
|
established <- true
|
||||||
|
return errors.New("ending") // return an error to exit immediately
|
||||||
|
})
|
||||||
|
o <- true
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
net.Dial("unix", l.listen.Addr().String())
|
||||||
|
require.Equal(t, true, <-established)
|
||||||
|
l.Close(MockCloser)
|
||||||
|
<-o
|
||||||
|
}
|
@@ -7,6 +7,7 @@ package listeners
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -137,25 +138,35 @@ type wsConn struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Read reads the next span of bytes from the websocket connection and returns the number of bytes read.
|
// Read reads the next span of bytes from the websocket connection and returns the number of bytes read.
|
||||||
func (ws *wsConn) Read(p []byte) (n int, err error) {
|
func (ws *wsConn) Read(p []byte) (int, error) {
|
||||||
op, r, err := ws.c.NextReader()
|
op, r, err := ws.c.NextReader()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if op != websocket.BinaryMessage {
|
if op != websocket.BinaryMessage {
|
||||||
err = ErrInvalidMessage
|
err = ErrInvalidMessage
|
||||||
return
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.Read(p)
|
var n, br int
|
||||||
|
for {
|
||||||
|
br, err = r.Read(p[n:])
|
||||||
|
n += br
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write writes bytes to the websocket connection.
|
// Write writes bytes to the websocket connection.
|
||||||
func (ws *wsConn) Write(p []byte) (n int, err error) {
|
func (ws *wsConn) Write(p []byte) (int, error) {
|
||||||
err = ws.c.WriteMessage(websocket.BinaryMessage, p)
|
err := ws.c.WriteMessage(websocket.BinaryMessage, p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return len(p), nil
|
return len(p), nil
|
||||||
|
50
server.go
50
server.go
@@ -26,10 +26,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
Version = "2.0.0" // the current server version.
|
Version = "2.1.0" // the current server version.
|
||||||
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
|
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
|
||||||
defaultFanPoolSize uint64 = 64 // the number of concurrent workers in the pool
|
defaultFanPoolSize uint64 = 32 // the number of concurrent workers in the pool
|
||||||
defaultFanPoolQueueSize uint64 = 32 * 128 // the capacity of each worker queue
|
defaultFanPoolQueueSize uint64 = 1024 // the capacity of each worker queue
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -376,7 +376,7 @@ func (s *Server) attachClient(cl *Client, listener string) error {
|
|||||||
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryIntervalFlag && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
|
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryIntervalFlag && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
|
||||||
s.hooks.OnDisconnect(cl, err, expire)
|
s.hooks.OnDisconnect(cl, err, expire)
|
||||||
if expire {
|
if expire {
|
||||||
s.unsubscribeClient(cl)
|
s.UnsubscribeClient(cl)
|
||||||
cl.ClearInflights(math.MaxInt64, 0)
|
cl.ClearInflights(math.MaxInt64, 0)
|
||||||
s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23]
|
s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23]
|
||||||
}
|
}
|
||||||
@@ -455,7 +455,7 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
|
|||||||
defer existing.Unlock()
|
defer existing.Unlock()
|
||||||
s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
|
s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
|
||||||
if pk.Connect.Clean || (existing.Properties.Clean && cl.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
|
if pk.Connect.Clean || (existing.Properties.Clean && cl.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
|
||||||
s.unsubscribeClient(existing)
|
s.UnsubscribeClient(existing)
|
||||||
existing.ClearInflights(math.MaxInt64, 0)
|
existing.ClearInflights(math.MaxInt64, 0)
|
||||||
return false // [MQTT-3.2.2-3]
|
return false // [MQTT-3.2.2-3]
|
||||||
}
|
}
|
||||||
@@ -697,6 +697,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
|
|||||||
s.publishToSubscribers(pk)
|
s.publishToSubscribers(pk)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
s.hooks.OnPublished(cl, pk)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -727,8 +728,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
|
|||||||
s.publishToSubscribers(pk)
|
s.publishToSubscribers(pk)
|
||||||
})
|
})
|
||||||
|
|
||||||
s.hooks.OnPublish(cl, pk)
|
s.hooks.OnPublished(cl, pk)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -962,7 +962,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
|
|||||||
code = packets.ErrPacketIdentifierInUse
|
code = packets.ErrPacketIdentifierInUse
|
||||||
}
|
}
|
||||||
|
|
||||||
existed := false
|
filterExisted := make([]bool, len(pk.Filters))
|
||||||
reasonCodes := make([]byte, len(pk.Filters))
|
reasonCodes := make([]byte, len(pk.Filters))
|
||||||
for i, sub := range pk.Filters {
|
for i, sub := range pk.Filters {
|
||||||
if code != packets.CodeSuccess {
|
if code != packets.CodeSuccess {
|
||||||
@@ -978,8 +978,8 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
|
|||||||
} else if sub.NoLocal && IsSharedFilter(sub.Filter) {
|
} else if sub.NoLocal && IsSharedFilter(sub.Filter) {
|
||||||
reasonCodes[i] = packets.ErrProtocolViolationInvalidSharedNoLocal.Code // [MQTT-3.8.3-4]
|
reasonCodes[i] = packets.ErrProtocolViolationInvalidSharedNoLocal.Code // [MQTT-3.8.3-4]
|
||||||
} else {
|
} else {
|
||||||
existed = !s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
|
isNew := s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
|
||||||
if !existed {
|
if isNew {
|
||||||
atomic.AddInt64(&s.Info.Subscriptions, 1)
|
atomic.AddInt64(&s.Info.Subscriptions, 1)
|
||||||
}
|
}
|
||||||
cl.State.Subscriptions.Add(sub.Filter, sub) // [MQTT-3.2.2-10]
|
cl.State.Subscriptions.Add(sub.Filter, sub) // [MQTT-3.2.2-10]
|
||||||
@@ -988,6 +988,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
|
|||||||
sub.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9]
|
sub.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
filterExisted[i] = !isNew
|
||||||
reasonCodes[i] = sub.Qos // [MQTT-3.9.3-1] [MQTT-3.8.4-7]
|
reasonCodes[i] = sub.Qos // [MQTT-3.9.3-1] [MQTT-3.8.4-7]
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1022,7 +1023,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
s.publishRetainedToClient(cl, sub, existed)
|
s.publishRetainedToClient(cl, sub, filterExisted[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -1072,14 +1073,20 @@ func (s *Server) processUnsubscribe(cl *Client, pk packets.Packet) error {
|
|||||||
return cl.WritePacket(ack)
|
return cl.WritePacket(ack)
|
||||||
}
|
}
|
||||||
|
|
||||||
// unsubscribeClient unsubscribes a client from all of their subscriptions.
|
// UnsubscribeClient unsubscribes a client from all of their subscriptions.
|
||||||
func (s *Server) unsubscribeClient(cl *Client) {
|
func (s *Server) UnsubscribeClient(cl *Client) {
|
||||||
for k := range cl.State.Subscriptions.GetAll() {
|
i := 0
|
||||||
|
filterMap := cl.State.Subscriptions.GetAll()
|
||||||
|
filters := make([]packets.Subscription, len(filterMap))
|
||||||
|
for k, v := range filterMap {
|
||||||
cl.State.Subscriptions.Delete(k)
|
cl.State.Subscriptions.Delete(k)
|
||||||
if s.Topics.Unsubscribe(k, cl.ID) {
|
if s.Topics.Unsubscribe(k, cl.ID) {
|
||||||
atomic.AddInt64(&s.Info.Subscriptions, -1)
|
atomic.AddInt64(&s.Info.Subscriptions, -1)
|
||||||
}
|
}
|
||||||
|
filters[i] = v
|
||||||
|
i++
|
||||||
}
|
}
|
||||||
|
s.hooks.OnUnsubscribed(cl, packets.Packet{Filters: filters})
|
||||||
}
|
}
|
||||||
|
|
||||||
// processAuth processes an Auth packet.
|
// processAuth processes an Auth packet.
|
||||||
@@ -1126,12 +1133,15 @@ func (s *Server) DisconnectClient(cl *Client, code packets.Code) error {
|
|||||||
|
|
||||||
// We already have a code we are using to disconnect the client, so we are not
|
// We already have a code we are using to disconnect the client, so we are not
|
||||||
// interested if the write packet fails due to a closed connection (as we are closing it).
|
// interested if the write packet fails due to a closed connection (as we are closing it).
|
||||||
_ = cl.WritePacket(out)
|
err := cl.WritePacket(out)
|
||||||
if !s.Options.Capabilities.Compatibilities.PassiveClientDisconnect {
|
if !s.Options.Capabilities.Compatibilities.PassiveClientDisconnect {
|
||||||
cl.Stop(code)
|
cl.Stop(code)
|
||||||
|
if code.Code >= packets.ErrUnspecifiedError.Code {
|
||||||
|
return code
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return code
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// publishSysTopics publishes the current values to the server $SYS topics.
|
// publishSysTopics publishes the current values to the server $SYS topics.
|
||||||
@@ -1450,8 +1460,10 @@ func (s *Server) clearExpiredRetainedMessages(now int64) {
|
|||||||
// clearExpiredInflights deletes any inflight messages which have expired.
|
// clearExpiredInflights deletes any inflight messages which have expired.
|
||||||
func (s *Server) clearExpiredInflights(now int64) {
|
func (s *Server) clearExpiredInflights(now int64) {
|
||||||
for _, client := range s.Clients.GetAll() {
|
for _, client := range s.Clients.GetAll() {
|
||||||
if d := client.ClearInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); d > 0 {
|
if deleted := client.ClearInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); len(deleted) > 0 {
|
||||||
s.hooks.OnExpireInflights(client, now)
|
for _, id := range deleted {
|
||||||
|
s.hooks.OnQosDropped(client, packets.Packet{PacketID: id})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -844,7 +844,7 @@ func TestServerUnsubscribeClient(t *testing.T) {
|
|||||||
s.Topics.Subscribe(cl.ID, pk)
|
s.Topics.Subscribe(cl.ID, pk)
|
||||||
subs := s.Topics.Subscribers("a/b/c")
|
subs := s.Topics.Subscribers("a/b/c")
|
||||||
require.Equal(t, 1, len(subs.Subscriptions))
|
require.Equal(t, 1, len(subs.Subscriptions))
|
||||||
s.unsubscribeClient(cl)
|
s.UnsubscribeClient(cl)
|
||||||
subs = s.Topics.Subscribers("a/b/c")
|
subs = s.Topics.Subscribers("a/b/c")
|
||||||
require.Equal(t, 0, len(subs.Subscriptions))
|
require.Equal(t, 0, len(subs.Subscriptions))
|
||||||
}
|
}
|
||||||
@@ -2291,6 +2291,21 @@ func TestServerRecievePacketDisconnectClientZeroNonZero(t *testing.T) {
|
|||||||
require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectZeroNonZeroExpiry).RawBytes, buf)
|
require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectZeroNonZeroExpiry).RawBytes, buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServerRecievePacketDisconnectClient(t *testing.T) {
|
||||||
|
s := newServer()
|
||||||
|
cl, r, w := newTestClient()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
err := s.DisconnectClient(cl, packets.CodeDisconnect)
|
||||||
|
require.NoError(t, err)
|
||||||
|
w.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
buf, err := io.ReadAll(r)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes, buf)
|
||||||
|
}
|
||||||
|
|
||||||
func TestServerProcessPacketDisconnect(t *testing.T) {
|
func TestServerProcessPacketDisconnect(t *testing.T) {
|
||||||
s := newServer()
|
s := newServer()
|
||||||
cl, _, _ := newTestClient()
|
cl, _, _ := newTestClient()
|
||||||
|
Reference in New Issue
Block a user