Compare commits

...

19 Commits

Author SHA1 Message Date
mochi-co
1adb02e087 Update readme and server version 2022-12-21 20:47:58 +00:00
JB
4d4140aa99 Connect ReturnResponseInfo only applies to Connack values (#128) 2022-12-21 20:37:08 +00:00
JB
e31840a37d Optimize inflight expiry (#127)
* Small formatting/style changes for filter existed

* Use OnQoSDropped hook instead of onInflightExpired
2022-12-21 19:44:25 +00:00
JB
7d2e16f2d3 Merge pull request #123 from wind-c/master
Variable existed in the method processSubscribe is unstable
2022-12-21 11:41:14 +00:00
JB
92cd935a16 Merge branch 'master' into master 2022-12-21 11:38:28 +00:00
JB
25ce27ce2d Merge pull request #124 from zgwit/master
Add unix socket listener
2022-12-21 11:28:23 +00:00
jason
527d084a4b Add unix socket listener 2022-12-20 23:02:59 +08:00
Wind
bb9f937bb0 Variable existed in the method processSubscribe is unstable
The variable existed can be changed repeatedly within a for loop. An array variable must be used to record the subscription of each filter.
2022-12-18 13:46:06 +08:00
Wind
511fe88684 Merge branch 'mochi-co:master' into master 2022-12-17 12:33:09 +08:00
JB
75504ff201 Update server version 2022-12-16 18:27:29 +00:00
Wind
a556feb325 Add the OnUnsubscribed hook to the unsubscribeClient method (#122)
Add the OnUnsubscribed hook to the unsubscribeClient method,and change the unsubscribeClient to externally visible. In a clustered environment, if a client is disconnected and then connected to another node, the subscriptions on the previous node need to be cleared.
2022-12-16 18:23:58 +00:00
“Wind”
d06f47f4b9 Add the OnUnsubscribed hook to the unsubscribeClient method
Add the OnUnsubscribed hook to the unsubscribeClient method,and change the unsubscribeClient to externally visible. In a clustered environment, if a client is disconnected and then connected to another node, the subscriptions on the previous node need to be cleared.
2022-12-17 00:40:06 +08:00
JB
8d4cc091b4 Update version number 2022-12-16 00:31:59 +00:00
JB
d8f28cb843 Enforce server max packet (#121)
* Enforce Server Maximum Packet Size on client read
* Fix tests
2022-12-16 00:30:23 +00:00
JB
88861c219d Merge pull request #116 from tommyminds/bugfix/ws_malformed_package
Fix websocket malformed packet bug
2022-12-15 18:21:53 +00:00
JB
7ba6cf28d9 Merge branch 'master' into bugfix/ws_malformed_package 2022-12-15 18:21:33 +00:00
JB
c174cfdc6b Merge pull request #119 from mochi-co/fix-on-published
Fix mis-typed onpublished hook, update version, fanpool defaults
2022-12-15 18:21:19 +00:00
mochi-co
4f198a99dd Fix mis-typed onpublished hook, update version, fanpool defaults 2022-12-15 18:19:02 +00:00
Tommy Maintz
2a9c9fcc40 Fix websocket malformed packet bug 2022-12-14 21:41:33 +01:00
17 changed files with 405 additions and 271 deletions

View File

@@ -113,6 +113,7 @@ Examples of running the broker with various configurations can be found in the [
The server comes with a variety of pre-packaged network listeners which allow the broker to accept connections on different protocols. The current listeners are:
- `listeners.NewTCP(...)` - A TCP listener.
- `listeners.NewUnixSock(...)` - A Unix Socket listener.
- `listeners.NewWebsocket(...)` A Websocket listener.
- `listeners.NewHTTPStats(...)` An HTTP $SYS info dashboard.
- Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know!
@@ -296,7 +297,6 @@ The function signatures for all the hooks and `mqtt.Hook` interface can be found
| OnWillSent | Called when an LWT message has been issued from a disconnecting client. |
| OnClientExpired | Called when a client session has expired and should be deleted. |
| OnRetainedExpired | Called when a retained message has expired and should be deleted. |
| OnExpireInflights | Called when the server issues a clear request for expired inflight messages.|
| StoredClients | Returns clients, eg. from a persistent store. |
| StoredSubscriptions | Returns client subscriptions, eg. from a persistent store. |
| StoredInflightMessages | Returns inflight messages, eg. from a persistent store. |

View File

@@ -290,17 +290,18 @@ func (cl *Client) ResendInflightMessages(force bool) error {
}
// ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session.
func (cl *Client) ClearInflights(now, maximumExpiry int64) int64 {
var deleted int64
func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 {
deleted := []uint16{}
for _, tk := range cl.State.Inflight.GetAll(false) {
if (tk.Expiry > 0 && tk.Expiry < now) || tk.Created+maximumExpiry < now {
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
cl.ops.hooks.OnQosDropped(cl, tk)
atomic.AddInt64(&cl.ops.info.Inflight, -1)
deleted++
deleted = append(deleted, uint16(tk.PacketID))
}
}
}
return deleted
}
@@ -383,6 +384,10 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
return err
}
if cl.ops.capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.capabilities.MaximumPacketSize {
return packets.ErrPacketTooLarge // [MQTT-3.2.2-15]
}
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1))
return nil
}
@@ -472,8 +477,8 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set
}
if cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.capabilities.Compatibilities.AlwaysReturnResponseInfo {
pk.Mods.AllowResponseInfo = true // NB we need to know which properties we can encode
if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.capabilities.Compatibilities.AlwaysReturnResponseInfo {
pk.Mods.AllowResponseInfo = true // [MQTT-3.1.2-28] we need to know which properties we can encode
}
pk = cl.ops.hooks.OnPacketEncode(cl, pk)

View File

@@ -272,7 +272,9 @@ func TestClientClearInflights(t *testing.T) {
cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n})
require.Equal(t, 5, cl.State.Inflight.Len())
cl.ClearInflights(n, 4)
deleted := cl.ClearInflights(n, 4)
require.Len(t, deleted, 3)
require.ElementsMatch(t, []uint16{1, 2, 5}, deleted)
require.Equal(t, 2, cl.State.Inflight.Len())
}
@@ -350,6 +352,22 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) {
require.Error(t, err)
}
func TestClientReadFixedHeaderPacketOversized(t *testing.T) {
cl, r, _ := newTestClient()
cl.ops.capabilities.MaximumPacketSize = 2
defer cl.Stop(errClientStop)
go func() {
r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes)
r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrPacketTooLarge)
}
func TestClientReadFixedHeaderReadEOF(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)

