Compare commits

...

11 Commits

Author SHA1 Message Date
JB
75504ff201 Update server version 2022-12-16 18:27:29 +00:00
Wind
a556feb325 Add the OnUnsubscribed hook to the unsubscribeClient method (#122)
Add the OnUnsubscribed hook to the unsubscribeClient method,and change the unsubscribeClient to externally visible. In a clustered environment, if a client is disconnected and then connected to another node, the subscriptions on the previous node need to be cleared.
2022-12-16 18:23:58 +00:00
JB
8d4cc091b4 Update version number 2022-12-16 00:31:59 +00:00
JB
d8f28cb843 Enforce server max packet (#121)
* Enforce Server Maximum Packet Size on client read
* Fix tests
2022-12-16 00:30:23 +00:00
JB
88861c219d Merge pull request #116 from tommyminds/bugfix/ws_malformed_package
Fix websocket malformed packet bug
2022-12-15 18:21:53 +00:00
JB
7ba6cf28d9 Merge branch 'master' into bugfix/ws_malformed_package 2022-12-15 18:21:33 +00:00
JB
c174cfdc6b Merge pull request #119 from mochi-co/fix-on-published
Fix mis-typed onpublished hook, update version, fanpool defaults
2022-12-15 18:21:19 +00:00
mochi-co
4f198a99dd Fix mis-typed onpublished hook, update version, fanpool defaults 2022-12-15 18:19:02 +00:00
Tommy Maintz
2a9c9fcc40 Fix websocket malformed packet bug 2022-12-14 21:41:33 +01:00
JB
835a85c8bf Update README.md 2022-12-12 11:44:36 +00:00
mochi-co
fe5d9ffa61 Simplify Client construction, add NewClient method to Server, add Publish convenience method 2022-12-12 11:37:19 +00:00
9 changed files with 396 additions and 251 deletions

View File

@@ -305,13 +305,22 @@ The function signatures for all the hooks and `mqtt.Hook` interface can be found
If you are building a persistent storage hook, see the existing persistent hooks for inspiration and patterns. If you are building an auth hook, you will need `OnACLCheck` and `OnConnectAuthenticate`. If you are building a persistent storage hook, see the existing persistent hooks for inspiration and patterns. If you are building an auth hook, you will need `OnACLCheck` and `OnConnectAuthenticate`.
### Packet Injection
It's also possible to inject custom MQTT packets directly into the runtime as though they had been received by a specific client. This special client is called an InlineClient, and it has unique privileges: it bypasses all ACL and topic validation checks, meaning it can even publish to $SYS topics.
Packet injection can be used with MQTT packet, including ping requests, subscriptions, etc. And because the Clients structs and methods are now exported, you can even inject packets on behalf of a connected client (if you have a very custom requirement). ### Direct Publish
To publish basic message to a topic from within the embedding application, you can use the `server.Publish(topic string, payload []byte, retain bool, qos byte) error` method.
```go ```go
cl := mqtt.NewInlineClient("inline", "local") err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
```
> The Qos byte in this case is only used to set the upper qos limit available for subscribers, as per MQTT v5 spec.
### Packet Injection
If you want more control, or want to set specific MQTT v5 properties and other values you can create your own publish packets from a client of your choice. This method allows you to inject MQTT packets (no just publish) directly into the runtime as though they had been received by a specific client. Most of the time you'll want to use the special client flag `inline=true`, as it has unique privileges: it bypasses all ACL and topic validation checks, meaning it can even publish to $SYS topics.
Packet injection can be used for any MQTT packet, including ping requests, subscriptions, etc. And because the Clients structs and methods are now exported, you can even inject packets on behalf of a connected client (if you have a very custom requirements).
```go
cl := server.NewClient(nil, "local", "inline", true)
server.InjectPacket(cl, packets.Packet{ server.InjectPacket(cl, packets.Packet{
FixedHeader: packets.FixedHeader{ FixedHeader: packets.FixedHeader{
Type: packets.Publish, Type: packets.Publish,

View File

@@ -146,14 +146,10 @@ type ClientState struct {
keepalive uint16 // the number of seconds the connection can wait keepalive uint16 // the number of seconds the connection can wait
} }
// NewClient returns a new instance of Client. // newClient returns a new instance of Client. This is almost exclusively used by Server
func NewClient(c net.Conn, o *ops) *Client { // for creating new clients, but it lives here because it's not dependent.
func newClient(c net.Conn, o *ops) *Client {
cl := &Client{ cl := &Client{
Net: ClientConnection{
conn: c,
bconn: bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)),
Remote: c.RemoteAddr().String(),
},
State: ClientState{ State: ClientState{
Inflight: NewInflights(), Inflight: NewInflights(),
Subscriptions: NewSubscriptions(), Subscriptions: NewSubscriptions(),
@@ -166,46 +162,19 @@ func NewClient(c net.Conn, o *ops) *Client {
ops: o, ops: o,
} }
if c != nil {
cl.Net = ClientConnection{
conn: c,
bconn: bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)),
Remote: c.RemoteAddr().String(),
}
}
cl.refreshDeadline(cl.State.keepalive) cl.refreshDeadline(cl.State.keepalive)
return cl return cl
} }
// NewInlineClient returns a client used when publishing from the embedding system.
func NewInlineClient(id, remote string) *Client {
return &Client{
ID: id,
Net: ClientConnection{
Remote: remote,
Inline: true,
},
State: ClientState{
Inflight: NewInflights(),
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(0),
},
Properties: ClientProperties{
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
},
}
}
// newClientStub returns an instance of Client with minimal initializations, such as
// restoring client data from a db. In particular, the client is marked as offline (done).
func newClientStub() *Client {
return &Client{
State: ClientState{
Inflight: NewInflights(),
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(0),
done: 1,
},
Properties: ClientProperties{
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
},
}
}
// ParseConnect parses the connect parameters and properties for a client. // ParseConnect parses the connect parameters and properties for a client.
func (cl *Client) ParseConnect(lid string, pk packets.Packet) { func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Net.Listener = lid cl.Net.Listener = lid
@@ -414,6 +383,10 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
return err return err
} }
if cl.ops.capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.capabilities.MaximumPacketSize {
return packets.ErrPacketTooLarge // [MQTT-3.2.2-15]
}
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1)) atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1))
return nil return nil
} }

View File

