mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-10-05 08:07:06 +08:00
Compare commits
8 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
0de1d731db | ||
![]() |
80746abc52 | ||
![]() |
a73cf4ca0e | ||
![]() |
bc549ee7ed | ||
![]() |
c464b46713 | ||
![]() |
05ce56008c | ||
![]() |
8254cb0cbc | ||
![]() |
4ae58b79e3 |
@@ -83,7 +83,11 @@ docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 mochi:latest
|
||||
Importing Mochi MQTT as a package requires just a few lines of code to get started.
|
||||
``` go
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -94,7 +98,7 @@ func main() {
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
// Create a TCP listener on a standard port.
|
||||
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err := server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
18
clients.go
18
clients.go
@@ -106,7 +106,7 @@ type Client struct {
|
||||
|
||||
// ClientConnection contains the connection transport and metadata for the client.
|
||||
type ClientConnection struct {
|
||||
conn net.Conn // the net.Conn used to establish the connection
|
||||
Conn net.Conn // the net.Conn used to establish the connection
|
||||
bconn *bufio.ReadWriter // a buffered net.Conn for reading packets
|
||||
Remote string // the remote address of the client
|
||||
Listener string // listener id of the client
|
||||
@@ -164,7 +164,7 @@ func newClient(c net.Conn, o *ops) *Client {
|
||||
|
||||
if c != nil {
|
||||
cl.Net = ClientConnection{
|
||||
conn: c,
|
||||
Conn: c,
|
||||
bconn: bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)),
|
||||
Remote: c.RemoteAddr().String(),
|
||||
}
|
||||
@@ -223,13 +223,13 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
|
||||
|
||||
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
|
||||
func (cl *Client) refreshDeadline(keepalive uint16) {
|
||||
if cl.Net.conn != nil {
|
||||
if cl.Net.Conn != nil {
|
||||
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
|
||||
if keepalive > 0 {
|
||||
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
|
||||
}
|
||||
|
||||
_ = cl.Net.conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
|
||||
_ = cl.Net.Conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -297,7 +297,7 @@ func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 {
|
||||
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
|
||||
cl.ops.hooks.OnQosDropped(cl, tk)
|
||||
atomic.AddInt64(&cl.ops.info.Inflight, -1)
|
||||
deleted = append(deleted, uint16(tk.PacketID))
|
||||
deleted = append(deleted, tk.PacketID)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -341,8 +341,8 @@ func (cl *Client) Stop(err error) {
|
||||
}
|
||||
|
||||
cl.State.endOnce.Do(func() {
|
||||
if cl.Net.conn != nil {
|
||||
_ = cl.Net.conn.Close() // omit close error
|
||||
if cl.Net.Conn != nil {
|
||||
_ = cl.Net.Conn.Close() // omit close error
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -464,7 +464,7 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
|
||||
if cl.Net.conn == nil {
|
||||
if cl.Net.Conn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -533,7 +533,7 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
|
||||
}
|
||||
|
||||
nb := net.Buffers{buf.Bytes()}
|
||||
n, err := nb.WriteTo(cl.Net.conn)
|
||||
n, err := nb.WriteTo(cl.Net.Conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -127,7 +127,7 @@ func TestNewClient(t *testing.T) {
|
||||
require.NotNil(t, cl.State.TopicAliases)
|
||||
require.Equal(t, defaultKeepalive, cl.State.keepalive)
|
||||
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
|
||||
require.NotNil(t, cl.Net.conn)
|
||||
require.NotNil(t, cl.Net.Conn)
|
||||
require.NotNil(t, cl.Net.bconn)
|
||||
require.False(t, cl.Net.Inline)
|
||||
}
|
||||
@@ -320,7 +320,7 @@ func TestClientResendInflightMessagesNoMessages(t *testing.T) {
|
||||
func TestClientRefreshDeadline(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.refreshDeadline(10)
|
||||
require.NotNil(t, cl.Net.conn) // how do we check net.Conn deadline?
|
||||
require.NotNil(t, cl.Net.Conn) // how do we check net.Conn deadline?
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeader(t *testing.T) {
|
||||
@@ -586,7 +586,7 @@ func TestClientReadPacket(t *testing.T) {
|
||||
|
||||
func TestClientReadPacketInvalidTypeError(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Net.conn.Close()
|
||||
cl.Net.Conn.Close()
|
||||
_, err := cl.ReadPacket(&packets.FixedHeader{})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid packet type")
|
||||
@@ -610,7 +610,7 @@ func TestClientWritePacket(t *testing.T) {
|
||||
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
|
||||
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
cl.Net.conn.Close()
|
||||
cl.Net.Conn.Close()
|
||||
|
||||
require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc)
|
||||
|
||||
@@ -692,7 +692,7 @@ func TestClientWritePacketWriteNoConn(t *testing.T) {
|
||||
|
||||
func TestClientWritePacketWriteError(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Net.conn.Close()
|
||||
cl.Net.Conn.Close()
|
||||
|
||||
err := cl.WritePacket(*pkTable[1].Packet)
|
||||
require.Error(t, err)
|
||||
|
89
hooks.go
89
hooks.go
@@ -1,6 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
// SPDX-FileContributor: mochi-co, thedevop
|
||||
|
||||
package mqtt
|
||||
|
||||
@@ -109,11 +109,11 @@ type HookOptions struct {
|
||||
|
||||
// Hooks is a slice of Hook interfaces to be called in sequence.
|
||||
type Hooks struct {
|
||||
Log *zerolog.Logger // a logger for the hook (from the server)
|
||||
internal []Hook // a slice of hooks
|
||||
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
|
||||
qty int64 // the number of hooks in use
|
||||
sync.Mutex // a mutex
|
||||
Log *zerolog.Logger // a logger for the hook (from the server)
|
||||
internal []Hook // a slice of hooks
|
||||
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
|
||||
qty int64 // the number of hooks in use
|
||||
sync.RWMutex // a mutex
|
||||
}
|
||||
|
||||
// Len returns the number of hooks added.
|
||||
@@ -123,7 +123,7 @@ func (h *Hooks) Len() int64 {
|
||||
|
||||
// Provides returns true if any one hook provides any of the requested hook methods.
|
||||
func (h *Hooks) Provides(b ...byte) bool {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
for _, hb := range b {
|
||||
if hook.Provides(hb) {
|
||||
return true
|
||||
@@ -154,10 +154,17 @@ func (h *Hooks) Add(hook Hook, config any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAll returns a slice of all the hooks.
|
||||
func (h *Hooks) GetAll() []Hook {
|
||||
h.RLock()
|
||||
defer h.RUnlock()
|
||||
return append([]Hook{}, h.internal...)
|
||||
}
|
||||
|
||||
// Stop indicates all attached hooks to gracefully end.
|
||||
func (h *Hooks) Stop() {
|
||||
go func() {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
h.Log.Info().Str("hook", hook.ID()).Msg("stopping hook")
|
||||
if err := hook.Stop(); err != nil {
|
||||
h.Log.Debug().Err(err).Str("hook", hook.ID()).Msg("problem stopping hook")
|
||||
@@ -172,7 +179,7 @@ func (h *Hooks) Stop() {
|
||||
|
||||
// OnSysInfoTick is called when the $SYS topic values are published out.
|
||||
func (h *Hooks) OnSysInfoTick(sys *system.Info) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSysInfoTick) {
|
||||
hook.OnSysInfoTick(sys)
|
||||
}
|
||||
@@ -181,7 +188,7 @@ func (h *Hooks) OnSysInfoTick(sys *system.Info) {
|
||||
|
||||
// OnStarted is called when the server has successfully started.
|
||||
func (h *Hooks) OnStarted() {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnStarted) {
|
||||
hook.OnStarted()
|
||||
}
|
||||
@@ -190,7 +197,7 @@ func (h *Hooks) OnStarted() {
|
||||
|
||||
// OnStopped is called when the server has successfully stopped.
|
||||
func (h *Hooks) OnStopped() {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnStopped) {
|
||||
hook.OnStopped()
|
||||
}
|
||||
@@ -199,7 +206,7 @@ func (h *Hooks) OnStopped() {
|
||||
|
||||
// OnConnect is called when a new client connects.
|
||||
func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnConnect) {
|
||||
hook.OnConnect(cl, pk)
|
||||
}
|
||||
@@ -208,7 +215,7 @@ func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) {
|
||||
|
||||
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
|
||||
func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSessionEstablished) {
|
||||
hook.OnSessionEstablished(cl, pk)
|
||||
}
|
||||
@@ -217,7 +224,7 @@ func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
|
||||
|
||||
// OnDisconnect is called when a client is disconnected for any reason.
|
||||
func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnDisconnect) {
|
||||
hook.OnDisconnect(cl, err, expire)
|
||||
}
|
||||
@@ -227,7 +234,7 @@ func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
|
||||
// OnPacketRead is called when a packet is received from a client.
|
||||
func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
|
||||
pkx = pk
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketRead) {
|
||||
npk, err := hook.OnPacketRead(cl, pkx)
|
||||
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
|
||||
@@ -248,7 +255,7 @@ func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet,
|
||||
// to create their own auth packet handling mechanisms.
|
||||
func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
|
||||
pkx = pk
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnAuthPacket) {
|
||||
npk, err := hook.OnAuthPacket(cl, pkx)
|
||||
if err != nil {
|
||||
@@ -264,7 +271,7 @@ func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet,
|
||||
|
||||
// OnPacketEncode is called immediately before a packet is encoded to be sent to a client.
|
||||
func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketEncode) {
|
||||
pk = hook.OnPacketEncode(cl, pk)
|
||||
}
|
||||
@@ -275,7 +282,7 @@ func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
|
||||
|
||||
// OnPacketProcessed is called when a packet has been received and successfully handled by the broker.
|
||||
func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketProcessed) {
|
||||
hook.OnPacketProcessed(cl, pk, err)
|
||||
}
|
||||
@@ -285,7 +292,7 @@ func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
|
||||
// OnPacketSent is called when a packet has been sent to a client. It takes a bytes parameter
|
||||
// containing the bytes sent.
|
||||
func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketSent) {
|
||||
hook.OnPacketSent(cl, pk, b)
|
||||
}
|
||||
@@ -297,7 +304,7 @@ func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
|
||||
// before the packet is processed. The return values of the hook methods are passed-through
|
||||
// in the order the hooks were attached.
|
||||
func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSubscribe) {
|
||||
pk = hook.OnSubscribe(cl, pk)
|
||||
}
|
||||
@@ -307,7 +314,7 @@ func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
|
||||
// OnSubscribed is called when a client subscribes to one or more filters.
|
||||
func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSubscribed) {
|
||||
hook.OnSubscribed(cl, pk, reasonCodes)
|
||||
}
|
||||
@@ -319,7 +326,7 @@ func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte)
|
||||
// remove or add clients to a publish to subscribers process, or to select the subscriber for a shared
|
||||
// group in a custom manner (such as based on client id, ip, etc).
|
||||
func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSelectSubscribers) {
|
||||
subs = hook.OnSelectSubscribers(subs, pk)
|
||||
}
|
||||
@@ -332,7 +339,7 @@ func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subsc
|
||||
// before the packet is processed. The return values of the hook methods are passed-through
|
||||
// in the order the hooks were attached.
|
||||
func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnUnsubscribe) {
|
||||
pk = hook.OnUnsubscribe(cl, pk)
|
||||
}
|
||||
@@ -342,7 +349,7 @@ func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
|
||||
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
|
||||
func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnUnsubscribed) {
|
||||
hook.OnUnsubscribed(cl, pk)
|
||||
}
|
||||
@@ -354,7 +361,7 @@ func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
|
||||
// The return values of the hook methods are passed-through in the order the hooks were attached.
|
||||
func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
|
||||
pkx = pk
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPublish) {
|
||||
npk, err := hook.OnPublish(cl, pkx)
|
||||
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
|
||||
@@ -373,7 +380,7 @@ func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, er
|
||||
|
||||
// OnPublished is called when a client has published a message to subscribers.
|
||||
func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPublished) {
|
||||
hook.OnPublished(cl, pk)
|
||||
}
|
||||
@@ -382,7 +389,7 @@ func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) {
|
||||
|
||||
// OnRetainMessage is called then a published message is retained.
|
||||
func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnRetainMessage) {
|
||||
hook.OnRetainMessage(cl, pk, r)
|
||||
}
|
||||
@@ -393,7 +400,7 @@ func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
|
||||
// In other words, this method is called when a new inflight message is created or resent.
|
||||
// It is typically used to store a new inflight message.
|
||||
func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnQosPublish) {
|
||||
hook.OnQosPublish(cl, pk, sent, resends)
|
||||
}
|
||||
@@ -404,7 +411,7 @@ func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends
|
||||
// In other words, when an inflight message is resolved.
|
||||
// It is typically used to delete an inflight message from a store.
|
||||
func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnQosComplete) {
|
||||
hook.OnQosComplete(cl, pk)
|
||||
}
|
||||
@@ -415,7 +422,7 @@ func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
|
||||
// an inflight message expires or is abandoned. It is typically used to delete an
|
||||
// inflight message from a store.
|
||||
func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnQosDropped) {
|
||||
hook.OnQosDropped(cl, pk)
|
||||
}
|
||||
@@ -427,7 +434,7 @@ func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
|
||||
// published. The return values of the hook methods are passed-through in the order
|
||||
// the hooks were attached.
|
||||
func (h *Hooks) OnWill(cl *Client, will Will) Will {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnWill) {
|
||||
mlwt, err := hook.OnWill(cl, will)
|
||||
if err != nil {
|
||||
@@ -443,7 +450,7 @@ func (h *Hooks) OnWill(cl *Client, will Will) Will {
|
||||
|
||||
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
|
||||
func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnWillSent) {
|
||||
hook.OnWillSent(cl, pk)
|
||||
}
|
||||
@@ -452,7 +459,7 @@ func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
|
||||
|
||||
// OnClientExpired is called when a client session has expired and should be deleted.
|
||||
func (h *Hooks) OnClientExpired(cl *Client) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnClientExpired) {
|
||||
hook.OnClientExpired(cl)
|
||||
}
|
||||
@@ -461,7 +468,7 @@ func (h *Hooks) OnClientExpired(cl *Client) {
|
||||
|
||||
// OnRetainedExpired is called when a retained message has expired and should be deleted.
|
||||
func (h *Hooks) OnRetainedExpired(filter string) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnRetainedExpired) {
|
||||
hook.OnRetainedExpired(filter)
|
||||
}
|
||||
@@ -471,7 +478,7 @@ func (h *Hooks) OnRetainedExpired(filter string) {
|
||||
// StoredClients returns all clients, e.g. from a persistent store, is used to
|
||||
// populate the server clients list before start.
|
||||
func (h *Hooks) StoredClients() (v []storage.Client, err error) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredClients) {
|
||||
v, err := hook.StoredClients()
|
||||
if err != nil {
|
||||
@@ -491,7 +498,7 @@ func (h *Hooks) StoredClients() (v []storage.Client, err error) {
|
||||
// StoredSubscriptions returns all subcriptions, e.g. from a persistent store, and is
|
||||
// used to populate the server subscriptions list before start.
|
||||
func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredSubscriptions) {
|
||||
v, err := hook.StoredSubscriptions()
|
||||
if err != nil {
|
||||
@@ -511,7 +518,7 @@ func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
// StoredInflightMessages returns all inflight messages, e.g. from a persistent store,
|
||||
// and is used to populate the restored clients with inflight messages before start.
|
||||
func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredInflightMessages) {
|
||||
v, err := hook.StoredInflightMessages()
|
||||
if err != nil {
|
||||
@@ -531,7 +538,7 @@ func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
// StoredRetainedMessages returns all retained messages, e.g. from a persistent store,
|
||||
// and is used to populate the server topics with retained messages before start.
|
||||
func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredRetainedMessages) {
|
||||
v, err := hook.StoredRetainedMessages()
|
||||
if err != nil {
|
||||
@@ -550,7 +557,7 @@ func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
|
||||
// StoredSysInfo returns a set of system info values.
|
||||
func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredSysInfo) {
|
||||
v, err := hook.StoredSysInfo()
|
||||
if err != nil {
|
||||
@@ -572,7 +579,7 @@ func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
// server (see hooks/auth/allow_all or basic). It can be used in custom hooks to
|
||||
// check connecting users against an existing user database.
|
||||
func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnConnectAuthenticate) {
|
||||
if ok := hook.OnConnectAuthenticate(cl, pk); ok {
|
||||
return true
|
||||
@@ -588,7 +595,7 @@ func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
|
||||
// (see hooks/auth/allow_all or basic). It can be used in custom hooks to
|
||||
// check publishing and subscribing users against an existing permissions or roles database.
|
||||
func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnACLCheck) {
|
||||
if ok := hook.OnACLCheck(cl, topic, write); ok {
|
||||
return true
|
||||
|
@@ -154,7 +154,7 @@ func (ws *wsConn) Read(p []byte) (int, error) {
|
||||
br, err = r.Read(p[n:])
|
||||
n += br
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
}
|
||||
return n, err
|
||||
|
@@ -124,4 +124,23 @@ var (
|
||||
ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
|
||||
ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}
|
||||
ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"}
|
||||
|
||||
// MQTTv3 specific bytes.
|
||||
Err3UnsupportedProtocolVersion = Code{Code: 0x01}
|
||||
Err3ClientIdentifierNotValid = Code{Code: 0x02}
|
||||
Err3ServerUnavailable = Code{Code: 0x03}
|
||||
ErrMalformedUsernameOrPassword = Code{Code: 0x04}
|
||||
Err3NotAuthorized = Code{Code: 0x05}
|
||||
|
||||
// V5CodesToV3 maps MQTTv5 Connack reason codes to MQTTv3 return codes.
|
||||
// This is required because MQTTv3 has different return byte specification.
|
||||
// See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349257
|
||||
V5CodesToV3 = map[Code]Code{
|
||||
ErrUnsupportedProtocolVersion: Err3UnsupportedProtocolVersion,
|
||||
ErrClientIdentifierNotValid: Err3ClientIdentifierNotValid,
|
||||
ErrServerUnavailable: Err3ServerUnavailable,
|
||||
ErrMalformedUsername: ErrMalformedUsernameOrPassword,
|
||||
ErrMalformedPassword: ErrMalformedUsernameOrPassword,
|
||||
ErrBadUsernameOrPassword: Err3NotAuthorized,
|
||||
}
|
||||
)
|
||||
|
@@ -89,6 +89,7 @@ const (
|
||||
TConnackServerUnavailable
|
||||
TConnackBadUsernamePassword
|
||||
TConnackBadUsernamePasswordNoSession
|
||||
TConnackMqtt5BadUsernamePasswordNoSession
|
||||
TConnackNotAuthorised
|
||||
TConnackMalSessionPresent
|
||||
TConnackMalReturnCode
|
||||
@@ -1316,10 +1317,28 @@ var TPacketData = map[byte]TPacketCases{
|
||||
Desc: "bad username or password no session",
|
||||
RawBytes: []byte{
|
||||
Connack << 4, 2, // fixed header
|
||||
0, // No session present
|
||||
ErrBadUsernameOrPassword.Code,
|
||||
0, // No session present
|
||||
Err3NotAuthorized.Code, // use v3 remapping
|
||||
},
|
||||
Packet: &Packet{
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Connack,
|
||||
Remaining: 2,
|
||||
},
|
||||
ReasonCode: Err3NotAuthorized.Code,
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TConnackMqtt5BadUsernamePasswordNoSession,
|
||||
Desc: "mqtt5 bad username or password no session",
|
||||
RawBytes: []byte{
|
||||
Connack << 4, 3, // fixed header
|
||||
0, // No session present
|
||||
ErrBadUsernameOrPassword.Code,
|
||||
0,
|
||||
},
|
||||
Packet: &Packet{
|
||||
ProtocolVersion: 5,
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Connack,
|
||||
Remaining: 2,
|
||||
@@ -1327,6 +1346,7 @@ var TPacketData = map[byte]TPacketCases{
|
||||
ReasonCode: ErrBadUsernameOrPassword.Code,
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
Case: TConnackNotAuthorised,
|
||||
Desc: "not authorised",
|
||||
|
14
server.go
14
server.go
@@ -26,7 +26,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
Version = "2.1.1" // the current server version.
|
||||
Version = "2.1.4" // the current server version.
|
||||
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
|
||||
defaultFanPoolSize uint64 = 32 // the number of concurrent workers in the pool
|
||||
defaultFanPoolQueueSize uint64 = 1024 // the capacity of each worker queue
|
||||
@@ -489,6 +489,12 @@ func (s *Server) sendConnack(cl *Client, reason packets.Code, present bool) erro
|
||||
}
|
||||
|
||||
if reason.Code >= packets.ErrUnspecifiedError.Code {
|
||||
if cl.Properties.ProtocolVersion < 5 {
|
||||
if v3reason, ok := packets.V5CodesToV3[reason]; ok { // NB v3 3.2.2.3 Connack return codes
|
||||
reason = v3reason
|
||||
}
|
||||
}
|
||||
|
||||
properties.ReasonString = reason.Reason
|
||||
ack := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
@@ -773,12 +779,12 @@ func (s *Server) publishToSubscribers(pk packets.Packet) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (out packets.Packet, err error) {
|
||||
func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (packets.Packet, error) {
|
||||
if sub.NoLocal && pk.Origin == cl.ID {
|
||||
return pk, nil // [MQTT-3.8.3-3]
|
||||
}
|
||||
|
||||
out = pk.Copy(false)
|
||||
out := pk.Copy(false)
|
||||
if !sub.RetainAsPublished { // ![MQTT-3.3.1-13]
|
||||
out.FixedHeader.Retain = false // [MQTT-3.3.1-12]
|
||||
}
|
||||
@@ -828,7 +834,7 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet
|
||||
}
|
||||
}
|
||||
|
||||
if cl.Net.conn == nil || atomic.LoadUint32(&cl.State.done) == 1 {
|
||||
if cl.Net.Conn == nil || atomic.LoadUint32(&cl.State.done) == 1 {
|
||||
return pk, packets.CodeDisconnect
|
||||
}
|
||||
|
||||
|
@@ -117,7 +117,7 @@ func TestServerNewClient(t *testing.T) {
|
||||
require.NotNil(t, cl.State.TopicAliases)
|
||||
require.Equal(t, defaultKeepalive, cl.State.keepalive)
|
||||
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
|
||||
require.NotNil(t, cl.Net.conn)
|
||||
require.NotNil(t, cl.Net.Conn)
|
||||
require.NotNil(t, cl.Net.bconn)
|
||||
require.NotNil(t, cl.ops)
|
||||
require.Equal(t, s.Log, cl.ops.log)
|
||||
@@ -806,7 +806,7 @@ func TestInheritClientSession(t *testing.T) {
|
||||
n := time.Now().Unix()
|
||||
|
||||
existing, _, _ := newTestClient()
|
||||
existing.Net.conn = nil
|
||||
existing.Net.Conn = nil
|
||||
existing.ID = "mochi"
|
||||
existing.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
|
||||
existing.State.Inflight = NewInflights()
|
||||
@@ -1440,7 +1440,7 @@ func TestPublishToClientExhaustedPacketID(t *testing.T) {
|
||||
func TestPublishToClientNoConn(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Net.conn = nil
|
||||
cl.Net.Conn = nil
|
||||
|
||||
_, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c"}, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
|
||||
require.Error(t, err)
|
||||
@@ -1863,7 +1863,7 @@ func TestServerProcessInboundQos2Flow(t *testing.T) {
|
||||
for i, tx := range tt {
|
||||
t.Run("qos step"+strconv.Itoa(i), func(t *testing.T) {
|
||||
r, w = net.Pipe()
|
||||
cl.Net.conn = w
|
||||
cl.Net.Conn = w
|
||||
|
||||
recv := make(chan []byte)
|
||||
go func() { // receive the ack
|
||||
@@ -1937,7 +1937,7 @@ func TestServerProcessOutboundQos2Flow(t *testing.T) {
|
||||
for i, tx := range tt {
|
||||
t.Run("qos step"+strconv.Itoa(i), func(t *testing.T) {
|
||||
r, w := net.Pipe()
|
||||
cl.Net.conn = w
|
||||
cl.Net.Conn = w
|
||||
|
||||
recv := make(chan []byte)
|
||||
go func() { // receive the ack
|
||||
|
Reference in New Issue
Block a user