View File

@@ -29,7 +29,6 @@ func main() {
server.Options.Capabilities.ServerKeepAlive = 60
server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true
server.Options.Capabilities.Compatibilities.AlwaysReturnResponseInfo = true
_ = server.AddHook(new(pahoAuthHook), nil)
tcp := listeners.NewTCP("t1", ":1883", nil)

View File

@@ -47,7 +47,6 @@ const (
OnWillSent
OnClientExpired
OnRetainedExpired
OnExpireInflights
StoredClients
StoredSubscriptions
StoredInflightMessages
@@ -96,7 +95,6 @@ type Hook interface {
OnWillSent(cl *Client, pk packets.Packet)
OnClientExpired(cl *Client)
OnRetainedExpired(filter string)
OnExpireInflights(cl *Client, expiry int64)
StoredClients() ([]storage.Client, error)
StoredSubscriptions() ([]storage.Subscription, error)
StoredInflightMessages() ([]storage.Message, error)
@@ -351,7 +349,7 @@ func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
}
}
// OnPublish is called when a client publishes a message. This method differs from OnMessage
// OnPublish is called when a client publishes a message. This method differs from OnPublished
// in that it allows you to modify you to modify the incoming packet before it is processed.
// The return values of the hook methods are passed-through in the order the hooks were attached.
func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
@@ -414,8 +412,8 @@ func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
}
// OnQosDropped is called the Qos flow for a message expires. In other words, when
// an inflight message expires or is abandoned.
// It is typically used to delete an inflight message from a store.
// an inflight message expires or is abandoned. It is typically used to delete an
// inflight message from a store.
func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
if hook.Provides(OnQosDropped) {
@@ -601,19 +599,6 @@ func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
return false
}
// OnExpireInflights is called when the server issues a clear request for expired
// inflight messages. Expiry should be the time after which the message is no longer
// valid (usually some time in the past). A message has expired if it's created time
// is older than time.Now() minus Inflight TTL. This method can be used to expire
// old inflight messages in a persistent store which doesnt support per-item TTL.
func (h *Hooks) OnExpireInflights(cl *Client, expiry int64) {
for _, hook := range h.internal {
if hook.Provides(OnExpireInflights) {
hook.OnExpireInflights(cl, expiry)
}
}
}
// HookBase provides a set of default methods for each hook. It should be embedded in
// all hooks.
type HookBase struct {
@@ -755,9 +740,6 @@ func (h *HookBase) OnClientExpired(cl *Client) {}
// OnRetainedExpired is called when a retained message for a topic has expired.
func (h *HookBase) OnRetainedExpired(topic string) {}
// OnExpireInflights is called when the server issues a clear request for expired inflight messages.
func (h *HookBase) OnExpireInflights(cl *Client, expiry int64) {}
// StoredClients returns all clients from a store.
func (h *HookBase) StoredClients() (v []storage.Client, err error) {
return

View File

@@ -80,7 +80,6 @@ func (h *Hook) Provides(b byte) bool {
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.OnExpireInflights,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
@@ -348,32 +347,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
}
}
// OnExpireInflights removes all inflight messages which have passed the provided expiry time.
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
var v []storage.Message
err := h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data")
return
}
for _, m := range v {
if m.Created < expiry || m.Created == 0 {
err := h.db.Delete(m.ID, new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Interface("data", m.ID).Msg("failed to delete inflight message data")
}
}
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
err := h.db.Delete(retainedKey(filter), new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data")
@@ -382,6 +362,11 @@ func (h *Hook) OnRetainedExpired(filter string) {
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.Delete(clientKey(cl), new(storage.Client))
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data")

View File

@@ -5,13 +5,11 @@
package badger
import (
"errors"
"os"
"strings"
"testing"
"time"
"github.com/asdine/storm/v3"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
@@ -170,6 +168,21 @@ func TestOnClientExpired(t *testing.T) {
require.ErrorIs(t, badgerhold.ErrNotFound, err)
}
func TestOnClientExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnClientExpired(client)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnClientExpired(client)
}
func TestOnSessionEstablishedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
@@ -333,6 +346,21 @@ func TestOnRetainedExpired(t *testing.T) {
require.ErrorIs(t, err, badgerhold.ErrNotFound)
}
func TestOnRetainExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
@@ -419,48 +447,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
h.OnQosDropped(client, packets.Packet{})
}
func TestOnExpireInflights(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
err = h.db.Upsert("i1", &storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1})
require.NoError(t, err)
err = h.db.Upsert("i2", &storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20})
require.NoError(t, err)
err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
h.OnExpireInflights(client, time.Now().Unix()-10)
var v []storage.Message
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
require.Len(t, v, 1)
require.Equal(t, "i1", v[0].ID)
}
func TestOnExpireInflightsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnExpireInflightsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnSysInfoTick(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)

