diff --git a/clients.go b/clients.go index 33bb010..4209680 100644 --- a/clients.go +++ b/clients.go @@ -99,7 +99,7 @@ func (cl *Clients) GetByListener(id string) []*Client { type Client struct { Properties ClientProperties // client properties State ClientState // the operational state of the client. - Net ClientConnection // network connection state of the clinet + Net ClientConnection // network connection state of the client ID string // the client id. ops *ops // ops provides a reference to server ops. sync.RWMutex // mutex @@ -111,7 +111,7 @@ type ClientConnection struct { bconn *bufio.ReadWriter // a buffered net.Conn for reading packets Remote string // the remote address of the client Listener string // listener id of the client - Inline bool // client is an inline programmetic client + Inline bool // if true, the client is the built-in 'inline' embedded client } // ClientProperties contains the properties which define the client behaviour. @@ -134,7 +134,7 @@ type Will struct { Retain bool // - } -// State tracks the state of the client. +// ClientState tracks the state of the client. type ClientState struct { TopicAliases TopicAliases // a map of topic aliases stopCause atomic.Value // reason for stopping @@ -311,7 +311,7 @@ func (cl *Client) ResendInflightMessages(force bool) error { return nil } -// ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session. +// ClearInflights deletes all inflight messages for the client, e.g. for a disconnected user with a clean session. func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 { deleted := []uint16{} for _, tk := range cl.State.Inflight.GetAll(false) { diff --git a/clients_test.go b/clients_test.go index ae91de6..37b5ba9 100644 --- a/clients_test.go +++ b/clients_test.go @@ -263,7 +263,7 @@ func TestClientNextPacketIDOverflow(t *testing.T) { cl.State.Inflight.internal[uint16(i)] = packets.Packet{} } - cl.State.packetID = uint32(cl.ops.options.Capabilities.maximumPacketID - 1) + cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID - 1 i, err := cl.NextPacketID() require.NoError(t, err) require.Equal(t, cl.ops.options.Capabilities.maximumPacketID, i) @@ -303,7 +303,7 @@ func TestClientResendInflightMessages(t *testing.T) { err := cl.ResendInflightMessages(true) require.NoError(t, err) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -315,7 +315,7 @@ func TestClientResendInflightMessages(t *testing.T) { func TestClientResendInflightMessagesWriteFailure(t *testing.T) { pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup) cl, r, _ := newTestClient() - r.Close() + _ = r.Close() cl.State.Inflight.Set(*pk1.Packet) require.Equal(t, 1, cl.State.Inflight.Len()) @@ -342,8 +342,8 @@ func TestClientReadFixedHeader(t *testing.T) { defer cl.Stop(errClientStop) go func() { - r.Write([]byte{packets.Connect << 4, 0x00}) - r.Close() + _, _ = r.Write([]byte{packets.Connect << 4, 0x00}) + _ = r.Close() }() fh := new(packets.FixedHeader) @@ -357,8 +357,8 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) { defer cl.Stop(errClientStop) go func() { - r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00}) - r.Close() + _, _ = r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00}) + _ = r.Close() }() fh := new(packets.FixedHeader) @@ -372,8 +372,8 @@ func TestClientReadFixedHeaderPacketOversized(t *testing.T) { defer cl.Stop(errClientStop) go func() { - r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes) - r.Close() + _, _ = r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes) + _ = r.Close() }() fh := new(packets.FixedHeader) @@ -387,7 +387,7 @@ func TestClientReadFixedHeaderReadEOF(t *testing.T) { defer cl.Stop(errClientStop) go func() { - r.Close() + _ = r.Close() }() fh := new(packets.FixedHeader) @@ -401,8 +401,8 @@ func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) { defer cl.Stop(errClientStop) go func() { - r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01}) - r.Close() + _, _ = r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01}) + _ = r.Close() }() fh := new(packets.FixedHeader) @@ -414,7 +414,7 @@ func TestClientReadOK(t *testing.T) { cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { - r.Write([]byte{ + _, _ = r.Write([]byte{ packets.Publish << 4, 18, // Fixed header 0, 5, // Topic Name - LSB+MSB 'a', '/', 'b', '/', 'c', // Topic Name @@ -424,7 +424,7 @@ func TestClientReadOK(t *testing.T) { 'd', '/', 'e', '/', 'f', // Topic Name 'y', 'e', 'a', 'h', // Payload }) - r.Close() + _ = r.Close() }() var pks []packets.Packet @@ -499,10 +499,10 @@ func TestClientReadFixedHeaderError(t *testing.T) { cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { - r.Write([]byte{ + _, _ = r.Write([]byte{ packets.Publish << 4, 11, // Fixed header }) - r.Close() + _ = r.Close() }() cl.Net.bconn = nil @@ -516,13 +516,13 @@ func TestClientReadReadHandlerErr(t *testing.T) { cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { - r.Write([]byte{ + _, _ = r.Write([]byte{ packets.Publish << 4, 11, // Fixed header 0, 5, // Topic Name - LSB+MSB 'd', '/', 'e', '/', 'f', // Topic Name 'y', 'e', 'a', 'h', // Payload }) - r.Close() + _ = r.Close() }() err := cl.Read(func(cl *Client, pk packets.Packet) error { @@ -536,13 +536,13 @@ func TestClientReadReadPacketOK(t *testing.T) { cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { - r.Write([]byte{ + _, _ = r.Write([]byte{ packets.Publish << 4, 11, // Fixed header 0, 5, 'd', '/', 'e', '/', 'f', 'y', 'e', 'a', 'h', }) - r.Close() + _ = r.Close() }() fh := new(packets.FixedHeader) @@ -573,7 +573,7 @@ func TestClientReadPacket(t *testing.T) { t.Run(tt.Desc, func(t *testing.T) { atomic.StoreInt64(&cl.ops.info.PacketsReceived, 0) go func() { - r.Write(tt.RawBytes) + _, _ = r.Write(tt.RawBytes) }() fh := new(packets.FixedHeader) @@ -600,7 +600,7 @@ func TestClientReadPacket(t *testing.T) { func TestClientReadPacketInvalidTypeError(t *testing.T) { cl, _, _ := newTestClient() - cl.Net.Conn.Close() + _ = cl.Net.Conn.Close() _, err := cl.ReadPacket(&packets.FixedHeader{}) require.Error(t, err) require.Contains(t, err.Error(), "invalid packet type") @@ -624,7 +624,7 @@ func TestClientWritePacket(t *testing.T) { require.NoError(t, err, pkInfo, tt.Case, tt.Desc) time.Sleep(2 * time.Millisecond) - cl.Net.Conn.Close() + _ = cl.Net.Conn.Close() require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc) @@ -660,13 +660,13 @@ func TestClientReadPacketReadingError(t *testing.T) { cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { - r.Write([]byte{ + _, _ = r.Write([]byte{ 0, 11, // Fixed header 0, 5, 'd', '/', 'e', '/', 'f', 'y', 'e', 'a', 'h', }) - r.Close() + _ = r.Close() }() _, err := cl.ReadPacket(&packets.FixedHeader{ @@ -680,13 +680,13 @@ func TestClientReadPacketReadUnknown(t *testing.T) { cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { - r.Write([]byte{ + _, _ = r.Write([]byte{ 0, 11, // Fixed header 0, 5, 'd', '/', 'e', '/', 'f', 'y', 'e', 'a', 'h', }) - r.Close() + _ = r.Close() }() _, err := cl.ReadPacket(&packets.FixedHeader{ @@ -706,7 +706,7 @@ func TestClientWritePacketWriteNoConn(t *testing.T) { func TestClientWritePacketWriteError(t *testing.T) { cl, _, _ := newTestClient() - cl.Net.Conn.Close() + _ = cl.Net.Conn.Close() err := cl.WritePacket(*pkTable[1].Packet) require.Error(t, err) diff --git a/cmd/main.go b/cmd/main.go index 4ef3e82..2c14ab8 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -60,7 +60,7 @@ func main() { <-done server.Log.Warn("caught signal, stopping...") - server.Close() + _ = server.Close() server.Log.Info("main.go finished") } diff --git a/examples/auth/basic/main.go b/examples/auth/basic/main.go index 0bc8b4d..964e03f 100644 --- a/examples/auth/basic/main.go +++ b/examples/auth/basic/main.go @@ -78,6 +78,6 @@ func main() { <-done server.Log.Warn("caught signal, stopping...") - server.Close() + _ = server.Close() server.Log.Info("main.go finished") } diff --git a/examples/auth/encoded/main.go b/examples/auth/encoded/main.go index 15477fb..6bd1dab 100644 --- a/examples/auth/encoded/main.go +++ b/examples/auth/encoded/main.go @@ -60,6 +60,6 @@ func main() { <-done server.Log.Warn("caught signal, stopping...") - server.Close() + _ = server.Close() server.Log.Info("main.go finished") } diff --git a/examples/benchmark/main.go b/examples/benchmark/main.go index f791b88..b997086 100644 --- a/examples/benchmark/main.go +++ b/examples/benchmark/main.go @@ -47,6 +47,6 @@ func main() { <-done server.Log.Warn("caught signal, stopping...") - server.Close() + _ = server.Close() server.Log.Info("main.go finished") } diff --git a/examples/debug/main.go b/examples/debug/main.go index 55cf0f0..ffdc199 100644 --- a/examples/debug/main.go +++ b/examples/debug/main.go @@ -61,6 +61,6 @@ func main() { <-done server.Log.Warn("caught signal, stopping...") - server.Close() + _ = server.Close() server.Log.Info("main.go finished") } diff --git a/examples/direct/main.go b/examples/direct/main.go index 72c7c99..6e6db49 100644 --- a/examples/direct/main.go +++ b/examples/direct/main.go @@ -26,7 +26,9 @@ func main() { done <- true }() - server := mqtt.New(nil) + server := mqtt.New(&mqtt.Options{ + InlineClient: true, // you must enable inline client to use direct publishing and subscribing. + }) _ = server.AddHook(new(auth.AllowHook), nil) // Start the server @@ -50,12 +52,13 @@ func main() { server.Log.Info("inline client received message from subscription", "client", cl.ID, "subscriptionId", sub.Identifier, "topic", pk.TopicName, "payload", string(pk.Payload)) } server.Log.Info("inline client subscribing") - server.Subscribe("direct/#", 1, callbackFn) - server.Subscribe("direct/#", 2, callbackFn) + _ = server.Subscribe("direct/#", 1, callbackFn) + _ = server.Subscribe("direct/#", 2, callbackFn) }() - // There is a shorthand convenience function, Publish, for easily sending - // publish packets if you are not concerned with creating your own packets. + // There is a shorthand convenience function, Publish, for easily sending publish packets if you are not + // concerned with creating your own packets. If you want to have more control over your packets, you can + //directly inject a packet of any kind into the broker. See examples/hooks/main.go for usage. go func() { for range time.Tick(time.Second * 3) { err := server.Publish("direct/publish", []byte("scheduled message"), false, 0) @@ -70,23 +73,8 @@ func main() { time.Sleep(time.Second * 10) // Unsubscribe from the same filter to stop receiving messages. server.Log.Info("inline client unsubscribing") - server.Unsubscribe("direct/#", 1) + _ = server.Unsubscribe("direct/#", 1) }() - // If you want to have more control over your packets, you can directly inject a packet of any kind into the broker. - //go func() { - // for range time.Tick(time.Second * 5) { - // err := server.InjectPacket(cl, packets.Packet{ - // FixedHeader: packets.FixedHeader{ - // Type: packets.Publish, - // }, - // TopicName: "direct/publish", - // Payload: []byte("injected scheduled message"), - // }) - // if err != nil { - // log.Fatal(err) - // } - // } - //}() <-done server.Log.Warn("caught signal, stopping...") diff --git a/examples/hooks/main.go b/examples/hooks/main.go index f24b414..9cb8024 100644 --- a/examples/hooks/main.go +++ b/examples/hooks/main.go @@ -83,7 +83,7 @@ func main() { <-done server.Log.Warn("caught signal, stopping...") - server.Close() + _ = server.Close() server.Log.Info("main.go finished") } diff --git a/examples/paho.testing/main.go b/examples/paho.testing/main.go index cf92eac..4cc942d 100644 --- a/examples/paho.testing/main.go +++ b/examples/paho.testing/main.go @@ -46,7 +46,7 @@ func main() { <-done server.Log.Warn("caught signal, stopping...") - server.Close() + _ = server.Close() server.Log.Info("main.go finished") } diff --git a/examples/persistence/badger/main.go b/examples/persistence/badger/main.go index d0b88c7..58f3d6c 100644 --- a/examples/persistence/badger/main.go +++ b/examples/persistence/badger/main.go @@ -53,6 +53,6 @@ func main() { <-done server.Log.Warn("caught signal, stopping...") - server.Close() + _ = server.Close() server.Log.Info("main.go finished") } diff --git a/examples/persistence/bolt/main.go b/examples/persistence/bolt/main.go index 7e05f30..3a9351f 100644 --- a/examples/persistence/bolt/main.go +++ b/examples/persistence/bolt/main.go @@ -55,6 +55,6 @@ func main() { <-done server.Log.Warn("caught signal, stopping...") - server.Close() + _ = server.Close() server.Log.Info("main.go finished") } diff --git a/examples/persistence/redis/main.go b/examples/persistence/redis/main.go index 9f300bb..64e434e 100644 --- a/examples/persistence/redis/main.go +++ b/examples/persistence/redis/main.go @@ -63,6 +63,6 @@ func main() { <-done server.Log.Warn("caught signal, stopping...") - server.Close() + _ = server.Close() server.Log.Info("main.go finished") } diff --git a/examples/tcp/main.go b/examples/tcp/main.go index 2cb419d..f600ef0 100644 --- a/examples/tcp/main.go +++ b/examples/tcp/main.go @@ -53,6 +53,6 @@ func main() { <-done server.Log.Warn("caught signal, stopping...") - server.Close() + _ = server.Close() server.Log.Info("main.go finished") } diff --git a/examples/tls/main.go b/examples/tls/main.go index fe9f93f..67fbbc0 100644 --- a/examples/tls/main.go +++ b/examples/tls/main.go @@ -112,6 +112,6 @@ func main() { <-done server.Log.Warn("caught signal, stopping...") - server.Close() + _ = server.Close() server.Log.Info("main.go finished") } diff --git a/examples/websocket/main.go b/examples/websocket/main.go index 7d42748..0a85e27 100644 --- a/examples/websocket/main.go +++ b/examples/websocket/main.go @@ -42,6 +42,6 @@ func main() { <-done server.Log.Warn("caught signal, stopping...") - server.Close() + _ = server.Close() server.Log.Info("main.go finished") } diff --git a/hooks/auth/ledger.go b/hooks/auth/ledger.go index 9e5e2e6..694b19d 100644 --- a/hooks/auth/ledger.go +++ b/hooks/auth/ledger.go @@ -80,8 +80,8 @@ func (r RString) Matches(a string) bool { } // FilterMatches returns true if a filter matches a topic rule. -func (f RString) FilterMatches(a string) bool { - _, ok := MatchTopic(string(f), a) +func (r RString) FilterMatches(a string) bool { + _, ok := MatchTopic(string(r), a) return ok } @@ -161,7 +161,7 @@ func (l *Ledger) AuthOk(cl *mqtt.Client, pk packets.Packet) (n int, ok bool) { } // ACLOk returns true if the rules indicate the user is allowed to read or write to -// a specific filter or topic respectively, based on the write bool. +// a specific filter or topic respectively, based on the `write` bool. func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok bool) { // If the users map is set, always check for a predefined user first instead // of iterating through global rules. @@ -209,7 +209,7 @@ func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok boo } } - for filter, _ := range rule.Filters { + for filter := range rule.Filters { if filter.FilterMatches(topic) { return n, false } diff --git a/hooks/auth/ledger_test.go b/hooks/auth/ledger_test.go index ab84768..8f46d14 100644 --- a/hooks/auth/ledger_test.go +++ b/hooks/auth/ledger_test.go @@ -561,17 +561,17 @@ func TestLedgerUpdate(t *testing.T) { }, } - new := &Ledger{ + n := &Ledger{ Auth: AuthRules{ {Remote: "127.0.0.1", Allow: true}, {Remote: "192.168.*", Allow: true}, }, } - old.Update(new) + old.Update(n) require.Len(t, old.Auth, 2) require.Equal(t, RString("192.168.*"), old.Auth[1].Remote) - require.NotSame(t, new, old) + require.NotSame(t, n, old) } func TestLedgerToJSON(t *testing.T) { diff --git a/hooks/debug/debug.go b/hooks/debug/debug.go index 00404f2..6a2f86b 100644 --- a/hooks/debug/debug.go +++ b/hooks/debug/debug.go @@ -114,7 +114,7 @@ func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { h.Log.Debug("inflight dropped", "m", h.packetMeta(pk)) } -// OnLWTSent is called when a will message has been issued from a disconnecting client. +// OnLWTSent is called when a Will Message has been issued from a disconnecting client. func (h *Hook) OnLWTSent(cl *mqtt.Client, pk packets.Packet) { h.Log.Debug("sent lwt for client", "method", "OnLWTSent", "client", cl.ID) } @@ -136,25 +136,25 @@ func (h *Hook) StoredClients() (v []storage.Client, err error) { return v, nil } -// StoredClients is called when the server restores subscriptions from a store. +// StoredSubscriptions is called when the server restores subscriptions from a store. func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { h.Log.Debug("", "method", "StoredSubscriptions") return v, nil } -// StoredClients is called when the server restores retained messages from a store. +// StoredRetainedMessages is called when the server restores retained messages from a store. func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { h.Log.Debug("", "method", "StoredRetainedMessages") return v, nil } -// StoredClients is called when the server restores inflight messages from a store. +// StoredInflightMessages is called when the server restores inflight messages from a store. func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { h.Log.Debug("", "method", "StoredInflightMessages") return v, nil } -// StoredClients is called when the server restores system info from a store. +// StoredSysInfo is called when the server restores system info from a store. func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { h.Log.Debug("", "method", "StoredSysInfo") diff --git a/hooks/storage/badger/badger.go b/hooks/storage/badger/badger.go index ca416f6..fdc653f 100644 --- a/hooks/storage/badger/badger.go +++ b/hooks/storage/badger/badger.go @@ -128,8 +128,7 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) { h.updateClient(cl) } -// OnWillSent is called when a client sends a will message and the will message is removed -// from the client record. +// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record. func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { h.updateClient(cl) } diff --git a/hooks/storage/badger/badger_test.go b/hooks/storage/badger/badger_test.go index 0800367..7813cd5 100644 --- a/hooks/storage/badger/badger_test.go +++ b/hooks/storage/badger/badger_test.go @@ -38,8 +38,8 @@ var ( ) func teardown(t *testing.T, path string, h *Hook) { - h.Stop() - h.db.Badger().Close() + _ = h.Stop() + _ = h.db.Badger().Close() err := os.RemoveAll("./" + strings.Replace(path, "..", "", -1)) require.NoError(t, err) } diff --git a/hooks/storage/bolt/bolt.go b/hooks/storage/bolt/bolt.go index b57c24a..0ea9bd1 100644 --- a/hooks/storage/bolt/bolt.go +++ b/hooks/storage/bolt/bolt.go @@ -1,7 +1,8 @@ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co -// package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead. + +// Package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead. package bolt import ( @@ -132,8 +133,7 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) { h.updateClient(cl) } -// OnWillSent is called when a client sends a will message and the will message is removed -// from the client record. +// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record. func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { h.updateClient(cl) } diff --git a/hooks/storage/bolt/bolt_test.go b/hooks/storage/bolt/bolt_test.go index 70695e7..6c1a638 100644 --- a/hooks/storage/bolt/bolt_test.go +++ b/hooks/storage/bolt/bolt_test.go @@ -38,7 +38,7 @@ var ( ) func teardown(t *testing.T, path string, h *Hook) { - h.Stop() + _ = h.Stop() err := os.Remove(path) require.NoError(t, err) } diff --git a/hooks/storage/redis/redis.go b/hooks/storage/redis/redis.go index e780225..3d9fa98 100644 --- a/hooks/storage/redis/redis.go +++ b/hooks/storage/redis/redis.go @@ -15,7 +15,7 @@ import ( "github.com/mochi-mqtt/server/v2/packets" "github.com/mochi-mqtt/server/v2/system" - redis "github.com/go-redis/redis/v8" + "github.com/go-redis/redis/v8" ) // defaultAddr is the default address to the redis service. @@ -134,7 +134,7 @@ func (h *Hook) Init(config any) error { return nil } -// Close closes the redis connection. +// Stop closes the redis connection. func (h *Hook) Stop() error { h.Log.Info("disconnecting from redis service") @@ -146,8 +146,7 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) { h.updateClient(cl) } -// OnWillSent is called when a client sends a will message and the will message is removed -// from the client record. +// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record. func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { h.updateClient(cl) } diff --git a/hooks/storage/storage.go b/hooks/storage/storage.go index 12ade7b..eb57508 100644 --- a/hooks/storage/storage.go +++ b/hooks/storage/storage.go @@ -25,7 +25,7 @@ var ( ErrDBFileNotOpen = errors.New("db file not open") ) -// Client is a storable representation of an mqtt client. +// Client is a storable representation of an MQTT client. type Client struct { Will ClientWill `json:"will"` // will topic and payload data if applicable Properties ClientProperties `json:"properties"` // the connect properties for the client @@ -147,7 +147,7 @@ func (d *Message) ToPacket() packets.Packet { return pk } -// Subscription is a storable representation of an mqtt subscription. +// Subscription is a storable representation of an MQTT subscription. type Subscription struct { T string `json:"t"` ID string `json:"id" storm:"id"` diff --git a/listeners/http_healthcheck.go b/listeners/http_healthcheck.go index fb2d2a8..a82e2e3 100644 --- a/listeners/http_healthcheck.go +++ b/listeners/http_healthcheck.go @@ -79,9 +79,9 @@ func (l *HTTPHealthCheck) Init(_ *slog.Logger) error { // Serve starts listening for new connections and serving responses. func (l *HTTPHealthCheck) Serve(establish EstablishFn) { if l.listen.TLSConfig != nil { - l.listen.ListenAndServeTLS("", "") + _ = l.listen.ListenAndServeTLS("", "") } else { - l.listen.ListenAndServe() + _ = l.listen.ListenAndServe() } } @@ -93,7 +93,7 @@ func (l *HTTPHealthCheck) Close(closeClients CloseFn) { if atomic.CompareAndSwapUint32(&l.end, 0, 1) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - l.listen.Shutdown(ctx) + _ = l.listen.Shutdown(ctx) } closeClients(l.id) diff --git a/listeners/http_healthcheck_test.go b/listeners/http_healthcheck_test.go index 784c892..1c753c1 100644 --- a/listeners/http_healthcheck_test.go +++ b/listeners/http_healthcheck_test.go @@ -39,7 +39,7 @@ func TestHTTPHealthCheckTLSProtocol(t *testing.T) { TLSConfig: tlsConfigBasic, }) - l.Init(logger) + _ = l.Init(logger) require.Equal(t, "https", l.Protocol()) } diff --git a/listeners/http_sysinfo.go b/listeners/http_sysinfo.go index 771716f..98303a2 100644 --- a/listeners/http_sysinfo.go +++ b/listeners/http_sysinfo.go @@ -81,9 +81,9 @@ func (l *HTTPStats) Init(_ *slog.Logger) error { // Serve starts listening for new connections and serving responses. func (l *HTTPStats) Serve(establish EstablishFn) { if l.listen.TLSConfig != nil { - l.listen.ListenAndServeTLS("", "") + _ = l.listen.ListenAndServeTLS("", "") } else { - l.listen.ListenAndServe() + _ = l.listen.ListenAndServe() } } @@ -95,7 +95,7 @@ func (l *HTTPStats) Close(closeClients CloseFn) { if atomic.CompareAndSwapUint32(&l.end, 0, 1) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - l.listen.Shutdown(ctx) + _ = l.listen.Shutdown(ctx) } closeClients(l.id) @@ -107,8 +107,8 @@ func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) { out, err := json.MarshalIndent(info, "", "\t") if err != nil { - io.WriteString(w, err.Error()) + _, _ = io.WriteString(w, err.Error()) } - w.Write(out) + _, _ = w.Write(out) } diff --git a/listeners/http_sysinfo_test.go b/listeners/http_sysinfo_test.go index ffd000c..2ab7ec3 100644 --- a/listeners/http_sysinfo_test.go +++ b/listeners/http_sysinfo_test.go @@ -42,7 +42,7 @@ func TestHTTPStatsTLSProtocol(t *testing.T) { TLSConfig: tlsConfigBasic, }, nil) - l.Init(logger) + _ = l.Init(logger) require.Equal(t, "https", l.Protocol()) } diff --git a/listeners/listeners.go b/listeners/listeners.go index f213d45..429f497 100644 --- a/listeners/listeners.go +++ b/listeners/listeners.go @@ -22,7 +22,7 @@ type Config struct { // EstablishFn is a callback function for establishing new clients. type EstablishFn func(id string, c net.Conn) error -// CloseFunc is a callback function for closing all listener clients. +// CloseFn is a callback function for closing all listener clients. type CloseFn func(id string) // Listener is an interface for network listeners. A network listener listens diff --git a/listeners/mock_test.go b/listeners/mock_test.go index 735401a..46aa922 100644 --- a/listeners/mock_test.go +++ b/listeners/mock_test.go @@ -16,7 +16,7 @@ func TestMockEstablisher(t *testing.T) { _, w := net.Pipe() err := MockEstablisher("t1", w) require.NoError(t, err) - w.Close() + _ = w.Close() } func TestNewMockListener(t *testing.T) { @@ -86,7 +86,7 @@ func TestMockListenerServe(t *testing.T) { require.Equal(t, true, closed) <-o - mocked.Init(nil) + _ = mocked.Init(nil) } func TestMockListenerClose(t *testing.T) { diff --git a/listeners/net_test.go b/listeners/net_test.go index 8afc666..14a1ad6 100644 --- a/listeners/net_test.go +++ b/listeners/net_test.go @@ -98,7 +98,7 @@ func TestNetEstablishThenEnd(t *testing.T) { }() time.Sleep(time.Millisecond) - net.Dial("tcp", n.Addr().String()) + _, _ = net.Dial("tcp", n.Addr().String()) require.Equal(t, true, <-established) l.Close(MockCloser) <-o diff --git a/listeners/tcp_test.go b/listeners/tcp_test.go index d1e6002..636c8ab 100644 --- a/listeners/tcp_test.go +++ b/listeners/tcp_test.go @@ -39,7 +39,7 @@ func TestTCPProtocolTLS(t *testing.T) { TLSConfig: tlsConfigBasic, }) - l.Init(logger) + _ = l.Init(logger) defer l.listen.Close() require.Equal(t, "tcp", l.Protocol()) } @@ -124,7 +124,7 @@ func TestTCPEstablishThenEnd(t *testing.T) { }() time.Sleep(time.Millisecond) - net.Dial("tcp", l.listen.Addr().String()) + _, _ = net.Dial("tcp", l.listen.Addr().String()) require.Equal(t, true, <-established) l.Close(MockCloser) <-o diff --git a/listeners/unixsock_test.go b/listeners/unixsock_test.go index 905b1cc..06ce24d 100644 --- a/listeners/unixsock_test.go +++ b/listeners/unixsock_test.go @@ -89,7 +89,7 @@ func TestUnixSockEstablishThenEnd(t *testing.T) { }() time.Sleep(time.Millisecond) - net.Dial("unix", l.listen.Addr().String()) + _, _ = net.Dial("unix", l.listen.Addr().String()) require.Equal(t, true, <-established) l.Close(MockCloser) <-o diff --git a/listeners/websocket.go b/listeners/websocket.go index d7934df..50715fc 100644 --- a/listeners/websocket.go +++ b/listeners/websocket.go @@ -30,7 +30,7 @@ type Websocket struct { // [MQTT-4.2.0-1] id string // the internal id of the listener address string // the network address to bind to config *Config // configuration values for the listener - listen *http.Server // an http server for serving websocket connections + listen *http.Server // a http server for serving websocket connections log *slog.Logger // server logger establish EstablishFn // the server's establish connection handler upgrader *websocket.Upgrader // upgrade the incoming http/tcp connection to a websocket compliant connection. @@ -112,9 +112,9 @@ func (l *Websocket) Serve(establish EstablishFn) { l.establish = establish if l.listen.TLSConfig != nil { - l.listen.ListenAndServeTLS("", "") + _ = l.listen.ListenAndServeTLS("", "") } else { - l.listen.ListenAndServe() + _ = l.listen.ListenAndServe() } } @@ -126,7 +126,7 @@ func (l *Websocket) Close(closeClients CloseFn) { if atomic.CompareAndSwapUint32(&l.end, 0, 1) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - l.listen.Shutdown(ctx) + _ = l.listen.Shutdown(ctx) } closeClients(l.id) @@ -137,7 +137,7 @@ type wsConn struct { net.Conn c *websocket.Conn - // reader for the current message (may be nil) + // reader for the current message (can be nil) r io.Reader } diff --git a/listeners/websocket_test.go b/listeners/websocket_test.go index bcaee72..a2db1bb 100644 --- a/listeners/websocket_test.go +++ b/listeners/websocket_test.go @@ -37,14 +37,14 @@ func TestWebsocketProtocol(t *testing.T) { require.Equal(t, "ws", l.Protocol()) } -func TestWebsocketProtocoTLS(t *testing.T) { +func TestWebsocketProtocolTLS(t *testing.T) { l := NewWebsocket("t1", testAddr, &Config{ TLSConfig: tlsConfigBasic, }) require.Equal(t, "wss", l.Protocol()) } -func TestWebsockeInit(t *testing.T) { +func TestWebsocketInit(t *testing.T) { l := NewWebsocket("t1", testAddr, nil) require.Nil(t, l.listen) err := l.Init(logger) @@ -54,7 +54,7 @@ func TestWebsockeInit(t *testing.T) { func TestWebsocketServeAndClose(t *testing.T) { l := NewWebsocket("t1", testAddr, nil) - l.Init(logger) + _ = l.Init(logger) o := make(chan bool) go func(o chan bool) { @@ -96,7 +96,7 @@ func TestWebsocketServeTLSAndClose(t *testing.T) { func TestWebsocketUpgrade(t *testing.T) { l := NewWebsocket("t1", testAddr, nil) - l.Init(logger) + _ = l.Init(logger) e := make(chan bool) l.establish = func(id string, c net.Conn) error { @@ -110,12 +110,12 @@ func TestWebsocketUpgrade(t *testing.T) { require.Equal(t, true, <-e) s.Close() - ws.Close() + _ = ws.Close() } func TestWebsocketConnectionReads(t *testing.T) { l := NewWebsocket("t1", testAddr, nil) - l.Init(nil) + _ = l.Init(nil) recv := make(chan []byte) l.establish = func(id string, c net.Conn) error { @@ -151,5 +151,5 @@ func TestWebsocketConnectionReads(t *testing.T) { require.Equal(t, pkt, got) s.Close() - ws.Close() + _ = ws.Close() } diff --git a/packets/codes_test.go b/packets/codes_test.go index e6c196d..aed8e57 100644 --- a/packets/codes_test.go +++ b/packets/codes_test.go @@ -19,7 +19,7 @@ func TestCodesString(t *testing.T) { require.Equal(t, "test", c.String()) } -func TestCodesErrorr(t *testing.T) { +func TestCodesError(t *testing.T) { c := Code{ Reason: "error", Code: 0x1, diff --git a/packets/packets.go b/packets/packets.go index 2611bcb..ff5930b 100644 --- a/packets/packets.go +++ b/packets/packets.go @@ -14,7 +14,7 @@ import ( "sync" ) -// All of the valid packet types and their packet identifier. +// All valid packet types and their packet identifiers. const ( Reserved byte = iota // 0 - we use this in packet tests to indicate special-test or all packets. Connect // 1 @@ -37,9 +37,9 @@ const ( var ( // ErrNoValidPacketAvailable indicates the packet type byte provided does not exist in the mqtt specification. - ErrNoValidPacketAvailable error = errors.New("no valid packet available") + ErrNoValidPacketAvailable = errors.New("no valid packet available") - // PacketNames is a map of packet bytes to human readable names, for easier debugging. + // PacketNames is a map of packet bytes to human-readable names, for easier debugging. PacketNames = map[byte]string{ 0: "Reserved", 1: "Connect", @@ -272,28 +272,28 @@ func (s Subscription) Merge(n Subscription) Subscription { } // encode encodes a subscription and properties into bytes. -func (p Subscription) encode() byte { +func (s Subscription) encode() byte { var flag byte - flag |= p.Qos + flag |= s.Qos - if p.NoLocal { + if s.NoLocal { flag |= 1 << 2 } - if p.RetainAsPublished { + if s.RetainAsPublished { flag |= 1 << 3 } - flag |= p.RetainHandling << 4 + flag |= s.RetainHandling << 4 return flag } // decode decodes subscription bytes into a subscription struct. -func (p *Subscription) decode(b byte) { - p.Qos = b & 3 // byte - p.NoLocal = 1&(b>>2) > 0 // bool - p.RetainAsPublished = 1&(b>>3) > 0 // bool - p.RetainHandling = 3 & (b >> 4) // byte +func (s *Subscription) decode(b byte) { + s.Qos = b & 3 // byte + s.NoLocal = 1&(b>>2) > 0 // bool + s.RetainAsPublished = 1&(b>>3) > 0 // bool + s.RetainHandling = 3 & (b >> 4) // byte } // ConnectEncode encodes a connect packet. @@ -343,7 +343,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -505,7 +505,7 @@ func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -548,7 +548,7 @@ func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -619,7 +619,7 @@ func (pk *Packet) PublishEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -707,7 +707,7 @@ func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -844,7 +844,7 @@ func (pk *Packet) SubackEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -901,7 +901,7 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -996,7 +996,7 @@ func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -1049,7 +1049,7 @@ func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -1109,7 +1109,7 @@ func (pk *Packet) AuthEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } diff --git a/packets/packets_test.go b/packets/packets_test.go index 0e3b9ce..1e18f1f 100644 --- a/packets/packets_test.go +++ b/packets/packets_test.go @@ -150,7 +150,7 @@ func TestPacketEncode(t *testing.T) { } pk := new(Packet) - copier.Copy(pk, wanted.Packet) + _ = copier.Copy(pk, wanted.Packet) require.Equal(t, pkt, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc) pk.Mods.AllowResponseInfo = true @@ -218,7 +218,7 @@ func TestPacketDecode(t *testing.T) { pk := &Packet{FixedHeader: FixedHeader{Type: pkt}} pk.Mods.AllowResponseInfo = true - pk.FixedHeader.Decode(wanted.RawBytes[0]) + _ = pk.FixedHeader.Decode(wanted.RawBytes[0]) if len(wanted.RawBytes) > 0 { pk.FixedHeader.Remaining = int(wanted.RawBytes[1]) } diff --git a/packets/properties.go b/packets/properties.go index c5eefc1..1fc02fd 100644 --- a/packets/properties.go +++ b/packets/properties.go @@ -77,7 +77,7 @@ type UserProperty struct { // [MQTT-1.5.7-1] Val string `json:"v"` } -// Properties contains all of the mqtt v5 properties available for a packet. +// Properties contains all mqtt v5 properties available for a packet. // Some properties have valid values of 0 or not-present. In this case, we opt for // property flags to indicate the usage of property. // Refer to mqtt v5 2.2.2.2 Property spec for more information. @@ -355,7 +355,7 @@ func (p *Properties) Encode(pkt byte, mods Mods, b *bytes.Buffer, n int) { } encodeLength(b, int64(buf.Len())) - buf.WriteTo(b) // [MQTT-3.1.3-10] + _, _ = buf.WriteTo(b) // [MQTT-3.1.3-10] } // Decode decodes property bytes into a properties struct. diff --git a/packets/tpackets.go b/packets/tpackets.go index 1ad1255..267721e 100644 --- a/packets/tpackets.go +++ b/packets/tpackets.go @@ -40,7 +40,6 @@ const ( TConnectMqtt5 TConnectMqtt5LWT TConnectClean - TConnectCleanLWT TConnectUserPass TConnectUserPassLWT TConnectMalProtocolName @@ -61,7 +60,6 @@ const ( TConnectInvalidProtocolVersion2 TConnectInvalidReservedBit TConnectInvalidClientIDTooLong - TConnectInvalidPasswordNoUsername TConnectInvalidFlagNoUsername TConnectInvalidFlagNoPassword TConnectInvalidUsernameNoFlag @@ -186,7 +184,6 @@ const ( TUnsubscribe TUnsubscribeMany TUnsubscribeMqtt5 - TUnsubscribeDropProperties TUnsubscribeMalPacketID TUnsubscribeMalTopicName TUnsubscribeMalProperties @@ -204,7 +201,6 @@ const ( TDisconnect TDisconnectTakeover TDisconnectMqtt5 - TDisconnectNormalMqtt5 TDisconnectSecondConnect TDisconnectReceiveMaximum TDisconnectDropProperties diff --git a/server.go b/server.go index c9474ba..6b0c23b 100644 --- a/server.go +++ b/server.go @@ -2,7 +2,7 @@ // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co -// package mqtt provides a high performance, fully compliant MQTT v5 broker server with v3.1.1 backward compatibility. +// Package mqtt provides a high performance, fully compliant MQTT v5 broker server with v3.1.1 backward compatibility. package mqtt import ( @@ -26,7 +26,7 @@ import ( ) const ( - Version = "2.3.0" // the current server version. + Version = "2.4.0" // the current server version. defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes LocalListener = "local" InlineClientId = "inline" @@ -38,7 +38,7 @@ var ( MaximumSessionExpiryInterval: math.MaxUint32, // maximum number of seconds to keep disconnected sessions MaximumMessageExpiryInterval: 60 * 60 * 24, // maximum message expiry if message expiry is 0 or over ReceiveMaximum: 1024, // maximum number of concurrent qos messages per client - MaximumQos: 2, // maxmimum qos value available to clients + MaximumQos: 2, // maximum qos value available to clients RetainAvailable: 1, // retain messages is available MaximumPacketSize: 0, // no maximum packet size TopicAliasMaximum: math.MaxUint16, // maximum topic alias value @@ -49,8 +49,9 @@ var ( MaximumClientWritesPending: 1024 * 8, // maximum number of pending message writes for a client } - ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists. - ErrConnectionClosed = errors.New("connection not open") // connection is closed + ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists + ErrConnectionClosed = errors.New("connection not open") // connection is closed + ErrInlineClientNotEnabled = errors.New("please set Options.InlineClient=true to use this feature") // inline client is not enabled by default ) // Capabilities indicates the capabilities and features provided by the server. @@ -106,6 +107,10 @@ type Options struct { // SysTopicResendInterval specifies the interval between $SYS topic updates in seconds. SysTopicResendInterval int64 + + // Enable Inline client to allow direct subscribing and publishing from the parent codebase, + // with negligible performance difference (disabled by default to prevent confusion in statistics). + InlineClient bool } // Server is an MQTT broker server. It should be created with server.New() @@ -119,8 +124,8 @@ type Server struct { loop *loop // loop contains tickers for the system event loop done chan bool // indicate that the server is ending Log *slog.Logger // minimal no-alloc logger - hooks *Hooks // hooks contains hooks for extra functionality such as auth and persistent storage. - inlineClient *Client // inlineClient is a special client used for inline subscriptions and inline Publish. + hooks *Hooks // hooks contains hooks for extra functionality such as auth and persistent storage + inlineClient *Client // inlineClient is a special client used for inline subscriptions and inline Publish } // loop contains interval tickers for the system events loop. @@ -129,7 +134,7 @@ type loop struct { clientExpiry *time.Ticker // interval ticker for cleaning expired clients inflightExpiry *time.Ticker // interval ticker for cleaning up expired inflight messages retainedExpiry *time.Ticker // interval ticker for cleaning retained messages - willDelaySend *time.Ticker // interval ticker for sending will messages with a delay + willDelaySend *time.Ticker // interval ticker for sending Will Messages with a delay willDelayed *packets.Packets // activate LWT packets which will be sent after a delay } @@ -173,8 +178,11 @@ func New(opts *Options) *Server { Log: opts.Logger, }, } - s.inlineClient = s.NewClient(nil, LocalListener, InlineClientId, true) - s.Clients.Add(s.inlineClient) + + if s.Options.InlineClient { + s.inlineClient = s.NewClient(nil, LocalListener, InlineClientId, true) + s.Clients.Add(s.inlineClient) + } return s } @@ -426,7 +434,7 @@ func (s *Server) receivePacket(cl *Client, pk packets.Packet) error { if code, ok := err.(packets.Code); ok && cl.Properties.ProtocolVersion == 5 && code.Code >= packets.ErrUnspecifiedError.Code { - s.DisconnectClient(cl, code) + _ = s.DisconnectClient(cl, code) } s.Log.Warn("error processing packet", "error", err, "client", cl.ID, "listener", cl.Net.Listener, "pk", pk) @@ -464,7 +472,7 @@ func (s *Server) validateConnect(cl *Client, pk packets.Packet) packets.Code { // session is abandoned. func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool { if existing, ok := s.Clients.Get(pk.Connect.ClientIdentifier); ok { - s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3] + _ = s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3] if pk.Connect.Clean || (existing.Properties.Clean && existing.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4] s.UnsubscribeClient(existing) existing.ClearInflights(math.MaxInt64, 0) @@ -649,11 +657,15 @@ func (s *Server) processPingreq(cl *Client, _ packets.Packet) error { }) } -// Publish publishes a publish packet into the broker as if it were sent from the speicfied client. +// Publish publishes a publish packet into the broker as if it were sent from the specified client. // This is a convenience function which wraps InjectPacket. As such, this method can publish packets // to any topic (including $SYS) and bypass ACL checks. The qos byte is used for limiting the // outbound qos (mqtt v5) rather than issuing to the broker (we assume qos 2 complete). func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) error { + if !s.Options.InlineClient { + return ErrInlineClientNotEnabled + } + return s.InjectPacket(s.inlineClient, packets.Packet{ FixedHeader: packets.FixedHeader{ Type: packets.Publish, @@ -669,11 +681,18 @@ func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) er // Subscribe adds an inline subscription for the specified topic filter and subscription identifier // with the provided handler function. func (s *Server) Subscribe(filter string, subscriptionId int, handler InlineSubFn) error { + if !s.Options.InlineClient { + return ErrInlineClientNotEnabled + } + if handler == nil { return packets.ErrInlineSubscriptionHandlerInvalid - } else if !IsValidFilter(filter, false) { + } + + if !IsValidFilter(filter, false) { return packets.ErrTopicFilterInvalid } + subscription := packets.Subscription{ Identifier: subscriptionId, Filter: filter, @@ -704,6 +723,10 @@ func (s *Server) Subscribe(filter string, subscriptionId int, handler InlineSubF // It allows you to unsubscribe a specific subscription from the internal subscription // associated with the given topic filter. func (s *Server) Unsubscribe(filter string, subscriptionId int) error { + if !s.Options.InlineClient { + return ErrInlineClientNotEnabled + } + if !IsValidFilter(filter, false) { return packets.ErrTopicFilterInvalid } @@ -761,12 +784,12 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { return s.DisconnectClient(cl, packets.ErrNotAuthorized) } - if pk.FixedHeader.Qos == 1 { - ack := s.buildAck(pk.PacketID, packets.Puback, 0, pk.Properties, packets.ErrNotAuthorized) - return cl.WritePacket(ack) + ackType := packets.Puback + if pk.FixedHeader.Qos == 2 { + ackType = packets.Pubrec } - ack := s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.ErrNotAuthorized) + ack := s.buildAck(pk.PacketID, ackType, 0, pk.Properties, packets.ErrNotAuthorized) return cl.WritePacket(ack) } @@ -790,7 +813,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { } if pk.FixedHeader.Qos > s.Options.Capabilities.MaximumQos { - pk.FixedHeader.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9] Reduce Qos based on server max qos capability + pk.FixedHeader.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9] Reduce qos based on server max qos capability } pkx, err := s.hooks.OnPublish(cl, pk) @@ -1373,7 +1396,7 @@ func (s *Server) Close() error { func (s *Server) closeListenerClients(listener string) { clients := s.Clients.GetByListener(listener) for _, cl := range clients { - s.DisconnectClient(cl, packets.ErrServerShuttingDown) + _ = s.DisconnectClient(cl, packets.ErrServerShuttingDown) } } diff --git a/server_test.go b/server_test.go index 64b110f..ec0dc3b 100644 --- a/server_test.go +++ b/server_test.go @@ -10,7 +10,6 @@ import ( "io" "log/slog" "net" - "os" "strconv" "sync" "sync/atomic" @@ -25,7 +24,7 @@ import ( "github.com/stretchr/testify/require" ) -var logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) +var logger = slog.New(slog.NewTextHandler(io.Discard, nil)) type ProtocolTest []struct { protocolVersion byte @@ -100,13 +99,24 @@ func newServer() *Server { cc := *DefaultServerCapabilities cc.MaximumMessageExpiryInterval = 0 cc.ReceiveMaximum = 0 - s := New(&Options{ Logger: logger, Capabilities: &cc, }) + _ = s.AddHook(new(AllowHook), nil) + return s +} - s.AddHook(new(AllowHook), nil) +func newServerWithInlineClient() *Server { + cc := *DefaultServerCapabilities + cc.MaximumMessageExpiryInterval = 0 + cc.ReceiveMaximum = 0 + s := New(&Options{ + Logger: logger, + Capabilities: &cc, + InlineClient: true, + }) + _ = s.AddHook(new(AllowHook), nil) return s } @@ -138,6 +148,14 @@ func TestNew(t *testing.T) { require.NotNil(t, s.hooks) require.NotNil(t, s.hooks.Log) require.NotNil(t, s.done) + require.Nil(t, s.inlineClient) + require.Equal(t, 0, s.Clients.Len()) +} + +func TestNewWithInlineClient(t *testing.T) { + s := New(&Options{ + InlineClient: true, + }) require.NotNil(t, s.inlineClient) require.Equal(t, 1, s.Clients.Len()) } @@ -282,8 +300,8 @@ func TestServerReadConnectionPacket(t *testing.T) { }() go func() { - r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) - r.Close() + _, _ = r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) + _ = r.Close() }() require.Equal(t, *packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet, <-o) @@ -303,8 +321,8 @@ func TestServerReadConnectionPacketBadFixedHeader(t *testing.T) { }() go func() { - r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalFixedHeader).RawBytes) - r.Close() + _, _ = r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalFixedHeader).RawBytes) + _ = r.Close() }() err := <-o @@ -320,8 +338,8 @@ func TestServerReadConnectionPacketBadPacketType(t *testing.T) { s.Clients.Add(cl) go func() { - r.Write(packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes) - r.Close() + _, _ = r.Write(packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes) + _ = r.Close() }() _, err := s.readConnectionPacket(cl) @@ -337,8 +355,8 @@ func TestServerReadConnectionPacketBadPacket(t *testing.T) { s.Clients.Add(cl) go func() { - r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalProtocolName).RawBytes) - r.Close() + _, _ = r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalProtocolName).RawBytes) + _ = r.Close() }() _, err := s.readConnectionPacket(cl) @@ -357,8 +375,8 @@ func TestEstablishConnection(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) - w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) }() // receive the connack @@ -381,8 +399,8 @@ func TestEstablishConnection(t *testing.T) { require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv) - w.Close() - r.Close() + _ = w.Close() + _ = r.Close() // client must be deleted on session close if Clean = true _, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).Packet.Connect.ClientIdentifier) @@ -400,15 +418,15 @@ func TestEstablishConnectionAckFailure(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) - w.Close() + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + _ = w.Close() }() err := <-o require.Error(t, err) require.ErrorIs(t, err, io.ErrClosedPipe) - r.Close() + _ = r.Close() } func TestEstablishConnectionReadError(t *testing.T) { @@ -422,8 +440,8 @@ func TestEstablishConnectionReadError(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).RawBytes) - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) // second connect error + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) // second connect error }() // receive the connack @@ -449,8 +467,8 @@ func TestEstablishConnectionReadError(t *testing.T) { ret, ) - w.Close() - r.Close() + _ = w.Close() + _ = r.Close() } func TestEstablishConnectionInheritExisting(t *testing.T) { @@ -473,9 +491,9 @@ func TestEstablishConnectionInheritExisting(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) time.Sleep(time.Millisecond) // we want to receive the queued inflight, so we need to wait a moment before sending the disconnect. - w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) }() // receive the disconnect session takeover @@ -510,8 +528,8 @@ func TestEstablishConnectionInheritExisting(t *testing.T) { require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectTakeover).RawBytes, <-takeover) time.Sleep(time.Microsecond * 100) - w.Close() - r.Close() + _ = w.Close() + _ = r.Close() clw, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier) require.True(t, ok) @@ -526,7 +544,7 @@ func TestEstablishConnectionInheritExistingTrueTakeover(t *testing.T) { s := newServer() d := new(DelayHook) d.DisconnectDelay = time.Millisecond * 200 - s.AddHook(d, nil) + _ = s.AddHook(d, nil) defer s.Close() // Clean session, 0 session expiry interval @@ -551,7 +569,7 @@ func TestEstablishConnectionInheritExistingTrueTakeover(t *testing.T) { o1 <- err }() go func() { - w1.Write(cl1RawBytes) + _, _ = w1.Write(cl1RawBytes) }() // receive the first connack @@ -580,7 +598,7 @@ func TestEstablishConnectionInheritExistingTrueTakeover(t *testing.T) { go func() { x := packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes[:] x[19] = '.' // differentiate username bytes in debugging - w2.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes) + _, _ = w2.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes) }() // receive the second connack @@ -608,7 +626,7 @@ func TestEstablishConnectionInheritExistingTrueTakeover(t *testing.T) { require.NotEmpty(t, clp2.State.Subscriptions.GetAll()) require.Empty(t, clp1.State.Subscriptions.GetAll()) - w2.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w2.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) require.NoError(t, <-o2) } @@ -631,7 +649,7 @@ func TestEstablishConnectionResentPendingInflightsError(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) }() go func() { @@ -666,8 +684,8 @@ func TestEstablishConnectionInheritExistingClean(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) - w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) }() // receive the disconnect @@ -697,8 +715,8 @@ func TestEstablishConnectionInheritExistingClean(t *testing.T) { require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv) require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes, <-takeover) - w.Close() - r.Close() + _ = w.Close() + _ = r.Close() clw, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier) require.True(t, ok) @@ -718,8 +736,8 @@ func TestEstablishConnectionBadAuthentication(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) - w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) }() // receive the connack @@ -735,8 +753,8 @@ func TestEstablishConnectionBadAuthentication(t *testing.T) { require.ErrorIs(t, err, packets.ErrBadUsernameOrPassword) require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackBadUsernamePasswordNoSession).RawBytes, <-recv) - w.Close() - r.Close() + _ = w.Close() + _ = r.Close() } func TestEstablishConnectionBadAuthenticationAckFailure(t *testing.T) { @@ -752,15 +770,15 @@ func TestEstablishConnectionBadAuthenticationAckFailure(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) - w.Close() + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + _ = w.Close() }() err := <-o require.Error(t, err) require.ErrorIs(t, err, io.ErrClosedPipe) - r.Close() + _ = r.Close() } func TestServerEstablishConnectionInvalidConnect(t *testing.T) { @@ -773,8 +791,8 @@ func TestServerEstablishConnectionInvalidConnect(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalReservedBit).RawBytes) - w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalReservedBit).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) }() // receive the connack @@ -790,7 +808,7 @@ func TestServerEstablishConnectionInvalidConnect(t *testing.T) { require.ErrorIs(t, packets.ErrProtocolViolationReservedBit, err) require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackProtocolViolationNoSession).RawBytes, <-recv) - r.Close() + _ = r.Close() } // See https://github.com/mochi-mqtt/server/issues/178 @@ -804,8 +822,8 @@ func TestServerEstablishConnectionZeroByteUsernameIsValid(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectZeroByteUsername).RawBytes) - w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectZeroByteUsername).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) }() // receive the connack error @@ -817,7 +835,7 @@ func TestServerEstablishConnectionZeroByteUsernameIsValid(t *testing.T) { err := <-o require.NoError(t, err) - r.Close() + _ = r.Close() } func TestServerEstablishConnectionInvalidConnectAckFailure(t *testing.T) { @@ -830,15 +848,15 @@ func TestServerEstablishConnectionInvalidConnectAckFailure(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalReservedBit).RawBytes) - w.Close() + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalReservedBit).RawBytes) + _ = w.Close() }() err := <-o require.Error(t, err) require.ErrorIs(t, err, io.ErrClosedPipe) - r.Close() + _ = r.Close() } func TestServerEstablishConnectionBadPacket(t *testing.T) { @@ -851,15 +869,15 @@ func TestServerEstablishConnectionBadPacket(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnackBadProtocolVersion).RawBytes) - w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnackBadProtocolVersion).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) }() err := <-o require.Error(t, err) require.ErrorIs(t, err, packets.ErrProtocolViolationRequireFirstConnect) - r.Close() + _ = r.Close() } func TestServerEstablishConnectionOnConnectError(t *testing.T) { @@ -876,14 +894,14 @@ func TestServerEstablishConnectionOnConnectError(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) }() err = <-o require.Error(t, err) require.ErrorIs(t, err, errTestHook) - r.Close() + _ = r.Close() } func TestServerSendConnack(t *testing.T) { @@ -897,7 +915,7 @@ func TestServerSendConnack(t *testing.T) { go func() { err := s.SendConnack(cl, packets.CodeSuccess, true, nil) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -912,7 +930,7 @@ func TestServerSendConnackFailureReason(t *testing.T) { go func() { err := s.SendConnack(cl, packets.ErrUnspecifiedError, true, nil) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -929,7 +947,7 @@ func TestServerSendConnackWithServerKeepalive(t *testing.T) { go func() { err := s.SendConnack(cl, packets.CodeSuccess, true, nil) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1008,7 +1026,7 @@ func TestServerSendConnackAdjustedExpiryInterval(t *testing.T) { go func() { err := s.SendConnack(cl, packets.CodeSuccess, false, nil) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1087,7 +1105,7 @@ func TestServerProcessPacketPingreq(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Pingreq].Get(packets.TPingreq).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1116,7 +1134,7 @@ func TestServerProcessPacketPublishInvalid(t *testing.T) { func TestInjectPacketPublishAndReceive(t *testing.T) { s := newServer() - s.Serve() + _ = s.Serve() defer s.Close() sender, _, w1 := newTestClient() @@ -1141,17 +1159,18 @@ func TestInjectPacketPublishAndReceive(t *testing.T) { go func() { err := s.InjectPacket(sender, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) require.NoError(t, err) - w1.Close() + _ = w1.Close() time.Sleep(time.Millisecond * 10) - w2.Close() + _ = w2.Close() }() require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) } -func TestServerDirectPublishAndReceive(t *testing.T) { - s := newServer() - s.Serve() +func TestServerPublishAndReceive(t *testing.T) { + s := newServerWithInlineClient() + + _ = s.Serve() defer s.Close() sender, _, w1 := newTestClient() @@ -1177,14 +1196,22 @@ func TestServerDirectPublishAndReceive(t *testing.T) { pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet err := s.Publish(pkx.TopicName, pkx.Payload, pkx.FixedHeader.Retain, pkx.FixedHeader.Qos) require.NoError(t, err) - w1.Close() + _ = w1.Close() time.Sleep(time.Millisecond * 10) - w2.Close() + _ = w2.Close() }() require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) } +func TestServerPublishNoInlineClient(t *testing.T) { + s := newServer() + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + err := s.Publish(pkx.TopicName, pkx.Payload, pkx.FixedHeader.Retain, pkx.FixedHeader.Qos) + require.Error(t, err) + require.ErrorIs(t, err, ErrInlineClientNotEnabled) +} + func TestInjectPacketError(t *testing.T) { s := newServer() defer s.Close() @@ -1209,7 +1236,7 @@ func TestInjectPacketPublishInvalidTopic(t *testing.T) { func TestServerProcessPacketPublishAndReceive(t *testing.T) { s := newServer() - s.Serve() + _ = s.Serve() defer s.Close() sender, _, w1 := newTestClient() @@ -1235,8 +1262,8 @@ func TestServerProcessPacketPublishAndReceive(t *testing.T) { err := s.processPacket(sender, *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) require.NoError(t, err) time.Sleep(time.Millisecond * 10) - w1.Close() - w2.Close() + _ = w1.Close() + _ = w2.Close() }() require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) @@ -1301,7 +1328,7 @@ func TestServerProcessPacketAndNextImmediate(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1313,13 +1340,13 @@ func TestServerProcessPacketAndNextImmediate(t *testing.T) { func TestServerProcessPublishAckFailure(t *testing.T) { s := newServer() - s.Serve() + _ = s.Serve() defer s.Close() cl, _, w := newTestClient() s.Clients.Add(cl) - w.Close() + _ = w.Close() err := s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet) require.Error(t, err) require.ErrorIs(t, err, io.ErrClosedPipe) @@ -1336,7 +1363,7 @@ func TestServerProcessPublishOnPublishAckErrorRWError(t *testing.T) { cl, _, w := newTestClient() cl.Properties.ProtocolVersion = 5 s.Clients.Add(cl) - w.Close() + _ = w.Close() err = s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) require.Error(t, err) @@ -1350,7 +1377,7 @@ func TestServerProcessPublishOnPublishAckErrorContinue(t *testing.T) { hook.err = packets.ErrPayloadFormatInvalid err := s.AddHook(hook, nil) require.NoError(t, err) - s.Serve() + _ = s.Serve() defer s.Close() cl, r, w := newTestClient() @@ -1360,7 +1387,7 @@ func TestServerProcessPublishOnPublishAckErrorContinue(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1375,7 +1402,7 @@ func TestServerProcessPublishOnPublishPkIgnore(t *testing.T) { hook.err = packets.CodeSuccessIgnore err := s.AddHook(hook, nil) require.NoError(t, err) - s.Serve() + _ = s.Serve() defer s.Close() cl, r, w := newTestClient() @@ -1399,8 +1426,8 @@ func TestServerProcessPublishOnPublishPkIgnore(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) require.NoError(t, err) - w.Close() - w2.Close() + _ = w.Close() + _ = w2.Close() }() buf, err := io.ReadAll(r) @@ -1412,7 +1439,7 @@ func TestServerProcessPublishOnPublishPkIgnore(t *testing.T) { func TestServerProcessPacketPublishMaximumReceive(t *testing.T) { s := newServer() - s.Serve() + _ = s.Serve() defer s.Close() cl, r, w := newTestClient() @@ -1424,7 +1451,7 @@ func TestServerProcessPacketPublishMaximumReceive(t *testing.T) { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) require.Error(t, err) require.ErrorIs(t, err, packets.ErrReceiveMaximum) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1434,15 +1461,15 @@ func TestServerProcessPacketPublishMaximumReceive(t *testing.T) { func TestServerProcessPublishInvalidTopic(t *testing.T) { s := newServer() - s.Serve() + _ = s.Serve() defer s.Close() cl, _, _ := newTestClient() err := s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishSpecDenySysTopic).Packet) - require.NoError(t, err) // $SYS topics should be ignored? + require.NoError(t, err) // $SYS Topics should be ignored? } func TestServerProcessPublishACLCheckDeny(t *testing.T) { - tests := []struct { + tt := []struct { name string protocolVersion byte pk packets.Packet @@ -1450,54 +1477,88 @@ func TestServerProcessPublishACLCheckDeny(t *testing.T) { expectReponse []byte expectDisconnect bool }{ - {"v4_QOS0", 4, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet, - nil, nil, false}, - {"v4_QOS1", 4, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet, - packets.ErrNotAuthorized, nil, true}, - {"v4_QOS2", 4, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet, - packets.ErrNotAuthorized, nil, true}, - - {"v5_QOS0", 5, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).Packet, - nil, nil, false}, - {"v5_QOS1", 5, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Mqtt5).Packet, - nil, packets.TPacketData[packets.Puback].Get(packets.TPubrecMqtt5NotAuthorized).RawBytes, false}, - {"v5_QOS2", 5, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2Mqtt5).Packet, - nil, packets.TPacketData[packets.Pubrec].Get(packets.TPubrecMqtt5NotAuthorized).RawBytes, false}, + { + name: "v4_QOS0", + protocolVersion: 4, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet, + expectErr: nil, + expectReponse: nil, + expectDisconnect: false, + }, + { + name: "v4_QOS1", + protocolVersion: 4, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet, + expectErr: packets.ErrNotAuthorized, + expectReponse: nil, + expectDisconnect: true, + }, + { + name: "v4_QOS2", + protocolVersion: 4, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet, + expectErr: packets.ErrNotAuthorized, + expectReponse: nil, + expectDisconnect: true, + }, + { + name: "v5_QOS0", + protocolVersion: 5, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).Packet, + expectErr: nil, + expectReponse: nil, + expectDisconnect: false, + }, + { + name: "v5_QOS1", + protocolVersion: 5, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Mqtt5).Packet, + expectErr: nil, + expectReponse: packets.TPacketData[packets.Puback].Get(packets.TPubrecMqtt5NotAuthorized).RawBytes, + expectDisconnect: false, + }, + { + name: "v5_QOS2", + protocolVersion: 5, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2Mqtt5).Packet, + expectErr: nil, + expectReponse: packets.TPacketData[packets.Pubrec].Get(packets.TPubrecMqtt5NotAuthorized).RawBytes, + expectDisconnect: false, + }, } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { + for _, tx := range tt { + t.Run(tx.name, func(t *testing.T) { cc := *DefaultServerCapabilities s := New(&Options{ Logger: logger, Capabilities: &cc, }) - s.AddHook(new(DenyHook), nil) - s.Serve() + _ = s.AddHook(new(DenyHook), nil) + _ = s.Serve() defer s.Close() cl, r, w := newTestClient() - cl.Properties.ProtocolVersion = tt.protocolVersion + cl.Properties.ProtocolVersion = tx.protocolVersion s.Clients.Add(cl) wg := sync.WaitGroup{} wg.Add(1) go func() { defer wg.Done() - err := s.processPublish(cl, tt.pk) - require.ErrorIs(t, err, tt.expectErr) - w.Close() + err := s.processPublish(cl, tx.pk) + require.ErrorIs(t, err, tx.expectErr) + _ = w.Close() }() buf, err := io.ReadAll(r) require.NoError(t, err) - if tt.expectReponse != nil { - require.Equal(t, tt.expectReponse, buf) + if tx.expectReponse != nil { + require.Equal(t, tx.expectReponse, buf) } - require.Equal(t, tt.expectDisconnect, cl.Closed()) + require.Equal(t, tx.expectDisconnect, cl.Closed()) wg.Wait() }) } @@ -1513,7 +1574,7 @@ func TestServerProcessPublishOnMessageRecvRejected(t *testing.T) { err := s.AddHook(hook, nil) require.NoError(t, err) - s.Serve() + _ = s.Serve() defer s.Close() cl, _, _ := newTestClient() err = s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) @@ -1527,7 +1588,7 @@ func TestServerProcessPacketPublishQos0(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1544,7 +1605,7 @@ func TestServerProcessPacketPublishQos1PacketIDInUse(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1563,7 +1624,7 @@ func TestServerProcessPacketPublishQos2PacketIDInUse(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2Mqtt5).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1579,7 +1640,7 @@ func TestServerProcessPacketPublishQos1(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1594,7 +1655,7 @@ func TestServerProcessPacketPublishQos2(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1610,7 +1671,7 @@ func TestServerProcessPacketPublishDowngradeQos(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1630,7 +1691,7 @@ func TestPublishToSubscribersSelfNoLocal(t *testing.T) { pkx.Origin = cl.ID s.publishToSubscribers(pkx) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() receiverBuf := make(chan []byte) @@ -1682,9 +1743,9 @@ func TestPublishToSubscribers(t *testing.T) { go func() { s.publishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) time.Sleep(time.Millisecond) - w1.Close() - w2.Close() - w3.Close() + _ = w1.Close() + _ = w2.Close() + _ = w3.Close() }() require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-cl1Recv) @@ -1725,7 +1786,7 @@ func TestPublishToSubscribersMessageExpiryDelta(t *testing.T) { pkx.Created = time.Now().Unix() - 30 s.publishToSubscribers(pkx) time.Sleep(time.Millisecond) - w1.Close() + _ = w1.Close() }() b := <-cl1Recv @@ -1749,7 +1810,7 @@ func TestPublishToSubscribersIdentifiers(t *testing.T) { go func() { s.publishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() receiverBuf := make(chan []byte) @@ -1774,7 +1835,7 @@ func TestPublishToSubscribersPkIgnore(t *testing.T) { pk.Ignore = true s.publishToSubscribers(pk) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() receiverBuf := make(chan []byte) @@ -1801,9 +1862,9 @@ func TestPublishToClientServerDowngradeQos(t *testing.T) { go func() { pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet pkx.FixedHeader.Qos = 2 - s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, pkx) + _, _ = s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, pkx) time.Sleep(time.Microsecond * 100) - w.Close() + _ = w.Close() }() receiverBuf := make(chan []byte) @@ -1830,9 +1891,9 @@ func TestPublishToClientSubscriptionDowngradeQos(t *testing.T) { go func() { pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet pkx.FixedHeader.Qos = 2 - s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 1}, pkx) + _, _ = s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 1}, pkx) time.Sleep(time.Microsecond * 100) - w.Close() + _ = w.Close() }() receiverBuf := make(chan []byte) @@ -1881,10 +1942,10 @@ func TestPublishToClientServerTopicAlias(t *testing.T) { go func() { pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).Packet - s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx) - s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx) + _, _ = s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx) + _, _ = s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() receiverBuf := make(chan []byte) @@ -1937,6 +1998,19 @@ func TestPublishToClientExhaustedPacketID(t *testing.T) { require.ErrorIs(t, err, packets.ErrQuotaExceeded) } +func TestPublishToClientACLNotAuthorized(t *testing.T) { + s := New(&Options{ + Logger: logger, + }) + err := s.AddHook(new(DenyHook), nil) + require.NoError(t, err) + cl, _, _ := newTestClient() + + _, err = s.publishToClient(cl, packets.Subscription{Filter: "a/b/c"}, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrNotAuthorized) +} + func TestPublishToClientNoConn(t *testing.T) { s := newServer() cl, _, _ := newTestClient() @@ -1963,10 +2037,10 @@ func TestProcessPublishWithTopicAlias(t *testing.T) { pkx.Properties.SubscriptionIdentifier = []int{} // must not contain from client to server pkx.TopicName = "" pkx.Properties.TopicAlias = 1 - s.processPacket(cl2, pkx) + _ = s.processPacket(cl2, pkx) time.Sleep(time.Millisecond) - w2.Close() - w.Close() + _ = w2.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1985,12 +2059,12 @@ func TestPublishToSubscribersExhaustedSendQuota(t *testing.T) { // coverage: subscriber publish errors are non-returnable // can we hook into zerolog ? - r.Close() + _ = r.Close() pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet pkx.PacketID = 0 s.publishToSubscribers(pkx) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() } func TestPublishToSubscribersExhaustedPacketIDs(t *testing.T) { @@ -2006,12 +2080,12 @@ func TestPublishToSubscribersExhaustedPacketIDs(t *testing.T) { // coverage: subscriber publish errors are non-returnable // can we hook into zerolog ? - r.Close() + _ = r.Close() pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet pkx.PacketID = 0 s.publishToSubscribers(pkx) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() } func TestPublishToSubscribersNoConnection(t *testing.T) { @@ -2023,10 +2097,10 @@ func TestPublishToSubscribersNoConnection(t *testing.T) { // coverage: subscriber publish errors are non-returnable // can we hook into zerolog ? - r.Close() + _ = r.Close() s.publishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() } func TestPublishRetainedToClient(t *testing.T) { @@ -2043,7 +2117,7 @@ func TestPublishRetainedToClient(t *testing.T) { go func() { s.publishRetainedToClient(cl, packets.Subscription{Filter: "a/b/c"}, false) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2062,7 +2136,7 @@ func TestPublishRetainedToClientIsShared(t *testing.T) { go func() { s.publishRetainedToClient(cl, sub, false) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2082,7 +2156,7 @@ func TestPublishRetainedToClientError(t *testing.T) { retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) require.Equal(t, int64(1), retained) - w.Close() + _ = w.Close() s.publishRetainedToClient(cl, sub, false) } @@ -2185,7 +2259,7 @@ func TestServerProcessPacketPubrec(t *testing.T) { err := s.processPacket(cl, *packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() require.Equal(t, packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).RawBytes, <-recv) @@ -2213,7 +2287,7 @@ func TestServerProcessPacketPubrecNoPacketID(t *testing.T) { pk := *packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).Packet // not sending properties err := s.processPacket(cl, pk) require.NoError(t, err) - w.Close() + _ = w.Close() require.Equal(t, packets.TPacketData[packets.Pubrel].Get(packets.TPubrelMqtt5AckNoPacket).RawBytes, <-recv) @@ -2263,7 +2337,7 @@ func TestServerProcessPacketPubrel(t *testing.T) { err := s.processPacket(cl, *packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() require.Equal(t, int32(4), atomic.LoadInt32(&cl.State.Inflight.receiveQuota)) require.Equal(t, int32(4), atomic.LoadInt32(&cl.State.Inflight.sendQuota)) @@ -2292,7 +2366,7 @@ func TestServerProcessPacketPubrelNoPacketID(t *testing.T) { pk := *packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).Packet // not sending properties err := s.processPacket(cl, pk) require.NoError(t, err) - w.Close() + _ = w.Close() require.Equal(t, packets.TPacketData[packets.Pubcomp].Get(packets.TPubcompMqtt5AckNoPacket).RawBytes, <-recv) @@ -2404,7 +2478,7 @@ func TestServerProcessInboundQos2Flow(t *testing.T) { err := s.processPacket(cl, *tx.in.Packet) require.NoError(t, err) - w.Close() + _ = w.Close() require.Equal(t, tx.out.RawBytes, <-recv) if i == 0 { @@ -2485,7 +2559,7 @@ func TestServerProcessOutboundQos2Flow(t *testing.T) { } time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() if i != 2 { require.Equal(t, tx.out.RawBytes, <-recv) @@ -2508,7 +2582,7 @@ func TestServerProcessPacketSubscribe(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2527,7 +2601,7 @@ func TestServerProcessPacketSubscribePacketIDInUse(t *testing.T) { go func() { err := s.processPacket(cl, pkx) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2553,7 +2627,7 @@ func TestServerProcessPacketSubscribeInvalidFilter(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeInvalidFilter).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2569,7 +2643,7 @@ func TestServerProcessPacketSubscribeInvalidSharedNoLocal(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeInvalidSharedNoLocal).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2589,7 +2663,7 @@ func TestServerProcessSubscribeWithRetain(t *testing.T) { require.NoError(t, err) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2610,7 +2684,7 @@ func TestServerProcessSubscribeDowngradeQos(t *testing.T) { require.NoError(t, err) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2632,7 +2706,7 @@ func TestServerProcessSubscribeWithRetainHandling1(t *testing.T) { require.NoError(t, err) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2653,7 +2727,7 @@ func TestServerProcessSubscribeWithRetainHandling2(t *testing.T) { require.NoError(t, err) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2674,7 +2748,7 @@ func TestServerProcessSubscribeWithNotRetainAsPublished(t *testing.T) { require.NoError(t, err) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2688,7 +2762,7 @@ func TestServerProcessSubscribeWithNotRetainAsPublished(t *testing.T) { func TestServerProcessSubscribeNoConnection(t *testing.T) { s := newServer() cl, r, _ := newTestClient() - r.Close() + _ = r.Close() err := s.processSubscribe(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet) require.Error(t, err) require.ErrorIs(t, err, io.ErrClosedPipe) @@ -2698,14 +2772,14 @@ func TestServerProcessSubscribeACLCheckDeny(t *testing.T) { s := New(&Options{ Logger: logger, }) - s.Serve() + _ = s.Serve() cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 go func() { err := s.processSubscribe(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2717,7 +2791,7 @@ func TestServerProcessSubscribeACLCheckDenyObscure(t *testing.T) { s := New(&Options{ Logger: logger, }) - s.Serve() + _ = s.Serve() s.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 @@ -2725,7 +2799,7 @@ func TestServerProcessSubscribeACLCheckDenyObscure(t *testing.T) { go func() { err := s.processSubscribe(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2742,7 +2816,7 @@ func TestServerProcessSubscribeErrorDowngrade(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeInvalidSharedNoLocal).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2758,7 +2832,7 @@ func TestServerProcessPacketUnsubscribe(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2775,7 +2849,7 @@ func TestServerProcessPacketUnsubscribePackedIDInUse(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2811,7 +2885,7 @@ func TestServerRecievePacketDisconnectClientZeroNonZero(t *testing.T) { err := s.receivePacket(cl, *packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5).Packet) require.Error(t, err) require.ErrorIs(t, err, packets.ErrProtocolViolationZeroNonZeroExpiry) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2826,7 +2900,7 @@ func TestServerRecievePacketDisconnectClient(t *testing.T) { go func() { err := s.DisconnectClient(cl, packets.CodeDisconnect) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2871,7 +2945,7 @@ func TestServerProcessPacketAuth(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Auth].Get(packets.TAuth).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2905,7 +2979,7 @@ func TestServerProcessPacketAuthFailure(t *testing.T) { func TestServerSendLWT(t *testing.T) { s := newServer() - s.Serve() + _ = s.Serve() defer s.Close() sender, _, w1 := newTestClient() @@ -2935,8 +3009,8 @@ func TestServerSendLWT(t *testing.T) { go func() { s.sendLWT(sender) time.Sleep(time.Millisecond * 10) - w1.Close() - w2.Close() + _ = w1.Close() + _ = w2.Close() }() require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) @@ -2944,7 +3018,7 @@ func TestServerSendLWT(t *testing.T) { func TestServerSendLWTRetain(t *testing.T) { s := newServer() - s.Serve() + _ = s.Serve() defer s.Close() sender, _, w1 := newTestClient() @@ -2975,8 +3049,8 @@ func TestServerSendLWTRetain(t *testing.T) { go func() { s.sendLWT(sender) time.Sleep(time.Millisecond * 10) - w1.Close() - w2.Close() + _ = w1.Close() + _ = w2.Close() }() require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) @@ -3010,7 +3084,7 @@ func TestServerSendLWTDelayed(t *testing.T) { s.sendDelayedLWT(time.Now().Unix()) require.Equal(t, 0, s.loop.willDelayed.Len()) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() recv := make(chan []byte) @@ -3026,7 +3100,7 @@ func TestServerSendLWTDelayed(t *testing.T) { func TestServerReadStore(t *testing.T) { s := newServer() hook := new(modifiedHookBase) - s.AddHook(hook, nil) + _ = s.AddHook(hook, nil) hook.failAt = 1 // clients err := s.readStore() @@ -3057,11 +3131,9 @@ func TestServerLoadClients(t *testing.T) { } s := newServer() - // Here, server.Clients also includes a server.inlineClient, - // so the total number of clients here should be increased by 1 - require.Equal(t, 1, s.Clients.Len()) + require.Equal(t, 0, s.Clients.Len()) s.loadClients(v) - require.Equal(t, 4, s.Clients.Len()) + require.Equal(t, 3, s.Clients.Len()) cl, ok := s.Clients.Get("mochi") require.True(t, ok) require.Equal(t, "mochi", cl.ID) @@ -3090,9 +3162,7 @@ func TestServerLoadInflightMessages(t *testing.T) { {ID: "mochi-co"}, }) - // server.Clients also includes a server.inlineClient, - // so the total number of clients here should be increased by 1 - require.Equal(t, 4, s.Clients.Len()) + require.Equal(t, 3, s.Clients.Len()) v := []storage.Message{ {Origin: "mochi", PacketID: 1, Payload: []byte("hello world"), TopicName: "a/b/c"}, @@ -3136,7 +3206,7 @@ func TestServerClose(t *testing.T) { s := newServer() hook := new(modifiedHookBase) - s.AddHook(hook, nil) + _ = s.AddHook(hook, nil) cl, r, _ := newTestClient() cl.Net.Listener = "t1" @@ -3145,7 +3215,7 @@ func TestServerClose(t *testing.T) { err := s.AddListener(listeners.NewMockListener("t1", ":1882")) require.NoError(t, err) - s.Serve() + _ = s.Serve() // receive the disconnect recv := make(chan []byte) @@ -3162,7 +3232,7 @@ func TestServerClose(t *testing.T) { require.Equal(t, true, ok) require.Equal(t, true, listener.(*listeners.MockListener).IsServing()) - s.Close() + _ = s.Close() time.Sleep(time.Millisecond) require.Equal(t, false, listener.(*listeners.MockListener).IsServing()) require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectShuttingDown).RawBytes, <-recv) @@ -3248,15 +3318,10 @@ func TestServerClearExpiredClients(t *testing.T) { cl2.Properties.Props.SessionExpiryIntervalFlag = true s.Clients.Add(cl2) - // server.Clients also includes a server.inlineClient, - // so the total number of clients here should be increased by 1 - require.Equal(t, 5, s.Clients.Len()) + require.Equal(t, 4, s.Clients.Len()) s.clearExpiredClients(n) - - // server.Clients also includes a server.inlineClient, - // so the total number of clients here should be increased by 1 - require.Equal(t, 3, s.Clients.Len()) + require.Equal(t, 2, s.Clients.Len()) } func TestLoadServerInfoRestoreOnRestart(t *testing.T) { @@ -3277,12 +3342,9 @@ func TestAtomicItoa(t *testing.T) { } func TestServerSubscribe(t *testing.T) { + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {} - handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { - // handler logic - } - - s := New(nil) + s := newServerWithInlineClient() require.NotNil(t, s) tt := []struct { @@ -3335,7 +3397,7 @@ func TestServerSubscribe(t *testing.T) { expect: nil, }, { - desc: "subscribe invalied ###", + desc: "subscribe invalid ###", filter: "###", identifier: 1, handler: handler, @@ -3357,12 +3419,19 @@ func TestServerSubscribe(t *testing.T) { } } +func TestServerSubscribeNoInlineClient(t *testing.T) { + s := newServer() + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {}) + require.Error(t, err) + require.ErrorIs(t, err, ErrInlineClientNotEnabled) +} + func TestServerUnsubscribe(t *testing.T) { handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { // handler logic } - s := New(nil) + s := newServerWithInlineClient() err := s.Subscribe("a/b/c", 1, handler) require.Nil(t, err) @@ -3388,8 +3457,15 @@ func TestServerUnsubscribe(t *testing.T) { require.Equal(t, packets.ErrTopicFilterInvalid, err) } -func TestPublishToInlineSubscriber(t *testing.T) { +func TestServerUnsubscribeNoInlineClient(t *testing.T) { s := newServer() + err := s.Unsubscribe("a/b/c", 1) + require.Error(t, err) + require.ErrorIs(t, err, ErrInlineClientNotEnabled) +} + +func TestPublishToInlineSubscriber(t *testing.T) { + s := newServerWithInlineClient() finishCh := make(chan bool) err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { require.Equal(t, []byte("hello mochi"), pk.Payload) @@ -3409,8 +3485,8 @@ func TestPublishToInlineSubscriber(t *testing.T) { require.Equal(t, true, <-finishCh) } -func TestPublishToInlineSubscribersDiffrentFilter(t *testing.T) { - s := newServer() +func TestPublishToInlineSubscribersDifferentFilter(t *testing.T) { + s := newServerWithInlineClient() subNumber := 2 finishCh := make(chan bool, subNumber) @@ -3447,8 +3523,8 @@ func TestPublishToInlineSubscribersDiffrentFilter(t *testing.T) { } } -func TestPublishToInlineSubscribersDiffrentIdentifier(t *testing.T) { - s := newServer() +func TestPublishToInlineSubscribersDifferentIdentifier(t *testing.T) { + s := newServerWithInlineClient() subNumber := 2 finishCh := make(chan bool, subNumber) @@ -3483,7 +3559,7 @@ func TestPublishToInlineSubscribersDiffrentIdentifier(t *testing.T) { } func TestServerSubscribeWithRetain(t *testing.T) { - s := newServer() + s := newServerWithInlineClient() subNumber := 1 finishCh := make(chan bool, subNumber) @@ -3502,8 +3578,8 @@ func TestServerSubscribeWithRetain(t *testing.T) { require.Equal(t, true, <-finishCh) } -func TestServerSubscribeWithRetainDiffrentFilter(t *testing.T) { - s := newServer() +func TestServerSubscribeWithRetainDifferentFilter(t *testing.T) { + s := newServerWithInlineClient() subNumber := 2 finishCh := make(chan bool, subNumber) @@ -3537,8 +3613,8 @@ func TestServerSubscribeWithRetainDiffrentFilter(t *testing.T) { } } -func TestServerSubscribeWithRetainDiffrentIdentifier(t *testing.T) { - s := newServer() +func TestServerSubscribeWithRetainDifferentIdentifier(t *testing.T) { + s := newServerWithInlineClient() subNumber := 2 finishCh := make(chan bool, subNumber) diff --git a/topics.go b/topics.go index fe8b378..63f704d 100644 --- a/topics.go +++ b/topics.go @@ -705,7 +705,7 @@ func IsSharedFilter(filter string) bool { // IsValidFilter returns true if the filter is valid. func IsValidFilter(filter string, forPublish bool) bool { - if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publihs. + if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publish. return false // [MQTT-4.7.3-1] }