Small fixes and cleanups (#295)

* fix typos, indicate unused returns

* Add test for publishToClient acl unauthorized

* Add Inline Client as a server option
This commit is contained in:
JB
2023-09-08 23:06:14 +01:00
committed by GitHub
parent 58f9fed336
commit add87fea2e
44 changed files with 468 additions and 387 deletions

View File

@@ -99,7 +99,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
type Client struct {
Properties ClientProperties // client properties
State ClientState // the operational state of the client.
Net ClientConnection // network connection state of the clinet
Net ClientConnection // network connection state of the client
ID string // the client id.
ops *ops // ops provides a reference to server ops.
sync.RWMutex // mutex
@@ -111,7 +111,7 @@ type ClientConnection struct {
bconn *bufio.ReadWriter // a buffered net.Conn for reading packets
Remote string // the remote address of the client
Listener string // listener id of the client
Inline bool // client is an inline programmetic client
Inline bool // if true, the client is the built-in 'inline' embedded client
}
// ClientProperties contains the properties which define the client behaviour.
@@ -134,7 +134,7 @@ type Will struct {
Retain bool // -
}
// State tracks the state of the client.
// ClientState tracks the state of the client.
type ClientState struct {
TopicAliases TopicAliases // a map of topic aliases
stopCause atomic.Value // reason for stopping
@@ -311,7 +311,7 @@ func (cl *Client) ResendInflightMessages(force bool) error {
return nil
}
// ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session.
// ClearInflights deletes all inflight messages for the client, e.g. for a disconnected user with a clean session.
func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 {
deleted := []uint16{}
for _, tk := range cl.State.Inflight.GetAll(false) {

View File

@@ -263,7 +263,7 @@ func TestClientNextPacketIDOverflow(t *testing.T) {
cl.State.Inflight.internal[uint16(i)] = packets.Packet{}
}
cl.State.packetID = uint32(cl.ops.options.Capabilities.maximumPacketID - 1)
cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID - 1
i, err := cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, cl.ops.options.Capabilities.maximumPacketID, i)
@@ -303,7 +303,7 @@ func TestClientResendInflightMessages(t *testing.T) {
err := cl.ResendInflightMessages(true)
require.NoError(t, err)
time.Sleep(time.Millisecond)
w.Close()
_ = w.Close()
}()
buf, err := io.ReadAll(r)
@@ -315,7 +315,7 @@ func TestClientResendInflightMessages(t *testing.T) {
func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup)
cl, r, _ := newTestClient()
r.Close()
_ = r.Close()
cl.State.Inflight.Set(*pk1.Packet)
require.Equal(t, 1, cl.State.Inflight.Len())
@@ -342,8 +342,8 @@ func TestClientReadFixedHeader(t *testing.T) {
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{packets.Connect << 4, 0x00})
r.Close()
_, _ = r.Write([]byte{packets.Connect << 4, 0x00})
_ = r.Close()
}()
fh := new(packets.FixedHeader)
@@ -357,8 +357,8 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) {
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00})
r.Close()
_, _ = r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00})
_ = r.Close()
}()
fh := new(packets.FixedHeader)
@@ -372,8 +372,8 @@ func TestClientReadFixedHeaderPacketOversized(t *testing.T) {
defer cl.Stop(errClientStop)
go func() {
r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes)
r.Close()
_, _ = r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes)
_ = r.Close()
}()
fh := new(packets.FixedHeader)
@@ -387,7 +387,7 @@ func TestClientReadFixedHeaderReadEOF(t *testing.T) {
defer cl.Stop(errClientStop)
go func() {
r.Close()
_ = r.Close()
}()
fh := new(packets.FixedHeader)
@@ -401,8 +401,8 @@ func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01})
r.Close()
_, _ = r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01})
_ = r.Close()
}()
fh := new(packets.FixedHeader)
@@ -414,7 +414,7 @@ func TestClientReadOK(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
_, _ = r.Write([]byte{
packets.Publish << 4, 18, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
@@ -424,7 +424,7 @@ func TestClientReadOK(t *testing.T) {
'd', '/', 'e', '/', 'f', // Topic Name
'y', 'e', 'a', 'h', // Payload
})
r.Close()
_ = r.Close()
}()
var pks []packets.Packet
@@ -499,10 +499,10 @@ func TestClientReadFixedHeaderError(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
_, _ = r.Write([]byte{
packets.Publish << 4, 11, // Fixed header
})
r.Close()
_ = r.Close()
}()
cl.Net.bconn = nil
@@ -516,13 +516,13 @@ func TestClientReadReadHandlerErr(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
_, _ = r.Write([]byte{
packets.Publish << 4, 11, // Fixed header
0, 5, // Topic Name - LSB+MSB
'd', '/', 'e', '/', 'f', // Topic Name
'y', 'e', 'a', 'h', // Payload
})
r.Close()
_ = r.Close()
}()
err := cl.Read(func(cl *Client, pk packets.Packet) error {
@@ -536,13 +536,13 @@ func TestClientReadReadPacketOK(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
_, _ = r.Write([]byte{
packets.Publish << 4, 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
})
r.Close()
_ = r.Close()
}()
fh := new(packets.FixedHeader)
@@ -573,7 +573,7 @@ func TestClientReadPacket(t *testing.T) {
t.Run(tt.Desc, func(t *testing.T) {
atomic.StoreInt64(&cl.ops.info.PacketsReceived, 0)
go func() {
r.Write(tt.RawBytes)
_, _ = r.Write(tt.RawBytes)
}()
fh := new(packets.FixedHeader)
@@ -600,7 +600,7 @@ func TestClientReadPacket(t *testing.T) {
func TestClientReadPacketInvalidTypeError(t *testing.T) {
cl, _, _ := newTestClient()
cl.Net.Conn.Close()
_ = cl.Net.Conn.Close()
_, err := cl.ReadPacket(&packets.FixedHeader{})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid packet type")
@@ -624,7 +624,7 @@ func TestClientWritePacket(t *testing.T) {
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
time.Sleep(2 * time.Millisecond)
cl.Net.Conn.Close()
_ = cl.Net.Conn.Close()
require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc)
@@ -660,13 +660,13 @@ func TestClientReadPacketReadingError(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
_, _ = r.Write([]byte{
0, 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
})
r.Close()
_ = r.Close()
}()
_, err := cl.ReadPacket(&packets.FixedHeader{
@@ -680,13 +680,13 @@ func TestClientReadPacketReadUnknown(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
_, _ = r.Write([]byte{
0, 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
})
r.Close()
_ = r.Close()
}()
_, err := cl.ReadPacket(&packets.FixedHeader{
@@ -706,7 +706,7 @@ func TestClientWritePacketWriteNoConn(t *testing.T) {
func TestClientWritePacketWriteError(t *testing.T) {
cl, _, _ := newTestClient()
cl.Net.Conn.Close()
_ = cl.Net.Conn.Close()
err := cl.WritePacket(*pkTable[1].Packet)
require.Error(t, err)

View File

@@ -60,7 +60,7 @@ func main() {
<-done
server.Log.Warn("caught signal, stopping...")
server.Close()
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -78,6 +78,6 @@ func main() {
<-done
server.Log.Warn("caught signal, stopping...")
server.Close()
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -60,6 +60,6 @@ func main() {
<-done
server.Log.Warn("caught signal, stopping...")
server.Close()
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -47,6 +47,6 @@ func main() {
<-done
server.Log.Warn("caught signal, stopping...")
server.Close()
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -61,6 +61,6 @@ func main() {
<-done
server.Log.Warn("caught signal, stopping...")
server.Close()
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -26,7 +26,9 @@ func main() {
done <- true
}()
server := mqtt.New(nil)
server := mqtt.New(&mqtt.Options{
InlineClient: true, // you must enable inline client to use direct publishing and subscribing.
})
_ = server.AddHook(new(auth.AllowHook), nil)
// Start the server
@@ -50,12 +52,13 @@ func main() {
server.Log.Info("inline client received message from subscription", "client", cl.ID, "subscriptionId", sub.Identifier, "topic", pk.TopicName, "payload", string(pk.Payload))
}
server.Log.Info("inline client subscribing")
server.Subscribe("direct/#", 1, callbackFn)
server.Subscribe("direct/#", 2, callbackFn)
_ = server.Subscribe("direct/#", 1, callbackFn)
_ = server.Subscribe("direct/#", 2, callbackFn)
}()
// There is a shorthand convenience function, Publish, for easily sending
// publish packets if you are not concerned with creating your own packets.
// There is a shorthand convenience function, Publish, for easily sending publish packets if you are not
// concerned with creating your own packets. If you want to have more control over your packets, you can
//directly inject a packet of any kind into the broker. See examples/hooks/main.go for usage.
go func() {
for range time.Tick(time.Second * 3) {
err := server.Publish("direct/publish", []byte("scheduled message"), false, 0)
@@ -70,23 +73,8 @@ func main() {
time.Sleep(time.Second * 10)
// Unsubscribe from the same filter to stop receiving messages.
server.Log.Info("inline client unsubscribing")
server.Unsubscribe("direct/#", 1)
_ = server.Unsubscribe("direct/#", 1)
}()
// If you want to have more control over your packets, you can directly inject a packet of any kind into the broker.
//go func() {
// for range time.Tick(time.Second * 5) {
// err := server.InjectPacket(cl, packets.Packet{
// FixedHeader: packets.FixedHeader{
// Type: packets.Publish,
// },
// TopicName: "direct/publish",
// Payload: []byte("injected scheduled message"),
// })
// if err != nil {
// log.Fatal(err)
// }
// }
//}()
<-done
server.Log.Warn("caught signal, stopping...")

View File

@@ -83,7 +83,7 @@ func main() {
<-done
server.Log.Warn("caught signal, stopping...")
server.Close()
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -46,7 +46,7 @@ func main() {
<-done
server.Log.Warn("caught signal, stopping...")
server.Close()
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -53,6 +53,6 @@ func main() {
<-done
server.Log.Warn("caught signal, stopping...")
server.Close()
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -55,6 +55,6 @@ func main() {
<-done
server.Log.Warn("caught signal, stopping...")
server.Close()
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -63,6 +63,6 @@ func main() {
<-done
server.Log.Warn("caught signal, stopping...")
server.Close()
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -53,6 +53,6 @@ func main() {
<-done
server.Log.Warn("caught signal, stopping...")
server.Close()
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -112,6 +112,6 @@ func main() {
<-done
server.Log.Warn("caught signal, stopping...")
server.Close()
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -42,6 +42,6 @@ func main() {
<-done
server.Log.Warn("caught signal, stopping...")
server.Close()
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -80,8 +80,8 @@ func (r RString) Matches(a string) bool {
}
// FilterMatches returns true if a filter matches a topic rule.
func (f RString) FilterMatches(a string) bool {
_, ok := MatchTopic(string(f), a)
func (r RString) FilterMatches(a string) bool {
_, ok := MatchTopic(string(r), a)
return ok
}
@@ -161,7 +161,7 @@ func (l *Ledger) AuthOk(cl *mqtt.Client, pk packets.Packet) (n int, ok bool) {
}
// ACLOk returns true if the rules indicate the user is allowed to read or write to
// a specific filter or topic respectively, based on the write bool.
// a specific filter or topic respectively, based on the `write` bool.
func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok bool) {
// If the users map is set, always check for a predefined user first instead
// of iterating through global rules.
@@ -209,7 +209,7 @@ func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok boo
}
}
for filter, _ := range rule.Filters {
for filter := range rule.Filters {
if filter.FilterMatches(topic) {
return n, false
}

View File

@@ -561,17 +561,17 @@ func TestLedgerUpdate(t *testing.T) {
},
}
new := &Ledger{
n := &Ledger{
Auth: AuthRules{
{Remote: "127.0.0.1", Allow: true},
{Remote: "192.168.*", Allow: true},
},
}
old.Update(new)
old.Update(n)
require.Len(t, old.Auth, 2)
require.Equal(t, RString("192.168.*"), old.Auth[1].Remote)
require.NotSame(t, new, old)
require.NotSame(t, n, old)
}
func TestLedgerToJSON(t *testing.T) {

View File

@@ -114,7 +114,7 @@ func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
h.Log.Debug("inflight dropped", "m", h.packetMeta(pk))
}
// OnLWTSent is called when a will message has been issued from a disconnecting client.
// OnLWTSent is called when a Will Message has been issued from a disconnecting client.
func (h *Hook) OnLWTSent(cl *mqtt.Client, pk packets.Packet) {
h.Log.Debug("sent lwt for client", "method", "OnLWTSent", "client", cl.ID)
}
@@ -136,25 +136,25 @@ func (h *Hook) StoredClients() (v []storage.Client, err error) {
return v, nil
}
// StoredClients is called when the server restores subscriptions from a store.
// StoredSubscriptions is called when the server restores subscriptions from a store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
h.Log.Debug("", "method", "StoredSubscriptions")
return v, nil
}
// StoredClients is called when the server restores retained messages from a store.
// StoredRetainedMessages is called when the server restores retained messages from a store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
h.Log.Debug("", "method", "StoredRetainedMessages")
return v, nil
}
// StoredClients is called when the server restores inflight messages from a store.
// StoredInflightMessages is called when the server restores inflight messages from a store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
h.Log.Debug("", "method", "StoredInflightMessages")
return v, nil
}
// StoredClients is called when the server restores system info from a store.
// StoredSysInfo is called when the server restores system info from a store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
h.Log.Debug("", "method", "StoredSysInfo")

View File

@@ -128,8 +128,7 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// OnWillSent is called when a client sends a will message and the will message is removed
// from the client record.
// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}

View File

@@ -38,8 +38,8 @@ var (
)
func teardown(t *testing.T, path string, h *Hook) {
h.Stop()
h.db.Badger().Close()
_ = h.Stop()
_ = h.db.Badger().Close()
err := os.RemoveAll("./" + strings.Replace(path, "..", "", -1))
require.NoError(t, err)
}

View File

@@ -1,7 +1,8 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
// package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead.
// Package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead.
package bolt
import (
@@ -132,8 +133,7 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// OnWillSent is called when a client sends a will message and the will message is removed
// from the client record.
// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}

View File

@@ -38,7 +38,7 @@ var (
)
func teardown(t *testing.T, path string, h *Hook) {
h.Stop()
_ = h.Stop()
err := os.Remove(path)
require.NoError(t, err)
}

View File

@@ -15,7 +15,7 @@ import (
"github.com/mochi-mqtt/server/v2/packets"
"github.com/mochi-mqtt/server/v2/system"
redis "github.com/go-redis/redis/v8"
"github.com/go-redis/redis/v8"
)
// defaultAddr is the default address to the redis service.
@@ -134,7 +134,7 @@ func (h *Hook) Init(config any) error {
return nil
}
// Close closes the redis connection.
// Stop closes the redis connection.
func (h *Hook) Stop() error {
h.Log.Info("disconnecting from redis service")
@@ -146,8 +146,7 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// OnWillSent is called when a client sends a will message and the will message is removed
// from the client record.
// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}

View File

@@ -25,7 +25,7 @@ var (
ErrDBFileNotOpen = errors.New("db file not open")
)
// Client is a storable representation of an mqtt client.
// Client is a storable representation of an MQTT client.
type Client struct {
Will ClientWill `json:"will"` // will topic and payload data if applicable
Properties ClientProperties `json:"properties"` // the connect properties for the client
@@ -147,7 +147,7 @@ func (d *Message) ToPacket() packets.Packet {
return pk
}
// Subscription is a storable representation of an mqtt subscription.
// Subscription is a storable representation of an MQTT subscription.
type Subscription struct {
T string `json:"t"`
ID string `json:"id" storm:"id"`

View File

@@ -79,9 +79,9 @@ func (l *HTTPHealthCheck) Init(_ *slog.Logger) error {
// Serve starts listening for new connections and serving responses.
func (l *HTTPHealthCheck) Serve(establish EstablishFn) {
if l.listen.TLSConfig != nil {
l.listen.ListenAndServeTLS("", "")
_ = l.listen.ListenAndServeTLS("", "")
} else {
l.listen.ListenAndServe()
_ = l.listen.ListenAndServe()
}
}
@@ -93,7 +93,7 @@ func (l *HTTPHealthCheck) Close(closeClients CloseFn) {
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)
_ = l.listen.Shutdown(ctx)
}
closeClients(l.id)

View File

@@ -39,7 +39,7 @@ func TestHTTPHealthCheckTLSProtocol(t *testing.T) {
TLSConfig: tlsConfigBasic,
})
l.Init(logger)
_ = l.Init(logger)
require.Equal(t, "https", l.Protocol())
}

View File

@@ -81,9 +81,9 @@ func (l *HTTPStats) Init(_ *slog.Logger) error {
// Serve starts listening for new connections and serving responses.
func (l *HTTPStats) Serve(establish EstablishFn) {
if l.listen.TLSConfig != nil {
l.listen.ListenAndServeTLS("", "")
_ = l.listen.ListenAndServeTLS("", "")
} else {
l.listen.ListenAndServe()
_ = l.listen.ListenAndServe()
}
}
@@ -95,7 +95,7 @@ func (l *HTTPStats) Close(closeClients CloseFn) {
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)
_ = l.listen.Shutdown(ctx)
}
closeClients(l.id)
@@ -107,8 +107,8 @@ func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) {
out, err := json.MarshalIndent(info, "", "\t")
if err != nil {
io.WriteString(w, err.Error())
_, _ = io.WriteString(w, err.Error())
}
w.Write(out)
_, _ = w.Write(out)
}

View File

@@ -42,7 +42,7 @@ func TestHTTPStatsTLSProtocol(t *testing.T) {
TLSConfig: tlsConfigBasic,
}, nil)
l.Init(logger)
_ = l.Init(logger)
require.Equal(t, "https", l.Protocol())
}

View File

@@ -22,7 +22,7 @@ type Config struct {
// EstablishFn is a callback function for establishing new clients.
type EstablishFn func(id string, c net.Conn) error
// CloseFunc is a callback function for closing all listener clients.
// CloseFn is a callback function for closing all listener clients.
type CloseFn func(id string)
// Listener is an interface for network listeners. A network listener listens

View File

@@ -16,7 +16,7 @@ func TestMockEstablisher(t *testing.T) {
_, w := net.Pipe()
err := MockEstablisher("t1", w)
require.NoError(t, err)
w.Close()
_ = w.Close()
}
func TestNewMockListener(t *testing.T) {
@@ -86,7 +86,7 @@ func TestMockListenerServe(t *testing.T) {
require.Equal(t, true, closed)
<-o
mocked.Init(nil)
_ = mocked.Init(nil)
}
func TestMockListenerClose(t *testing.T) {

View File

@@ -98,7 +98,7 @@ func TestNetEstablishThenEnd(t *testing.T) {
}()
time.Sleep(time.Millisecond)
net.Dial("tcp", n.Addr().String())
_, _ = net.Dial("tcp", n.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o

View File

@@ -39,7 +39,7 @@ func TestTCPProtocolTLS(t *testing.T) {
TLSConfig: tlsConfigBasic,
})
l.Init(logger)
_ = l.Init(logger)
defer l.listen.Close()
require.Equal(t, "tcp", l.Protocol())
}
@@ -124,7 +124,7 @@ func TestTCPEstablishThenEnd(t *testing.T) {
}()
time.Sleep(time.Millisecond)
net.Dial("tcp", l.listen.Addr().String())
_, _ = net.Dial("tcp", l.listen.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o

View File

@@ -89,7 +89,7 @@ func TestUnixSockEstablishThenEnd(t *testing.T) {
}()
time.Sleep(time.Millisecond)
net.Dial("unix", l.listen.Addr().String())
_, _ = net.Dial("unix", l.listen.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o

View File

@@ -30,7 +30,7 @@ type Websocket struct { // [MQTT-4.2.0-1]
id string // the internal id of the listener
address string // the network address to bind to
config *Config // configuration values for the listener
listen *http.Server // an http server for serving websocket connections
listen *http.Server // a http server for serving websocket connections
log *slog.Logger // server logger
establish EstablishFn // the server's establish connection handler
upgrader *websocket.Upgrader // upgrade the incoming http/tcp connection to a websocket compliant connection.
@@ -112,9 +112,9 @@ func (l *Websocket) Serve(establish EstablishFn) {
l.establish = establish
if l.listen.TLSConfig != nil {
l.listen.ListenAndServeTLS("", "")
_ = l.listen.ListenAndServeTLS("", "")
} else {
l.listen.ListenAndServe()
_ = l.listen.ListenAndServe()
}
}
@@ -126,7 +126,7 @@ func (l *Websocket) Close(closeClients CloseFn) {
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)
_ = l.listen.Shutdown(ctx)
}
closeClients(l.id)
@@ -137,7 +137,7 @@ type wsConn struct {
net.Conn
c *websocket.Conn
// reader for the current message (may be nil)
// reader for the current message (can be nil)
r io.Reader
}

View File

@@ -37,14 +37,14 @@ func TestWebsocketProtocol(t *testing.T) {
require.Equal(t, "ws", l.Protocol())
}
func TestWebsocketProtocoTLS(t *testing.T) {
func TestWebsocketProtocolTLS(t *testing.T) {
l := NewWebsocket("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
require.Equal(t, "wss", l.Protocol())
}
func TestWebsockeInit(t *testing.T) {
func TestWebsocketInit(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
require.Nil(t, l.listen)
err := l.Init(logger)
@@ -54,7 +54,7 @@ func TestWebsockeInit(t *testing.T) {
func TestWebsocketServeAndClose(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
l.Init(logger)
_ = l.Init(logger)
o := make(chan bool)
go func(o chan bool) {
@@ -96,7 +96,7 @@ func TestWebsocketServeTLSAndClose(t *testing.T) {
func TestWebsocketUpgrade(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
l.Init(logger)
_ = l.Init(logger)
e := make(chan bool)
l.establish = func(id string, c net.Conn) error {
@@ -110,12 +110,12 @@ func TestWebsocketUpgrade(t *testing.T) {
require.Equal(t, true, <-e)
s.Close()
ws.Close()
_ = ws.Close()
}
func TestWebsocketConnectionReads(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
l.Init(nil)
_ = l.Init(nil)
recv := make(chan []byte)
l.establish = func(id string, c net.Conn) error {
@@ -151,5 +151,5 @@ func TestWebsocketConnectionReads(t *testing.T) {
require.Equal(t, pkt, got)
s.Close()
ws.Close()
_ = ws.Close()
}

View File

@@ -19,7 +19,7 @@ func TestCodesString(t *testing.T) {
require.Equal(t, "test", c.String())
}
func TestCodesErrorr(t *testing.T) {
func TestCodesError(t *testing.T) {
c := Code{
Reason: "error",
Code: 0x1,

View File

@@ -14,7 +14,7 @@ import (
"sync"
)
// All of the valid packet types and their packet identifier.
// All valid packet types and their packet identifiers.
const (
Reserved byte = iota // 0 - we use this in packet tests to indicate special-test or all packets.
Connect // 1
@@ -37,9 +37,9 @@ const (
var (
// ErrNoValidPacketAvailable indicates the packet type byte provided does not exist in the mqtt specification.
ErrNoValidPacketAvailable error = errors.New("no valid packet available")
ErrNoValidPacketAvailable = errors.New("no valid packet available")
// PacketNames is a map of packet bytes to human readable names, for easier debugging.
// PacketNames is a map of packet bytes to human-readable names, for easier debugging.
PacketNames = map[byte]string{
0: "Reserved",
1: "Connect",
@@ -272,28 +272,28 @@ func (s Subscription) Merge(n Subscription) Subscription {
}
// encode encodes a subscription and properties into bytes.
func (p Subscription) encode() byte {
func (s Subscription) encode() byte {
var flag byte
flag |= p.Qos
flag |= s.Qos
if p.NoLocal {
if s.NoLocal {
flag |= 1 << 2
}
if p.RetainAsPublished {
if s.RetainAsPublished {
flag |= 1 << 3
}
flag |= p.RetainHandling << 4
flag |= s.RetainHandling << 4
return flag
}
// decode decodes subscription bytes into a subscription struct.
func (p *Subscription) decode(b byte) {
p.Qos = b & 3 // byte
p.NoLocal = 1&(b>>2) > 0 // bool
p.RetainAsPublished = 1&(b>>3) > 0 // bool
p.RetainHandling = 3 & (b >> 4) // byte
func (s *Subscription) decode(b byte) {
s.Qos = b & 3 // byte
s.NoLocal = 1&(b>>2) > 0 // bool
s.RetainAsPublished = 1&(b>>3) > 0 // bool
s.RetainHandling = 3 & (b >> 4) // byte
}
// ConnectEncode encodes a connect packet.
@@ -343,7 +343,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Encode(buf)
nb.WriteTo(buf)
_, _ = nb.WriteTo(buf)
return nil
}
@@ -505,7 +505,7 @@ func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Encode(buf)
nb.WriteTo(buf)
_, _ = nb.WriteTo(buf)
return nil
}
@@ -548,7 +548,7 @@ func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Encode(buf)
nb.WriteTo(buf)
_, _ = nb.WriteTo(buf)
return nil
}
@@ -619,7 +619,7 @@ func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Encode(buf)
nb.WriteTo(buf)
_, _ = nb.WriteTo(buf)
return nil
}
@@ -707,7 +707,7 @@ func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Encode(buf)
nb.WriteTo(buf)
_, _ = nb.WriteTo(buf)
return nil
}
@@ -844,7 +844,7 @@ func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Encode(buf)
nb.WriteTo(buf)
_, _ = nb.WriteTo(buf)
return nil
}
@@ -901,7 +901,7 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Encode(buf)
nb.WriteTo(buf)
_, _ = nb.WriteTo(buf)
return nil
}
@@ -996,7 +996,7 @@ func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Encode(buf)
nb.WriteTo(buf)
_, _ = nb.WriteTo(buf)
return nil
}
@@ -1049,7 +1049,7 @@ func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Encode(buf)
nb.WriteTo(buf)
_, _ = nb.WriteTo(buf)
return nil
}
@@ -1109,7 +1109,7 @@ func (pk *Packet) AuthEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Encode(buf)
nb.WriteTo(buf)
_, _ = nb.WriteTo(buf)
return nil
}

View File

@@ -150,7 +150,7 @@ func TestPacketEncode(t *testing.T) {
}
pk := new(Packet)
copier.Copy(pk, wanted.Packet)
_ = copier.Copy(pk, wanted.Packet)
require.Equal(t, pkt, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)
pk.Mods.AllowResponseInfo = true
@@ -218,7 +218,7 @@ func TestPacketDecode(t *testing.T) {
pk := &Packet{FixedHeader: FixedHeader{Type: pkt}}
pk.Mods.AllowResponseInfo = true
pk.FixedHeader.Decode(wanted.RawBytes[0])
_ = pk.FixedHeader.Decode(wanted.RawBytes[0])
if len(wanted.RawBytes) > 0 {
pk.FixedHeader.Remaining = int(wanted.RawBytes[1])
}

View File

@@ -77,7 +77,7 @@ type UserProperty struct { // [MQTT-1.5.7-1]
Val string `json:"v"`
}
// Properties contains all of the mqtt v5 properties available for a packet.
// Properties contains all mqtt v5 properties available for a packet.
// Some properties have valid values of 0 or not-present. In this case, we opt for
// property flags to indicate the usage of property.
// Refer to mqtt v5 2.2.2.2 Property spec for more information.
@@ -355,7 +355,7 @@ func (p *Properties) Encode(pkt byte, mods Mods, b *bytes.Buffer, n int) {
}
encodeLength(b, int64(buf.Len()))
buf.WriteTo(b) // [MQTT-3.1.3-10]
_, _ = buf.WriteTo(b) // [MQTT-3.1.3-10]
}
// Decode decodes property bytes into a properties struct.

View File

@@ -40,7 +40,6 @@ const (
TConnectMqtt5
TConnectMqtt5LWT
TConnectClean
TConnectCleanLWT
TConnectUserPass
TConnectUserPassLWT
TConnectMalProtocolName
@@ -61,7 +60,6 @@ const (
TConnectInvalidProtocolVersion2
TConnectInvalidReservedBit
TConnectInvalidClientIDTooLong
TConnectInvalidPasswordNoUsername
TConnectInvalidFlagNoUsername
TConnectInvalidFlagNoPassword
TConnectInvalidUsernameNoFlag
@@ -186,7 +184,6 @@ const (
TUnsubscribe
TUnsubscribeMany
TUnsubscribeMqtt5
TUnsubscribeDropProperties
TUnsubscribeMalPacketID
TUnsubscribeMalTopicName
TUnsubscribeMalProperties
@@ -204,7 +201,6 @@ const (
TDisconnect
TDisconnectTakeover
TDisconnectMqtt5
TDisconnectNormalMqtt5
TDisconnectSecondConnect
TDisconnectReceiveMaximum
TDisconnectDropProperties

View File

@@ -2,7 +2,7 @@
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
// package mqtt provides a high performance, fully compliant MQTT v5 broker server with v3.1.1 backward compatibility.
// Package mqtt provides a high performance, fully compliant MQTT v5 broker server with v3.1.1 backward compatibility.
package mqtt
import (
@@ -26,7 +26,7 @@ import (
)
const (
Version = "2.3.0" // the current server version.
Version = "2.4.0" // the current server version.
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
LocalListener = "local"
InlineClientId = "inline"
@@ -38,7 +38,7 @@ var (
MaximumSessionExpiryInterval: math.MaxUint32, // maximum number of seconds to keep disconnected sessions
MaximumMessageExpiryInterval: 60 * 60 * 24, // maximum message expiry if message expiry is 0 or over
ReceiveMaximum: 1024, // maximum number of concurrent qos messages per client
MaximumQos: 2, // maxmimum qos value available to clients
MaximumQos: 2, // maximum qos value available to clients
RetainAvailable: 1, // retain messages is available
MaximumPacketSize: 0, // no maximum packet size
TopicAliasMaximum: math.MaxUint16, // maximum topic alias value
@@ -49,8 +49,9 @@ var (
MaximumClientWritesPending: 1024 * 8, // maximum number of pending message writes for a client
}
ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists.
ErrConnectionClosed = errors.New("connection not open") // connection is closed
ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists
ErrConnectionClosed = errors.New("connection not open") // connection is closed
ErrInlineClientNotEnabled = errors.New("please set Options.InlineClient=true to use this feature") // inline client is not enabled by default
)
// Capabilities indicates the capabilities and features provided by the server.
@@ -106,6 +107,10 @@ type Options struct {
// SysTopicResendInterval specifies the interval between $SYS topic updates in seconds.
SysTopicResendInterval int64
// Enable Inline client to allow direct subscribing and publishing from the parent codebase,
// with negligible performance difference (disabled by default to prevent confusion in statistics).
InlineClient bool
}
// Server is an MQTT broker server. It should be created with server.New()
@@ -119,8 +124,8 @@ type Server struct {
loop *loop // loop contains tickers for the system event loop
done chan bool // indicate that the server is ending
Log *slog.Logger // minimal no-alloc logger
hooks *Hooks // hooks contains hooks for extra functionality such as auth and persistent storage.
inlineClient *Client // inlineClient is a special client used for inline subscriptions and inline Publish.
hooks *Hooks // hooks contains hooks for extra functionality such as auth and persistent storage
inlineClient *Client // inlineClient is a special client used for inline subscriptions and inline Publish
}
// loop contains interval tickers for the system events loop.
@@ -129,7 +134,7 @@ type loop struct {
clientExpiry *time.Ticker // interval ticker for cleaning expired clients
inflightExpiry *time.Ticker // interval ticker for cleaning up expired inflight messages
retainedExpiry *time.Ticker // interval ticker for cleaning retained messages
willDelaySend *time.Ticker // interval ticker for sending will messages with a delay
willDelaySend *time.Ticker // interval ticker for sending Will Messages with a delay
willDelayed *packets.Packets // activate LWT packets which will be sent after a delay
}
@@ -173,8 +178,11 @@ func New(opts *Options) *Server {
Log: opts.Logger,
},
}
s.inlineClient = s.NewClient(nil, LocalListener, InlineClientId, true)
s.Clients.Add(s.inlineClient)
if s.Options.InlineClient {
s.inlineClient = s.NewClient(nil, LocalListener, InlineClientId, true)
s.Clients.Add(s.inlineClient)
}
return s
}
@@ -426,7 +434,7 @@ func (s *Server) receivePacket(cl *Client, pk packets.Packet) error {
if code, ok := err.(packets.Code); ok &&
cl.Properties.ProtocolVersion == 5 &&
code.Code >= packets.ErrUnspecifiedError.Code {
s.DisconnectClient(cl, code)
_ = s.DisconnectClient(cl, code)
}
s.Log.Warn("error processing packet", "error", err, "client", cl.ID, "listener", cl.Net.Listener, "pk", pk)
@@ -464,7 +472,7 @@ func (s *Server) validateConnect(cl *Client, pk packets.Packet) packets.Code {
// session is abandoned.
func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
if existing, ok := s.Clients.Get(pk.Connect.ClientIdentifier); ok {
s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
_ = s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
if pk.Connect.Clean || (existing.Properties.Clean && existing.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
s.UnsubscribeClient(existing)
existing.ClearInflights(math.MaxInt64, 0)
@@ -649,11 +657,15 @@ func (s *Server) processPingreq(cl *Client, _ packets.Packet) error {
})
}
// Publish publishes a publish packet into the broker as if it were sent from the speicfied client.
// Publish publishes a publish packet into the broker as if it were sent from the specified client.
// This is a convenience function which wraps InjectPacket. As such, this method can publish packets
// to any topic (including $SYS) and bypass ACL checks. The qos byte is used for limiting the
// outbound qos (mqtt v5) rather than issuing to the broker (we assume qos 2 complete).
func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) error {
if !s.Options.InlineClient {
return ErrInlineClientNotEnabled
}
return s.InjectPacket(s.inlineClient, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
@@ -669,11 +681,18 @@ func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) er
// Subscribe adds an inline subscription for the specified topic filter and subscription identifier
// with the provided handler function.
func (s *Server) Subscribe(filter string, subscriptionId int, handler InlineSubFn) error {
if !s.Options.InlineClient {
return ErrInlineClientNotEnabled
}
if handler == nil {
return packets.ErrInlineSubscriptionHandlerInvalid
} else if !IsValidFilter(filter, false) {
}
if !IsValidFilter(filter, false) {
return packets.ErrTopicFilterInvalid
}
subscription := packets.Subscription{
Identifier: subscriptionId,
Filter: filter,
@@ -704,6 +723,10 @@ func (s *Server) Subscribe(filter string, subscriptionId int, handler InlineSubF
// It allows you to unsubscribe a specific subscription from the internal subscription
// associated with the given topic filter.
func (s *Server) Unsubscribe(filter string, subscriptionId int) error {
if !s.Options.InlineClient {
return ErrInlineClientNotEnabled
}
if !IsValidFilter(filter, false) {
return packets.ErrTopicFilterInvalid
}
@@ -761,12 +784,12 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
return s.DisconnectClient(cl, packets.ErrNotAuthorized)
}
if pk.FixedHeader.Qos == 1 {
ack := s.buildAck(pk.PacketID, packets.Puback, 0, pk.Properties, packets.ErrNotAuthorized)
return cl.WritePacket(ack)
ackType := packets.Puback
if pk.FixedHeader.Qos == 2 {
ackType = packets.Pubrec
}
ack := s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.ErrNotAuthorized)
ack := s.buildAck(pk.PacketID, ackType, 0, pk.Properties, packets.ErrNotAuthorized)
return cl.WritePacket(ack)
}
@@ -790,7 +813,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
}
if pk.FixedHeader.Qos > s.Options.Capabilities.MaximumQos {
pk.FixedHeader.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9] Reduce Qos based on server max qos capability
pk.FixedHeader.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9] Reduce qos based on server max qos capability
}
pkx, err := s.hooks.OnPublish(cl, pk)
@@ -1373,7 +1396,7 @@ func (s *Server) Close() error {
func (s *Server) closeListenerClients(listener string) {
clients := s.Clients.GetByListener(listener)
for _, cl := range clients {
s.DisconnectClient(cl, packets.ErrServerShuttingDown)
_ = s.DisconnectClient(cl, packets.ErrServerShuttingDown)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -705,7 +705,7 @@ func IsSharedFilter(filter string) bool {
// IsValidFilter returns true if the filter is valid.
func IsValidFilter(filter string, forPublish bool) bool {
if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publihs.
if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publish.
return false // [MQTT-4.7.3-1]
}