View File

@@ -85,7 +85,6 @@ func (h *Hook) Provides(b byte) bool {
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.OnExpireInflights,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
@@ -369,34 +368,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
}
}
// OnExpireInflights removes all inflight messages which have passed the
// provided expiry time.
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
var v []storage.Message
err := h.db.Find("T", storage.InflightKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data")
return
}
for _, m := range v {
if m.Created < expiry || m.Created == 0 {
err := h.db.DeleteStruct(&storage.Message{ID: m.ID})
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to clear inflight data")
return
}
}
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish")
}
@@ -404,6 +382,11 @@ func (h *Hook) OnRetainedExpired(filter string) {
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")

View File

@@ -5,7 +5,6 @@
package bolt
import (
"errors"
"os"
"testing"
"time"
@@ -212,6 +211,21 @@ func TestOnClientExpired(t *testing.T) {
require.ErrorIs(t, storm.ErrNotFound, err)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnClientExpired(client)
}
func TestOnClientExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnClientExpired(client)
}
func TestOnDisconnectNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
@@ -341,6 +355,21 @@ func TestOnRetainedExpired(t *testing.T) {
require.Equal(t, storm.ErrNotFound, err)
}
func TestOnRetainedExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainedExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
@@ -427,48 +456,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
h.OnQosDropped(client, packets.Packet{})
}
func TestOnExpireInflights(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
err = h.db.Save(&storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
h.OnExpireInflights(client, time.Now().Unix()-10)
var v []storage.Message
err = h.db.Find("T", storage.InflightKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
require.Len(t, v, 1)
require.Equal(t, "i1", v[0].ID)
}
func TestOnExpireInflightsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnExpireInflightsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnSysInfoTick(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)

View File

@@ -83,7 +83,6 @@ func (h *Hook) Provides(b byte) bool {
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.OnExpireInflights,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
@@ -364,37 +363,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
}
}
// OnExpireInflights removes all inflight messages which have passed the
// provided expiry time.
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll inflight data")
return
}
for _, row := range rows {
var d storage.Message
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data")
}
if d.Created < expiry || d.Created == 0 {
err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), d.ID).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data")
}
}
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data")
@@ -403,6 +378,11 @@ func (h *Hook) OnRetainedExpired(filter string) {
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")

View File

@@ -253,6 +253,22 @@ func TestOnClientExpired(t *testing.T) {
require.ErrorIs(t, redis.Nil, err)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnClientExpired(client)
}
func TestOnClientExpiredNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnClientExpired(client)
}
func TestOnDisconnectNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
@@ -392,6 +408,22 @@ func TestOnRetainedExpired(t *testing.T) {
require.ErrorIs(t, err, redis.Nil)
}
func TestOnRetainedExpiredClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainedExpiredNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
@@ -484,60 +516,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
h.OnQosDropped(client, packets.Packet{})
}
func TestOnExpireInflights(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
n := time.Now().Unix()
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1",
&storage.Message{ID: "i1", T: storage.InflightKey, Created: n - 1},
).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2",
&storage.Message{ID: "i2", T: storage.InflightKey, Created: n - 20},
).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3",
&storage.Message{ID: "i3", T: storage.InflightKey},
).Err()
require.NoError(t, err)
h.OnExpireInflights(client, time.Now().Unix()-10)
var r []storage.Message
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
require.NoError(t, err)
require.Len(t, rows, 1)
for _, row := range rows {
var d storage.Message
err = d.UnmarshalBinary([]byte(row))
require.NoError(t, err)
r = append(r, d)
}
require.Len(t, r, 1)
require.Equal(t, "i1", r[0].ID)
}
func TestOnExpireInflightsClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnExpireInflightsNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnSysInfoTick(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()

View File

@@ -231,7 +231,6 @@ func TestHooksNonReturns(t *testing.T) {
h.OnWillSent(cl, packets.Packet{})
h.OnClientExpired(cl)
h.OnRetainedExpired("a/b/c")
h.OnExpireInflights(cl, time.Now().Unix()-1)
// on second iteration, check added hook methods
err := h.Add(new(modifiedHookBase), nil)

98
listeners/unixsock.go Normal file
View File

@@ -0,0 +1,98 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: jason@zgwit.com
package listeners
import (
"net"
"os"
"sync"
"sync/atomic"
"github.com/rs/zerolog"
)
// UnixSock is a listener for establishing client connections on basic UnixSock protocol.
type UnixSock struct {
sync.RWMutex
id string // the internal id of the listener.
address string // the network address to bind to.
listen net.Listener // a net.Listener which will listen for new clients.
log *zerolog.Logger // server logger
end uint32 // ensure the close methods are only called once.
}
// NewUnixSock initialises and returns a new UnixSock listener, listening on an address.
func NewUnixSock(id, address string) *UnixSock {
return &UnixSock{
id: id,
address: address,
}
}
// ID returns the id of the listener.
func (l *UnixSock) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *UnixSock) Address() string {
return l.address
}
// Protocol returns the address of the listener.
func (l *UnixSock) Protocol() string {
return "unix"
}
// Init initializes the listener.
func (l *UnixSock) Init(log *zerolog.Logger) error {
l.log = log
var err error
_ = os.Remove(l.address)
l.listen, err = net.Listen("unix", l.address)
return err
}
// Serve starts waiting for new UnixSock connections, and calls the establish
// connection callback for any received.
func (l *UnixSock) Serve(establish EstablishFn) {
for {
if atomic.LoadUint32(&l.end) == 1 {
return
}
conn, err := l.listen.Accept()
if err != nil {
return
}
if atomic.LoadUint32(&l.end) == 0 {
go func() {
err = establish(l.id, conn)
if err != nil {
l.log.Warn().Err(err).Send()
}
}()
}
}
}
// Close closes the listener and any client connections.
func (l *UnixSock) Close(closeClients CloseFn) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}
if l.listen != nil {
err := l.listen.Close()
if err != nil {
return
}
}
}