@@ -22,10 +22,10 @@ const pkInfo = "packet type %v, %s"
var errClientStop = errors.New("test stop") var errClientStop = errors.New("test stop")
func newClient() (cl *Client, r net.Conn, w net.Conn) { func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
r, w = net.Pipe() r, w = net.Pipe()
cl = NewClient(w, &ops{ cl = newClient(w, &ops{
info: new(system.Info), info: new(system.Info),
hooks: new(Hooks), hooks: new(Hooks),
log: &logger, log: &logger,
@@ -119,34 +119,21 @@ func TestClientsGetByListener(t *testing.T) {
} }
func TestNewClient(t *testing.T) { func TestNewClient(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
require.NotNil(t, cl) require.NotNil(t, cl)
require.NotNil(t, cl.State.Inflight.internal) require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions) require.NotNil(t, cl.State.Subscriptions)
require.Nil(t, cl.StopCause()) require.NotNil(t, cl.State.TopicAliases)
} require.Equal(t, defaultKeepalive, cl.State.keepalive)
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
func TestNewClientStub(t *testing.T) { require.NotNil(t, cl.Net.conn)
cl := newClientStub() require.NotNil(t, cl.Net.bconn)
require.NotNil(t, cl) require.False(t, cl.Net.Inline)
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.State.done))
}
func TestNewInlineClient(t *testing.T) {
cl := NewInlineClient("inline", "local")
require.NotNil(t, cl)
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.Equal(t, uint32(0), atomic.LoadUint32(&cl.State.done))
require.Equal(t, "inline", cl.ID)
require.Equal(t, "local", cl.Net.Remote)
} }
func TestClientParseConnect(t *testing.T) { func TestClientParseConnect(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
pk := packets.Packet{ pk := packets.Packet{
ProtocolVersion: 4, ProtocolVersion: 4,
@@ -183,7 +170,7 @@ func TestClientParseConnect(t *testing.T) {
} }
func TestClientParseConnectOverrideWillDelay(t *testing.T) { func TestClientParseConnectOverrideWillDelay(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
pk := packets.Packet{ pk := packets.Packet{
ProtocolVersion: 4, ProtocolVersion: 4,
@@ -208,13 +195,13 @@ func TestClientParseConnectOverrideWillDelay(t *testing.T) {
} }
func TestClientParseConnectNoID(t *testing.T) { func TestClientParseConnectNoID(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.ParseConnect("tcp1", packets.Packet{}) cl.ParseConnect("tcp1", packets.Packet{})
require.NotEmpty(t, cl.ID) require.NotEmpty(t, cl.ID)
} }
func TestClientNextPacketID(t *testing.T) { func TestClientNextPacketID(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
i, err := cl.NextPacketID() i, err := cl.NextPacketID()
require.NoError(t, err) require.NoError(t, err)
@@ -226,7 +213,7 @@ func TestClientNextPacketID(t *testing.T) {
} }
func TestClientNextPacketIDInUse(t *testing.T) { func TestClientNextPacketIDInUse(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
// skip over 2 // skip over 2
cl.State.Inflight.Set(packets.Packet{PacketID: 2}) cl.State.Inflight.Set(packets.Packet{PacketID: 2})
@@ -249,7 +236,7 @@ func TestClientNextPacketIDInUse(t *testing.T) {
} }
func TestClientNextPacketIDExhausted(t *testing.T) { func TestClientNextPacketIDExhausted(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
for i := 0; i <= 65535; i++ { for i := 0; i <= 65535; i++ {
cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)}) cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)})
} }
@@ -261,7 +248,7 @@ func TestClientNextPacketIDExhausted(t *testing.T) {
} }
func TestClientNextPacketIDOverflow(t *testing.T) { func TestClientNextPacketIDOverflow(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.State.packetID = uint32(65534) cl.State.packetID = uint32(65534)
@@ -275,7 +262,7 @@ func TestClientNextPacketIDOverflow(t *testing.T) {
} }
func TestClientClearInflights(t *testing.T) { func TestClientClearInflights(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
n := time.Now().Unix() n := time.Now().Unix()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1}) cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1})
@@ -291,7 +278,7 @@ func TestClientClearInflights(t *testing.T) {
func TestClientResendInflightMessages(t *testing.T) { func TestClientResendInflightMessages(t *testing.T) {
pk1 := packets.TPacketData[packets.Puback].Get(packets.TPuback) pk1 := packets.TPacketData[packets.Puback].Get(packets.TPuback)
cl, r, w := newClient() cl, r, w := newTestClient()
cl.State.Inflight.Set(*pk1.Packet) cl.State.Inflight.Set(*pk1.Packet)
require.Equal(t, 1, cl.State.Inflight.Len()) require.Equal(t, 1, cl.State.Inflight.Len())
@@ -311,7 +298,7 @@ func TestClientResendInflightMessages(t *testing.T) {
func TestClientResendInflightMessagesWriteFailure(t *testing.T) { func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup) pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup)
cl, r, _ := newClient() cl, r, _ := newTestClient()
r.Close() r.Close()
cl.State.Inflight.Set(*pk1.Packet) cl.State.Inflight.Set(*pk1.Packet)
@@ -323,19 +310,19 @@ func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
} }
func TestClientResendInflightMessagesNoMessages(t *testing.T) { func TestClientResendInflightMessagesNoMessages(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
err := cl.ResendInflightMessages(true) err := cl.ResendInflightMessages(true)
require.NoError(t, err) require.NoError(t, err)
} }
func TestClientRefreshDeadline(t *testing.T) { func TestClientRefreshDeadline(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.refreshDeadline(10) cl.refreshDeadline(10)
require.NotNil(t, cl.Net.conn) // how do we check net.Conn deadline? require.NotNil(t, cl.Net.conn) // how do we check net.Conn deadline?
} }
func TestClientReadFixedHeader(t *testing.T) { func TestClientReadFixedHeader(t *testing.T) {
cl, r, _ := newClient() cl, r, _ := newTestClient()
defer cl.Stop(errClientStop) defer cl.Stop(errClientStop)
go func() { go func() {
@@ -350,7 +337,7 @@ func TestClientReadFixedHeader(t *testing.T) {
} }
func TestClientReadFixedHeaderDecodeError(t *testing.T) { func TestClientReadFixedHeaderDecodeError(t *testing.T) {
cl, r, _ := newClient() cl, r, _ := newTestClient()
defer cl.Stop(errClientStop) defer cl.Stop(errClientStop)
go func() { go func() {
@@ -363,8 +350,24 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) {
require.Error(t, err) require.Error(t, err)
} }
func TestClientReadFixedHeaderPacketOversized(t *testing.T) {
cl, r, _ := newTestClient()
cl.ops.capabilities.MaximumPacketSize = 2
defer cl.Stop(errClientStop)
go func() {
r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes)
r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrPacketTooLarge)
}
func TestClientReadFixedHeaderReadEOF(t *testing.T) { func TestClientReadFixedHeaderReadEOF(t *testing.T) {
cl, r, _ := newClient() cl, r, _ := newTestClient()
defer cl.Stop(errClientStop) defer cl.Stop(errClientStop)
go func() { go func() {
@@ -378,7 +381,7 @@ func TestClientReadFixedHeaderReadEOF(t *testing.T) {
} }
func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) { func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
cl, r, _ := newClient() cl, r, _ := newTestClient()
defer cl.Stop(errClientStop) defer cl.Stop(errClientStop)
go func() { go func() {
@@ -392,7 +395,7 @@ func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
} }
func TestClientReadOK(t *testing.T) { func TestClientReadOK(t *testing.T) {
cl, r, _ := newClient() cl, r, _ := newTestClient()
defer cl.Stop(errClientStop) defer cl.Stop(errClientStop)
go func() { go func() {
r.Write([]byte{ r.Write([]byte{
@@ -446,7 +449,7 @@ func TestClientReadOK(t *testing.T) {
} }
func TestClientReadDone(t *testing.T) { func TestClientReadDone(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
defer cl.Stop(errClientStop) defer cl.Stop(errClientStop)
cl.State.done = 1 cl.State.done = 1
@@ -461,15 +464,16 @@ func TestClientReadDone(t *testing.T) {
} }
func TestClientStop(t *testing.T) { func TestClientStop(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.Stop(nil) cl.Stop(nil)
require.Equal(t, nil, cl.State.stopCause.Load()) require.Equal(t, nil, cl.State.stopCause.Load())
require.Equal(t, time.Now().Unix(), cl.State.disconnected) require.Equal(t, time.Now().Unix(), cl.State.disconnected)
require.Equal(t, uint32(1), cl.State.done) require.Equal(t, uint32(1), cl.State.done)
require.Equal(t, nil, cl.StopCause())
} }
func TestClientReadFixedHeaderError(t *testing.T) { func TestClientReadFixedHeaderError(t *testing.T) {
cl, r, _ := newClient() cl, r, _ := newTestClient()
defer cl.Stop(errClientStop) defer cl.Stop(errClientStop)
go func() { go func() {
r.Write([]byte{ r.Write([]byte{
@@ -486,7 +490,7 @@ func TestClientReadFixedHeaderError(t *testing.T) {
} }
func TestClientReadReadHandlerErr(t *testing.T) { func TestClientReadReadHandlerErr(t *testing.T) {
cl, r, _ := newClient() cl, r, _ := newTestClient()
defer cl.Stop(errClientStop) defer cl.Stop(errClientStop)
go func() { go func() {
r.Write([]byte{ r.Write([]byte{
@@ -506,7 +510,7 @@ func TestClientReadReadHandlerErr(t *testing.T) {
} }
func TestClientReadReadPacketOK(t *testing.T) { func TestClientReadReadPacketOK(t *testing.T) {
cl, r, _ := newClient() cl, r, _ := newTestClient()
defer cl.Stop(errClientStop) defer cl.Stop(errClientStop)
go func() { go func() {
r.Write([]byte{ r.Write([]byte{
@@ -538,7 +542,7 @@ func TestClientReadReadPacketOK(t *testing.T) {
} }
func TestClientReadPacket(t *testing.T) { func TestClientReadPacket(t *testing.T) {
cl, r, _ := newClient() cl, r, _ := newTestClient()
defer cl.Stop(errClientStop) defer cl.Stop(errClientStop)
for _, tx := range pkTable { for _, tx := range pkTable {
@@ -571,9 +575,17 @@ func TestClientReadPacket(t *testing.T) {
} }
} }
func TestClientReadPacketInvalidTypeError(t *testing.T) {
cl, _, _ := newTestClient()
cl.Net.conn.Close()
_, err := cl.ReadPacket(&packets.FixedHeader{})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid packet type")
}
func TestClientWritePacket(t *testing.T) { func TestClientWritePacket(t *testing.T) {
for _, tt := range pkTable { for _, tt := range pkTable {
cl, r, _ := newClient() cl, r, _ := newTestClient()
defer cl.Stop(errClientStop) defer cl.Stop(errClientStop)
cl.Properties.ProtocolVersion = tt.Packet.ProtocolVersion cl.Properties.ProtocolVersion = tt.Packet.ProtocolVersion
@@ -613,7 +625,7 @@ func TestClientWritePacket(t *testing.T) {
} }
func TestWriteClientOversizePacket(t *testing.T) { func TestWriteClientOversizePacket(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.Properties.Props.MaximumPacketSize = 2 cl.Properties.Props.MaximumPacketSize = 2
pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishDropOversize).Packet pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishDropOversize).Packet
err := cl.WritePacket(pk) err := cl.WritePacket(pk)
@@ -622,7 +634,7 @@ func TestWriteClientOversizePacket(t *testing.T) {
} }
func TestClientReadPacketReadingError(t *testing.T) { func TestClientReadPacketReadingError(t *testing.T) {
cl, r, _ := newClient() cl, r, _ := newTestClient()
defer cl.Stop(errClientStop) defer cl.Stop(errClientStop)
go func() { go func() {
r.Write([]byte{ r.Write([]byte{
@@ -642,7 +654,7 @@ func TestClientReadPacketReadingError(t *testing.T) {
} }
func TestClientReadPacketReadUnknown(t *testing.T) { func TestClientReadPacketReadUnknown(t *testing.T) {
cl, r, _ := newClient() cl, r, _ := newTestClient()
defer cl.Stop(errClientStop) defer cl.Stop(errClientStop)
go func() { go func() {
r.Write([]byte{ r.Write([]byte{
@@ -661,7 +673,7 @@ func TestClientReadPacketReadUnknown(t *testing.T) {
} }
func TestClientWritePacketWriteNoConn(t *testing.T) { func TestClientWritePacketWriteNoConn(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.Stop(errClientStop) cl.Stop(errClientStop)
err := cl.WritePacket(*pkTable[1].Packet) err := cl.WritePacket(*pkTable[1].Packet)
@@ -670,7 +682,7 @@ func TestClientWritePacketWriteNoConn(t *testing.T) {
} }
func TestClientWritePacketWriteError(t *testing.T) { func TestClientWritePacketWriteError(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.Net.conn.Close() cl.Net.conn.Close()
err := cl.WritePacket(*pkTable[1].Packet) err := cl.WritePacket(*pkTable[1].Packet)
@@ -678,7 +690,7 @@ func TestClientWritePacketWriteError(t *testing.T) {
} }
func TestClientWritePacketInvalidPacket(t *testing.T) { func TestClientWritePacketInvalidPacket(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
err := cl.WritePacket(packets.Packet{}) err := cl.WritePacket(packets.Packet{})
require.Error(t, err) require.Error(t, err)
} }

View File

@@ -52,15 +52,30 @@ func main() {
// `server.Publish` method. Subscribe to `direct/publish` using your // `server.Publish` method. Subscribe to `direct/publish` using your
// MQTT client to see the messages. // MQTT client to see the messages.
go func() { go func() {
cl := mqtt.NewInlineClient("inline", "local") cl := server.NewClient(nil, "local", "inline", true)
for range time.Tick(time.Second * 10) { for range time.Tick(time.Second * 1) {
server.InjectPacket(cl, packets.Packet{ err := server.InjectPacket(cl, packets.Packet{
FixedHeader: packets.FixedHeader{ FixedHeader: packets.FixedHeader{
Type: packets.Publish, Type: packets.Publish,
}, },
TopicName: "direct/publish", TopicName: "direct/publish",
Payload: []byte("scheduled message"), Payload: []byte("injected scheduled message"),
}) })
if err != nil {
server.Log.Error().Err(err).Msg("server.InjectPacket")
}
server.Log.Info().Msgf("main.go injected packet to direct/publish")
}
}()
// There is also a shorthand convenience function, Publish, for easily sending
// publish packets if you are not concerned with creating your own packets.
go func() {
for range time.Tick(time.Second * 5) {
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
if err != nil {
server.Log.Error().Err(err).Msg("server.Publish")
}
server.Log.Info().Msgf("main.go issued direct message to direct/publish") server.Log.Info().Msgf("main.go issued direct message to direct/publish")
} }
}() }()

View File

@@ -351,7 +351,7 @@ func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
} }
} }
// OnPublish is called when a client publishes a message. This method differs from OnMessage // OnPublish is called when a client publishes a message. This method differs from OnPublished
// in that it allows you to modify you to modify the incoming packet before it is processed. // in that it allows you to modify you to modify the incoming packet before it is processed.
// The return values of the hook methods are passed-through in the order the hooks were attached. // The return values of the hook methods are passed-through in the order the hooks were attached.
func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) { func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {

View File

@@ -13,7 +13,7 @@ import (
) )
func TestInflightSet(t *testing.T) { func TestInflightSet(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
r := cl.State.Inflight.Set(packets.Packet{PacketID: 1}) r := cl.State.Inflight.Set(packets.Packet{PacketID: 1})
require.True(t, r) require.True(t, r)
@@ -25,7 +25,7 @@ func TestInflightSet(t *testing.T) {
} }
func TestInflightGet(t *testing.T) { func TestInflightGet(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2}) cl.State.Inflight.Set(packets.Packet{PacketID: 2})
msg, ok := cl.State.Inflight.Get(2) msg, ok := cl.State.Inflight.Get(2)
@@ -34,7 +34,7 @@ func TestInflightGet(t *testing.T) {
} }
func TestInflightGetAllAndImmediate(t *testing.T) { func TestInflightGetAllAndImmediate(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1}) cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2}) cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1}) cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
@@ -56,13 +56,13 @@ func TestInflightGetAllAndImmediate(t *testing.T) {
} }
func TestInflightLen(t *testing.T) { func TestInflightLen(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2}) cl.State.Inflight.Set(packets.Packet{PacketID: 2})
require.Equal(t, 1, cl.State.Inflight.Len()) require.Equal(t, 1, cl.State.Inflight.Len())
} }
func TestInflightDelete(t *testing.T) { func TestInflightDelete(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 3}) cl.State.Inflight.Set(packets.Packet{PacketID: 3})
require.NotNil(t, cl.State.Inflight.internal[3]) require.NotNil(t, cl.State.Inflight.internal[3])
@@ -163,7 +163,7 @@ func TestSendQuota(t *testing.T) {
} }
func TestNextImmediate(t *testing.T) { func TestNextImmediate(t *testing.T) {
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1}) cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2}) cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1}) cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})

View File

@@ -7,6 +7,7 @@ package listeners
import ( import (
"context" "context"
"errors" "errors"
"io"
"net" "net"
"net/http" "net/http"
"sync" "sync"
@@ -137,25 +138,35 @@ type wsConn struct {
} }
// Read reads the next span of bytes from the websocket connection and returns the number of bytes read. // Read reads the next span of bytes from the websocket connection and returns the number of bytes read.
func (ws *wsConn) Read(p []byte) (n int, err error) { func (ws *wsConn) Read(p []byte) (int, error) {
op, r, err := ws.c.NextReader() op, r, err := ws.c.NextReader()
if err != nil { if err != nil {
return return 0, err
} }
if op != websocket.BinaryMessage { if op != websocket.BinaryMessage {
err = ErrInvalidMessage err = ErrInvalidMessage
return return 0, err
} }
return r.Read(p) var n, br int
for {
br, err = r.Read(p[n:])
n += br
if err != nil {
if err == io.EOF {
err = nil
}
return n, err
}
}
} }
// Write writes bytes to the websocket connection. // Write writes bytes to the websocket connection.
func (ws *wsConn) Write(p []byte) (n int, err error) { func (ws *wsConn) Write(p []byte) (int, error) {
err = ws.c.WriteMessage(websocket.BinaryMessage, p) err := ws.c.WriteMessage(websocket.BinaryMessage, p)
if err != nil { if err != nil {
return return 0, err
} }
return len(p), nil return len(p), nil

View File

@@ -26,10 +26,10 @@ import (
) )
const ( const (
Version = "2.0.0" // the current server version. Version = "2.0.7" // the current server version.
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
defaultFanPoolSize uint64 = 64 // the number of concurrent workers in the pool defaultFanPoolSize uint64 = 32 // the number of concurrent workers in the pool
defaultFanPoolQueueSize uint64 = 32 * 128 // the capacity of each worker queue defaultFanPoolQueueSize uint64 = 1024 // the capacity of each worker queue
) )
var ( var (
@@ -199,6 +199,31 @@ func (o *Options) ensureDefaults() {
} }
} }
// NewClient returns a new Client instance, populated with all the required values and
// references to be used with the server. If you are using this client to directly publish
// messages from the embedding application, set the inline flag to true to bypass ACL and
// topic validation checks.
func (s *Server) NewClient(c net.Conn, listener string, id string, inline bool) *Client {
cl := newClient(c, &ops{ // [MQTT-3.1.2-6] implicit
capabilities: s.Options.Capabilities,
info: s.Info,
hooks: s.hooks,
log: s.Log,
})
cl.ID = id
cl.Net.Listener = listener
if inline { // inline clients bypass acl and some validity checks.
cl.Net.Inline = true
// By default we don't want to restrict developer publishes,
// but if you do, reset this after creating inline client.
cl.State.Inflight.ResetReceiveQuota(math.MaxInt32)
}
return cl
}
// AddHook attaches a new Hook to the server. Ideally, this should be called // AddHook attaches a new Hook to the server. Ideally, this should be called
// before the server is started with s.Serve(). // before the server is started with s.Serve().
func (s *Server) AddHook(hook Hook, config any) error { func (s *Server) AddHook(hook Hook, config any) error {
@@ -281,27 +306,21 @@ func (s *Server) eventLoop() {
} }
// EstablishConnection establishes a new client when a listener accepts a new connection. // EstablishConnection establishes a new client when a listener accepts a new connection.
func (s *Server) EstablishConnection(lid string, c net.Conn) error { func (s *Server) EstablishConnection(listener string, c net.Conn) error {
cl := NewClient(c, &ops{ // [MQTT-3.1.2-6] implicit cl := s.NewClient(c, listener, "", false)
capabilities: s.Options.Capabilities, return s.attachClient(cl, listener)
info: s.Info,
hooks: s.hooks,
log: s.Log,
})
return s.attachClient(cl, lid)
} }
// attachClient validates an incoming client connection and if viable, attaches the client // attachClient validates an incoming client connection and if viable, attaches the client
// to the server, performs session housekeeping, and reads incoming packets. // to the server, performs session housekeeping, and reads incoming packets.
func (s *Server) attachClient(cl *Client, lid string) error { func (s *Server) attachClient(cl *Client, listener string) error {
defer cl.Stop(nil) defer cl.Stop(nil)
pk, err := s.readConnectionPacket(cl) pk, err := s.readConnectionPacket(cl)
if err != nil { if err != nil {
return fmt.Errorf("read connection: %w", err) return fmt.Errorf("read connection: %w", err)
} }
cl.ParseConnect(lid, pk) cl.ParseConnect(listener, pk)
code := s.validateConnect(cl, pk) // [MQTT-3.1.4-1] [MQTT-3.1.4-2] code := s.validateConnect(cl, pk) // [MQTT-3.1.4-1] [MQTT-3.1.4-2]
if code != packets.CodeSuccess { if code != packets.CodeSuccess {
if err := s.sendConnack(cl, code, false); err != nil { if err := s.sendConnack(cl, code, false); err != nil {
@@ -353,11 +372,11 @@ func (s *Server) attachClient(cl *Client, lid string) error {
cl.Properties.Will = Will{} // [MQTT-3.14.4-3] [MQTT-3.1.2-10] cl.Properties.Will = Will{} // [MQTT-3.14.4-3] [MQTT-3.1.2-10]
} }
s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", lid).Msg("client disconnected") s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", listener).Msg("client disconnected")
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryIntervalFlag && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean) expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryIntervalFlag && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
s.hooks.OnDisconnect(cl, err, expire) s.hooks.OnDisconnect(cl, err, expire)
if expire { if expire {
s.unsubscribeClient(cl) s.UnsubscribeClient(cl)
cl.ClearInflights(math.MaxInt64, 0) cl.ClearInflights(math.MaxInt64, 0)
s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23] s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23]
} }
@@ -436,7 +455,7 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
defer existing.Unlock() defer existing.Unlock()
s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3] s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
if pk.Connect.Clean || (existing.Properties.Clean && cl.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4] if pk.Connect.Clean || (existing.Properties.Clean && cl.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
s.unsubscribeClient(existing) s.UnsubscribeClient(existing)
existing.ClearInflights(math.MaxInt64, 0) existing.ClearInflights(math.MaxInt64, 0)
return false // [MQTT-3.2.2-3] return false // [MQTT-3.2.2-3]
} }
@@ -592,6 +611,24 @@ func (s *Server) processPingreq(cl *Client, _ packets.Packet) error {
}) })
} }
// Publish publishes a publish packet into the broker as if it were sent from the speicfied client.
// This is a convenience function which wraps InjectPacket. As such, this method can publish packets
// to any topic (including $SYS) and bypass ACL checks. The qos byte is used for limiting the
// outbound qos (mqtt v5) rather than issuing to the broker (we assume qos 2 complete).
func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) error {
cl := s.NewClient(nil, "local", "inline", true)
return s.InjectPacket(cl, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Qos: qos,
Retain: retain,
},
TopicName: topic,
Payload: payload,
PacketID: uint16(qos), // we never process the inbound qos, but we need a packet id for validity checks.
})
}
// InjectPacket injects a packet into the broker as if it were sent from the specified client. // InjectPacket injects a packet into the broker as if it were sent from the specified client.
// InlineClients using this method can publish packets to any topic (including $SYS) and bypass ACL checks. // InlineClients using this method can publish packets to any topic (including $SYS) and bypass ACL checks.
func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error { func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error {
@@ -627,7 +664,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
pk.Origin = cl.ID pk.Origin = cl.ID
pk.Created = time.Now().Unix() pk.Created = time.Now().Unix()
if pki, ok := cl.State.Inflight.Get(pk.PacketID); ok { if pki, ok := cl.State.Inflight.Get(pk.PacketID); ok && !cl.Net.Inline {
if pki.FixedHeader.Type == packets.Pubrec { // [MQTT-4.3.3-10] if pki.FixedHeader.Type == packets.Pubrec { // [MQTT-4.3.3-10]
ack := s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.ErrPacketIdentifierInUse) ack := s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.ErrPacketIdentifierInUse)
return cl.WritePacket(ack) return cl.WritePacket(ack)
@@ -660,6 +697,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
s.publishToSubscribers(pk) s.publishToSubscribers(pk)
}) })
s.hooks.OnPublished(cl, pk)
return nil return nil
} }
@@ -690,8 +728,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
s.publishToSubscribers(pk) s.publishToSubscribers(pk)
}) })
s.hooks.OnPublish(cl, pk) s.hooks.OnPublished(cl, pk)
return nil return nil
} }
@@ -1035,14 +1072,20 @@ func (s *Server) processUnsubscribe(cl *Client, pk packets.Packet) error {
return cl.WritePacket(ack) return cl.WritePacket(ack)
} }
// unsubscribeClient unsubscribes a client from all of their subscriptions. // UnsubscribeClient unsubscribes a client from all of their subscriptions.
func (s *Server) unsubscribeClient(cl *Client) { func (s *Server) UnsubscribeClient(cl *Client) {
for k := range cl.State.Subscriptions.GetAll() { i := 0
filterMap := cl.State.Subscriptions.GetAll()
filters := make([]packets.Subscription, len(filterMap))
for k, v := range filterMap {
cl.State.Subscriptions.Delete(k) cl.State.Subscriptions.Delete(k)
if s.Topics.Unsubscribe(k, cl.ID) { if s.Topics.Unsubscribe(k, cl.ID) {
atomic.AddInt64(&s.Info.Subscriptions, -1) atomic.AddInt64(&s.Info.Subscriptions, -1)
} }
filters[i] = v
i++
} }
s.hooks.OnUnsubscribed(cl, packets.Packet{Filters: filters})
} }
// processAuth processes an Auth packet. // processAuth processes an Auth packet.
@@ -1087,9 +1130,14 @@ func (s *Server) DisconnectClient(cl *Client, code packets.Code) error {
out.Properties.ReasonString = code.Reason // // [MQTT-3.14.2-1] out.Properties.ReasonString = code.Reason // // [MQTT-3.14.2-1]
} }
// We already have a code we are using to disconnect the client, so we are not
// interested if the write packet fails due to a closed connection (as we are closing it).
err := cl.WritePacket(out) err := cl.WritePacket(out)
if !s.Options.Capabilities.Compatibilities.PassiveClientDisconnect { if !s.Options.Capabilities.Compatibilities.PassiveClientDisconnect {
cl.Stop(code) cl.Stop(code)
if code.Code >= packets.ErrUnspecifiedError.Code {
return code
}
} }
return err return err
@@ -1304,9 +1352,7 @@ func (s *Server) loadSubscriptions(v []storage.Subscription) {
// loadClients restores clients from the datastore. // loadClients restores clients from the datastore.
func (s *Server) loadClients(v []storage.Client) { func (s *Server) loadClients(v []storage.Client) {
for _, c := range v { for _, c := range v {
cl := newClientStub() cl := s.NewClient(nil, c.Listener, c.ID, false)
cl.ID = c.ID
cl.Net.Listener = c.Listener
cl.Properties.Username = c.Username cl.Properties.Username = c.Username
cl.Properties.Clean = c.Clean cl.Properties.Clean = c.Clean
cl.Properties.ProtocolVersion = c.ProtocolVersion cl.Properties.ProtocolVersion = c.ProtocolVersion

View File

@@ -102,7 +102,34 @@ func TestNewNilOpts(t *testing.T) {
require.NotNil(t, s.Options) require.NotNil(t, s.Options)
} }
func TestAddHook(t *testing.T) { func TestServerNewClient(t *testing.T) {
s := New(nil)
s.Log = &logger
r, _ := net.Pipe()
cl := s.NewClient(r, "testing", "test", false)
require.NotNil(t, cl)
require.Equal(t, "test", cl.ID)
require.Equal(t, "testing", cl.Net.Listener)
require.False(t, cl.Net.Inline)
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.NotNil(t, cl.State.TopicAliases)
require.Equal(t, defaultKeepalive, cl.State.keepalive)
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
require.NotNil(t, cl.Net.conn)
require.NotNil(t, cl.Net.bconn)
require.NotNil(t, cl.ops)
require.Equal(t, s.Log, cl.ops.log)
}
func TestServerNewClientInline(t *testing.T) {
s := New(nil)
cl := s.NewClient(nil, "testing", "test", true)
require.True(t, cl.Net.Inline)
}
func TestServerAddHook(t *testing.T) {
s := New(nil) s := New(nil)
s.Log = &logger s.Log = &logger
require.NotNil(t, s) require.NotNil(t, s)
@@ -113,7 +140,7 @@ func TestAddHook(t *testing.T) {
require.Equal(t, int64(1), s.hooks.Len()) require.Equal(t, int64(1), s.hooks.Len())
} }
func TestAddListener(t *testing.T) { func TestServerAddListener(t *testing.T) {
s := newServer() s := newServer()
defer s.Close() defer s.Close()
@@ -128,7 +155,7 @@ func TestAddListener(t *testing.T) {
require.Equal(t, ErrListenerIDExists, err) require.Equal(t, ErrListenerIDExists, err)
} }
func TestAddListenerInitFailure(t *testing.T) { func TestServerAddListenerInitFailure(t *testing.T) {
s := newServer() s := newServer()
defer s.Close() defer s.Close()
@@ -197,7 +224,7 @@ func TestServerReadConnectionPacket(t *testing.T) {
s := newServer() s := newServer()
defer s.Close() defer s.Close()
cl, r, _ := newClient() cl, r, _ := newTestClient()
s.Clients.Add(cl) s.Clients.Add(cl)
o := make(chan packets.Packet) o := make(chan packets.Packet)
@@ -219,7 +246,7 @@ func TestServerReadConnectionPacketBadFixedHeader(t *testing.T) {
s := newServer() s := newServer()
defer s.Close() defer s.Close()
cl, r, _ := newClient() cl, r, _ := newTestClient()
s.Clients.Add(cl) s.Clients.Add(cl)
o := make(chan error) o := make(chan error)
@@ -242,7 +269,7 @@ func TestServerReadConnectionPacketBadPacketType(t *testing.T) {
s := newServer() s := newServer()
defer s.Close() defer s.Close()
cl, r, _ := newClient() cl, r, _ := newTestClient()
s.Clients.Add(cl) s.Clients.Add(cl)
go func() { go func() {
@@ -259,7 +286,7 @@ func TestServerReadConnectionPacketBadPacket(t *testing.T) {
s := newServer() s := newServer()
defer s.Close() defer s.Close()
cl, r, _ := newClient() cl, r, _ := newTestClient()
s.Clients.Add(cl) s.Clients.Add(cl)
go func() { go func() {
@@ -377,7 +404,7 @@ func TestEstablishConnectionInheritExisting(t *testing.T) {
s := newServer() s := newServer()
defer s.Close() defer s.Close()
cl, r0, _ := newClient() cl, r0, _ := newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier
cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1}) cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
@@ -438,7 +465,7 @@ func TestEstablishConnectionResentPendingInflightsError(t *testing.T) {
defer s.Close() defer s.Close()
n := time.Now().Unix() n := time.Now().Unix()
cl, r0, _ := newClient() cl, r0, _ := newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier
cl.State.Inflight = NewInflights() cl.State.Inflight = NewInflights()
@@ -474,7 +501,7 @@ func TestEstablishConnectionInheritExistingClean(t *testing.T) {
s := newServer() s := newServer()
defer s.Close() defer s.Close()
cl, r0, _ := newClient() cl, r0, _ := newTestClient()
cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier
cl.Properties.Clean = true cl.Properties.Clean = true
cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1}) cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
@@ -660,7 +687,7 @@ func TestServerEstablishConnectionBadPacket(t *testing.T) {
func TestServerSendConnack(t *testing.T) { func TestServerSendConnack(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
s.Options.Capabilities.ServerKeepAlive = 20 s.Options.Capabilities.ServerKeepAlive = 20
s.Options.Capabilities.MaximumQos = 1 s.Options.Capabilities.MaximumQos = 1
@@ -680,7 +707,7 @@ func TestServerSendConnack(t *testing.T) {
func TestServerSendConnackFailureReason(t *testing.T) { func TestServerSendConnackFailureReason(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
s.Options.Capabilities.ServerKeepAlive = 20 s.Options.Capabilities.ServerKeepAlive = 20
go func() { go func() {
@@ -758,7 +785,7 @@ func TestServerValidateConnect(t *testing.T) {
func TestServerSendConnackAdjustedExpiryInterval(t *testing.T) { func TestServerSendConnackAdjustedExpiryInterval(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
cl.Properties.Props.SessionExpiryInterval = uint32(300) cl.Properties.Props.SessionExpiryInterval = uint32(300)
s.Options.Capabilities.MaximumSessionExpiryInterval = 120 s.Options.Capabilities.MaximumSessionExpiryInterval = 120
@@ -778,7 +805,7 @@ func TestInheritClientSession(t *testing.T) {
n := time.Now().Unix() n := time.Now().Unix()
existing, _, _ := newClient() existing, _, _ := newTestClient()
existing.Net.conn = nil existing.Net.conn = nil
existing.ID = "mochi" existing.ID = "mochi"
existing.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1}) existing.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
@@ -788,7 +815,7 @@ func TestInheritClientSession(t *testing.T) {
s.Clients.Add(existing) s.Clients.Add(existing)
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
require.Equal(t, 0, cl.State.Inflight.Len()) require.Equal(t, 0, cl.State.Inflight.Len())
@@ -801,7 +828,7 @@ func TestInheritClientSession(t *testing.T) {
require.Equal(t, 1, cl.State.Subscriptions.Len()) require.Equal(t, 1, cl.State.Subscriptions.Len())
// On clean, clear existing properties // On clean, clear existing properties
cl, _, _ = newClient() cl, _, _ = newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
b = s.inheritClientSession(packets.Packet{Connect: packets.ConnectParams{ClientIdentifier: "mochi", Clean: true}}, cl) b = s.inheritClientSession(packets.Packet{Connect: packets.ConnectParams{ClientIdentifier: "mochi", Clean: true}}, cl)
require.False(t, b) require.False(t, b)
@@ -811,27 +838,27 @@ func TestInheritClientSession(t *testing.T) {
func TestServerUnsubscribeClient(t *testing.T) { func TestServerUnsubscribeClient(t *testing.T) {
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
pk := packets.Subscription{Filter: "a/b/c", Qos: 1} pk := packets.Subscription{Filter: "a/b/c", Qos: 1}
cl.State.Subscriptions.Add("a/b/c", pk) cl.State.Subscriptions.Add("a/b/c", pk)
s.Topics.Subscribe(cl.ID, pk) s.Topics.Subscribe(cl.ID, pk)
subs := s.Topics.Subscribers("a/b/c") subs := s.Topics.Subscribers("a/b/c")
require.Equal(t, 1, len(subs.Subscriptions)) require.Equal(t, 1, len(subs.Subscriptions))
s.unsubscribeClient(cl) s.UnsubscribeClient(cl)
subs = s.Topics.Subscribers("a/b/c") subs = s.Topics.Subscribers("a/b/c")
require.Equal(t, 0, len(subs.Subscriptions)) require.Equal(t, 0, len(subs.Subscriptions))
} }
func TestServerProcessPacketFailure(t *testing.T) { func TestServerProcessPacketFailure(t *testing.T) {
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
err := s.processPacket(cl, packets.Packet{}) err := s.processPacket(cl, packets.Packet{})
require.Error(t, err) require.Error(t, err)
} }
func TestServerProcessPacketConnect(t *testing.T) { func TestServerProcessPacketConnect(t *testing.T) {
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
err := s.processPacket(cl, *packets.TPacketData[packets.Connect].Get(packets.TConnectClean).Packet) err := s.processPacket(cl, *packets.TPacketData[packets.Connect].Get(packets.TConnectClean).Packet)
require.Error(t, err) require.Error(t, err)
@@ -839,7 +866,7 @@ func TestServerProcessPacketConnect(t *testing.T) {
func TestServerProcessPacketPingreq(t *testing.T) { func TestServerProcessPacketPingreq(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
go func() { go func() {
err := s.processPacket(cl, *packets.TPacketData[packets.Pingreq].Get(packets.TPingreq).Packet) err := s.processPacket(cl, *packets.TPacketData[packets.Pingreq].Get(packets.TPingreq).Packet)
@@ -854,7 +881,7 @@ func TestServerProcessPacketPingreq(t *testing.T) {
func TestServerProcessPacketPingreqError(t *testing.T) { func TestServerProcessPacketPingreqError(t *testing.T) {
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.Stop(packets.CodeDisconnect) cl.Stop(packets.CodeDisconnect)
err := s.processPacket(cl, *packets.TPacketData[packets.Pingreq].Get(packets.TPingreq).Packet) err := s.processPacket(cl, *packets.TPacketData[packets.Pingreq].Get(packets.TPingreq).Packet)
@@ -864,7 +891,7 @@ func TestServerProcessPacketPingreqError(t *testing.T) {
func TestServerProcessPacketPublishInvalid(t *testing.T) { func TestServerProcessPacketPublishInvalid(t *testing.T) {
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishInvalidQosMustPacketID).Packet) err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishInvalidQosMustPacketID).Packet)
require.Error(t, err) require.Error(t, err)
@@ -876,12 +903,12 @@ func TestInjectPacketPublishAndReceive(t *testing.T) {
s.Serve() s.Serve()
defer s.Close() defer s.Close()
sender, _, w1 := newClient() sender, _, w1 := newTestClient()
sender.Net.Inline = true sender.Net.Inline = true
sender.ID = "sender" sender.ID = "sender"
s.Clients.Add(sender) s.Clients.Add(sender)
receiver, r2, w2 := newClient() receiver, r2, w2 := newTestClient()
receiver.ID = "receiver" receiver.ID = "receiver"
s.Clients.Add(receiver) s.Clients.Add(receiver)
s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"}) s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"})
@@ -906,10 +933,46 @@ func TestInjectPacketPublishAndReceive(t *testing.T) {
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf)
} }
func TestServerDirectPublishAndReceive(t *testing.T) {
s := newServer()
s.Serve()
defer s.Close()
sender, _, w1 := newTestClient()
sender.Net.Inline = true
sender.ID = "sender"
s.Clients.Add(sender)
receiver, r2, w2 := newTestClient()
receiver.ID = "receiver"
s.Clients.Add(receiver)
s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"})
require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived))
receiverBuf := make(chan []byte)
go func() {
buf, err := io.ReadAll(r2)
require.NoError(t, err)
receiverBuf <- buf
}()
go func() {
pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
err := s.Publish(pkx.TopicName, pkx.Payload, pkx.FixedHeader.Retain, pkx.FixedHeader.Qos)
require.NoError(t, err)
w1.Close()
time.Sleep(time.Millisecond * 10)
w2.Close()
}()
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf)
}
func TestInjectPacketError(t *testing.T) { func TestInjectPacketError(t *testing.T) {
s := newServer() s := newServer()
defer s.Close() defer s.Close()
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.Net.Inline = true cl.Net.Inline = true
pkx := *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet pkx := *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet
pkx.Filters = packets.Subscriptions{} pkx.Filters = packets.Subscriptions{}
@@ -920,7 +983,7 @@ func TestInjectPacketError(t *testing.T) {
func TestInjectPacketPublishInvalidTopic(t *testing.T) { func TestInjectPacketPublishInvalidTopic(t *testing.T) {
s := newServer() s := newServer()
defer s.Close() defer s.Close()
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.Net.Inline = true cl.Net.Inline = true
pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
pkx.TopicName = "$SYS/test" pkx.TopicName = "$SYS/test"
@@ -933,11 +996,11 @@ func TestServerProcessPacketPublishAndReceive(t *testing.T) {
s.Serve() s.Serve()
defer s.Close() defer s.Close()
sender, _, w1 := newClient() sender, _, w1 := newTestClient()
sender.ID = "sender" sender.ID = "sender"
s.Clients.Add(sender) s.Clients.Add(sender)
receiver, r2, w2 := newClient() receiver, r2, w2 := newTestClient()
receiver.ID = "receiver" receiver.ID = "receiver"
s.Clients.Add(receiver) s.Clients.Add(receiver)
s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"}) s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"})
@@ -966,7 +1029,7 @@ func TestServerProcessPacketPublishAndReceive(t *testing.T) {
func TestServerProcessPacketAndNextImmediate(t *testing.T) { func TestServerProcessPacketAndNextImmediate(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
next := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet next := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
next.Expiry = -1 next.Expiry = -1
@@ -993,7 +1056,7 @@ func TestServerProcessPacketPublishAckFailure(t *testing.T) {
s.Serve() s.Serve()
defer s.Close() defer s.Close()
cl, _, w := newClient() cl, _, w := newTestClient()
s.Clients.Add(cl) s.Clients.Add(cl)
w.Close() w.Close()
@@ -1007,14 +1070,15 @@ func TestServerProcessPacketPublishMaximumReceive(t *testing.T) {
s.Serve() s.Serve()
defer s.Close() defer s.Close()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
cl.State.Inflight.ResetReceiveQuota(0) cl.State.Inflight.ResetReceiveQuota(0)
s.Clients.Add(cl) s.Clients.Add(cl)
go func() { go func() {
err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
require.NoError(t, err) require.Error(t, err)
require.ErrorIs(t, err, packets.ErrReceiveMaximum)
w.Close() w.Close()
}() }()
@@ -1027,7 +1091,7 @@ func TestServerProcessPublishInvalidTopic(t *testing.T) {
s := newServer() s := newServer()
s.Serve() s.Serve()
defer s.Close() defer s.Close()
cl, _, _ := newClient() cl, _, _ := newTestClient()
err := s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishSpecDenySysTopic).Packet) err := s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishSpecDenySysTopic).Packet)
require.NoError(t, err) // $SYS topics should be ignored? require.NoError(t, err) // $SYS topics should be ignored?
} }
@@ -1040,7 +1104,7 @@ func TestServerProcessPublishACLCheckDeny(t *testing.T) {
}) })
s.Serve() s.Serve()
defer s.Close() defer s.Close()
cl, _, _ := newClient() cl, _, _ := newTestClient()
err := s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) err := s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet)
require.NoError(t, err) // ACL check fails silently require.NoError(t, err) // ACL check fails silently
} }
@@ -1057,14 +1121,14 @@ func TestServerProcessPublishOnMessageRecvRejected(t *testing.T) {
s.Serve() s.Serve()
defer s.Close() defer s.Close()
cl, _, _ := newClient() cl, _, _ := newTestClient()
err = s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) err = s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet)
require.NoError(t, err) // packets rejected silently require.NoError(t, err) // packets rejected silently
} }
func TestServerProcessPacketPublishQos0(t *testing.T) { func TestServerProcessPacketPublishQos0(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
go func() { go func() {
err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet)
@@ -1079,7 +1143,7 @@ func TestServerProcessPacketPublishQos0(t *testing.T) {
func TestServerProcessPacketPublishQos1PacketIDInUse(t *testing.T) { func TestServerProcessPacketPublishQos1PacketIDInUse(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 7, FixedHeader: packets.FixedHeader{Type: packets.Publish}}) cl.State.Inflight.Set(packets.Packet{PacketID: 7, FixedHeader: packets.FixedHeader{Type: packets.Publish}})
atomic.StoreInt64(&s.Info.Inflight, 1) atomic.StoreInt64(&s.Info.Inflight, 1)
@@ -1097,7 +1161,7 @@ func TestServerProcessPacketPublishQos1PacketIDInUse(t *testing.T) {
func TestServerProcessPacketPublishQos2PacketIDInUse(t *testing.T) { func TestServerProcessPacketPublishQos2PacketIDInUse(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
cl.State.Inflight.Set(packets.Packet{PacketID: 7, FixedHeader: packets.FixedHeader{Type: packets.Pubrec}}) cl.State.Inflight.Set(packets.Packet{PacketID: 7, FixedHeader: packets.FixedHeader{Type: packets.Pubrec}})
atomic.StoreInt64(&s.Info.Inflight, 1) atomic.StoreInt64(&s.Info.Inflight, 1)
@@ -1116,7 +1180,7 @@ func TestServerProcessPacketPublishQos2PacketIDInUse(t *testing.T) {
func TestServerProcessPacketPublishQos1(t *testing.T) { func TestServerProcessPacketPublishQos1(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
go func() { go func() {
err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
@@ -1131,7 +1195,7 @@ func TestServerProcessPacketPublishQos1(t *testing.T) {
func TestServerProcessPacketPublishQos2(t *testing.T) { func TestServerProcessPacketPublishQos2(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
go func() { go func() {
err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet) err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet)
@@ -1147,7 +1211,7 @@ func TestServerProcessPacketPublishQos2(t *testing.T) {
func TestServerProcessPacketPublishDowngradeQos(t *testing.T) { func TestServerProcessPacketPublishDowngradeQos(t *testing.T) {
s := newServer() s := newServer()
s.Options.Capabilities.MaximumQos = 1 s.Options.Capabilities.MaximumQos = 1
cl, r, w := newClient() cl, r, w := newTestClient()
go func() { go func() {
err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet) err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet)
@@ -1162,7 +1226,7 @@ func TestServerProcessPacketPublishDowngradeQos(t *testing.T) {
func TestPublishToSubscribersSelfNoLocal(t *testing.T) { func TestPublishToSubscribersSelfNoLocal(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
s.Clients.Add(cl) s.Clients.Add(cl)
subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", NoLocal: true}) subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", NoLocal: true})
require.True(t, subbed) require.True(t, subbed)
@@ -1187,11 +1251,11 @@ func TestPublishToSubscribersSelfNoLocal(t *testing.T) {
func TestPublishToSubscribers(t *testing.T) { func TestPublishToSubscribers(t *testing.T) {
s := newServer() s := newServer()
cl, r1, w1 := newClient() cl, r1, w1 := newTestClient()
cl.ID = "cl1" cl.ID = "cl1"
cl2, r2, w2 := newClient() cl2, r2, w2 := newTestClient()
cl2.ID = "cl2" cl2.ID = "cl2"
cl3, r3, w3 := newClient() cl3, r3, w3 := newTestClient()
cl3.ID = "cl3" cl3.ID = "cl3"
s.Clients.Add(cl) s.Clients.Add(cl)
s.Clients.Add(cl2) s.Clients.Add(cl2)
@@ -1249,7 +1313,7 @@ func TestPublishToSubscribers(t *testing.T) {
func TestPublishToSubscribersMessageExpiryDelta(t *testing.T) { func TestPublishToSubscribersMessageExpiryDelta(t *testing.T) {
s := newServer() s := newServer()
s.Options.Capabilities.MaximumMessageExpiryInterval = 86400 s.Options.Capabilities.MaximumMessageExpiryInterval = 86400
cl, r1, w1 := newClient() cl, r1, w1 := newTestClient()
cl.ID = "cl1" cl.ID = "cl1"
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
s.Clients.Add(cl) s.Clients.Add(cl)
@@ -1278,7 +1342,7 @@ func TestPublishToSubscribersMessageExpiryDelta(t *testing.T) {
func TestPublishToSubscribersIdentifiers(t *testing.T) { func TestPublishToSubscribersIdentifiers(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
s.Clients.Add(cl) s.Clients.Add(cl)
subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/+", Identifier: 2}) subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/+", Identifier: 2})
@@ -1308,7 +1372,7 @@ func TestPublishToClientServerDowngradeQos(t *testing.T) {
s := newServer() s := newServer()
s.Options.Capabilities.MaximumQos = 1 s.Options.Capabilities.MaximumQos = 1
cl, r, w := newClient() cl, r, w := newTestClient()
s.Clients.Add(cl) s.Clients.Add(cl)
_, ok := cl.State.Inflight.Get(1) _, ok := cl.State.Inflight.Get(1)
@@ -1334,7 +1398,7 @@ func TestPublishToClientServerDowngradeQos(t *testing.T) {
func TestPublishToClientServerTopicAlias(t *testing.T) { func TestPublishToClientServerTopicAlias(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
cl.Properties.Props.TopicAliasMaximum = 5 cl.Properties.Props.TopicAliasMaximum = 5
s.Clients.Add(cl) s.Clients.Add(cl)
@@ -1363,7 +1427,7 @@ func TestPublishToClientServerTopicAlias(t *testing.T) {
func TestPublishToClientExhaustedPacketID(t *testing.T) { func TestPublishToClientExhaustedPacketID(t *testing.T) {
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
for i := 0; i <= 65535; i++ { for i := 0; i <= 65535; i++ {
cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)}) cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)})
} }
@@ -1375,7 +1439,7 @@ func TestPublishToClientExhaustedPacketID(t *testing.T) {
func TestPublishToClientNoConn(t *testing.T) { func TestPublishToClientNoConn(t *testing.T) {
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.Net.conn = nil cl.Net.conn = nil
_, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c"}, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) _, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c"}, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
@@ -1385,12 +1449,12 @@ func TestPublishToClientNoConn(t *testing.T) {
func TestProcessPublishWithTopicAlias(t *testing.T) { func TestProcessPublishWithTopicAlias(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
s.Clients.Add(cl) s.Clients.Add(cl)
subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 0}) subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 0})
require.True(t, subbed) require.True(t, subbed)
cl2, _, w2 := newClient() cl2, _, w2 := newTestClient()
cl2.Properties.ProtocolVersion = 5 cl2.Properties.ProtocolVersion = 5
cl2.State.TopicAliases.Inbound.Set(1, "a/b/c") cl2.State.TopicAliases.Inbound.Set(1, "a/b/c")
@@ -1412,7 +1476,7 @@ func TestProcessPublishWithTopicAlias(t *testing.T) {
func TestPublishToSubscribersExhaustedSendQuota(t *testing.T) { func TestPublishToSubscribersExhaustedSendQuota(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
s.Clients.Add(cl) s.Clients.Add(cl)
cl.State.Inflight.sendQuota = 0 cl.State.Inflight.sendQuota = 0
@@ -1431,7 +1495,7 @@ func TestPublishToSubscribersExhaustedSendQuota(t *testing.T) {
func TestPublishToSubscribersExhaustedPacketIDs(t *testing.T) { func TestPublishToSubscribersExhaustedPacketIDs(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
s.Clients.Add(cl) s.Clients.Add(cl)
for i := 0; i <= 65535; i++ { for i := 0; i <= 65535; i++ {
cl.State.Inflight.Set(packets.Packet{PacketID: 1}) cl.State.Inflight.Set(packets.Packet{PacketID: 1})
@@ -1452,7 +1516,7 @@ func TestPublishToSubscribersExhaustedPacketIDs(t *testing.T) {
func TestPublishToSubscribersNoConnection(t *testing.T) { func TestPublishToSubscribersNoConnection(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
s.Clients.Add(cl) s.Clients.Add(cl)
subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2}) subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2})
require.True(t, subbed) require.True(t, subbed)
@@ -1467,7 +1531,7 @@ func TestPublishToSubscribersNoConnection(t *testing.T) {
func TestPublishRetainedToClient(t *testing.T) { func TestPublishRetainedToClient(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
s.Clients.Add(cl) s.Clients.Add(cl)
subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2}) subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2})
@@ -1489,7 +1553,7 @@ func TestPublishRetainedToClient(t *testing.T) {
func TestPublishRetainedToClientIsShared(t *testing.T) { func TestPublishRetainedToClientIsShared(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
s.Clients.Add(cl) s.Clients.Add(cl)
sub := packets.Subscription{Filter: SharePrefix + "/test/a/b/c"} sub := packets.Subscription{Filter: SharePrefix + "/test/a/b/c"}
@@ -1508,7 +1572,7 @@ func TestPublishRetainedToClientIsShared(t *testing.T) {
func TestPublishRetainedToClientError(t *testing.T) { func TestPublishRetainedToClientError(t *testing.T) {
s := newServer() s := newServer()
cl, _, w := newClient() cl, _, w := newTestClient()
s.Clients.Add(cl) s.Clients.Add(cl)
sub := packets.Subscription{Filter: "a/b/c"} sub := packets.Subscription{Filter: "a/b/c"}
@@ -1538,7 +1602,7 @@ func TestServerProcessPacketPuback(t *testing.T) {
t.Run(strconv.Itoa(int(tx.protocolVersion)), func(t *testing.T) { t.Run(strconv.Itoa(int(tx.protocolVersion)), func(t *testing.T) {
pID := uint16(7) pID := uint16(7)
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.State.Inflight.sendQuota = 3 cl.State.Inflight.sendQuota = 3
cl.State.Inflight.receiveQuota = 3 cl.State.Inflight.receiveQuota = 3
@@ -1560,7 +1624,7 @@ func TestServerProcessPacketPuback(t *testing.T) {
func TestServerProcessPacketPubackNoPacketID(t *testing.T) { func TestServerProcessPacketPubackNoPacketID(t *testing.T) {
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.State.Inflight.sendQuota = 3 cl.State.Inflight.sendQuota = 3
cl.State.Inflight.receiveQuota = 3 cl.State.Inflight.receiveQuota = 3
@@ -1575,7 +1639,7 @@ func TestServerProcessPacketPubackNoPacketID(t *testing.T) {
func TestServerProcessPacketPubrec(t *testing.T) { func TestServerProcessPacketPubrec(t *testing.T) {
pID := uint16(7) pID := uint16(7)
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.State.Inflight.sendQuota = 3 cl.State.Inflight.sendQuota = 3
cl.State.Inflight.receiveQuota = 3 cl.State.Inflight.receiveQuota = 3
@@ -1604,7 +1668,7 @@ func TestServerProcessPacketPubrec(t *testing.T) {
func TestServerProcessPacketPubrecNoPacketID(t *testing.T) { func TestServerProcessPacketPubrecNoPacketID(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
cl.State.Inflight.sendQuota = 3 cl.State.Inflight.sendQuota = 3
cl.State.Inflight.receiveQuota = 3 cl.State.Inflight.receiveQuota = 3
@@ -1630,7 +1694,7 @@ func TestServerProcessPacketPubrecNoPacketID(t *testing.T) {
func TestServerProcessPacketPubrecInvalidReason(t *testing.T) { func TestServerProcessPacketPubrecInvalidReason(t *testing.T) {
pID := uint16(7) pID := uint16(7)
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: pID}) cl.State.Inflight.Set(packets.Packet{PacketID: pID})
err := s.processPacket(cl, *packets.TPacketData[packets.Pubrec].Get(packets.TPubrecInvalidReason).Packet) err := s.processPacket(cl, *packets.TPacketData[packets.Pubrec].Get(packets.TPubrecInvalidReason).Packet)
require.NoError(t, err) require.NoError(t, err)
@@ -1642,7 +1706,7 @@ func TestServerProcessPacketPubrecInvalidReason(t *testing.T) {
func TestServerProcessPacketPubrecFailure(t *testing.T) { func TestServerProcessPacketPubrecFailure(t *testing.T) {
pID := uint16(7) pID := uint16(7)
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: pID}) cl.State.Inflight.Set(packets.Packet{PacketID: pID})
cl.Stop(packets.CodeDisconnect) cl.Stop(packets.CodeDisconnect)
err := s.processPacket(cl, *packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).Packet) err := s.processPacket(cl, *packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).Packet)
@@ -1653,7 +1717,7 @@ func TestServerProcessPacketPubrecFailure(t *testing.T) {
func TestServerProcessPacketPubrel(t *testing.T) { func TestServerProcessPacketPubrel(t *testing.T) {
pID := uint16(7) pID := uint16(7)
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.State.Inflight.sendQuota = 3 cl.State.Inflight.sendQuota = 3
cl.State.Inflight.receiveQuota = 3 cl.State.Inflight.receiveQuota = 3
@@ -1683,7 +1747,7 @@ func TestServerProcessPacketPubrel(t *testing.T) {
func TestServerProcessPacketPubrelNoPacketID(t *testing.T) { func TestServerProcessPacketPubrelNoPacketID(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
cl.State.Inflight.sendQuota = 3 cl.State.Inflight.sendQuota = 3
cl.State.Inflight.receiveQuota = 3 cl.State.Inflight.receiveQuota = 3
@@ -1709,7 +1773,7 @@ func TestServerProcessPacketPubrelNoPacketID(t *testing.T) {
func TestServerProcessPacketPubrelFailure(t *testing.T) { func TestServerProcessPacketPubrelFailure(t *testing.T) {
pID := uint16(7) pID := uint16(7)
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: pID}) cl.State.Inflight.Set(packets.Packet{PacketID: pID})
cl.Stop(packets.CodeDisconnect) cl.Stop(packets.CodeDisconnect)
err := s.processPacket(cl, *packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).Packet) err := s.processPacket(cl, *packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).Packet)
@@ -1720,7 +1784,7 @@ func TestServerProcessPacketPubrelFailure(t *testing.T) {
func TestServerProcessPacketPubrelBadReason(t *testing.T) { func TestServerProcessPacketPubrelBadReason(t *testing.T) {
pID := uint16(7) pID := uint16(7)
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: pID}) cl.State.Inflight.Set(packets.Packet{PacketID: pID})
err := s.processPacket(cl, *packets.TPacketData[packets.Pubrel].Get(packets.TPubrelInvalidReason).Packet) err := s.processPacket(cl, *packets.TPacketData[packets.Pubrel].Get(packets.TPubrelInvalidReason).Packet)
require.NoError(t, err) require.NoError(t, err)
@@ -1745,7 +1809,7 @@ func TestServerProcessPacketPubcomp(t *testing.T) {
t.Run(strconv.Itoa(int(tx.protocolVersion)), func(t *testing.T) { t.Run(strconv.Itoa(int(tx.protocolVersion)), func(t *testing.T) {
pID := uint16(7) pID := uint16(7)
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.Properties.ProtocolVersion = tx.protocolVersion cl.Properties.ProtocolVersion = tx.protocolVersion
cl.State.Inflight.sendQuota = 3 cl.State.Inflight.sendQuota = 3
cl.State.Inflight.receiveQuota = 3 cl.State.Inflight.receiveQuota = 3
@@ -1792,7 +1856,7 @@ func TestServerProcessInboundQos2Flow(t *testing.T) {
pID := uint16(7) pID := uint16(7)
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.State.Inflight.sendQuota = 3 cl.State.Inflight.sendQuota = 3
cl.State.Inflight.receiveQuota = 3 cl.State.Inflight.receiveQuota = 3
@@ -1863,7 +1927,7 @@ func TestServerProcessOutboundQos2Flow(t *testing.T) {
pID := uint16(6) pID := uint16(6)
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.State.packetID = uint32(6) cl.State.packetID = uint32(6)
cl.State.Inflight.sendQuota = 3 cl.State.Inflight.sendQuota = 3
cl.State.Inflight.receiveQuota = 3 cl.State.Inflight.receiveQuota = 3
@@ -1907,7 +1971,7 @@ func TestServerProcessOutboundQos2Flow(t *testing.T) {
func TestServerProcessPacketSubscribe(t *testing.T) { func TestServerProcessPacketSubscribe(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
go func() { go func() {
err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5).Packet) err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5).Packet)
@@ -1922,7 +1986,7 @@ func TestServerProcessPacketSubscribe(t *testing.T) {
func TestServerProcessPacketSubscribePacketIDInUse(t *testing.T) { func TestServerProcessPacketSubscribePacketIDInUse(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
cl.State.Inflight.Set(packets.Packet{PacketID: 15, FixedHeader: packets.FixedHeader{Type: packets.Publish}}) cl.State.Inflight.Set(packets.Packet{PacketID: 15, FixedHeader: packets.FixedHeader{Type: packets.Publish}})
@@ -1941,7 +2005,7 @@ func TestServerProcessPacketSubscribePacketIDInUse(t *testing.T) {
func TestServerProcessPacketSubscribeInvalid(t *testing.T) { func TestServerProcessPacketSubscribeInvalid(t *testing.T) {
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeSpecQosMustPacketID).Packet) err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeSpecQosMustPacketID).Packet)
@@ -1951,7 +2015,7 @@ func TestServerProcessPacketSubscribeInvalid(t *testing.T) {
func TestServerProcessPacketSubscribeInvalidFilter(t *testing.T) { func TestServerProcessPacketSubscribeInvalidFilter(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
go func() { go func() {
@@ -1967,7 +2031,7 @@ func TestServerProcessPacketSubscribeInvalidFilter(t *testing.T) {
func TestServerProcessPacketSubscribeInvalidSharedNoLocal(t *testing.T) { func TestServerProcessPacketSubscribeInvalidSharedNoLocal(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
go func() { go func() {
@@ -1983,7 +2047,7 @@ func TestServerProcessPacketSubscribeInvalidSharedNoLocal(t *testing.T) {
func TestServerProcessSubscribeWithRetain(t *testing.T) { func TestServerProcessSubscribeWithRetain(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
require.Equal(t, int64(1), retained) require.Equal(t, int64(1), retained)
@@ -2007,7 +2071,7 @@ func TestServerProcessSubscribeWithRetain(t *testing.T) {
func TestServerProcessSubscribeDowngradeQos(t *testing.T) { func TestServerProcessSubscribeDowngradeQos(t *testing.T) {
s := newServer() s := newServer()
s.Options.Capabilities.MaximumQos = 1 s.Options.Capabilities.MaximumQos = 1
cl, r, w := newClient() cl, r, w := newTestClient()
go func() { go func() {
err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMany).Packet) err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMany).Packet)
@@ -2024,7 +2088,7 @@ func TestServerProcessSubscribeDowngradeQos(t *testing.T) {
func TestServerProcessSubscribeWithRetainHandling1(t *testing.T) { func TestServerProcessSubscribeWithRetainHandling1(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c"}) s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c"})
s.Clients.Add(cl) s.Clients.Add(cl)
@@ -2046,7 +2110,7 @@ func TestServerProcessSubscribeWithRetainHandling1(t *testing.T) {
func TestServerProcessSubscribeWithRetainHandling2(t *testing.T) { func TestServerProcessSubscribeWithRetainHandling2(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
s.Clients.Add(cl) s.Clients.Add(cl)
retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
@@ -2067,7 +2131,7 @@ func TestServerProcessSubscribeWithRetainHandling2(t *testing.T) {
func TestServerProcessSubscribeWithNotRetainAsPublished(t *testing.T) { func TestServerProcessSubscribeWithNotRetainAsPublished(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
s.Clients.Add(cl) s.Clients.Add(cl)
retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
@@ -2091,7 +2155,7 @@ func TestServerProcessSubscribeWithNotRetainAsPublished(t *testing.T) {
func TestServerProcessSubscribeNoConnection(t *testing.T) { func TestServerProcessSubscribeNoConnection(t *testing.T) {
s := newServer() s := newServer()
cl, r, _ := newClient() cl, r, _ := newTestClient()
r.Close() r.Close()
err := s.processSubscribe(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet) err := s.processSubscribe(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet)
require.Error(t, err) require.Error(t, err)
@@ -2105,7 +2169,7 @@ func TestServerProcessSubscribeACLCheckDeny(t *testing.T) {
FanPoolQueueSize: 10, FanPoolQueueSize: 10,
}) })
s.Serve() s.Serve()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
go func() { go func() {
@@ -2127,7 +2191,7 @@ func TestServerProcessSubscribeACLCheckDenyObscure(t *testing.T) {
}) })
s.Serve() s.Serve()
s.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true s.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
cl, r, w := newClient() cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
go func() { go func() {
@@ -2143,7 +2207,7 @@ func TestServerProcessSubscribeACLCheckDenyObscure(t *testing.T) {
func TestServerProcessSubscribeErrorDowngrade(t *testing.T) { func TestServerProcessSubscribeErrorDowngrade(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 3 cl.Properties.ProtocolVersion = 3
cl.State.packetID = 1 // just to match the same packet id (7) in the fixtures cl.State.packetID = 1 // just to match the same packet id (7) in the fixtures
@@ -2160,7 +2224,7 @@ func TestServerProcessSubscribeErrorDowngrade(t *testing.T) {
func TestServerProcessPacketUnsubscribe(t *testing.T) { func TestServerProcessPacketUnsubscribe(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b", Qos: 0}) s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b", Qos: 0})
go func() { go func() {
@@ -2177,7 +2241,7 @@ func TestServerProcessPacketUnsubscribe(t *testing.T) {
func TestServerProcessPacketUnsubscribePackedIDInUse(t *testing.T) { func TestServerProcessPacketUnsubscribePackedIDInUse(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
cl.State.Inflight.Set(packets.Packet{PacketID: 15, FixedHeader: packets.FixedHeader{Type: packets.Publish}}) cl.State.Inflight.Set(packets.Packet{PacketID: 15, FixedHeader: packets.FixedHeader{Type: packets.Publish}})
go func() { go func() {
@@ -2194,7 +2258,7 @@ func TestServerProcessPacketUnsubscribePackedIDInUse(t *testing.T) {
func TestServerProcessPacketUnsubscribeInvalid(t *testing.T) { func TestServerProcessPacketUnsubscribeInvalid(t *testing.T) {
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
err := s.processPacket(cl, *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeSpecQosMustPacketID).Packet) err := s.processPacket(cl, *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeSpecQosMustPacketID).Packet)
require.Error(t, err) require.Error(t, err)
require.ErrorIs(t, err, packets.ErrProtocolViolationNoPacketID) require.ErrorIs(t, err, packets.ErrProtocolViolationNoPacketID)
@@ -2202,7 +2266,7 @@ func TestServerProcessPacketUnsubscribeInvalid(t *testing.T) {
func TestServerReceivePacketError(t *testing.T) { func TestServerReceivePacketError(t *testing.T) {
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
err := s.receivePacket(cl, *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeSpecQosMustPacketID).Packet) err := s.receivePacket(cl, *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeSpecQosMustPacketID).Packet)
require.Error(t, err) require.Error(t, err)
require.ErrorIs(t, err, packets.ErrProtocolViolationNoPacketID) require.ErrorIs(t, err, packets.ErrProtocolViolationNoPacketID)
@@ -2210,7 +2274,7 @@ func TestServerReceivePacketError(t *testing.T) {
func TestServerRecievePacketDisconnectClientZeroNonZero(t *testing.T) { func TestServerRecievePacketDisconnectClientZeroNonZero(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
cl.Properties.Props.SessionExpiryInterval = 0 cl.Properties.Props.SessionExpiryInterval = 0
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
cl.Properties.Props.RequestProblemInfo = 0 cl.Properties.Props.RequestProblemInfo = 0
@@ -2227,9 +2291,24 @@ func TestServerRecievePacketDisconnectClientZeroNonZero(t *testing.T) {
require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectZeroNonZeroExpiry).RawBytes, buf) require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectZeroNonZeroExpiry).RawBytes, buf)
} }
func TestServerRecievePacketDisconnectClient(t *testing.T) {
s := newServer()
cl, r, w := newTestClient()
go func() {
err := s.DisconnectClient(cl, packets.CodeDisconnect)
require.NoError(t, err)
w.Close()
}()
buf, err := io.ReadAll(r)
require.NoError(t, err)
require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes, buf)
}
func TestServerProcessPacketDisconnect(t *testing.T) { func TestServerProcessPacketDisconnect(t *testing.T) {
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.Properties.Props.SessionExpiryInterval = 30 cl.Properties.Props.SessionExpiryInterval = 30
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
@@ -2246,7 +2325,7 @@ func TestServerProcessPacketDisconnect(t *testing.T) {
func TestServerProcessPacketDisconnectNonZeroExpiryViolation(t *testing.T) { func TestServerProcessPacketDisconnectNonZeroExpiryViolation(t *testing.T) {
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.Properties.Props.SessionExpiryInterval = 0 cl.Properties.Props.SessionExpiryInterval = 0
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
cl.Properties.Props.RequestProblemInfo = 0 cl.Properties.Props.RequestProblemInfo = 0
@@ -2259,7 +2338,7 @@ func TestServerProcessPacketDisconnectNonZeroExpiryViolation(t *testing.T) {
func TestServerProcessPacketAuth(t *testing.T) { func TestServerProcessPacketAuth(t *testing.T) {
s := newServer() s := newServer()
cl, r, w := newClient() cl, r, w := newTestClient()
go func() { go func() {
err := s.processPacket(cl, *packets.TPacketData[packets.Auth].Get(packets.TAuth).Packet) err := s.processPacket(cl, *packets.TPacketData[packets.Auth].Get(packets.TAuth).Packet)
@@ -2274,7 +2353,7 @@ func TestServerProcessPacketAuth(t *testing.T) {
func TestServerProcessPacketAuthInvalidReason(t *testing.T) { func TestServerProcessPacketAuthInvalidReason(t *testing.T) {
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
pkx := *packets.TPacketData[packets.Auth].Get(packets.TAuth).Packet pkx := *packets.TPacketData[packets.Auth].Get(packets.TAuth).Packet
pkx.ReasonCode = 99 pkx.ReasonCode = 99
err := s.processPacket(cl, pkx) err := s.processPacket(cl, pkx)
@@ -2284,7 +2363,7 @@ func TestServerProcessPacketAuthInvalidReason(t *testing.T) {
func TestServerProcessPacketAuthFailure(t *testing.T) { func TestServerProcessPacketAuthFailure(t *testing.T) {
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
hook := new(modifiedHookBase) hook := new(modifiedHookBase)
hook.fail = true hook.fail = true
@@ -2301,7 +2380,7 @@ func TestServerSendLWT(t *testing.T) {
s.Serve() s.Serve()
defer s.Close() defer s.Close()
sender, _, w1 := newClient() sender, _, w1 := newTestClient()
sender.ID = "sender" sender.ID = "sender"
sender.Properties.Will = Will{ sender.Properties.Will = Will{
Flag: 1, Flag: 1,
@@ -2310,7 +2389,7 @@ func TestServerSendLWT(t *testing.T) {
} }
s.Clients.Add(sender) s.Clients.Add(sender)
receiver, r2, w2 := newClient() receiver, r2, w2 := newTestClient()
receiver.ID = "receiver" receiver.ID = "receiver"
s.Clients.Add(receiver) s.Clients.Add(receiver)
s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c", Qos: 0}) s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c", Qos: 0})
@@ -2337,7 +2416,7 @@ func TestServerSendLWT(t *testing.T) {
func TestServerSendLWTDelayed(t *testing.T) { func TestServerSendLWTDelayed(t *testing.T) {
s := newServer() s := newServer()
cl1, _, _ := newClient() cl1, _, _ := newTestClient()
cl1.ID = "cl1" cl1.ID = "cl1"
cl1.Properties.Will = Will{ cl1.Properties.Will = Will{
Flag: 1, Flag: 1,
@@ -2348,7 +2427,7 @@ func TestServerSendLWTDelayed(t *testing.T) {
} }
s.Clients.Add(cl1) s.Clients.Add(cl1)
cl2, r, w := newClient() cl2, r, w := newTestClient()
cl2.ID = "cl2" cl2.ID = "cl2"
s.Clients.Add(cl2) s.Clients.Add(cl2)
require.True(t, s.Topics.Subscribe(cl2.ID, packets.Subscription{Filter: "a/b/c"})) require.True(t, s.Topics.Subscribe(cl2.ID, packets.Subscription{Filter: "a/b/c"}))
@@ -2426,7 +2505,7 @@ func TestServerLoadSubscriptions(t *testing.T) {
} }
s := newServer() s := newServer()
cl, _, _ := newClient() cl, _, _ := newTestClient()
s.Clients.Add(cl) s.Clients.Add(cl)
require.Equal(t, 0, cl.State.Subscriptions.Len()) require.Equal(t, 0, cl.State.Subscriptions.Len())
s.loadSubscriptions(v) s.loadSubscriptions(v)
@@ -2486,7 +2565,7 @@ func TestServerClose(t *testing.T) {
hook := new(modifiedHookBase) hook := new(modifiedHookBase)
s.AddHook(hook, nil) s.AddHook(hook, nil)
cl, r, _ := newClient() cl, r, _ := newTestClient()
cl.Net.Listener = "t1" cl.Net.Listener = "t1"
cl.Properties.ProtocolVersion = 5 cl.Properties.ProtocolVersion = 5
s.Clients.Add(cl) s.Clients.Add(cl)
@@ -2524,7 +2603,7 @@ func TestServerClearExpiredInflights(t *testing.T) {
s.Options.Capabilities.MaximumMessageExpiryInterval = 4 s.Options.Capabilities.MaximumMessageExpiryInterval = 4
n := time.Now().Unix() n := time.Now().Unix()
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.ops.info = s.Info cl.ops.info = s.Info
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1}) cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1})
@@ -2564,12 +2643,12 @@ func TestServerClearExpiredClients(t *testing.T) {
n := time.Now().Unix() n := time.Now().Unix()
cl, _, _ := newClient() cl, _, _ := newTestClient()
cl.ID = "cl" cl.ID = "cl"
s.Clients.Add(cl) s.Clients.Add(cl)
// No Expiry // No Expiry
cl0, _, _ := newClient() cl0, _, _ := newTestClient()
cl0.ID = "c0" cl0.ID = "c0"
cl0.State.disconnected = n - 10 cl0.State.disconnected = n - 10
cl0.State.done = 1 cl0.State.done = 1
@@ -2579,7 +2658,7 @@ func TestServerClearExpiredClients(t *testing.T) {
s.Clients.Add(cl0) s.Clients.Add(cl0)
// Normal Expiry // Normal Expiry
cl1, _, _ := newClient() cl1, _, _ := newTestClient()
cl1.ID = "c1" cl1.ID = "c1"
cl1.State.disconnected = n - 10 cl1.State.disconnected = n - 10
cl1.State.done = 1 cl1.State.done = 1
@@ -2589,7 +2668,7 @@ func TestServerClearExpiredClients(t *testing.T) {
s.Clients.Add(cl1) s.Clients.Add(cl1)
// No Expiry, indefinite session // No Expiry, indefinite session
cl2, _, _ := newClient() cl2, _, _ := newTestClient()
cl2.ID = "c2" cl2.ID = "c2"
cl2.State.disconnected = n - 10 cl2.State.disconnected = n - 10
cl2.State.done = 1 cl2.State.done = 1