View File

@@ -0,0 +1,96 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: jason@zgwit.com
package listeners
import (
"errors"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
const testUnixAddr = "mochi.sock"
func TestNewUnixSock(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, "t1", l.id)
require.Equal(t, testUnixAddr, l.address)
}
func TestUnixSockID(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, "t1", l.ID())
}
func TestUnixSockAddress(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, testUnixAddr, l.Address())
}
func TestUnixSockProtocol(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, "unix", l.Protocol())
}
func TestUnixSockInit(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
err := l.Init(&logger)
l.Close(MockCloser)
require.NoError(t, err)
l2 := NewUnixSock("t2", testUnixAddr)
err = l2.Init(&logger)
l2.Close(MockCloser)
require.NoError(t, err)
}
func TestUnixSockServeAndClose(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
err := l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.True(t, closed)
<-o
l.Close(MockCloser) // coverage: close closed
l.Serve(MockEstablisher) // coverage: serve closed
}
func TestUnixSockEstablishThenEnd(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
err := l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
established := make(chan bool)
go func() {
l.Serve(func(id string, c net.Conn) error {
established <- true
return errors.New("ending") // return an error to exit immediately
})
o <- true
}()
time.Sleep(time.Millisecond)
net.Dial("unix", l.listen.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o
}

View File

@@ -7,6 +7,7 @@ package listeners
import (
"context"
"errors"
"io"
"net"
"net/http"
"sync"
@@ -137,25 +138,35 @@ type wsConn struct {
}
// Read reads the next span of bytes from the websocket connection and returns the number of bytes read.
func (ws *wsConn) Read(p []byte) (n int, err error) {
func (ws *wsConn) Read(p []byte) (int, error) {
op, r, err := ws.c.NextReader()
if err != nil {
return
return 0, err
}
if op != websocket.BinaryMessage {
err = ErrInvalidMessage
return
return 0, err
}
return r.Read(p)
var n, br int
for {
br, err = r.Read(p[n:])
n += br
if err != nil {
if err == io.EOF {
err = nil
}
return n, err
}
}
}
// Write writes bytes to the websocket connection.
func (ws *wsConn) Write(p []byte) (n int, err error) {
err = ws.c.WriteMessage(websocket.BinaryMessage, p)
func (ws *wsConn) Write(p []byte) (int, error) {
err := ws.c.WriteMessage(websocket.BinaryMessage, p)
if err != nil {
return
return 0, err
}
return len(p), nil

View File

@@ -26,10 +26,10 @@ import (
)
const (
Version = "2.0.0" // the current server version.
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
defaultFanPoolSize uint64 = 64 // the number of concurrent workers in the pool
defaultFanPoolQueueSize uint64 = 32 * 128 // the capacity of each worker queue
Version = "2.1.0" // the current server version.
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
defaultFanPoolSize uint64 = 32 // the number of concurrent workers in the pool
defaultFanPoolQueueSize uint64 = 1024 // the capacity of each worker queue
)
var (
@@ -376,7 +376,7 @@ func (s *Server) attachClient(cl *Client, listener string) error {
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryIntervalFlag && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
s.hooks.OnDisconnect(cl, err, expire)
if expire {
s.unsubscribeClient(cl)
s.UnsubscribeClient(cl)
cl.ClearInflights(math.MaxInt64, 0)
s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23]
}
@@ -455,7 +455,7 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
defer existing.Unlock()
s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
if pk.Connect.Clean || (existing.Properties.Clean && cl.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
s.unsubscribeClient(existing)
s.UnsubscribeClient(existing)
existing.ClearInflights(math.MaxInt64, 0)
return false // [MQTT-3.2.2-3]
}
@@ -697,6 +697,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
s.publishToSubscribers(pk)
})
s.hooks.OnPublished(cl, pk)
return nil
}
@@ -727,8 +728,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
s.publishToSubscribers(pk)
})
s.hooks.OnPublish(cl, pk)
s.hooks.OnPublished(cl, pk)
return nil
}
@@ -962,7 +962,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
code = packets.ErrPacketIdentifierInUse
}
existed := false
filterExisted := make([]bool, len(pk.Filters))
reasonCodes := make([]byte, len(pk.Filters))
for i, sub := range pk.Filters {
if code != packets.CodeSuccess {
@@ -978,8 +978,8 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
} else if sub.NoLocal && IsSharedFilter(sub.Filter) {
reasonCodes[i] = packets.ErrProtocolViolationInvalidSharedNoLocal.Code // [MQTT-3.8.3-4]
} else {
existed = !s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
if !existed {
isNew := s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
if isNew {
atomic.AddInt64(&s.Info.Subscriptions, 1)
}
cl.State.Subscriptions.Add(sub.Filter, sub) // [MQTT-3.2.2-10]
@@ -988,6 +988,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
sub.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9]
}
filterExisted[i] = !isNew
reasonCodes[i] = sub.Qos // [MQTT-3.9.3-1] [MQTT-3.8.4-7]
}
@@ -1022,7 +1023,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
continue
}
s.publishRetainedToClient(cl, sub, existed)
s.publishRetainedToClient(cl, sub, filterExisted[i])
}
return nil
@@ -1072,14 +1073,20 @@ func (s *Server) processUnsubscribe(cl *Client, pk packets.Packet) error {
return cl.WritePacket(ack)
}
// unsubscribeClient unsubscribes a client from all of their subscriptions.
func (s *Server) unsubscribeClient(cl *Client) {
for k := range cl.State.Subscriptions.GetAll() {
// UnsubscribeClient unsubscribes a client from all of their subscriptions.
func (s *Server) UnsubscribeClient(cl *Client) {
i := 0
filterMap := cl.State.Subscriptions.GetAll()
filters := make([]packets.Subscription, len(filterMap))
for k, v := range filterMap {
cl.State.Subscriptions.Delete(k)
if s.Topics.Unsubscribe(k, cl.ID) {
atomic.AddInt64(&s.Info.Subscriptions, -1)
}
filters[i] = v
i++
}
s.hooks.OnUnsubscribed(cl, packets.Packet{Filters: filters})
}
// processAuth processes an Auth packet.
@@ -1126,12 +1133,15 @@ func (s *Server) DisconnectClient(cl *Client, code packets.Code) error {
// We already have a code we are using to disconnect the client, so we are not
// interested if the write packet fails due to a closed connection (as we are closing it).
_ = cl.WritePacket(out)
err := cl.WritePacket(out)
if !s.Options.Capabilities.Compatibilities.PassiveClientDisconnect {
cl.Stop(code)
if code.Code >= packets.ErrUnspecifiedError.Code {
return code
}
}
return code
return err
}
// publishSysTopics publishes the current values to the server $SYS topics.
@@ -1450,8 +1460,10 @@ func (s *Server) clearExpiredRetainedMessages(now int64) {
// clearExpiredInflights deletes any inflight messages which have expired.
func (s *Server) clearExpiredInflights(now int64) {
for _, client := range s.Clients.GetAll() {
if d := client.ClearInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); d > 0 {
s.hooks.OnExpireInflights(client, now)
if deleted := client.ClearInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); len(deleted) > 0 {
for _, id := range deleted {
s.hooks.OnQosDropped(client, packets.Packet{PacketID: id})
}
}
}
}

View File

@@ -844,7 +844,7 @@ func TestServerUnsubscribeClient(t *testing.T) {
s.Topics.Subscribe(cl.ID, pk)
subs := s.Topics.Subscribers("a/b/c")
require.Equal(t, 1, len(subs.Subscriptions))
s.unsubscribeClient(cl)
s.UnsubscribeClient(cl)
subs = s.Topics.Subscribers("a/b/c")
require.Equal(t, 0, len(subs.Subscriptions))
}
@@ -2291,6 +2291,21 @@ func TestServerRecievePacketDisconnectClientZeroNonZero(t *testing.T) {
require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectZeroNonZeroExpiry).RawBytes, buf)
}
func TestServerRecievePacketDisconnectClient(t *testing.T) {
s := newServer()
cl, r, w := newTestClient()
go func() {
err := s.DisconnectClient(cl, packets.CodeDisconnect)
require.NoError(t, err)
w.Close()
}()
buf, err := io.ReadAll(r)
require.NoError(t, err)
require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes, buf)
}
func TestServerProcessPacketDisconnect(t *testing.T) {
s := newServer()
cl, _, _ := newTestClient()